The examples below demonstrate the difference between sample shape (specified with the sample_shape parameter of the sample method), batch shape (set with the batch_shape parameter of the expand method, or implied by broadcasting the distribution's parameters against each other) and event shape (determined by the distribution being sampled from).
Note that batch shape gives the number of independent instances of a distribution that are sampled; they must all belong to the same distribution family, though their parameters may differ.
The array returned by sample always has shape sample_shape + batch_shape + event_shape.
Reference for this is here.
from jax import random
import jax.numpy as jnp
import numpyro.distributions as dist
d = dist.Normal(0, 1)
s = d.sample(random.PRNGKey(42), sample_shape=(3,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [ 0.18693547 -1.2806505  -1.5593132 ]

Event shape: ()  Batch shape: ()  Sample shape: (3,)
d = dist.Normal(0, 1).expand(batch_shape=(1,))
s = d.sample(random.PRNGKey(42), sample_shape=(3,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[ 0.18693547]
 [-1.2806505 ]
 [-1.5593132 ]]

Event shape: ()  Batch shape: (1,)  Sample shape: (3, 1)
d = dist.Normal(0, 1).expand(batch_shape=(2,))
s = d.sample(random.PRNGKey(42), sample_shape=(3,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[ 0.6122652   1.1225883 ]
 [ 1.1373317  -0.8127325 ]
 [-0.890405    0.12623145]]

Event shape: ()  Batch shape: (2,)  Sample shape: (3, 2)
d = dist.Normal(0, 1).expand(batch_shape=[2])
s = d.sample(random.PRNGKey(42), sample_shape=(3,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[ 0.6122652   1.1225883 ]
 [ 1.1373317  -0.8127325 ]
 [-0.890405    0.12623145]]

Event shape: ()  Batch shape: (2,)  Sample shape: (3, 2)
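Batch dimensions can also be reinterpreted as event dimensions with the to_event method. A minimal sketch below folds the (2,) batch dimension of the expanded Normal above into its event shape, so log_prob sums over that dimension; the expected shapes are noted in the comments.
d = dist.Normal(0, 1).expand(batch_shape=(2,)) \
    .to_event(1)                                   # batch_shape (), event_shape (2,)
s = d.sample(random.PRNGKey(42), sample_shape=(3,))
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)   # (2,), (), (3, 2)
print('log_prob shape:', d.log_prob(s).shape)      # (3,) = sample_shape + batch_shape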
d = dist.LKJ(2, 2)
s = d.sample(random.PRNGKey(42), sample_shape=(1,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[[1.         0.44555074]
  [0.44555074 1.        ]]]

Event shape: (2, 2)  Batch shape: ()  Sample shape: (1, 2, 2)
d = dist.LKJ(2, 2).expand(batch_shape=[2])
s = d.sample(random.PRNGKey(42), sample_shape=(1,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[[[ 1.          0.44442013]
   [ 0.44442013  1.        ]]

  [[ 1.         -0.11944741]
   [-0.11944741  1.        ]]]]

Event shape: (2, 2)  Batch shape: (2,)  Sample shape: (1, 2, 2, 2)
lkj = dist.LKJ(2, 2)
lkj_sample = lkj.expand(batch_shape=[2]) \
    .sample(random.PRNGKey(42), sample_shape=(1,))
d = dist.MultivariateNormal(
    loc=jnp.stack([1, 2]),
    covariance_matrix=jnp.matmul(jnp.matmul(jnp.diag(jnp.array([3, 4])), lkj_sample), jnp.diag(jnp.array([3, 4])))
)
s = d.sample(random.PRNGKey(42), sample_shape=(3,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[[[ 2.1126237  -1.5519805 ]
   [-0.5430193   9.980261  ]]]

 [[[-2.8950949  -2.4949522 ]
   [ 2.6727245   2.1419883 ]]]

 [[[-1.4942696  -0.814054  ]
   [-0.64660287 -0.11682558]]]]

Event shape: (2,)  Batch shape: (1, 2)  Sample shape: (3, 1, 2, 2)
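Note that the MultivariateNormal above was never expanded: its (1, 2) batch shape comes from broadcasting the (2,) loc against the (1, 2, 2, 2) stack of covariance matrices. A minimal sketch of the same effect for a univariate Normal with an array-valued loc:
d = dist.Normal(jnp.zeros(3), 1.)                  # array-valued loc implies batch_shape (3,) without expand
s = d.sample(random.PRNGKey(42), sample_shape=(5,))
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)   # (), (3,), (5, 3)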
lkj_sample = lkj.sample(random.PRNGKey(42), sample_shape=(1,))
d = dist.MultivariateNormal(
    loc=jnp.stack([1, 2]),
    covariance_matrix=jnp.matmul(jnp.matmul(jnp.diag(jnp.array([3, 4])), lkj_sample), jnp.diag(jnp.array([3, 4])))
).expand(batch_shape=(6,))
s = d.sample(random.PRNGKey(42), sample_shape=(1,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[[ 2.1126237  -1.5476623 ]
  [-0.5430193   8.057651  ]
  [-2.8950949  -2.4994526 ]
  [ 2.6727245   3.3619635 ]
  [-1.4942696  -0.81697655]
  [-0.64660287 -1.1234295 ]]]

Event shape: (2,)  Batch shape: (6,)  Sample shape: (1, 6, 2)
lkj_sample = lkj.sample(random.PRNGKey(42), sample_shape=(1,))
d = dist.MultivariateNormal(
    loc=jnp.stack([1, 2]),
    covariance_matrix=jnp.matmul(jnp.matmul(jnp.diag(jnp.array([3, 4])), lkj_sample), jnp.diag(jnp.array([3, 4])))
).expand(batch_shape=(1,))
s = d.sample(random.PRNGKey(42), sample_shape=(6,))
print('Sample: \n', s)
print("\n")
print('Event shape:', d.event_shape, ' Batch shape:', d.batch_shape, ' Sample shape:', s.shape)
Sample: 
 [[[ 2.1126237  -1.5476623 ]]

 [[-0.5430193   8.057651  ]]

 [[-2.8950949  -2.4994526 ]]

 [[ 2.6727245   3.3619635 ]]

 [[-1.4942696  -0.81697655]]

 [[-0.64660287 -1.1234295 ]]]

Event shape: (2,)  Batch shape: (1,)  Sample shape: (6, 1, 2)
cov_mat = jnp.matmul(jnp.matmul(jnp.diag(jnp.array([3, 4])), lkj_sample), jnp.diag(jnp.array([3, 4])))
print('Covariance matrix: \n', cov_mat)
print("\n")
print('Shape:', cov_mat.shape)
Covariance matrix: 
 [[[ 9.        5.346609]
  [ 5.346609 16.      ]]]

Shape: (1, 2, 2)
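Tying the examples together: the array returned by sample always has shape sample_shape + batch_shape + event_shape. A quick check of that identity, a sketch reusing cov_mat from the cell above:
sample_shape = (6,)
d = dist.MultivariateNormal(
    loc=jnp.stack([1, 2]),
    covariance_matrix=cov_mat                      # a batch of one 2x2 covariance matrix -> batch_shape (1,)
)
s = d.sample(random.PRNGKey(42), sample_shape=sample_shape)
print(s.shape == sample_shape + d.batch_shape + d.event_shape)   # True: (6,) + (1,) + (2,) == (6, 1, 2)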