Converting NumPyro objects to DataTree#

DataTree is the data format ArviZ relies on.

This page covers multiple ways to generate a DataTree from NumPyro MCMC and SVI objects.

See also

We will start by importing the required packages and defining the model: the famous eight schools model.

import arviz_base as az
import numpy as np
from numpy.typing import ArrayLike

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive
from jax import random
import jax.numpy as jnp
# Data for the eight schools example: `J` is the plate size used by the
# models below, `y_obs` the observed values and `sigma` the scale of the
# observation distribution, one entry per school.
J = 8
y_obs = np.array(
    [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0],
)
sigma = np.array(
    [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0],
)
def eight_schools_model(J, sigma, y=None):
    """Eight schools model written with standardized effects.

    Sites, in sampling order: ``mu ~ Normal(0, 5)``, ``tau ~ HalfCauchy(5)``
    and, inside the ``"J"`` plate of size ``J``, ``eta ~ Normal(0, 1)``, the
    deterministic ``theta = mu + tau * eta`` and ``obs ~ Normal(theta, sigma)``.

    Parameters
    ----------
    J
        Size of the ``"J"`` plate.
    sigma
        Scale of the observation distribution.
    y
        Observed values for the ``obs`` site; when ``None`` the site is
        sampled instead of conditioned on.

    Returns
    -------
    The value of the ``obs`` site.
    """
    mu_value = numpyro.sample("mu", dist.Normal(0, 5))
    tau_value = numpyro.sample("tau", dist.HalfCauchy(5))

    with numpyro.plate("J", J):
        eta_value = numpyro.sample("eta", dist.Normal(0, 1))
        theta_value = numpyro.deterministic("theta", mu_value + tau_value * eta_value)
        obs_value = numpyro.sample("obs", dist.Normal(theta_value, sigma), obs=y)

    return obs_value
    

def eight_schools_custom_guide(J, sigma, y=None):
    """Hand-written variational guide for the eight schools model.

    Takes the same arguments as the model so that both can be called with the
    same inputs; ``sigma`` and ``y`` are not used inside the guide.

    Variational family:

    * ``mu ~ Normal(mu_loc, mu_scale)``
    * ``tau ~ LogNormal(log(tau_loc), tau_scale)``
    * ``eta ~ Normal(eta_loc, eta_scale)`` inside the ``"J"`` plate

    ``theta = mu + tau * eta`` is also recorded as a deterministic site.

    Parameters
    ----------
    J
        Size of the ``"J"`` plate and length of the ``eta`` parameters.
    sigma
        Unused; present to match the model signature.
    y
        Unused; present to match the model signature.
    """

    # Variational parameters for mu
    mu_loc = numpyro.param("mu_loc", 0.0)
    mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))

    # Variational parameters for tau (positive support).
    # tau_loc goes through jnp.log below, so it has to be constrained to be
    # positive; otherwise an optimizer step could push it to <= 0 and turn
    # the LogNormal location into nan/-inf.
    tau_loc = numpyro.param("tau_loc", 1.0, constraint=dist.constraints.positive)
    tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
    tau = numpyro.sample("tau", dist.LogNormal(jnp.log(tau_loc), tau_scale))

    # Variational parameters for eta
    eta_loc = numpyro.param("eta_loc", jnp.zeros(J))
    eta_scale = numpyro.param("eta_scale", jnp.ones(J), constraint=dist.constraints.positive)
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))

        # Deterministic transform
        numpyro.deterministic("theta", mu + tau * eta)

Convert from MCMC#

This first example shows conversion from a NumPyro MCMC object.

# Fit the model with NUTS, also recording the "num_steps" and "energy" fields
nuts = NUTS(eight_schools_model)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
mcmc.run(
    random.PRNGKey(0),
    J=J,
    sigma=sigma,
    y=y_obs,
    extra_fields=("num_steps", "energy"),
)

# Sample the posterior predictive; `y` is not passed, so "obs" is drawn
# from the model instead of being conditioned on the data
predictive = Predictive(eight_schools_model, posterior_samples=mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1), J=J, sigma=sigma)

# Convert the MCMC object, together with the predictive draws
idata_mcmc = az.from_numpyro(
    mcmc,
    posterior_predictive=samples_predictive,
)
idata_mcmc

Convert from SVI with Autoguide#

# Build an AutoNormal guide from the model and fit it with SVI
eight_schools_guide = autoguide.AutoNormal(
    eight_schools_model,
    init_loc_fn=numpyro.infer.init_to_median(num_samples=100),
)
svi = SVI(
    model=eight_schools_model,
    guide=eight_schools_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_result = svi.run(
    random.PRNGKey(0),
    num_steps=10_000,
    J=J,
    sigma=sigma,
    y=y_obs,
)

# Sample the posterior predictive through the fitted guide
predictive_svi = Predictive(
    eight_schools_model,
    guide=eight_schools_guide,
    params=svi_result.params,
    num_samples=4000,
)
samples_predictive_svi = predictive_svi(random.PRNGKey(1), J=J, sigma=sigma)

# Convert the SVI fit
idata_svi = az.from_numpyro_svi(
    svi,
    svi_result=svi_result,
    # SVI requires providing the args/kwargs that were used for the fit
    model_kwargs={"J": J, "sigma": sigma, "y": y_obs},
    # number of samples to draw in the posterior
    num_samples=4000,
    posterior_predictive=samples_predictive_svi,
)
idata_svi

Converting from SVI with a custom guide function#

# Fit the model with SVI using the hand-written guide
svi_custom_guide = SVI(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_custom_guide_result = svi_custom_guide.run(
    random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs
)

# Sample the posterior predictive.
# The params must come from the custom guide's own fit
# (`svi_custom_guide_result`): the result of a fit with a different guide
# holds differently named parameters, which this guide would not pick up.
predictive_svi_custom = Predictive(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    params=svi_custom_guide_result.params,
    num_samples=4000,
)
samples_predictive_svi_custom = predictive_svi_custom(random.PRNGKey(1), J=J, sigma=sigma)

idata_svi_custom_guide = az.from_numpyro_svi(
    svi_custom_guide,
    svi_result=svi_custom_guide_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs),  # SVI requires providing the fit args/kwargs
    num_samples=4000,  # number of samples to draw in the posterior
    posterior_predictive=samples_predictive_svi_custom,
)
idata_svi_custom_guide

Automatically Labelling Event Dims#

NumPyro batch dims are automatically labelled according to their corresponding plate names. In order to label event dims, we add infer={"event_dims": dim_labels} to the numpyro.sample statement as shown below:

def eight_schools_model_zsn(J, sigma, y=None):
    """Eight schools variant whose effects come from a ``ZeroSumNormal``.

    ``eta`` is sampled outside the plate from
    ``ZeroSumNormal(tau, event_shape=(J,))``, so its length-``J`` axis is an
    event dimension rather than a plate dimension. The ``infer`` entry
    attaches the label ``"J"`` to that event dimension.

    Parameters
    ----------
    J
        Size of the ``"J"`` plate and of the ``eta`` event dimension.
    sigma
        Scale of the observation distribution.
    y
        Observed values for the ``obs`` site; when ``None`` the site is
        sampled instead of conditioned on.

    Returns
    -------
    The value of the ``obs`` site.
    """
    mu_value = numpyro.sample("mu", dist.Normal(0, 5))
    tau_value = numpyro.sample("tau", dist.HalfCauchy(5))

    # note: the "event_dims" entry allows arviz to infer the event dimension labels
    zero_sum_effects = dist.ZeroSumNormal(tau_value, event_shape=(J,))
    eta_value = numpyro.sample("eta", zero_sum_effects, infer={"event_dims": ["J"]})

    with numpyro.plate("J", J):
        theta_value = numpyro.deterministic("theta", mu_value + eta_value)
        obs_value = numpyro.sample("obs", dist.Normal(theta_value, sigma), obs=y)

    return obs_value


# Fit the zero-sum model with NUTS, also recording "num_steps" and "energy"
nuts = NUTS(eight_schools_model_zsn)
mcmc2 = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
mcmc2.run(
    random.PRNGKey(0),
    J=J,
    sigma=sigma,
    y=y_obs,
    extra_fields=("num_steps", "energy"),
)


# Sample the posterior predictive.
# This must use the same model that produced the samples: in the zero-sum
# model theta is `mu + eta`, while `eight_schools_model` defines it as
# `mu + tau * eta`, so mixing the two would give wrong predictive draws.
predictive2 = Predictive(eight_schools_model_zsn, mcmc2.get_samples())
samples_predictive2 = predictive2(random.PRNGKey(1), J=J, sigma=sigma)

# Convert the MCMC object, together with the predictive draws
idata_mcmc2 = az.from_numpyro(mcmc2, posterior_predictive=samples_predictive2)

Notice that eta is now labelled with the J dimension.

# Show the converted object (bare expression at the end of the cell)
idata_mcmc2

Extending NumPyro Conversion to other Inference Objects#

NumPyroInferenceAdapter can be leveraged to extend ArviZ conversion to other NumPyro inference objects (such as the NestedSampler).

The example below illustrates this with the SVI implementation: an adapter class is created that inherits from the NumPyroInferenceAdapter base class.

class SVIAdapter(az.NumPyroInferenceAdapter):
    """Adapter that gives an SVI fit the attributes and methods of other inference objects.

    Parameters
    ----------
    svi
        The SVI object; required.
    svi_result
        The result of running ``svi``; its ``params`` are used when sampling.
        Required, keyword-only.
    model_args, model_kwargs
        Positional and keyword arguments forwarded to the base class.
    num_samples
        Number of draws; forwarded to the base class as ``sample_shape=(num_samples,)``.

    Raises
    ------
    ValueError
        If ``svi`` or ``svi_result`` is ``None``.
    """

    def __init__(
        self,
        svi,
        *,
        svi_result,
        model_args=None,
        model_kwargs=None,
        num_samples: int = 1000,
    ):
        if svi is None:
            raise ValueError("svi parameter is required for SVIAdapter")
        if svi_result is None:
            raise ValueError("svi_result parameter is required for SVIAdapter")

        # Prefer the guide's own `model` attribute when it has one,
        # otherwise fall back to the model stored on the SVI object.
        guide = svi.guide
        fallback_model = svi.model
        resolved_model = getattr(guide, "model", fallback_model)

        super().__init__(
            svi,
            model=resolved_model,
            model_args=model_args,
            model_kwargs=model_kwargs,
            sample_shape=(num_samples,),
        )
        self.result_obj = svi_result

    @property
    def sample_dims(self) -> list[str]:
        """Names of the sampling dimensions: a single ``"sample"`` dimension."""
        return ["sample"]

    def get_samples(
        self, seed: int | None = None, group_by_chain: bool = False, **kwargs: dict
    ) -> dict[str, ArrayLike]:
        """Draw posterior samples from the fitted guide.

        ``seed`` falls back to ``0`` when it is ``None`` (or otherwise falsy).
        ``group_by_chain`` and ``kwargs`` are accepted but not used here.

        NOTE(review): ``self.posterior``, ``self.prng_key_func``,
        ``self._args``, ``self._kwargs`` and ``self.sample_shape`` are
        presumably set by the base class from the constructor arguments;
        confirm against ``NumPyroInferenceAdapter``.
        """
        rng_key = self.prng_key_func(seed or 0)
        guide = self.posterior.guide

        if not isinstance(guide, numpyro.infer.autoguide.AutoGuide):
            # A custom guide was provided: sample by hand through Predictive
            n_draws = self.sample_shape[0]
            guide_predictive = numpyro.infer.Predictive(
                guide, params=self.result_obj.params, num_samples=n_draws
            )
            return guide_predictive(rng_key, *self._args, **self._kwargs)

        # Autoguides provide their own posterior sampler
        return guide.sample_posterior(
            rng_key,
            self.result_obj.params,
            *self._args,
            sample_shape=self.sample_shape,
            **self._kwargs,
        )

The instantiated adapter can now be passed directly into az.from_numpyro.

# Wrap the SVI fit in the adapter, passing the kwargs used for the fit
fit_kwargs = {"J": J, "sigma": sigma, "y": y_obs}
adapter = SVIAdapter(
    svi,
    svi_result=svi_result,
    model_kwargs=fit_kwargs,
    num_samples=4000,
)

# The adapter is converted through the same entry point as an MCMC object
idata_svi2 = az.from_numpyro(
    adapter,
    posterior_predictive=samples_predictive_svi,
)
idata_svi2
%load_ext watermark
%watermark -n -u -v -iv -w