Learn Hurst exponent for fractional Brownian motion

This example performs the task of predicting the Hurst exponent from synthetic data generated by fractional Brownian motion.

import equinox as eqx
import fbm
import jax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import nets
import optax  # https://github.com/deepmind/optax


# Ask JAX to run on the CPU platform.
jax.config.update("jax_platform_name", "cpu")
# Number of paths simulated per Hurst value for the train and test sets.
n_paths_train = 200
n_paths_test = 50
# Time increment passed to fbm.generate_fbm.
# NOTE(review): "300 time steps" presumably assumes a unit time horizon
# (1 / dt == 300) -- confirm against fbm.generate_fbm.
dt = 1e-2 / 3.0  # 300 time steps
# Regression targets: 7 evenly spaced values from 0.2 to 0.8, rounded to
# 7 decimals and converted to a plain Python list of floats.
hurst_exponent = jnp.around(jnp.linspace(0.2, 0.8, 7), decimals=7).tolist()
def dataloader(arrays, batch_size, *, key):
    """Yield shuffled mini-batches drawn from ``arrays``, forever.

    Adapted from the Equinox documentation.

    Each pass draws a fresh permutation of the row indices and yields every
    full batch of that permutation. A trailing remainder shorter than
    ``batch_size`` is dropped.

    Args:
        arrays: Sequence of arrays that all share the same leading dimension.
        batch_size: Rows per batch; must satisfy
            ``1 <= batch_size <= dataset_size``.
        key: JAX PRNG key driving the shuffling.

    Yields:
        A tuple with one entry per input array, each indexed with the same
        ``batch_size`` shuffled row indices.

    Raises:
        ValueError: If ``batch_size`` is outside ``[1, dataset_size]``.
            Because this is a generator, the error surfaces on the first
            ``next()`` call rather than at construction time.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    # Without this guard an oversized batch_size would spin forever without
    # yielding, and a non-positive one would yield empty/garbage batches.
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be in [1, {dataset_size}], got {batch_size}"
        )
    indices = jnp.arange(dataset_size)
    while True:
        # The permutation is drawn from the current key, which is then
        # replaced by a key derived from it for the next pass.
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        # The range stops at the last start for which a full batch fits, so
        # the final batch is kept when batch_size divides dataset_size.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start : start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)
def generate_data(key):
    """Simulate labelled fBm paths for training and testing.

    For every value in the module-level ``hurst_exponent`` list,
    ``n_paths_train`` training paths and ``n_paths_test`` test paths are
    requested from ``fbm.generate_fbm`` with step ``dt``; each path is
    labelled with the Hurst value it was generated with.

    Args:
        key: JAX PRNG key; it is split once into independent train and test
            keys.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``, each with a trailing axis of
        size 1 appended.
    """

    def _simulate(split_key, n_paths):
        # Build one data split: paths and matching labels for every Hurst
        # value, concatenated along the leading axis.
        # NOTE(review): assumes fbm.generate_fbm returns an array with the
        # paths along its first axis -- confirm against the fbm module.
        paths, labels = [], []
        for h in hurst_exponent:
            # Advance the key before each simulation so every Hurst value
            # is generated from a different key.
            split_key, _ = jrandom.split(split_key)
            paths.append(
                fbm.generate_fbm(hurst=h, n_paths=n_paths, dt=dt, key=split_key)
            )
            labels.append(jnp.array([h] * n_paths))
        return jnp.concatenate(paths), jnp.concatenate(labels)

    train_key, test_key = jrandom.split(key)
    train_paths, train_labels = _simulate(train_key, n_paths_train)
    test_paths, test_labels = _simulate(test_key, n_paths_test)

    # Append a trailing singleton axis to every output array.
    return tuple(
        arr[..., None]
        for arr in (train_paths, train_labels, test_paths, test_labels)
    )
# Fixed seed so that data generation, batching and model initialisation are
# reproducible; one independent key is derived for each of the three uses.
seed = 1234
key = jrandom.PRNGKey(seed)
data_key, loader_key, model_key = jrandom.split(key, 3)
X_train, Y_train, X_test, Y_test = generate_data(key=data_key)
# NOTE(review): the meaning of the architecture arguments is defined by the
# project-local `nets` module -- confirm there. The sigmoid final activation
# keeps outputs in (0, 1), which contains the 0.2-0.8 target range.
model = nets.create_simple_net(
    dim=1,
    signature_depth=3,
    augment_layer_size=(3,),
    augmented_kernel_size=3,
    mlp_width=32,
    mlp_depth=5,
    output_size=1,
    final_activation=jax.nn.sigmoid,
    key=model_key,
)
# Endless generator of shuffled (x, y) training batches of 128 rows.
iter_data = dataloader((X_train, Y_train), batch_size=128, key=loader_key)
optim = optax.adam(learning_rate=1e-3)
# Optimiser state is initialised from the array leaves of the model only.
opt_state = optim.init(eqx.filter(model, eqx.is_array))
@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    """Mean squared error between the batched model output and ``y``.

    The model is applied to each element of ``x`` along the leading axis via
    ``jax.vmap``. Because of the decorator, calling this function returns
    ``(loss, grads)`` with gradients taken with respect to ``model``.
    """
    predictions = jax.vmap(model)(x)
    # Only the batch dimension is checked here.
    assert y.shape[0] == predictions.shape[0]
    residual = predictions - y
    return jnp.mean(jnp.square(residual))
@eqx.filter_jit
def make_step(model, x, y, opt_state):
    """Perform one optimisation step on the batch ``(x, y)``.

    Args:
        model: Current model.
        x: Batch of inputs.
        y: Batch of targets.
        opt_state: Current state of the module-level optimiser ``optim``.

    Returns:
        ``(loss, model, opt_state)``: the loss evaluated before the update,
        the updated model and the updated optimiser state.
    """
    loss_value, gradients = compute_loss(model, x, y)
    param_updates, next_opt_state = optim.update(gradients, opt_state)
    updated_model = eqx.apply_updates(model, param_updates)
    return loss_value, updated_model, next_opt_state
# Train for 500 steps, recording the MSE on the full test set after every
# update and printing the training loss every 10th step.
test_mse = []
for step, (x, y) in zip(range(500), iter_data):
    loss, model, opt_state = make_step(model, x, y, opt_state)
    loss = loss.item()
    test_residual = jax.vmap(model)(X_test) - Y_test
    test_mse.append(jnp.mean(jnp.square(test_residual)).item())
    if not step % 10:
        print(f"step={step} \t loss={loss}")
step=0 	 loss=0.04126128554344177
step=10 	 loss=0.03442610800266266
step=20 	 loss=0.027970541268587112
step=30 	 loss=0.02049068734049797
step=40 	 loss=0.014842262491583824
step=50 	 loss=0.01463927049189806
step=60 	 loss=0.008735964074730873
step=70 	 loss=0.003853133413940668
step=80 	 loss=0.003183095483109355
step=90 	 loss=0.0017226741183549166
step=100 	 loss=0.0018800244433805346
step=110 	 loss=0.001246932428330183
step=120 	 loss=0.0010552150197327137
step=130 	 loss=0.0011177051346749067
step=140 	 loss=0.000960731296800077
step=150 	 loss=0.0007560451049357653
step=160 	 loss=0.0006463612080551684
step=170 	 loss=0.0007553484756499529
step=180 	 loss=0.0006547566154040396
step=190 	 loss=0.0005832638125866652
step=200 	 loss=0.0005575535469688475
step=210 	 loss=0.00035019911592826247
step=220 	 loss=0.00040039169834926724
step=230 	 loss=0.00045131167280487716
step=240 	 loss=0.0003143866779282689
step=250 	 loss=0.00036405035643838346
step=260 	 loss=0.0002983830636367202
step=270 	 loss=0.0002879722451325506
step=280 	 loss=0.000225563402636908
step=290 	 loss=0.00022640179668087512
step=300 	 loss=0.00024252216098830104
step=310 	 loss=0.00021433460642583668
step=320 	 loss=0.00020663876784965396
step=330 	 loss=0.0002086602326016873
step=340 	 loss=0.0001339660375379026
step=350 	 loss=0.0001248260377906263
step=360 	 loss=0.00014416594058275223
step=370 	 loss=0.0001910615392262116
step=380 	 loss=0.000121168268378824
step=390 	 loss=0.00010210465552518144
step=400 	 loss=0.00016413693083450198
step=410 	 loss=0.00010884729272220284
step=420 	 loss=0.00010439904872328043
step=430 	 loss=0.00014752228162251413
step=440 	 loss=0.00011902432743227109
step=450 	 loss=9.99133990262635e-05
step=460 	 loss=9.923087782226503e-05
step=470 	 loss=8.425781561527401e-05
step=480 	 loss=9.92306595435366e-05
step=490 	 loss=8.330534910783172e-05
# Plot the per-step test MSE on a logarithmic y-axis.
plt.plot(test_mse)
plt.yscale("log")
../_images/9b80b5b25054276a3d0e8dfd9670b45deaf0f32cc1077753ef4615ca29cb6925.png