arviz_base.from_numpyro

arviz_base.from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, predictions=None, constant_data=None, predictions_constant_data=None, log_likelihood=False, index_origin=None, coords=None, dims=None, pred_dims=None, extra_event_dims=None, sample_dims=None, num_chains=None)

Convert NumPyro MCMC inference data into an xarray DataTree object.

For a usage example, read Converting NumPyro objects to DataTree.

If no dims are provided, batch dim names are inferred from the NumPyro model's plates. For event dim names, such as those of a ZeroSumNormal, pass infer={"event_dims": dim_names} to numpyro.sample, for example:

import numpyro
import numpyro.distributions as dist


def model(n_groups):
    # equivalent to dims entry, {"gamma": ["groups"]}
    gamma = numpyro.sample(
        "gamma",
        dist.ZeroSumNormal(1, event_shape=(n_groups,)),
        infer={"event_dims": ["groups"]},
    )

The extra_event_dims argument covers the remaining edge cases, for instance deterministic sites with event dims, which have no infer argument through which to provide this metadata.

Parameters:
posterior : numpyro.infer.MCMC | NumPyroInferenceAdapter

A fitted MCMC object from NumPyro, or an instance of a child class of NumPyroInferenceAdapter.

prior : dict, optional

Prior samples from a NumPyro model.

posterior_predictive : dict, optional

Posterior predictive samples for the posterior.

predictions : dict, optional

Out-of-sample predictions.

constant_data : dict, optional

Dictionary containing constant data variables mapped to their values.

predictions_constant_data : dict, optional

Constant data used for out-of-sample predictions.

log_likelihood : bool, default False

Whether to compute and include log likelihood in the output.

index_origin : int, optional

coords : dict, optional

Map of dimensions to coordinates.

dims : dict of {str : list of str}, optional

Map of variable names to their dim names. Inferred from the model if not provided.

pred_dims : dict, optional

Dims for the predictions data: a map of variable names to their dim names. Inferred from the model if not provided.

extra_event_dims : dict, optional

Extra event dims, for example for deterministic sites, that could not be inferred from the model. Maps variable names to the names of these event dims.

sample_dims : list of str, optional

Names of the sample dimensions (e.g., ["chain", "draw"] for MCMC, ["sample"] for SVI). Must be provided if posterior is None. If posterior is provided, this argument is ignored and overwritten with posterior.sample_dims.

num_chains : int, optional

Number of chains used for sampling. Defaults to 1 for MCMC if not provided. Ignored if posterior is present.

Returns:
xarray.DataTree