# Source code for arviz_base.io_numpyro

"""NumPyro-specific conversion code."""

import warnings
from abc import ABC, abstractmethod
from collections import defaultdict

import numpy as np
from numpy.typing import ArrayLike
from xarray import DataTree

from arviz_base.base import dict_to_dataset, requires
from arviz_base.rcparams import rc_context, rcParams
from arviz_base.utils import expand_dims

try:
    import jax
    import numpyro
except ImportError as e:
    raise ImportError(
        "The NumPyro I/O backend requires optional dependencies, jax and numpyro.\n\n"
    ) from e


class NumPyroInferenceAdapter(ABC):
    """Abstract interface wrapping a NumPyro inference object for NumPyroConverter.

    Subclasses provide the sample dimension names and a way to obtain posterior
    samples; extra diagnostic fields are optional.
    """

    def __init__(self, inference_obj, model, model_args, model_kwargs, sample_shape):
        """Store the attributes shared by every adapter.

        Parameters
        ----------
        inference_obj : Any
            The wrapped NumPyro inference object (e.g. MCMC or SVI). It is stored
            as ``self.posterior``.
        model : callable
            NumPyro model function used for inference.
        model_args : tuple, optional
            Positional model arguments. Any falsy value (``None`` or ``()``) is
            stored as an empty tuple.
        model_kwargs : dict, optional
            Keyword model arguments. Any falsy value is stored as an empty dict.
        sample_shape : tuple of int
            Leading shape of the samples returned by ``get_samples``, e.g.
            ``(num_chains, num_draws)`` for MCMC or ``(num_samples,)`` for SVI.
        """
        self.posterior = inference_obj
        self.model = model
        self.sample_shape = sample_shape
        self._args = model_args if model_args else ()
        self._kwargs = model_kwargs if model_kwargs else {}

        # Factory that turns an integer seed into a JAX PRNG key.
        self.prng_key_func = jax.random.PRNGKey

    @property
    @abstractmethod
    def sample_dims(self):
        """Names of the sample dimensions.

        Returns
        -------
        list of str
            For example ``["chain", "draw"]`` for MCMC or ``["sample"]`` for SVI.
        """
        raise NotImplementedError

    @abstractmethod
    def get_samples(self, seed=None, group_by_chain=False, **kwargs):
        """Return posterior samples from the wrapped inference object.

        Parameters
        ----------
        seed : int, optional
            Random seed. Only meaningful for adapters that draw samples on
            demand (SVI); MCMC adapters ignore it because draws already exist.
        group_by_chain : bool, default False
            Whether samples are split along a leading chain dimension (MCMC).
            Ignored by SVI.
        **kwargs : dict
            Extra keyword arguments for the underlying sampling method.

        Returns
        -------
        dict of {str: array-like}
            Mapping from parameter name to samples. MCMC arrays have shape
            ``(num_chains, num_draws, ...)`` when grouped by chain and
            ``(num_chains * num_draws, ...)`` otherwise; SVI arrays have shape
            ``(num_samples, ...)``.
        """
        raise NotImplementedError

    def get_extra_fields(self, **kwargs):
        """Return extra diagnostic fields (e.g. divergences for MCMC).

        Returns
        -------
        dict of {str: array-like}
            Empty in this base implementation; subclasses may override.
        """
        return {}


class SVIAdapter(NumPyroInferenceAdapter):
    """Expose a fitted SVI object through the NumPyroInferenceAdapter interface."""

    def __init__(
        self,
        svi,
        *,
        svi_result,
        model_args=None,
        model_kwargs=None,
        num_samples: int = 1000,
    ):
        """Wrap a fitted SVI object together with its optimization result.

        Parameters
        ----------
        svi : numpyro.infer.SVI
            Fitted SVI object.
        svi_result : numpyro.infer.svi.SVIRunResult
            Result of the SVI run; its ``params`` attribute is used for sampling.
        model_args : tuple, optional
            Positional arguments for the model.
        model_kwargs : dict, optional
            Keyword arguments for the model.
        num_samples : int, default 1000
            Number of posterior samples to generate from the guide.

        Raises
        ------
        ValueError
            If ``svi`` or ``svi_result`` is None.
        """
        if svi is None:
            raise ValueError("svi parameter is required for SVIAdapter")
        if svi_result is None:
            raise ValueError("svi_result parameter is required for SVIAdapter")

        # Use the model attached to the guide when it has one, otherwise the SVI's own model.
        wrapped_model = getattr(svi.guide, "model", svi.model)
        super().__init__(
            svi,
            model=wrapped_model,
            model_args=model_args,
            model_kwargs=model_kwargs,
            sample_shape=(num_samples,),
        )
        self.result_obj = svi_result

    @property
    def sample_dims(self) -> list[str]:
        """Return the single sample dimension name used for SVI."""
        return ["sample"]

    def get_samples(
        self, seed: int | None = None, group_by_chain: bool = False, **kwargs: dict
    ) -> dict[str, ArrayLike]:
        """Draw posterior samples from the guide.

        ``seed`` falls back to 0 when it is None (or 0); ``group_by_chain`` and
        ``kwargs`` are accepted for interface compatibility and not used.
        """
        rng_key = self.prng_key_func(seed or 0)
        guide = self.posterior.guide
        fitted_params = self.result_obj.params

        if not isinstance(guide, numpyro.infer.autoguide.AutoGuide):
            # Custom guide: no ``sample_posterior`` helper, so sample through Predictive.
            sampler = numpyro.infer.Predictive(
                guide, params=fitted_params, num_samples=self.sample_shape[0]
            )
            return sampler(rng_key, *self._args, **self._kwargs)

        return guide.sample_posterior(
            rng_key,
            fitted_params,
            *self._args,
            sample_shape=self.sample_shape,
            **self._kwargs,
        )


class MCMCAdapter(NumPyroInferenceAdapter):
    """Expose a fitted MCMC object through the NumPyroInferenceAdapter interface."""

    def __init__(self, mcmc):
        """Wrap a fitted MCMC object.

        Parameters
        ----------
        mcmc : numpyro.infer.MCMC
            Fitted MCMC object with completed sampling.
        """
        chain_count = mcmc.num_chains
        # Draws per chain are computed as ``num_samples // thinning``.
        draw_count = mcmc.num_samples // mcmc.thinning
        self.nchains = chain_count
        self.ndraws = draw_count
        super().__init__(
            mcmc,
            model=mcmc.sampler.model,
            model_args=mcmc._args,
            model_kwargs=mcmc._kwargs,
            sample_shape=(chain_count, draw_count),
        )

    @property
    def sample_dims(self) -> list[str]:
        """Return the two sample dimension names used for MCMC."""
        return ["chain", "draw"]

    def get_samples(
        self, seed: int | None = None, group_by_chain: bool = True, **kwargs: dict
    ) -> dict[str, ArrayLike]:
        """Return the stored MCMC draws; ``seed`` is accepted but unused."""
        mcmc = self.posterior
        return mcmc.get_samples(group_by_chain=group_by_chain, **kwargs)

    def get_extra_fields(self, **kwargs) -> dict[str, ArrayLike]:
        """Return the sampler's extra fields, always grouped by chain."""
        mcmc = self.posterior
        return mcmc.get_extra_fields(group_by_chain=True, **kwargs)


def _add_dims(dims_a, dims_b):
    """Merge two dimension mappings by concatenating dimension labels.

    Used to combine batch dims with event dims by appending the dims of dims_b to dims_a.
    Neither input mapping (nor any of the lists it holds) is modified.

    Parameters
    ----------
    dims_a : dict of {str: list of str(s)}
        Mapping from site name to a list of dimension labels, typically
        representing batch dimensions.
    dims_b : dict of {str: list of str(s)}
        Mapping from site name to a list of dimension labels, typically
        representing event dimensions.

    Returns
    -------
    dict of {str: list of str(s)}
        New mapping where each site name is associated with the
        concatenated dimension labels from both inputs.
    """
    # Copy every list from dims_a first: extending the caller's own list objects
    # would silently alter dims_a.
    merged = {site: list(labels) for site, labels in dims_a.items()}
    for site, labels in dims_b.items():
        merged.setdefault(site, []).extend(labels)
    return merged


def infer_dims(
    model,
    model_args=None,
    model_kwargs=None,
):
    """Infer dimension names for model sites from plates and ``event_dims`` metadata.

    Parameters
    ----------
    model : callable
        A numpyro model function.
    model_args : tuple of (Any, ...), optional
        Input args for the numpyro model.
    model_kwargs : dict of {str: Any}, optional
        Input kwargs for the numpyro model.

    Returns
    -------
    dict of {str: list of str(s)}
        Mapping from model site name to list of dimension labels (plate names
        first, then any ``infer={"event_dims": ...}`` labels). Only sample and
        deterministic sites with at least one label are included.
    """
    distributions = numpyro.distributions
    handlers = numpyro.handlers
    init_strategy = numpyro.infer.initialization.init_to_sample
    pytree_trace_cls = numpyro.ops.pytree.PytreeTrace

    if model_args is None:
        model_args = ()
    if model_kwargs is None:
        model_kwargs = {}

    # Wrapper distributions whose ``base_dist`` holds the distribution of interest.
    wrapper_types = (
        distributions.Independent,
        distributions.ExpandedDistribution,
        distributions.MaskedDistribution,
    )

    def _base_dist_name(fn):
        # Peel wrapper layers until the underlying distribution is reached.
        while isinstance(fn, wrapper_types):
            fn = fn.base_dist
        return type(fn).__name__

    def _traced_model():
        # `init_to_sample` is used to get around the ImproperUniform distribution,
        # which does not have a `sample` method.
        seeded = handlers.seed(model, 0)
        substituted = handlers.substitute(seeded, substitute_fn=init_strategy)
        raw_trace = handlers.trace(substituted).get_trace(*model_args, **model_kwargs)
        # jax.eval_shape does not work for distribution outputs
        # (e.g. the function `lambda: dist.Normal(0, 1)`), so drop `fn`
        # from each sample site and keep only its name.
        for site in raw_trace.values():
            site_type = site["type"]
            if site_type == "sample":
                site["fn_name"] = _base_dist_name(site.pop("fn"))
            elif site_type == "deterministic":
                site["fn_name"] = "Deterministic"
        return pytree_trace_cls(raw_trace)

    # eval_shape avoids any actual array computation.
    abstract_trace = jax.eval_shape(_traced_model).trace

    named_dims = {}
    for site_name, site in abstract_trace.items():
        frames = sorted(site["cond_indep_stack"], key=lambda frame: frame.dim)
        plate_labels = [frame.name for frame in frames]
        event_labels = list(site.get("infer", {}).get("event_dims", []))

        if site["type"] not in ("sample", "deterministic"):
            continue
        if plate_labels or event_labels:
            # batch (plate) labels lead, event labels trail
            named_dims[site_name] = plate_labels + event_labels

    return named_dims


class NumPyroConverter:
    """Encapsulate NumPyro specific logic."""

    # pylint: disable=too-many-instance-attributes

    model = None

    def __init__(
        self,
        *,
        posterior=None,
        prior=None,
        posterior_predictive=None,
        predictions=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=False,
        index_origin=None,
        coords=None,
        dims=None,
        pred_dims=None,
        extra_event_dims=None,
        num_chains=None,
    ):
        """Collect NumPyro inference outputs for conversion to a DataTree.

        Parameters
        ----------
        posterior : NumPyroInferenceAdapter
            A NumPyroInferenceAdapter child class
        prior : dict, optional
            Prior samples from a NumPyro model
        posterior_predictive : dict, optional
            Posterior predictive samples for the posterior
        predictions : dict, optional
            Out of sample predictions
        constant_data : dict, optional
            Dictionary containing constant data variables mapped to their values.
        predictions_constant_data : dict, optional
            Constant data used for out-of-sample predictions.
        log_likelihood : bool, default False
            Whether to compute and include log likelihood in the output.
        index_origin : int, optional
            Falls back to ``rcParams["data.index_origin"]`` when None.
        coords : dict, optional
            Map of dimensions to coordinates
        dims : dict of {str : list of str}, optional
            Map variable names to their coordinates. Will be inferred if they are not provided.
        pred_dims : dict, optional
            Dims for predictions data. Map variable names to their coordinates.
        extra_event_dims : dict, optional
            Maps event dims that couldn't be inferred (i.e. deterministic sites) to their
            coordinates.
        num_chains : int, optional
            Number of chains used for sampling MCMC. Ignored if posterior is present, or if
            inference method is not MCMC.
        """
        self.posterior = posterior
        self.prior = jax.device_get(prior)
        self.posterior_predictive = jax.device_get(posterior_predictive)
        self.predictions = predictions
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.log_likelihood = log_likelihood
        if index_origin is None:
            index_origin = rcParams["data.index_origin"]
        self.index_origin = index_origin
        self.coords = coords
        self.dims = dims
        self.pred_dims = pred_dims
        self.extra_event_dims = extra_event_dims

        # The chain count helps infer the sample shape when no posterior is given;
        # it only applies when the first sample dim is "chain" (MCMC).
        if rcParams["data.sample_dims"][0] == "chain":
            self.nchains = num_chains or 1
        else:
            self.nchains = None

        if posterior is None:
            self.sample_shape = self._infer_sample_shape()
        else:
            draws = jax.device_get(posterior.get_samples())
            if hasattr(draws, "_asdict"):
                # namedtuple-like samples convert easily to a dictionary
                draws = {name: expand_dims(value) for name, value in draws._asdict().items()}
            if not isinstance(draws, dict):
                # MCMC run with a general potential_fn (instead of a NumPyro model)
                # whose args is not a dictionary (e.g. f(x) = x ** 2)
                leaves = jax.tree_util.tree_flatten(draws)[0]
                draws = {f"Param:{idx}": jax.device_get(leaf) for idx, leaf in enumerate(leaves)}
            self._samples = draws
            self.model = posterior.model
            self.sample_shape = posterior.sample_shape

            # model arguments and keyword arguments
            self._args = posterior._args  # pylint: disable=protected-access
            self._kwargs = posterior._kwargs  # pylint: disable=protected-access
            if self.dims is None:
                self.dims = self.infer_dims()
            if self.pred_dims is None:
                self.pred_dims = self.infer_pred_dims()

        # Observed values are read from a trace of the model, when a model is known.
        observed = {}
        if self.model is not None:
            model_trace = self._get_model_trace(
                self.model,
                model_args=self._args,
                model_kwargs=self._kwargs,
                key=jax.random.PRNGKey(0),
            )
            for site_name, site in model_trace.items():
                if site["type"] == "sample" and site["is_observed"]:
                    observed[site_name] = site["value"]
        self.observations = observed or None

    def _get_model_trace(self, model, model_args, model_kwargs, key):
        """Run the model once under a trace handler and return the resulting trace.

        Falsy ``model_args`` / ``model_kwargs`` are treated as empty.
        """
        args = model_args if model_args else ()
        kwargs = model_kwargs if model_kwargs else {}

        # An init strategy is needed to generate random samples for ImproperUniform sites.
        seeded = numpyro.handlers.seed(model, key)
        substituted = numpyro.handlers.substitute(
            seeded, substitute_fn=numpyro.infer.init_to_sample
        )
        return numpyro.handlers.trace(substituted).get_trace(*args, **kwargs)

    def _infer_sample_shape(self):
        """Infer the sample shape when no posterior object is available.

        The first non-None group among predictions, posterior_predictive and
        prior is inspected, using an arbitrary variable from it.

        Returns
        -------
        tuple of int
            ``(nchains, ndraws)`` when ``self.nchains`` is set, otherwise the
            leading ``len(rcParams["data.sample_dims"])`` axes of the inspected
            array. If only constant data is available, a tuple of ones with one
            entry per sample dim.

        Raises
        ------
        ValueError
            If the leading axis is not divisible by the number of chains, or if
            no sample group and no constant data were provided.
        """
        # try to use these sources to infer sample shape
        sources = [
            self.predictions,
            self.posterior_predictive,
            self.prior,
        ]
        # pick first available source
        get_from = next((src for src in sources if src is not None), None)
        no_constant_data = self.constant_data is None and self.predictions_constant_data is None
        if get_from is not None:
            aelem = next(iter(get_from.values()))  # pick an arbitrary element

            # For MCMC from numpyro, we need to reshape the sample shape
            # based on the number of chains provided
            if self.nchains is not None:
                ndraws, remainder = divmod(aelem.shape[0], self.nchains)
                if remainder != 0:
                    # Both fragments must be f-strings, otherwise the chain count
                    # placeholder is shown literally.
                    raise ValueError(
                        f"Sample Shape in shape provided {aelem.shape} is "
                        f"not divisible by the number of chains {self.nchains}."
                    )
                return (self.nchains, ndraws)
            return aelem.shape[: len(rcParams["data.sample_dims"])]
        if no_constant_data:
            raise ValueError(
                "When constructing InferenceData, must have at least one of "
                "posterior, prior, posterior_predictive, or predictions."
            )
        # fallback shape when there's no inference, but there is constant data
        return (1,) * len(rcParams["data.sample_dims"])

    @requires("posterior")
    def posterior_to_xarray(self):
        """Build the posterior dataset from the samples stored at construction."""
        dataset_options = {
            "inference_library": numpyro,
            "coords": self.coords,
            "dims": self.dims,
            "index_origin": self.index_origin,
        }
        return dict_to_dataset(self._samples, **dataset_options)

    @requires("posterior")
    def sample_stats_to_xarray(self):
        """Build the sample_stats dataset from the adapter's extra fields.

        Fields whose value is a dict or tuple are skipped. Known NumPyro field
        names are renamed, and ``tree_depth`` is derived from ``num_steps``.
        """
        renames = {
            "potential_energy": "lp",
            "adapt_state.step_size": "step_size",
            "num_steps": "n_steps",
            "accept_prob": "acceptance_rate",
        }
        stats = {}
        extra_fields = self.posterior.get_extra_fields()
        for field, values in extra_fields.items():
            if not isinstance(values, (dict, tuple)):
                copied = values.copy()
                stats[renames.get(field, field)] = copied
                if field == "num_steps":
                    stats["tree_depth"] = np.log2(copied).astype(int) + 1

        return dict_to_dataset(
            stats,
            inference_library=numpyro,
            dims=None,
            coords=self.coords,
            index_origin=self.index_origin,
        )

    @requires("posterior")
    @requires("model")
    def log_likelihood_to_xarray(self):
        """Build the log_likelihood dataset, or return None when it was not requested.

        When observations exist, the pointwise log likelihood is computed from
        the ungrouped posterior samples and reshaped so each array leads with
        ``self.sample_shape``.
        """
        if not self.log_likelihood:
            return None

        pointwise = {}
        if self.observations is not None:
            draws = self.posterior.get_samples(group_by_chain=False)
            if hasattr(draws, "_asdict"):
                draws = draws._asdict()
            per_site = numpyro.infer.log_likelihood(
                self.model, draws, *self._args, **self._kwargs
            )
            pointwise = {
                site: np.reshape(np.asarray(values), self.sample_shape + values.shape[1:])
                for site, values in per_site.items()
            }

        return dict_to_dataset(
            pointwise,
            inference_library=numpyro,
            dims=self.dims,
            coords=self.coords,
            index_origin=self.index_origin,
            skip_event_dims=True,
        )

    def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
        """Convert posterior_predictive or prediction samples to xarray.

        Parameters
        ----------
        dct : dict of {str: array-like}
            Samples keyed by variable name. Each array must expose ``shape``
            and ``reshape``.
        dims : dict
            Map of variable names to dimension names, forwarded to
            ``dict_to_dataset``.

        Notes
        -----
        Each array is handled according to its leading shape:

        * already leading with ``self.sample_shape`` -> used unchanged;
        * first axis equal to the product of ``self.sample_shape`` (flattened
          samples) -> reshaped to ``self.sample_shape + shape[1:]``;
        * anything else -> passed through ``expand_dims`` with a warning.
        """
        sample_shape = tuple(self.sample_shape)
        n_sample_dims = len(sample_shape)
        n_flat_samples = np.prod(sample_shape)

        data = {}
        for k, ary in dct.items():
            shape = ary.shape
            if shape[:n_sample_dims] == sample_shape:
                # Already grouped along the sample dims; reshaping here with
                # ``shape[1:]`` would duplicate the trailing sample axes.
                data[k] = ary
            elif shape[0] == n_flat_samples:
                data[k] = ary.reshape(sample_shape + shape[1:])
            else:
                data[k] = expand_dims(ary)
                warnings.warn(
                    "posterior predictive shape not compatible with sample shape. "
                    "This can mean that some sample dims are not represented."
                )
        return dict_to_dataset(
            data,
            inference_library=numpyro,
            coords=self.coords,
            dims=dims,
            index_origin=self.index_origin,
        )

    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Build the posterior_predictive dataset using the posterior dims."""
        samples, dims = self.posterior_predictive, self.dims
        return self.translate_posterior_predictive_dict_to_xarray(samples, dims)

    @requires("predictions")
    def predictions_to_xarray(self):
        """Build the predictions dataset using the prediction dims."""
        samples, dims = self.predictions, self.pred_dims
        return self.translate_posterior_predictive_dict_to_xarray(samples, dims)

    def priors_to_xarray(self):
        """Split prior samples into ``prior`` and ``prior_predictive`` datasets.

        With a posterior available, variables named like posterior variables go
        to ``prior`` and the remaining prior variables to ``prior_predictive``.
        Without a posterior, every prior variable goes to ``prior`` and
        ``prior_predictive`` is None.

        Returns
        -------
        dict
            Keys ``"prior"`` and ``"prior_predictive"``; each value is a dataset or None.
        """
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}

        if self.posterior is None:
            prior_vars = self.prior.keys()
            prior_predictive_vars = None
        else:
            prior_vars = list(self._samples.keys())
            prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]

        # don't expand dims for SVI (single sample dim)
        needs_expansion = len(rcParams["data.sample_dims"]) > 1

        def _group_dataset(var_names):
            # A missing variable list means the group is not produced at all.
            if var_names is None:
                return None
            values = {}
            for name in var_names:
                samples = self.prior[name]
                values[name] = expand_dims(samples) if needs_expansion else samples
            return dict_to_dataset(
                values,
                inference_library=numpyro,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            )

        return {
            "prior": _group_dataset(prior_vars),
            "prior_predictive": _group_dataset(prior_predictive_vars),
        }

    @requires("observations")
    @requires("model")
    def observed_data_to_xarray(self):
        """Build the observed_data dataset (no sample dims) from the observed sites."""
        dataset_options = {
            "inference_library": numpyro,
            "dims": self.dims,
            "coords": self.coords,
            "sample_dims": [],
            "index_origin": self.index_origin,
        }
        return dict_to_dataset(self.observations, **dataset_options)

    @requires("constant_data")
    def constant_data_to_xarray(self):
        """Build the constant_data dataset (no sample dims)."""
        dataset_options = {
            "inference_library": numpyro,
            "dims": self.dims,
            "coords": self.coords,
            "sample_dims": [],
            "index_origin": self.index_origin,
        }
        return dict_to_dataset(self.constant_data, **dataset_options)

    @requires("predictions_constant_data")
    def predictions_constant_data_to_xarray(self):
        """Build the predictions_constant_data dataset (no sample dims, prediction dims)."""
        dataset_options = {
            "inference_library": numpyro,
            "dims": self.pred_dims,
            "coords": self.coords,
            "sample_dims": [],
            "index_origin": self.index_origin,
        }
        return dict_to_dataset(self.predictions_constant_data, **dataset_options)

    def to_datatree(self):
        """Convert all available data to a DataTree.

        Every group converter is called in turn; groups whose converter
        returns None (for instance ``posterior`` and ``sample_stats`` when no
        posterior was given) are left out of the resulting tree.
        """
        groups = {}
        groups["posterior"] = self.posterior_to_xarray()
        groups["sample_stats"] = self.sample_stats_to_xarray()
        groups["log_likelihood"] = self.log_likelihood_to_xarray()
        groups["posterior_predictive"] = self.posterior_predictive_to_xarray()
        groups["predictions"] = self.predictions_to_xarray()
        # contributes the "prior" and "prior_predictive" entries
        groups.update(self.priors_to_xarray())
        groups["observed_data"] = self.observed_data_to_xarray()
        groups["constant_data"] = self.constant_data_to_xarray()
        groups["predictions_constant_data"] = self.predictions_constant_data_to_xarray()

        available = {}
        for group_name, dataset in groups.items():
            if dataset is not None:
                available[group_name] = dataset
        return DataTree.from_dict(available)

    @requires("posterior")
    @requires("model")
    def infer_dims(self) -> dict[str, list[str]]:
        """Infer dims from the model, then append any user-supplied extra event dims."""
        inferred = infer_dims(self.model, self._args, self._kwargs)
        if not self.extra_event_dims:
            return inferred
        return _add_dims(inferred, self.extra_event_dims)

    @requires("posterior")
    @requires("model")
    @requires("predictions")
    def infer_pred_dims(self) -> dict[str, list[str]]:
        """Infer dims for predictions data.

        Uses the same model and model arguments as the posterior dims, then
        appends any user-supplied extra event dims.
        """
        inferred = infer_dims(self.model, self._args, self._kwargs)
        if not self.extra_event_dims:
            return inferred
        return _add_dims(inferred, self.extra_event_dims)


def from_numpyro(
    posterior=None,
    *,
    prior=None,
    posterior_predictive=None,
    predictions=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=False,
    index_origin=None,
    coords=None,
    dims=None,
    pred_dims=None,
    extra_event_dims=None,
    sample_dims=None,
    num_chains=None,
):
    """Convert NumPyro mcmc inference data into a DataTree object.

    For a usage example read :ref:`numpyro_conversion`

    If no dims are provided, this will infer batch dim names from NumPyro model plates.
    For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
    can be provided in numpyro.sample, i.e.::

        # equivalent to dims entry, {"gamma": ["groups"]}
        gamma = numpyro.sample(
            "gamma",
            dist.ZeroSumNormal(1, event_shape=(n_groups,)),
            infer={"event_dims":["groups"]}
        )

    There is also an additional `extra_event_dims` input to cover any edge cases, for instance
    deterministic sites with event dims (which don't have an `infer` argument to provide
    metadata).

    Parameters
    ----------
    posterior : numpyro.infer.MCMC | NumPyroInferenceAdapter
        A fitted MCMC object from NumPyro, or an instance of a child class of
        NumPyroInferenceAdapter.
    prior : dict, optional
        Prior samples from a NumPyro model
    posterior_predictive : dict, optional
        Posterior predictive samples for the posterior
    predictions : dict, optional
        Out of sample predictions
    constant_data : dict, optional
        Dictionary containing constant data variables mapped to their values.
    predictions_constant_data : dict, optional
        Constant data used for out-of-sample predictions.
    log_likelihood : bool, default False
        Whether to compute and include log likelihood in the output.
    index_origin : int, optional
    coords : dict, optional
        Map of dimensions to coordinates
    dims : dict of {str : list of str}, optional
        Map variable names to their coordinates. Will be inferred if they are not provided.
    pred_dims : dict, optional
        Dims for predictions data. Map variable names to their coordinates.
        Default behavior is to infer dims if this is not provided
    extra_event_dims : dict, optional
        Extra event dims for deterministic sites. Maps event dims that couldn't be inferred to
        their coordinates.
    sample_dims : list of str, optional
        Names of the sample dimensions (e.g., ["chain", "draw"] for MCMC, ["sample"] for SVI).
        Must be provided if `posterior` is None. If `posterior` is provided, this argument is
        ignored and overwritten with `posterior.sample_dims`.
    num_chains : int, optional
        Number of chains used for sampling. Defaults to 1 for MCMC if not provided.
        Ignored if posterior is present.

    Returns
    -------
    DataTree

    Raises
    ------
    ValueError
        If both `posterior` and `sample_dims` are None.
    """
    if posterior is None:
        if sample_dims is None:
            raise ValueError(
                "sample_dims must be provided if posterior is None. "
                "For MCMC use ['chain', 'draw'], for SVI use ['sample']."
            )
    elif isinstance(posterior, numpyro.infer.MCMC):
        # Raw MCMC objects are wrapped so the converter only ever sees adapters.
        sample_dims = ["chain", "draw"]
        posterior = MCMCAdapter(posterior)
    else:
        sample_dims = posterior.sample_dims

    # The converter reads the sample dims from rcParams, so set them for its lifetime.
    with rc_context(rc={"data.sample_dims": sample_dims}):
        return NumPyroConverter(
            posterior=posterior,
            prior=prior,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data=constant_data,
            predictions_constant_data=predictions_constant_data,
            log_likelihood=log_likelihood,
            index_origin=index_origin,
            coords=coords,
            dims=dims,
            pred_dims=pred_dims,
            extra_event_dims=extra_event_dims,
            num_chains=num_chains,
        ).to_datatree()
def from_numpyro_svi(
    svi=None,
    *,
    svi_result=None,
    model_args=None,
    model_kwargs=None,
    prior=None,
    posterior_predictive=None,
    predictions=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=False,
    index_origin=None,
    coords=None,
    dims=None,
    pred_dims=None,
    extra_event_dims=None,
    num_samples=1000,
):
    """Convert NumPyro SVI results into a DataTree object.

    For a usage example read :ref:`numpyro_conversion`

    If no dims are provided, this will infer batch dim names from NumPyro model plates.
    For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
    can be provided in numpyro.sample, i.e.::

        # equivalent to dims entry, {"gamma": ["groups"]}
        gamma = numpyro.sample(
            "gamma",
            dist.ZeroSumNormal(1, event_shape=(n_groups,)),
            infer={"event_dims":["groups"]}
        )

    There is also an additional `extra_event_dims` input to cover any edge cases, for instance
    deterministic sites with event dims (which don't have an `infer` argument to provide
    metadata).

    Parameters
    ----------
    svi : numpyro.infer.SVI, optional
        Numpyro SVI instance used for fitting the model. If not provided, no posterior
        will be included in the output, and at least one of prior, posterior_predictive,
        or predictions must be provided.
    svi_result : numpyro.infer.svi.SVIRunResult, optional
        SVI results from a fitted model. Required if SVI is provided.
    model_args : tuple, optional
        Model arguments, should match those used for fitting the model.
    model_kwargs : dict, optional
        Model keyword arguments, should match those used for fitting the model.
    prior : dict, optional
        Prior samples from a NumPyro model
    posterior_predictive : dict, optional
        Posterior predictive samples for the posterior
    predictions : dict, optional
        Out of sample predictions
    constant_data : dict, optional
        Dictionary containing constant data variables mapped to their values.
    predictions_constant_data : dict, optional
        Constant data used for out-of-sample predictions.
    log_likelihood : bool, default False
        Whether to compute and include log likelihood in the output.
    index_origin : int, optional
    coords : dict, optional
        Map of dimensions to coordinates
    dims : dict of {str : list of str}, optional
        Map variable names to their coordinates. Will be inferred if they are not provided.
    pred_dims : dict, optional
        Dims for predictions data. Map variable names to their coordinates.
        Default behavior is to infer dims if this is not provided
    extra_event_dims : dict, optional
        Extra event dims for deterministic sites. Maps event dims that couldn't be inferred to
        their coordinates.
    num_samples : int, default 1000
        Number of posterior samples to generate

    Returns
    -------
    DataTree
    """
    # SVI results carry a single sample dimension; the converter reads it from rcParams.
    with rc_context(rc={"data.sample_dims": ["sample"]}):
        posterior = (
            SVIAdapter(
                svi,
                svi_result=svi_result,
                model_args=model_args,
                model_kwargs=model_kwargs,
                num_samples=num_samples,
            )
            if svi is not None
            else None
        )
        return NumPyroConverter(
            posterior=posterior,
            prior=prior,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data=constant_data,
            predictions_constant_data=predictions_constant_data,
            log_likelihood=log_likelihood,
            index_origin=index_origin,
            coords=coords,
            dims=dims,
            pred_dims=pred_dims,
            extra_event_dims=extra_event_dims,
        ).to_datatree()