Signax - Signature computation in JAX

Introduction

Signax is a JAX library for computing the signatures of paths.
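For reference, the standard definition of the truncated path signature follows. It is background from the rough-path literature, not text from the Signax documentation, and the notation is chosen here for illustration. For a path $X : [0, T] \to \mathbb{R}^d$, the level-$k$ terms are iterated integrals:

$$
S^{(k)}(X)^{i_1 \cdots i_k} = \int_{0 < t_1 < \cdots < t_k < T} \mathrm{d}X^{i_1}_{t_1} \cdots \mathrm{d}X^{i_k}_{t_k},
\qquad i_1, \ldots, i_k \in \{1, \ldots, d\}.
$$

Truncating at depth $N$ gives $S^{\le N}(X) = \big(S^{(1)}(X), \ldots, S^{(N)}(X)\big)$, where $S^{(k)}(X) \in (\mathbb{R}^d)^{\otimes k}$ has $d^k$ entries. For a single linear segment with increment $\Delta$, $S^{(k)} = \Delta^{\otimes k} / k!$, and Chen's identity $S(X * Y) = S(X) \otimes S(Y)$ combines consecutive segments, which is how a sampled, piecewise-linear path is typically handled. As a worked size check, with $d = 20$ and $N = 3$ the levels hold $20 + 400 + 8000 = 8420$ numbers.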

Installation

Install via pip:

python3 -m pip install signax

Install from source:

git clone https://github.com/anh-tong/signax.git
cd signax
python3 -m pip install -v -e .

Get Started

import jax
import jax.random as jrandom
import signax


key = jrandom.PRNGKey(0)
# use independent keys for the two examples instead of reusing one key
single_key, batch_key = jrandom.split(key)
depth = 3

# compute the signature of a single path of shape (length, dim)
length = 100
dim = 20
path = jrandom.normal(single_key, shape=(length, dim))
output = signax.signature(path, depth)
# output is a list of arrays representing an element of the truncated tensor algebra

# compute signatures for a batch of paths of shape (batch_size, length, dim);
# the batch dimension is handled via `jax.vmap`
batch_size = 20
path = jrandom.normal(batch_key, shape=(batch_size, length, dim))
output = signax.signature(path, depth)
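To make the output more concrete, the following is a minimal pure-JAX sketch of what a depth-2 signature of a sampled path contains. It is not Signax's implementation: it assumes the path is linearly interpolated between samples and applies Chen's identity segment by segment, and `signature_depth2` is a hypothetical helper name. `jax.vmap` then batches it over a leading axis, in the same spirit as the batched call above.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def signature_depth2(path):
    """Hypothetical helper: depth-2 signature of a path of shape (length, dim).

    Assumes linear interpolation between samples. Returns [level1, level2]
    with shapes (dim,) and (dim, dim).
    """
    increments = path[1:] - path[:-1]  # shape (length - 1, dim)
    # Level 1 is the total increment X_T - X_0.
    level1 = jnp.sum(increments, axis=0)
    # Chen's identity at depth 2: appending a segment with increment d to a
    # path whose signature is (a, B) gives (a + d, B + outer(a, d) + outer(d, d) / 2).
    # prefix[j] is the sum of all increments before segment j (exclusive cumsum).
    prefix = jnp.cumsum(increments, axis=0) - increments
    level2 = jnp.einsum("ti,tj->ij", prefix, increments) + 0.5 * jnp.einsum(
        "ti,tj->ij", increments, increments
    )
    return [level1, level2]


key = jrandom.PRNGKey(0)
single_key, batch_key = jrandom.split(key)

# Single path of shape (length, dim).
path = jrandom.normal(single_key, shape=(100, 20))
level1, level2 = signature_depth2(path)

# Batch of paths of shape (batch_size, length, dim): vmap over the leading axis.
batched_signature_depth2 = jax.vmap(signature_depth2)
paths = jrandom.normal(batch_key, shape=(8, 100, 20))
batch_level1, batch_level2 = batched_signature_depth2(paths)
```

Because `vmap` maps over pytrees, each level in the returned list gains a leading batch axis. A quick sanity check is the shuffle identity: `level2 + level2.T` equals the outer product of `level1` with itself, up to floating-point error.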