Time series generative model

This experiment follows Section 4.1 of the Deep Signature Transforms paper.

import equinox as eqx
import jax
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
from nets import create_generative_net
from utils.dataloader import DataLoader
from utils.signature_normalization import normalize_signature
from jax import numpy as jnp
from jax import random as jrandom
from signax.module import SignatureTransform
from utils.brownian_motion import get_bm_noise
from utils.ornstein_uhlenbeck import get_ou_signal

Parameter setup

# A single root key, built from a fixed seed, feeds every random source below.
random_seed = 1234
random_key = jrandom.PRNGKey(random_seed)

# Batch sizes (2**10 each), number of training epochs, and the `n_points`
# value forwarded to the data generators.
train_batch_size = 1024
val_batch_size = 1024
epochs = 300
n_points = 100

# Split the root key into four sub-keys, unpacked in this order.
train_key, eval_key, signal_key, model_key = jrandom.split(random_key, num=4)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Create data

# Noise fed to the generator: one set for training and one for evaluation,
# each drawn with its own key.
train_dataset = get_bm_noise(
    random_key=train_key, num_samples=train_batch_size, n_points=n_points
)
eval_dataset = get_bm_noise(
    random_key=eval_key, num_samples=val_batch_size, n_points=n_points
)

# Target signals the generator is trained to imitate.
signals = get_ou_signal(signal_key, train_batch_size, n_points)

Plot Ornstein-Uhlenbeck data

# Overlay the first 100 target paths; the low alpha keeps overlaps readable.
for ou_path in signals[:100]:
    # Each path's leading axis is unpacked into the positional plot arguments.
    plt.plot(*ou_path, "orange", alpha=0.1)
../_images/fb6714710156194f33a1484fe29f8adb0db912bd02971d0f3fd59e6a1be4ab0a.png

Create model

# Depth used for every signature computation and normalization below.
signature_depth = 4

# Generator network, initialized from its dedicated key.
model = create_generative_net(2, key=model_key)

# Clip the global gradient norm to 1.0, then apply Adam.
optim = optax.chain(
    optax.clip_by_global_norm(max_norm=1.0),
    optax.adam(learning_rate=0.1),
)

# Only the array leaves of the model take part in optimization.
opt_state = optim.init(eqx.filter(model, eqx.is_array))
@eqx.filter_jit
def batch_normalize(batch_generator, batch_data):
    """Compute a normalized signature for every element of a batch.

    ``batch_generator`` is applied to each element along the leading axis of
    ``batch_data``; each result is passed to ``normalize_signature`` together
    with the module-level ``signature_depth``.
    """

    def _normalized_signature(sample):
        return normalize_signature(batch_generator(sample), signature_depth)

    return jax.vmap(_normalized_signature)(batch_data)
# Signature transform shared by the target data and the generated paths.
signature_transform = SignatureTransform(signature_depth)

# Normalized signatures of the target signals, computed once up front.
# The last two axes are swapped before the transform is applied.
normalized_signal_sigs = batch_normalize(
    signature_transform,
    signals.swapaxes(1, 2),
)
@eqx.filter_jit
def predict(model_to_predict, path):
    """Return the signature of the path generated by the model from ``path``.

    The squeezed model output gets a leading zero prepended and is paired
    with a uniform timeline on [0, 1]. The two are stacked as columns and
    passed to the module-level ``signature_transform``.
    """
    generated = model_to_predict(path).squeeze()
    # One extra time point accounts for the prepended zero.
    timeline = jnp.linspace(0, 1, generated.shape[0] + 1)
    values = jnp.concatenate([jnp.array([0]), generated])
    # Stacking on the last axis gives one (time, value) row per point.
    stream = jnp.stack([timeline, values], axis=-1)
    return signature_transform(stream)
def kernel(sigs1, sigs2):
    """Return the mean of all pairwise inner products between two batches.

    Row ``i`` of ``sigs1`` is paired with row ``j`` of ``sigs2`` and the
    resulting Gram matrix is averaged into a scalar. This is the building
    block of the maximum mean discrepancy (MMD) loss.
    """
    gram = sigs1 @ jnp.transpose(sigs2)
    return gram.mean()


# Data-only term of the discrepancy. It does not depend on the model, so it
# is computed once here instead of inside the loss.
t1 = kernel(normalized_signal_sigs, normalized_signal_sigs)


@eqx.filter_value_and_grad
def loss(model_to_train, paths):
    """Log of the kernel discrepancy between target and generated signatures.

    Each element of ``paths`` is passed through ``predict`` and normalized.
    The returned value is ``log(t1 - 2 * cross + generated)``, where ``cross``
    compares target with generated signatures and ``generated`` compares the
    generated signatures with themselves. Because of the decorator, calling
    this yields both the value and its gradient.
    """

    def _signature_of(path):
        return predict(model_to_train, path)

    generated_sigs = batch_normalize(_signature_of, paths)

    cross_term = kernel(normalized_signal_sigs, generated_sigs)
    generated_term = kernel(generated_sigs, generated_sigs)

    return jnp.log(t1 - 2 * cross_term + generated_term)


@eqx.filter_jit
def make_step(model_to_train, paths, optimizer_state):
    """Run one optimization step.

    Evaluates ``loss`` and its gradient on ``paths``, turns the gradient into
    updates with the module-level ``optim``, and applies them to the model.

    Returns:
        A tuple ``(loss value, updated model, updated optimizer state)``.
    """
    loss_value, gradients = loss(model_to_train, paths)
    param_updates, next_state = optim.update(gradients, optimizer_state)
    updated_model = eqx.apply_updates(model_to_train, param_updates)
    return loss_value, updated_model, next_state

Performing optimization

# The transposed training data does not change between epochs, so build it
# once instead of on every iteration.
train_paths = train_dataset.transpose(0, 2, 1)

for step in range(epochs):
    # `train_key` was already used to draw `train_dataset`, and JAX keys must
    # not be reused. Derive a distinct key for each epoch so the loader does
    # not receive identical randomness every time.
    # NOTE(review): assumes DataLoader uses `random_key` for shuffling - confirm.
    epoch_key = jrandom.fold_in(train_key, step)
    train_dataloader = DataLoader(
        train_paths,
        batch_size=train_batch_size,
        random_key=epoch_key,
    )
    for train_datum in train_dataloader:
        loss_val, model, opt_state = make_step(
            model,
            train_datum,
            opt_state,
        )
        # Report progress every 100 epochs.
        if step % 100 == 0:
            print(f"step={step} \t loss={loss_val}")
step=0 	 loss=0.4924733638763428
step=100 	 loss=-4.107013702392578
step=200 	 loss=-5.374304294586182
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[11], line 8
      2 train_dataloader = DataLoader(
      3     train_dataset.transpose(0, 2, 1),
      4     batch_size=train_batch_size,
      5     random_key=train_key,
      6 )
      7 for train_datum in train_dataloader:
----> 8     loss_val, model, opt_state = make_step(
      9         model,
     10         train_datum,
     11         opt_state,
     12     )
     13     if step % 100 == 0:
     14         print(f"step={step} \t loss={loss_val}")

File ~/checkouts/readthedocs.org/user_builds/signax/envs/stable/lib/python3.11/site-packages/equinox/_jit.py:107, in _JitWrapper.__call__(self, *args, **kwargs)
    106 def __call__(self, /, *args, **kwargs):
--> 107     return self._call(False, args, kwargs)

File ~/checkouts/readthedocs.org/user_builds/signax/envs/stable/lib/python3.11/site-packages/equinox/_jit.py:103, in _JitWrapper._call(self, is_lower, args, kwargs)
    101         out = self._cached(dynamic, static)
    102 else:
--> 103     out = self._cached(dynamic, static)
    104 return _postprocess(out)

File ~/checkouts/readthedocs.org/user_builds/signax/envs/stable/lib/python3.11/site-packages/equinox/_module.py:307, in _unflatten_module(cls, aux, dynamic_field_values)
    297     aux = _FlattenedData(
    298         tuple(dynamic_field_names),
    299         tuple(static_field_names),
   (...)
    302         tuple(wrapper_field_values),
    303     )
    304     return tuple(dynamic_field_values), aux
--> 307 def _unflatten_module(cls: type["Module"], aux: _FlattenedData, dynamic_field_values):
    308     module = object.__new__(cls)
    309     for name, value in zip(aux.dynamic_field_names, dynamic_field_values):

KeyboardInterrupt: 

Visualize the real data alongside the generated samples

# Run the trained model on every evaluation sample. The last two axes are
# swapped going in and swapped back coming out.
eval_predicted = jax.vmap(model)(eval_dataset.swapaxes(1, 2)).swapaxes(1, 2)

# Generated samples in blue, target signals in red, both for indices 50-99.
plt.plot(eval_predicted[50:100, 0].T, "b", alpha=0.1)
plt.plot(signals[50:100, 1, 1:].T, "#ba0404", alpha=0.2);
../_images/ff96adf26080619e669a613863ce0e747e070c55cd93ce424c8ddd0c5c148db1.png