# from jax import grad, jit, vmap
# import jax
from jax import random
import jax.numpy as jnp
import datetime as dt
# import pyro
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import arviz as az
/home/brent/anaconda3/envs/pytorch_pyro/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
# Load the raw stock data.
df_raw = pd.read_csv('stock_data.csv')

# Parse the day/month/year date strings. This is done on the raw frame itself
# (it is mutated in place), then a single reporting date is sliced out into an
# independent working copy.
df_raw['date_stamp'] = pd.to_datetime(df_raw['date_stamp'], format="%d/%m/%Y")
df1 = df_raw[df_raw['date_stamp'] == '2021-06-30'].copy()


def scale(x):
    """Standardize x: subtract its mean and divide by its standard deviation."""
    return (x - x.mean()) / x.std()


# Encoder that maps sector labels to integer codes (used as group indices).
le = LabelEncoder()

# Derived columns. Note the sign flips: total_equity_cln is negated before the
# log and roe is negated outright -- presumably both are stored with a negative
# sign upstream (TODO confirm). roe_s is standardized from the *negated* roe.
negated_roe = -df1['roe']
df1 = df1.assign(
    log_mkt_cap=np.log(df1['mkt_cap']),
    log_assets=np.log(df1['total_assets']),
    log_equity_cln=np.log(-df1['total_equity_cln']),
    roe=negated_roe,
    roe_s=scale(negated_roe),
    leverage_s=scale(df1['leverage']),
    sector_tf=le.fit_transform(df1['sector'].values),
)

# Keep only the columns used downstream. 'log_pb' is not derived in this
# section, so it must already be present in the loaded data.
cols = [
    'date_stamp',
    'symbol',
    'log_mkt_cap',
    'log_assets',
    'log_equity_cln',
    'roe',
    'roe_s',
    'leverage',
    'leverage_s',
    'sector_tf',
    'log_pb',
]
df1 = df1[cols].copy()
df1.head()
| date_stamp | symbol | log_mkt_cap | log_assets | log_equity_cln | roe | roe_s | leverage | leverage_s | sector_tf | log_pb |
---|---|---|---|---|---|---|---|---|---|---|---|
53 | 2021-06-30 | A | 10.716239 | 9.177197 | 8.477204 | 0.166845 | 0.383000 | 0.503411 | -0.688139 | 6 | 2.234926 |
86 | 2021-06-30 | AA | 8.833674 | 9.606428 | 8.520388 | -0.031416 | -0.353283 | 0.662450 | 0.167220 | 10 | 0.313342 |
146 | 2021-06-30 | AAL | 9.515441 | 11.035019 | 8.732434 | -1.000000 | -3.950308 | 1.110744 | 2.578291 | 0 | 0.783072 |
203 | 2021-06-30 | AAN | 7.672204 | 7.906806 | 7.327254 | -0.165563 | -0.851463 | 0.439851 | -1.029989 | 0 | 0.345311 |
239 | 2021-06-30 | AAP | 9.518787 | 9.379208 | 8.177379 | 0.135932 | 0.268199 | 0.699356 | 0.365717 | 3 | 1.343235 |
Multilevel regression with correlation between group-level intercepts and slopes
Model | Attribute description | Attribute description |
---|---|---|
symbol | the ticker symbol identifying the company | the ticker symbol identifying the company |
date_stamp | date_stamp | date_stamp |
def m1(grp1, x1, y=None, n_grp1=11):
    """Varying-intercept, varying-slope regression with correlated group effects.

    Every group gets its own (intercept, slope) pair, drawn from a bivariate
    normal centred on the population-level pair ``(a, b)``. The covariance of
    that bivariate normal is assembled from two group-level standard
    deviations and an LKJ-distributed 2x2 correlation matrix, so intercepts
    and slopes are allowed to be correlated across groups.

    Args:
        grp1: Integer group index for each observation. It is used to index
            the per-group effects, so values must lie in ``[0, n_grp1)``.
        x1: Predictor value for each observation.
        y: Observed response for each observation, or ``None`` to leave the
            ``"y"`` site unobserved.
        n_grp1: Number of groups. Must be a plain Python int because it sets
            the batch shape of the group-effect site. Defaults to 11, the
            value that was previously hard-coded.
    """
    # Priors for the population-level intercept and slope.
    a = numpyro.sample("a", dist.Normal(1.25, 1))
    b = numpyro.sample("b", dist.Normal(1, 1.5))

    # Priors for the group-level standard deviations ('sd' per brms):
    # element 0 scales the intercepts, element 1 scales the slopes.
    sigma_grp1 = numpyro.sample('sigma_grp1', dist.Exponential(1), sample_shape=(2,))

    # Prior for the 2x2 correlation matrix between group intercepts and slopes.
    Rho = numpyro.sample("Rho", dist.LKJ(2, 2))

    # Covariance = diag(sd) @ Rho @ diag(sd).
    sd_diag = jnp.diag(sigma_grp1)
    cov = jnp.matmul(jnp.matmul(sd_diag, Rho), sd_diag)

    # One (intercept, slope) pair per group -> shape (n_grp1, 2).
    a_grp1_b_grp1 = numpyro.sample(
        "a_grp1_b_grp1",
        dist.MultivariateNormal(loc=jnp.stack([a, b]), covariance_matrix=cov)
        .expand(batch_shape=[n_grp1]),
    )
    a_grp1 = a_grp1_b_grp1[:, 0]
    b_grp1 = a_grp1_b_grp1[:, 1]

    # Linear predictor: each observation picks up its own group's coefficients.
    mu = a_grp1[grp1] + b_grp1[grp1] * x1

    # Prior for the residual SD of the response, then the likelihood.
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
Run MCMC
# NUTS sampler over model m1: 1000 warm-up draws followed by 1000 retained
# draws on a single chain.
m1_mcmc = MCMC(
    NUTS(m1),
    num_warmup=1000,
    num_samples=1000,
    num_chains=1,
)

# Fit with a fixed PRNG key: sector code as the group index, (unscaled) roe as
# the predictor and log_pb as the observed response.
m1_mcmc.run(
    random.PRNGKey(0),
    grp1=df1['sector_tf'].values,
    x1=df1['roe'].values,
    y=df1['log_pb'].values,
)
0%| | 0/2000 [00:00<?, ?it/s]/home/brent/anaconda3/envs/pytorch_pyro/lib/python3.9/site-packages/jax/_src/tree_util.py:185: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement. warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() ' sample: 100%|██████████| 2000/2000 [00:19<00:00, 100.96it/s, 15 steps of size 4.04e-01. acc. prob=0.88]
Render model
# Dummy inputs used only to trace the model for rendering: one observation per
# group. Group indices are zero-based (0..10) because the model indexes arrays
# holding 11 per-group effects; the previous 1..11 range pointed one past the
# end, which JAX silently clamps instead of raising.
grp1 = jnp.arange(11, dtype=int)
x1 = jnp.ones(11)
y = jnp.ones(11)
numpyro.render_model(m1, model_args=(grp1, x1, y), render_distributions=True)
Summary coefficients
# Print a per-site posterior summary table for the fitted sampler (mean, std,
# median, 5%/95% bounds, n_eff, r_hat) plus the number of divergences.
m1_mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat Rho[0,0] 1.00 0.00 1.00 1.00 1.00 nan nan Rho[0,1] -0.45 0.24 -0.47 -0.82 -0.08 819.76 1.00 Rho[1,0] -0.45 0.24 -0.47 -0.82 -0.08 819.76 1.00 Rho[1,1] 1.00 0.00 1.00 1.00 1.00 941.51 1.00 a 0.90 0.19 0.90 0.60 1.20 733.85 1.00 a_grp1_b_grp1[0,0] 1.35 0.07 1.35 1.24 1.47 1464.27 1.00 a_grp1_b_grp1[0,1] 0.81 0.22 0.80 0.44 1.17 1233.19 1.00 a_grp1_b_grp1[1,0] 1.50 0.10 1.50 1.34 1.68 838.82 1.01 a_grp1_b_grp1[1,1] 1.56 0.32 1.55 1.08 2.10 909.22 1.00 a_grp1_b_grp1[2,0] 0.98 0.15 0.99 0.76 1.23 726.27 1.00 a_grp1_b_grp1[2,1] 2.49 0.46 2.50 1.66 3.12 675.92 1.00 a_grp1_b_grp1[3,0] 1.40 0.08 1.40 1.26 1.51 891.77 1.00 a_grp1_b_grp1[3,1] 0.43 0.21 0.42 0.08 0.77 1392.12 1.00 a_grp1_b_grp1[4,0] -0.05 0.09 -0.05 -0.21 0.10 1148.65 1.00 a_grp1_b_grp1[4,1] 3.81 0.45 3.81 3.03 4.53 834.73 1.00 a_grp1_b_grp1[5,0] 0.27 0.14 0.27 0.04 0.49 675.94 1.00 a_grp1_b_grp1[5,1] 2.74 1.24 2.70 0.83 4.87 664.07 1.00 a_grp1_b_grp1[6,0] 1.70 0.10 1.70 1.54 1.87 736.03 1.00 a_grp1_b_grp1[6,1] -0.36 0.27 -0.35 -0.80 0.10 842.02 1.00 a_grp1_b_grp1[7,0] 0.42 0.13 0.42 0.21 0.63 854.53 1.00 a_grp1_b_grp1[7,1] -0.06 0.39 -0.06 -0.74 0.53 900.08 1.00 a_grp1_b_grp1[8,0] 0.70 0.24 0.69 0.32 1.08 809.02 1.00 a_grp1_b_grp1[8,1] 3.34 1.21 3.30 1.39 5.34 685.91 1.00 a_grp1_b_grp1[9,0] 0.52 0.07 0.52 0.41 0.64 1309.90 1.00 a_grp1_b_grp1[9,1] 2.54 0.55 2.55 1.66 3.45 1007.62 1.00 a_grp1_b_grp1[10,0] 0.95 0.13 0.94 0.73 1.14 1082.62 1.00 a_grp1_b_grp1[10,1] 1.71 0.61 1.72 0.73 2.70 1164.27 1.00 b 1.65 0.47 1.65 0.92 2.46 862.43 1.00 sigma 0.74 0.02 0.74 0.71 0.77 1285.21 1.00 sigma_grp1[0] 0.61 0.15 0.59 0.40 0.83 723.52 1.00 sigma_grp1[1] 1.54 0.40 1.48 0.96 2.13 642.96 1.00 Number of divergences: 0
# Wrap the fitted sampler in an ArviZ object for plotting and diagnostics.
m1_arviz = az.from_numpyro(m1_mcmc)
# Pull the posterior draws out of the fitted sampler.
m1_post = m1_mcmc.get_samples()
# Switch ArviZ to its grayscale theme for the figures drawn from here on.
az.style.use("arviz-grayscale")

# Forest plot of the posterior held in m1_arviz, with chains combined, on a
# 9x7 figure.
az.plot_forest(m1_arviz, kind="forestplot", combined=True, figsize=(9, 7))
plt.show()