Time series generative model

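This experiment follows Section 4.1 of the Deep Signature Transforms paper (Bonnier et al., NeurIPS 2019): a generator is trained to map Brownian-motion noise to paths whose signatures match those of Ornstein-Uhlenbeck sample paths.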

import equinox as eqx
import jax
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
from nets import create_generative_net
from utils.dataloader import DataLoader
from utils.signature_normalization import normalize_signature
from jax import numpy as jnp
from jax import random as jrandom
from signax.module import SignatureTransform
from utils.brownian_motion import get_bm_noise
from utils.ornstein_uhlenbeck import get_ou_signal

Parameter setup

# Run-length and path-resolution settings.
epochs = 300
n_points = 100

# Number of noise samples generated for training / evaluation. The training
# loader below also uses train_batch_size as its batch size, so each epoch
# is a single full batch.
train_batch_size = val_batch_size = 1024

random_seed = 1234
random_key = jrandom.PRNGKey(random_seed)

# One independent key per consumer: training noise, evaluation noise,
# the target OU signal, and model initialisation.
train_key, eval_key, signal_key, model_key = jrandom.split(random_key, 4)

Create data

# Brownian-motion noise: the generator's input during training and evaluation.
train_dataset = get_bm_noise(
    random_key=train_key, num_samples=train_batch_size, n_points=n_points
)
eval_dataset = get_bm_noise(
    random_key=eval_key, num_samples=val_batch_size, n_points=n_points
)

# Target Ornstein-Uhlenbeck sample paths whose signatures the generator must match.
signals = get_ou_signal(signal_key, train_batch_size, n_points)

Plot Ornstein-Uhlenbeck data

# Each target path has two rows: row 0 is presumably time, row 1 the OU value.
for times, values in signals[:100]:
    plt.plot(times, values, "orange", alpha=0.1)
../_images/edfd9306f136367b800277ee86896eee4c8cd41fcff80d1cac2d23afd597ce80.png

Create model

# Truncation depth used by every signature computation below.
signature_depth = 4
model = create_generative_net(2, key=model_key)

# Adam, with gradients clipped to a global norm of 1.0 before each update.
clip = optax.clip_by_global_norm(1.0)
adam = optax.adam(learning_rate=0.1)
optim = optax.chain(clip, adam)

# Only the array leaves of the Equinox module are optimised.
trainable_params = eqx.filter(model, eqx.is_array)
opt_state = optim.init(trainable_params)
@eqx.filter_jit
def batch_normalize(batch_generator, batch_data):
    """Map each sample through `batch_generator`, then normalize its signature.

    `batch_generator` is applied per sample, vmapped over the leading axis of
    `batch_data`. Its output is passed to `normalize_signature` together with
    the module-level `signature_depth`.
    """

    def _normalize_one(sample):
        raw_signature = batch_generator(sample)
        return normalize_signature(raw_signature, signature_depth)

    return jax.vmap(_normalize_one)(batch_data)
signature_transform = SignatureTransform(signature_depth)

# Swap the last two axes so each target path is presumably laid out as
# (time, channel) before its signature is taken and normalized.
normalized_signal_sigs = batch_normalize(
    signature_transform,
    signals.swapaxes(1, 2),
)
@eqx.filter_jit
def predict(model_to_predict, path):
    """Generate a path from the noise `path` and return its signature.

    The squeezed model output is prefixed with a 0 and paired with a uniform
    time grid on [0, 1]. The resulting (time, value) path is passed through
    `signature_transform`.
    """
    values = model_to_predict(path).squeeze()
    n_values = values.shape[0]

    # Force the generated path to start at 0.
    values = jnp.concatenate([jnp.array([0]), values])
    times = jnp.linspace(0, 1, n_values + 1)

    # Stack along the last axis: one (time, value) row per point.
    time_augmented = jnp.stack([times, values], axis=1)
    return signature_transform(time_augmented)
def kernel(sigs1, sigs2):
    """Return the linear kernel averaged over all pairs of signatures.

    Computes the mean of every entry (diagonal included) of the Gram matrix
    between the rows of `sigs1` and `sigs2`. It is used to build the maximum
    mean discrepancy (MMD) terms.
    """
    gram = jnp.matmul(sigs1, jnp.transpose(sigs2))
    return jnp.mean(gram)


# Target-vs-target MMD term. It depends only on the fixed target signatures,
# so it is computed once here and reused as a constant by `loss`.
t1 = kernel(normalized_signal_sigs, normalized_signal_sigs)
@eqx.filter_value_and_grad
def loss(model_to_train, paths):
    """Return log-MMD between target and generated signatures, with its gradient.

    The squared MMD uses the linear `kernel` (means over all pairs,
    diagonal included):
    ``t1 - 2 * k(target, generated) + k(generated, generated)``.
    `t1` is the precomputed target-vs-target term.
    """

    def _generate(path):
        return predict(model_to_train, path)

    generated_sigs = batch_normalize(_generate, paths)

    cross_term = kernel(normalized_signal_sigs, generated_sigs)
    generated_term = kernel(generated_sigs, generated_sigs)
    squared_mmd = t1 - 2 * cross_term + generated_term

    return jnp.log(squared_mmd)


@eqx.filter_jit
def make_step(model_to_train, paths, optimizer_state):
    """Run one optimisation step.

    Returns ``(loss_value, updated_model, updated_optimizer_state)``.
    """
    loss_value, gradients = loss(model_to_train, paths)
    param_updates, new_optimizer_state = optim.update(gradients, optimizer_state)
    updated_model = eqx.apply_updates(model_to_train, param_updates)
    return loss_value, updated_model, new_optimizer_state

Perform optimization

# The layout change does not depend on the epoch, so it is done once.
train_inputs = train_dataset.transpose(0, 2, 1)

for step in range(epochs):
    # Derive a fresh key for every epoch. `train_key` was already consumed to
    # generate `train_dataset`. Passing it unchanged on every epoch would reuse
    # that key, and would give the loader (which presumably shuffles with it)
    # the same key each epoch.
    epoch_key = jrandom.fold_in(train_key, step)
    train_dataloader = DataLoader(
        train_inputs,
        batch_size=train_batch_size,
        random_key=epoch_key,
    )
    for train_datum in train_dataloader:
        loss_val, model, opt_state = make_step(
            model,
            train_datum,
            opt_state,
        )
        if step % 100 == 0:
            print(f"step={step} \t loss={loss_val}")
step=0 	 loss=0.4741534888744354
step=100 	 loss=-3.988189935684204
step=200 	 loss=-5.587058067321777

Visualize the data against generated samples

# Run the trained generator on the held-out noise, one sample at a time.
eval_inputs = eval_dataset.transpose(0, 2, 1)
eval_predicted = jax.vmap(lambda noise: model(noise))(eval_inputs)
eval_predicted = eval_predicted.transpose(0, 2, 1)

# Blue: generated channel 0. Red: target OU value channel from index 1 onward.
plt.plot(eval_predicted[50:100, 0].T, "b", alpha=0.1)
plt.plot(signals[50:100, 1, 1:].T, "#ba0404", alpha=0.2);
../_images/2ca0bdcbae282997f0ad38399a486dc2b119d8a9c62734c990620678a14161d5.png