# Source code for aeppl.joint_logprob

import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import GraphRewriter, NodeRewriter
from aesara.tensor.var import TensorVariable

from aeppl.abstract import ValuedVariable, get_measurable_outputs
from aeppl.logprob import _logprob
from aeppl.rewriting import construct_ir_fgraph, ir_cleanup_db

if TYPE_CHECKING:
from aesara.graph.basic import Apply, Variable

class DensityNotFound(Exception):
    """Signal that no log-density could be derived for a variable.

    Raised while building log-probability graphs when a variable bound to a
    value has no owner node, or when its node exposes no measurable outputs.
    """

def conditional_logprob(
    *random_variables: TensorVariable,
    realized: Dict[TensorVariable, TensorVariable] = {},
    ir_rewriter: Optional[GraphRewriter] = None,
    extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
    **kwargs,
) -> Tuple[Dict[TensorVariable, TensorVariable], Tuple[TensorVariable, ...]]:
    r"""Create a map between random variables and their conditional log-probabilities.

    Consider the following Aesara model:

    .. code-block:: python

        import aesara.tensor as at

        srng = at.random.RandomStream()

        sigma2_rv = srng.invgamma(0.5, 0.5)
        Y_rv = srng.normal(0, at.sqrt(sigma2_rv))

    Which represents the following mathematical model:

    .. math::

        \sigma^2 \sim& \operatorname{InvGamma}(0.5, 0.5) \\
        Y \sim& \operatorname{N}\left(0, \sigma^2\right)

    We can generate the graph that computes the conditional log-density associated
    with each random variable with:

    .. code-block:: python

        import aeppl

        logprobs, (sigma2_vv, Y_vv) = aeppl.conditional_logprob(sigma2_rv, Y_rv)

    The list of random variables passed to `conditional_logprob` implicitly
    defines a joint density that factorizes according to the graphical model
    represented by the Aesara model. Here, ``logprobs[sigma2_rv]`` corresponds
    to the conditional log-density
    :math:`\operatorname{P}\left(\sigma^2=s \mid Y\right)` and
    ``logprobs[Y_rv]`` to :math:`\operatorname{P}\left(Y=y \mid \sigma^2\right)`.

    To build the log-density graphs, `conditional_logprob` must generate the
    value variable associated with each random variable. They are returned along
    with the graph in the same order as the random variables were passed to
    `conditional_logprob`. Here, the value variables ``sigma2_vv`` and ``Y_vv``
    correspond to :math:`s` and :math:`y` in the previous expressions
    respectively.

    It is also possible to call `conditional_logprob` omitting some of the
    random variables in the graph:

    .. code-block:: python

        logprobs, (Y_vv,) = aeppl.conditional_logprob(Y_rv)

    In this case, ``logprobs[Y_rv]`` corresponds to the conditional log-density
    :math:`\operatorname{P}\left(Y=y \mid \sigma^2\right)` where
    :math:`\sigma^2` is a stochastic variable.

    Another important case is when one of the variables is already realized. For
    instance, if ``Y_rv`` is observed we can include its realized value directly
    in the log-density graphs:

    .. code-block:: python

        y_obs = Y_rv.copy()
        logprobs, (sigma2_vv,) = aeppl.conditional_logprob(sigma2_rv, realized={Y_rv: y_obs})

    In this case, `conditional_logprob` uses the value variable passed in the
    conditional log-density graphs it produces, and no value variable is
    returned for the realized variable.

    Parameters
    ==========
    random_variables
        The random variables for which we need to return a conditional
        log-probability graph. A value variable is created (by cloning) for
        each of them.
    realized
        A ``dict`` that maps realized random variables to their realized
        values. These values are used in the generated conditional
        log-density graphs. The mapping is only read, never modified.
    ir_rewriter
        Rewriter that produces the intermediate representation of measurable
        variables.
    extra_rewrites
        Extra rewrites to be applied (e.g. reparameterizations, transforms,
        etc.)
    kwargs
        The deprecated ``warn_missing_rvs`` option is removed from ``kwargs``
        (with a `DeprecationWarning` when truthy); the remaining keyword
        arguments are forwarded to the log-probability dispatch function.

    Returns
    =======
    conditional_logprobs
        A ``dict`` that maps each random variable (requested or realized) to
        the graph that computes its conditional log-density implicitly defined
        by the random variables passed as arguments.
    value_variables
        A tuple of the value variables created for the random variables that
        are not realized. Empty if no value variable had to be created.

    Raises
    ======
    DensityNotFound
        If a variable bound to a value has no owner node, or if its node has
        no measurable outputs.
    ValueError
        If more values are bound to a node than it has measurable outputs.

    """

    deprecated_option = kwargs.pop("warn_missing_rvs", None)

    if deprecated_option:
        warnings.warn(
            "The warn_missing_rvs option is deprecated and has been removed.",
            DeprecationWarning,
            # Attribute the warning to the caller rather than to this module.
            stacklevel=2,
        )

    # Create value variables by cloning the input measurable variables
    original_rv_values = {}
    for rv in random_variables:
        vv = rv.clone()
        if rv.name:
            vv.name = f"{rv.name}_vv"
        original_rv_values[rv] = vv

    # Value variables are not cloned when constructing the conditional log-probability
    # graphs. We can thus use them to recover the original random variables to index the
    # maps to the logprob graphs and value variables before returning them.
    # Entries in `realized` take precedence over freshly created value variables.
    rv_values = {**original_rv_values, **realized}

    fgraph, new_rv_values = construct_ir_fgraph(
        rv_values, ir_rewriter=ir_rewriter, extra_rewrites=extra_rewrites
    )

    # We assign log-densities on a per-node basis, and not per-output/variable.
    realized_vars = set()
    new_to_old_rvs = {}
    nodes_to_vals: Dict["Apply", List[Tuple["Variable", "Variable"]]] = {}

    for bnd_var, (old_mvar, val) in zip(fgraph.outputs, new_rv_values.items()):
        mnode = bnd_var.owner
        assert mnode and isinstance(mnode.op, ValuedVariable)

        rv_var, val_var = mnode.inputs
        rv_node = rv_var.owner

        if rv_node is None:
            raise DensityNotFound(f"Couldn't derive a log-probability for {rv_var}")

        if old_mvar in realized:
            # The caller already owns the value of a realized variable, so it
            # must not be handed back as a newly created value variable.
            realized_vars.add(rv_var)

        nodes_to_vals.setdefault(rv_node, []).append((val_var, val))

        new_to_old_rvs[rv_var] = old_mvar

    value_vars: Tuple["Variable", ...] = ()
    logprob_vars = {}

    for rv_node, rv_val_pairs in nodes_to_vals.items():
        outputs = get_measurable_outputs(rv_node.op, rv_node)

        if not outputs:
            raise DensityNotFound(f"Couldn't derive a log-probability for {rv_node}")

        if len(outputs) < len(rv_val_pairs):
            raise ValueError(
                f"Too many values ({rv_val_pairs}) bound to node {rv_node}."
            )

        assert len(outputs) == len(rv_val_pairs)

        rv_vals, rv_base_vals = zip(*rv_val_pairs)

        rv_logprobs = _logprob(
            rv_node.op,
            rv_vals,
            *rv_node.inputs,
            **kwargs,
        )

        # Single-output nodes may return a bare variable; normalize to a sequence.
        if not isinstance(rv_logprobs, (tuple, list)):
            rv_logprobs = (rv_logprobs,)

        for lp_out, rv_out, rv_base_val in zip(rv_logprobs, outputs, rv_base_vals):
            old_mvar = new_to_old_rvs[rv_out]

            # Name the log-density after the variable the caller passed in; the
            # rewritten variable may be unnamed even when the original is not.
            if old_mvar.name:
                lp_out.name = f"{old_mvar.name}_logprob"

            logprob_vars[old_mvar] = lp_out

            if rv_out not in realized_vars:
                value_vars += (rv_base_val,)

        # # Recompute test values for the changes introduced by the
        # # replacements above.
        # if config.compute_test_value != "off":
        #     for node in io_toposort(graph_inputs([rv_logprobs]), outputs):
        #         compute_test_value(node)

    # Remove unneeded IR elements from the graph
    rv_logprobs_fg = FunctionGraph(outputs=tuple(logprob_vars.values()), clone=False)
    ir_cleanup_db.query("+basic").rewrite(rv_logprobs_fg)

    return logprob_vars, value_vars

def joint_logprob(
    *random_variables: TensorVariable,
    realized: Dict[TensorVariable, TensorVariable] = {},
    **kwargs,
) -> Optional[Tuple[TensorVariable, Tuple[TensorVariable, ...]]]:
    r"""Build the graph of the joint log-density of an Aesara graph.

    Consider the following Aesara model:

    .. code-block:: python

        import aesara.tensor as at

        srng = at.random.RandomStream()
        sigma2_rv = srng.invgamma(0.5, 0.5)
        Y_rv = srng.normal(0, at.sqrt(sigma2_rv))

    Which represents the following mathematical model:

    .. math::

        \sigma^2 \sim& \operatorname{InvGamma}(0.5, 0.5) \\
        Y \sim& \operatorname{N}\left(0, \sigma^2\right)

    We can generate the graph that computes the joint log-density associated
    with this model:

    .. code-block:: python

        import aeppl

        logprob, (sigma2_vv, Y_vv) = aeppl.joint_logprob(sigma2_rv, Y_rv)

    To build the joint log-density graph, `joint_logprob` must generate the
    value variable associated with each random variable. They are returned along
    with the graph in the same order as the random variables were passed to
    `joint_logprob`. Here, the value variables ``sigma2_vv`` and ``Y_vv``
    correspond to the values :math:`s` and :math:`y` taken by :math:`\sigma^2`
    and :math:`Y`, respectively.

    It is also possible to call `joint_logprob` omitting some of the random
    variables in the graph:

    .. code-block:: python

        logprob, (Y_vv,) = aeppl.joint_logprob(Y_rv)

    In this case, ``logprob`` corresponds to the joint log-density
    :math:`\operatorname{P}\left(Y, \sigma^2\right)` where :math:`\sigma^2` is
    a stochastic variable.

    Another important case is when one of the variables is already realized. For
    instance, if ``Y_rv`` is observed we can include its realized value directly
    in the log-density graphs:

    .. code-block:: python

        y_obs = Y_rv.copy()
        logprob, (sigma2_vv,) = aeppl.joint_logprob(sigma2_rv, realized={Y_rv: y_obs})

    In this case, `joint_logprob` uses the value ``y_obs`` mapped to ``Y_rv`` in the
    conditional log-density graphs it produces, so that ``logprob`` corresponds
    to the density :math:`\operatorname{P}\left(\sigma^2 \mid Y=y\right)` when
    :math:`y` and :math:`Y` correspond to ``y_obs`` and ``Y_rv``, respectively.

    Parameters
    ==========
    random_variables
        The random variables whose conditional log-densities are summed into
        the joint log-probability graph.
    realized
        A ``dict`` that maps random variables to their realized value.
    kwargs
        Forwarded unchanged to `conditional_logprob`.

    Returns
    =======
    logprob
        A ``TensorVariable`` that represents the joint log-probability of the graph
        implicitly defined by the random variables passed as arguments.
    value_variables
        The value variables returned by `conditional_logprob` for the given
        random variables.

    ``None`` is returned instead of the pair when `conditional_logprob`
    produces no conditional log-density at all (an empty mapping).

    """
    logprob, value_variables = conditional_logprob(
        *random_variables, realized=realized, **kwargs
    )

    # Nothing to sum: no random variable was given and none was realized.
    if not logprob:
        return None

    factors = tuple(logprob.values())

    if len(factors) == 1:
        # A single factor is summed directly, without an inner per-factor sum.
        return at.sum(factors), value_variables

    # Factors may have different shapes, so each one is reduced to a scalar
    # before the scalars are added together.
    total_logprob: TensorVariable = at.sum([at.sum(factor) for factor in factors])
    return total_logprob, value_variables