# Converting NumPyro objects to DataTree
DataTree is the data format ArviZ relies on.
This page covers multiple ways to generate a DataTree from NumPyro MCMC and SVI objects.
See also
Conversion from Python, numpy or pandas objects
DataTree for Exploratory Analysis of Bayesian Models for an overview of
DataTree and its role within ArviZ.
We will start by importing the required packages and defining the model: the famous eight schools model.
import arviz_base as az
import numpy as np
from numpy.typing import ArrayLike
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive
from jax import random
import jax.numpy as jnp
# Eight schools data: one observed effect (y_obs) per school and its
# standard deviation (sigma), which the model treats as known and uses as
# the fixed scale of the likelihood.
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# Number of schools, taken from the data so the two cannot drift apart.
J = len(y_obs)
def eight_schools_model(J, sigma, y=None):
    """Non-centered hierarchical model for the eight schools data.

    Parameters
    ----------
    J : int
        Number of schools; the size of the ``"J"`` plate.
    sigma : array-like
        Per-school observation scale, treated as known.
    y : array-like, optional
        Observed effects. When ``None`` the ``"obs"`` site is sampled rather
        than conditioned on, which is what posterior predictive sampling uses.

    Returns
    -------
    The value of the ``"obs"`` sample site.
    """
    mu = numpyro.sample("mu", dist.Normal(0, 5))  # shared location across schools
    tau = numpyro.sample("tau", dist.HalfCauchy(5))  # spread of school effects
    with numpyro.plate("J", J):
        # Non-centered form: theta = mu + tau * eta with standard-normal eta.
        eta = numpyro.sample("eta", dist.Normal(0, 1))
        school_effect = numpyro.deterministic("theta", mu + tau * eta)
        likelihood = dist.Normal(school_effect, sigma)
        return numpyro.sample("obs", likelihood, obs=y)
def eight_schools_custom_guide(J, sigma, y=None):
    """Hand-written mean-field variational guide for ``eight_schools_model``.

    The signature mirrors the model's, as SVI calls both with the same
    arguments; ``sigma`` and ``y`` are accepted but not used here.

    Parameters
    ----------
    J : int
        Number of schools; sets the size of the ``eta`` parameters and plate.
    sigma : array-like
        Unused; present to match the model signature.
    y : array-like, optional
        Unused; present to match the model signature.
    """
    # Variational parameters for mu
    mu_loc = numpyro.param("mu_loc", 0.0)
    mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))

    # Variational parameters for tau (positive support). tau_loc is the median
    # of the LogNormal and is passed through jnp.log, so it must itself be
    # constrained positive: left unconstrained, an optimizer step could push it
    # to <= 0 and turn the location (and the ELBO) into NaN.
    tau_loc = numpyro.param("tau_loc", 1.0, constraint=dist.constraints.positive)
    tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
    tau = numpyro.sample("tau", dist.LogNormal(jnp.log(tau_loc), tau_scale))

    # Variational parameters for eta, one location/scale pair per school
    eta_loc = numpyro.param("eta_loc", jnp.zeros(J))
    eta_scale = numpyro.param("eta_scale", jnp.ones(J), constraint=dist.constraints.positive)
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))

    # Deterministic transform mirroring the model, so theta is recorded in guide samples
    numpyro.deterministic("theta", mu + tau * eta)
## Convert from MCMC
This first example shows conversion from an MCMC object, together with posterior predictive samples drawn with `Predictive`.
# Fit with NUTS: 4 chains, each with 1000 warmup and 1000 retained draws
nuts = NUTS(eight_schools_model)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
# extra_fields asks NumPyro to also collect the sampler's num_steps and energy fields
mcmc.run(
    random.PRNGKey(0),
    J=J,
    sigma=sigma,
    y=y_obs,
    extra_fields=("num_steps", "energy"),
)

# Sample the posterior predictive; y is omitted so "obs" is drawn, not conditioned on
predictive = Predictive(eight_schools_model, posterior_samples=mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1), J=J, sigma=sigma)

# Convert the MCMC run (plus posterior predictive draws) to a DataTree
idata_mcmc = az.from_numpyro(mcmc, posterior_predictive=samples_predictive)
idata_mcmc
## Convert from SVI with an autoguide
# Automatically generated Normal guide, initialised at each site's prior
# median (estimated from 100 prior draws)
eight_schools_guide = autoguide.AutoNormal(
    eight_schools_model,
    init_loc_fn=numpyro.infer.init_to_median(num_samples=100),
)
svi = SVI(
    eight_schools_model,
    guide=eight_schools_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_result = svi.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)

# Posterior predictive: draw 4000 parameter sets from the fitted guide and run them through the model
predictive_svi = Predictive(
    eight_schools_model,
    guide=eight_schools_guide,
    params=svi_result.params,
    num_samples=4000,
)
samples_predictive_svi = predictive_svi(random.PRNGKey(1), J=J, sigma=sigma)

# SVI requires the fit args/kwargs so posterior draws can be sampled from the fitted guide
idata_svi = az.from_numpyro_svi(
    svi,
    svi_result=svi_result,
    model_kwargs={"J": J, "sigma": sigma, "y": y_obs},
    num_samples=4000,  # number of posterior draws taken from the guide
    posterior_predictive=samples_predictive_svi,
)
idata_svi
## Convert from SVI with a custom guide function
# Fit SVI with the hand-written guide
svi_custom_guide = SVI(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_custom_guide_result = svi_custom_guide.run(
    random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs
)

# Sample the posterior predictive using the params from *this* fit. The
# AutoNormal params in svi_result are named differently from the custom
# guide's mu_loc/tau_loc/... parameters, so passing them would leave the
# custom guide at its initial, untrained values.
predictive_svi_custom = Predictive(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    params=svi_custom_guide_result.params,
    num_samples=4000,
)
samples_predictive_svi_custom = predictive_svi_custom(random.PRNGKey(1), J=J, sigma=sigma)

# SVI requires the fit args/kwargs so posterior draws can be sampled from the fitted guide
idata_svi_custom_guide = az.from_numpyro_svi(
    svi_custom_guide,
    svi_result=svi_custom_guide_result,
    model_kwargs={"J": J, "sigma": sigma, "y": y_obs},
    num_samples=4000,  # number of posterior draws taken from the guide
    posterior_predictive=samples_predictive_svi_custom,
)
idata_svi_custom_guide
## Automatically labelling event dims
NumPyro batch dimensions are automatically labelled with the names of their corresponding plates. To label event dimensions, add `infer={"event_dims": dim_labels}` to the `numpyro.sample` statement, as shown below:
def eight_schools_model_zsn(J, sigma, y=None):
    """Eight schools variant whose school offsets follow a ZeroSumNormal.

    Parameters
    ----------
    J : int
        Number of schools; the event size of ``eta`` and the size of the ``"J"`` plate.
    sigma : array-like
        Per-school observation scale, treated as known.
    y : array-like, optional
        Observed effects; ``None`` samples ``"obs"`` instead of conditioning on it.

    Returns
    -------
    The value of the ``"obs"`` sample site.
    """
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    # eta is a single draw with event shape (J,) rather than a batch over a
    # plate, so no plate name applies to it. The "event_dims" entry in infer
    # lets ArviZ label that event dimension "J".
    offsets_prior = dist.ZeroSumNormal(tau, event_shape=(J,))
    eta = numpyro.sample("eta", offsets_prior, infer={"event_dims": ["J"]})
    with numpyro.plate("J", J):
        theta = numpyro.deterministic("theta", mu + eta)
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
# Fit the zero-sum model with NUTS: 4 chains of 1000 warmup + 1000 retained draws
nuts = NUTS(eight_schools_model_zsn)
mcmc2 = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
mcmc2.run(
    random.PRNGKey(0),
    J=J,
    sigma=sigma,
    y=y_obs,
    extra_fields=("num_steps", "energy"),
)

# Sample the posterior predictive with the same model that was fitted.
# eight_schools_model defines theta = mu + tau * eta rather than mu + eta,
# so reusing it here would yield wrong predictions.
predictive2 = Predictive(eight_schools_model_zsn, posterior_samples=mcmc2.get_samples())
samples_predictive2 = predictive2(random.PRNGKey(1), J=J, sigma=sigma)

# Convert the MCMC run (plus posterior predictive draws) to a DataTree
idata_mcmc2 = az.from_numpyro(mcmc2, posterior_predictive=samples_predictive2)
Notice that `eta` is labelled with the `J` dimension.
# Display the DataTree; eta's event dimension should carry the "J" label from its infer={"event_dims": ["J"]} annotation
idata_mcmc2
## Extending NumPyro conversion to other inference objects
`NumPyroInferenceAdapter` can be used to extend ArviZ conversion to other NumPyro inference objects (such as the `NestedSampler`).
The example below re-creates SVI support this way: an adapter class inherits from the `NumPyroInferenceAdapter` base class and implements `sample_dims` and `get_samples`.
class SVIAdapter(az.NumPyroInferenceAdapter):
    """Adapter for SVI to standardize attributes and methods with other inference objects.

    Posterior draws are produced by sampling the fitted guide, so the adapter
    can be passed to ``az.from_numpyro`` like an MCMC object.

    Parameters
    ----------
    svi : numpyro.infer.SVI
        The SVI object that was fitted. Required.
    svi_result
        Result of ``svi.run``; its ``params`` hold the fitted guide parameters. Required.
    model_args, model_kwargs : optional
        Positional / keyword arguments the model and guide were fitted with.
    num_samples : int, default 1000
        Number of posterior draws to take from the guide.

    Raises
    ------
    ValueError
        If ``svi`` or ``svi_result`` is ``None``.
    """

    def __init__(
        self,
        svi,
        *,
        svi_result,
        model_args=None,
        model_kwargs=None,
        num_samples: int = 1000,
    ):
        if svi is None:
            raise ValueError("svi parameter is required for SVIAdapter")
        if svi_result is None:
            raise ValueError("svi_result parameter is required for SVIAdapter")
        # Prefer the guide's own `model` attribute when it has one (as autoguides
        # do); otherwise fall back to the model registered on the SVI object.
        model = getattr(svi.guide, "model", svi.model)
        super().__init__(
            svi,
            model=model,
            model_args=model_args,
            model_kwargs=model_kwargs,
            sample_shape=(num_samples,),
        )
        self.result_obj = svi_result

    @property
    def sample_dims(self) -> list[str]:
        """Return the sample dimension names: SVI draws are indexed by a single "sample" dim."""
        return ["sample"]

    def get_samples(
        self, seed: int | None = None, group_by_chain: bool = False, **kwargs: dict
    ) -> dict[str, ArrayLike]:
        """Draw ``sample_shape`` posterior samples from the fitted guide.

        ``seed=None`` falls back to 0. ``group_by_chain`` and extra keyword
        arguments are accepted for interface compatibility but are not used.
        """
        rng_key = self.prng_key_func(seed or 0)
        guide = self.posterior.guide
        if not isinstance(guide, numpyro.infer.autoguide.AutoGuide):
            # Custom guide: run it through Predictive with the fitted params
            guide_predictive = numpyro.infer.Predictive(
                guide, params=self.result_obj.params, num_samples=self.sample_shape[0]
            )
            return guide_predictive(rng_key, *self._args, **self._kwargs)
        # Autoguides can sample their own approximate posterior directly
        return guide.sample_posterior(
            rng_key,
            self.result_obj.params,
            *self._args,
            sample_shape=self.sample_shape,
            **self._kwargs,
        )
The instantiated adapter can now be passed directly to `az.from_numpyro`.
# Wrap the autoguide SVI fit in the custom adapter, then convert it like any other NumPyro object
adapter = SVIAdapter(
    svi,
    svi_result=svi_result,
    model_kwargs={"J": J, "sigma": sigma, "y": y_obs},
    num_samples=4000,
)
idata_svi2 = az.from_numpyro(adapter, posterior_predictive=samples_predictive_svi)
idata_svi2
# Print environment information (date, Python and imported package versions) via the watermark IPython extension
%load_ext watermark
%watermark -n -u -v -iv -w