Source code for aeppl.joint_logprob

import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import GraphRewriter, NodeRewriter
from aesara.tensor.var import TensorVariable

from aeppl.abstract import ValuedVariable, get_measurable_outputs
from aeppl.logprob import _logprob
from aeppl.rewriting import construct_ir_fgraph, ir_cleanup_db

if TYPE_CHECKING:
    from aesara.graph.basic import Apply, Variable


class DensityNotFound(Exception):
    """An exception raised when a density cannot be found."""


def conditional_logprob(
    *random_variables: TensorVariable,
    realized: Dict[TensorVariable, TensorVariable] = {},
    ir_rewriter: Optional[GraphRewriter] = None,
    extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
    **kwargs,
) -> Tuple[Dict[TensorVariable, TensorVariable], Tuple[TensorVariable, ...]]:
    r"""Create a map between random variables and their conditional log-probabilities.

    Consider the following Aesara model:

    .. code-block:: python

        import aesara.tensor as at

        srng = at.random.RandomStream()
        sigma2_rv = srng.invgamma(0.5, 0.5)
        Y_rv = srng.normal(0, at.sqrt(sigma2_rv))

    Which represents the following mathematical model:

    .. math::

        \sigma^2 \sim& \operatorname{InvGamma}(0.5, 0.5) \\
        Y \sim& \operatorname{N}\left(0, \sigma^2\right)

    We can generate the graph that computes the conditional log-density
    associated with each random variable with:

    .. code-block:: python

        import aeppl

        logprobs, (sigma2_vv, Y_vv) = aeppl.conditional_logprob(sigma2_rv, Y_rv)

    The list of random variables passed to `conditional_logprob` implicitly
    defines a joint density that factorizes according to the graphical model
    represented by the Aesara model. Here, ``logprobs[sigma2_rv]`` corresponds
    to the conditional log-density
    :math:`\operatorname{P}\left(\sigma^2=s \mid Y\right)` and
    ``logprobs[Y_rv]`` to :math:`\operatorname{P}\left(Y=y \mid \sigma^2\right)`.

    To build the log-density graphs, `conditional_logprob` must generate the
    value variable associated with each random variable. They are returned
    along with the graphs, in the same order as the random variables were
    passed to `conditional_logprob`. Here, the value variables ``sigma2_vv``
    and ``Y_vv`` correspond to :math:`s` and :math:`y` in the previous
    expressions, respectively.

    It is also possible to call `conditional_logprob` omitting some of the
    random variables in the graph:

    .. code-block:: python

        logprobs, (Y_vv,) = aeppl.conditional_logprob(Y_rv)

    In this case, ``logprobs[Y_rv]`` corresponds to the conditional
    log-density :math:`\operatorname{P}\left(Y=y \mid \sigma^2\right)` where
    :math:`\sigma^2` is a stochastic variable.

    Another important case is when one of the variables is already realized.
    For instance, if `Y_rv` is observed, we can include its realized value
    directly in the log-density graphs:

    .. code-block:: python

        y_obs = Y_rv.copy()
        logprobs, (sigma2_vv,) = aeppl.conditional_logprob(sigma2_rv, realized={Y_rv: y_obs})

    In this case, `conditional_logprob` uses the realized value ``y_obs``
    mapped to ``Y_rv`` in the conditional log-density graphs it produces.

    Parameters
    ==========
    random_variables
        A ``list`` of random variables for which we need to return a
        conditional log-probability graph.
    realized
        A ``dict`` that maps realized random variables to their realized
        values. These values are used in the generated conditional
        log-density graphs.
    ir_rewriter
        Rewriter that produces the intermediate representation of measurable
        variables.
    extra_rewrites
        Extra rewrites to be applied (e.g. reparameterizations, transforms,
        etc.)

    Returns
    =======
    conditional_logprobs
        A ``dict`` that maps each random variable to the graph that computes
        its conditional log-density, as implicitly defined by the random
        variables passed as arguments. A `DensityNotFound` exception is raised
        when a log-density cannot be derived for a variable.
    value_variables
        A ``list`` of the created value variables, in the same order as the
        order in which their corresponding random variables were passed as
        arguments. Empty if `random_variables` is empty.
""" deprecated_option = kwargs.pop("warn_missing_rvs", None) if deprecated_option: warnings.warn( "The `warn_missing_rvs` option is deprecated and has been removed.", DeprecationWarning, ) # Create value variables by cloning the input measurable variables original_rv_values = {} for rv in random_variables: vv = rv.clone() if rv.name: vv.name = f"{rv.name}_vv" original_rv_values[rv] = vv # Value variables are not cloned when constructing the conditional log-probability # graphs. We can thus use them to recover the original random variables to index the # maps to the logprob graphs and value variables before returning them. rv_values = {**original_rv_values, **realized} fgraph, new_rv_values = construct_ir_fgraph( rv_values, ir_rewriter=ir_rewriter, extra_rewrites=extra_rewrites ) # We assign log-densities on a per-node basis, and not per-output/variable. realized_vars = set() new_to_old_rvs = {} nodes_to_vals: Dict["Apply", List[Tuple["Variable", "Variable"]]] = {} for bnd_var, (old_mvar, val) in zip(fgraph.outputs, new_rv_values.items()): mnode = bnd_var.owner assert mnode and isinstance(mnode.op, ValuedVariable) rv_var, val_var = mnode.inputs rv_node = rv_var.owner if rv_node is None: raise DensityNotFound(f"Couldn't derive a log-probability for {rv_var}") if old_mvar in realized: realized_vars.add(rv_var) nodes_to_vals.setdefault(rv_node, []).append((val_var, val)) new_to_old_rvs[rv_var] = old_mvar value_vars: Tuple["Variable", ...] = () logprob_vars = {} for rv_node, rv_val_pairs in nodes_to_vals.items(): outputs = get_measurable_outputs(rv_node.op, rv_node) if not outputs: raise DensityNotFound(f"Couldn't derive a log-probability for {rv_node}") if len(outputs) < len(rv_val_pairs): raise ValueError( f"Too many values ({rv_val_pairs}) bound to node {rv_node}." ) assert len(outputs) == len(rv_val_pairs) rv_vals, rv_base_vals = zip(*rv_val_pairs) rv_logprobs = _logprob( rv_node.op, rv_vals, *rv_node.inputs, **kwargs, ) if not isinstance(rv_logprobs, (tuple, list)): rv_logprobs = (rv_logprobs,) for lp_out, rv_out, rv_base_val in zip(rv_logprobs, outputs, rv_base_vals): old_mvar = new_to_old_rvs[rv_out] if old_mvar.name: lp_out.name = f"{rv_out.name}_logprob" logprob_vars[old_mvar] = lp_out if rv_out not in realized_vars: value_vars += (rv_base_val,) # # Recompute test values for the changes introduced by the # # replacements above. # if config.compute_test_value != "off": # for node in io_toposort(graph_inputs([rv_logprobs]), outputs): # compute_test_value(node) # Remove unneeded IR elements from the graph rv_logprobs_fg = FunctionGraph(outputs=tuple(logprob_vars.values()), clone=False) ir_cleanup_db.query("+basic").rewrite(rv_logprobs_fg) return logprob_vars, value_vars
def joint_logprob(
    *random_variables: TensorVariable,
    realized: Dict[TensorVariable, TensorVariable] = {},
    **kwargs,
) -> Optional[Tuple[TensorVariable, Tuple[TensorVariable, ...]]]:
    r"""Build the graph of the joint log-density of an Aesara graph.

    Consider the following Aesara model:

    .. code-block:: python

        import aesara.tensor as at

        srng = at.random.RandomStream()
        sigma2_rv = srng.invgamma(0.5, 0.5)
        Y_rv = srng.normal(0, at.sqrt(sigma2_rv))

    Which represents the following mathematical model:

    .. math::

        \sigma^2 \sim& \operatorname{InvGamma}(0.5, 0.5) \\
        Y \sim& \operatorname{N}\left(0, \sigma^2\right)

    We can generate the graph that computes the joint log-density associated
    with this model:

    .. code-block:: python

        import aeppl

        logprob, (sigma2_vv, Y_vv) = aeppl.joint_logprob(sigma2_rv, Y_rv)

    To build the joint log-density graph, `joint_logprob` must generate the
    value variable associated with each random variable. They are returned
    along with the graph, in the same order as the random variables were
    passed to `joint_logprob`. Here, the value variables ``sigma2_vv`` and
    ``Y_vv`` correspond to the values :math:`s` and :math:`y` taken by
    :math:`\sigma^2` and :math:`Y`, respectively.

    It is also possible to call `joint_logprob` omitting some of the random
    variables in the graph:

    .. code-block:: python

        logprob, (Y_vv,) = aeppl.joint_logprob(Y_rv)

    In this case, ``logprob`` corresponds to the joint log-density
    :math:`\operatorname{P}\left(Y, \sigma^2\right)` where :math:`\sigma^2` is
    a stochastic variable.

    Another important case is when one of the variables is already realized.
    For instance, if `Y_rv` is observed, we can include its realized value
    directly in the log-density graphs:

    .. code-block:: python

        y_obs = Y_rv.copy()
        logprob, (sigma2_vv,) = aeppl.joint_logprob(sigma2_rv, realized={Y_rv: y_obs})

    In this case, `joint_logprob` uses the value ``y_obs`` mapped to ``Y_rv``
    in the conditional log-density graphs it produces, so that ``logprob``
    corresponds to the density :math:`\operatorname{P}\left(\sigma^2 \mid
    Y=y\right)` when :math:`y` and :math:`Y` correspond to ``y_obs`` and
    ``Y_rv``, respectively.

    Parameters
    ==========
    random_variables
        A ``list`` of random variables for which we need to return a
        conditional log-probability graph.
    realized
        A ``dict`` that maps random variables to their realized values.

    Returns
    =======
    logprob
        A ``TensorVariable`` that represents the joint log-probability of the
        graph implicitly defined by the random variables passed as arguments.
        ``None`` if a log-density cannot be computed.
    value_variables
        A ``list`` of the created value variables, in the same order as the
        order in which their corresponding random variables were passed as
        arguments. Empty if ``random_variables`` is empty.

    """
    logprob, value_variables = conditional_logprob(
        *random_variables, realized=realized, **kwargs
    )

    if not logprob:
        return None
    elif len(logprob) == 1:
        cond_logprob = tuple(logprob.values())[0]
        return at.sum(cond_logprob), value_variables
    else:
        joint_logprob: TensorVariable = at.sum(
            [at.sum(factor) for factor in logprob.values()]
        )
        return joint_logprob, value_variables
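
# A minimal, illustrative sketch (not part of the library API): it compiles and
# evaluates the single scalar graph returned by `joint_logprob`, which sums the
# conditional log-densities of the factors. The model mirrors the docstring
# example above; the helper name `_example_joint_logprob` and the input values
# are hypothetical.
def _example_joint_logprob():
    import aesara

    srng = at.random.RandomStream(0)
    sigma2_rv = srng.invgamma(0.5, 0.5)
    Y_rv = srng.normal(0, at.sqrt(sigma2_rv))

    # A single scalar graph for the joint log-density, plus the value
    # variables it depends on
    logprob, (sigma2_vv, Y_vv) = joint_logprob(sigma2_rv, Y_rv)

    logp_fn = aesara.function([sigma2_vv, Y_vv], logprob)
    return logp_fn(1.0, 0.5)
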