from typing import List, Optional, Tuple, Union, cast
import aesara
import aesara.tensor as at
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, compute_test_value
from aesara.graph.rewriting.basic import (
EquilibriumGraphRewriter,
node_rewriter,
pre_greedy_node_rewriter,
)
from aesara.ifelse import ifelse
from aesara.scalar.basic import Switch
from aesara.tensor.basic import Join, MakeVector
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
local_subtensor_rv_lift,
)
from aesara.tensor.shape import shape_tuple
from aesara.tensor.subtensor import (
as_index_literal,
as_nontensor_scalar,
get_canonical_form_slice,
is_basic_idx,
)
from aesara.tensor.type import TensorType
from aesara.tensor.type_other import NoneConst, NoneTypeT, SliceType
from aesara.tensor.var import TensorVariable
from aeppl.abstract import (
MeasurableVariable,
ValuedVariable,
assign_custom_measurable_outputs,
)
from aeppl.logprob import _logprob, logprob
from aeppl.rewriting import local_lift_DiracDelta, logprob_rewrites_db, subtensor_ops
from aeppl.utils import get_constant_value
def is_newaxis(x):
    """Determine whether ``x`` is a "new axis" (i.e. ``None``) index."""
    return isinstance(x, type(None)) or isinstance(getattr(x, "type", None), NoneTypeT)
def expand_indices(
    indices: Tuple[Optional[Union[Variable, slice]], ...],
    shape: Tuple[TensorVariable, ...],
) -> Tuple[TensorVariable, ...]:
"""Convert basic and/or advanced indices into a single, broadcasted advanced indexing operation.
Parameters
----------
indices
The indices to convert.
shape
The shape of the array being indexed.
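
    Examples
    --------
    A minimal sketch of the intended behavior (the array and indices here are
    hypothetical): for a ``(3, 4)``-shaped ``x``, mixed basic/advanced indices
    equivalent to ``x[[0, 2], :]`` expand into two broadcasted integer index
    arrays.

    >>> import aesara.tensor as at
    >>> from aesara.tensor.shape import shape_tuple
    >>> x = at.ones((3, 4))
    >>> ridx, cidx = expand_indices((at.as_tensor([0, 2]), slice(None)), shape_tuple(x))
    >>> ridx.eval().shape, cidx.eval().shape
    ((2, 4), (2, 4))

    ``x[ridx, cidx]`` then selects the same elements as ``x[[0, 2], :]``.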
"""
n_non_newaxis = sum(1 for idx in indices if not is_newaxis(idx))
n_missing_dims = len(shape) - n_non_newaxis
full_indices = list(indices) + [slice(None)] * n_missing_dims
    # We need to know if a "subspace" was generated by advanced indices
    # bookending basic indices.  If so, we move the advanced indexing
    # subspace to the "front" of the shape (i.e. it occupies the left-most/
    # first output dimensions), per NumPy's advanced-indexing semantics.
index_types = [is_basic_idx(idx) for idx in full_indices]
first_adv_idx = len(shape)
try:
first_adv_idx = index_types.index(False)
first_bsc_after_adv_idx = index_types.index(True, first_adv_idx)
index_types.index(False, first_bsc_after_adv_idx)
moved_subspace = True
except ValueError:
moved_subspace = False
n_basic_indices = sum(index_types)
# The number of dimensions in the subspace created by the advanced indices
n_subspace_dims = max(
(
getattr(idx, "ndim", 0)
for idx, is_basic in zip(full_indices, index_types)
if not is_basic
),
default=0,
)
# The number of dimensions for each expanded index
n_output_dims = n_subspace_dims + n_basic_indices
adv_indices = []
shape_copy = list(shape)
n_preceding_basics = 0
for d, idx in enumerate(full_indices):
if not is_basic_idx(idx):
s = shape_copy.pop(0)
idx = at.as_tensor(idx)
if moved_subspace:
                # The subspace generated by the advanced indices appears as
                # the upper dimensions in the "expanded" index space, so we
                # need to add broadcast dimensions for the basic indices to
                # the end of these advanced indices
expanded_idx = idx[(Ellipsis,) + (None,) * n_basic_indices]
else:
                # In this case, we need to add broadcast dimensions for the
                # basic indices that precede and follow the group of advanced
                # indices; otherwise, a contiguous group of advanced indices
                # forms a broadcasted set of indices that are iterated over
                # within the same subspace, which means that all their
                # corresponding "expanded" indices have exactly the same shape.
expanded_idx = idx[(None,) * n_preceding_basics][
(Ellipsis,) + (None,) * (n_basic_indices - n_preceding_basics)
]
else:
if is_newaxis(idx):
n_preceding_basics += 1
continue
s = shape_copy.pop(0)
if isinstance(idx, slice) or isinstance(
getattr(idx, "type", None), SliceType
):
idx = as_index_literal(idx)
idx_slice, _ = get_canonical_form_slice(idx, s)
idx = at.arange(idx_slice.start, idx_slice.stop, idx_slice.step)
if moved_subspace:
# Basic indices appear in the lower dimensions
# (i.e. right-most) in the output, and are preceded by
# the subspace generated by the advanced indices.
expanded_idx = idx[(None,) * (n_subspace_dims + n_preceding_basics)][
(Ellipsis,) + (None,) * (n_basic_indices - n_preceding_basics - 1)
]
else:
# In this case, we need to know when the basic indices have
# moved past the contiguous group of advanced indices (in the
# "expanded" index space), so that we can properly pad those
# dimensions in this basic index's shape.
# Don't forget that a single advanced index can introduce an
# arbitrary number of dimensions to the expanded index space.
# If we're currently at a basic index that's past the first
# advanced index, then we're necessarily past the group of
# advanced indices.
n_preceding_dims = (
n_subspace_dims if d > first_adv_idx else 0
) + n_preceding_basics
expanded_idx = idx[(None,) * n_preceding_dims][
(Ellipsis,) + (None,) * (n_output_dims - n_preceding_dims - 1)
]
n_preceding_basics += 1
assert expanded_idx.ndim <= n_output_dims
adv_indices.append(expanded_idx)
    return cast(Tuple[TensorVariable, ...], tuple(at.broadcast_arrays(*adv_indices)))
def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
"""Pull a ``RandomVariable`` ``Op`` down through a graph, when possible."""
fgraph = FunctionGraph(outputs=dont_touch_vars or [], clone=False)
return pre_greedy_node_rewriter(
fgraph,
[
local_dimshuffle_rv_lift,
local_subtensor_rv_lift,
local_lift_DiracDelta,
],
x,
)
class MixtureRV(Op):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""
__props__ = ("indices_end_idx", "out_dtype", "out_shape")
def __init__(self, indices_end_idx, out_dtype, out_shape):
super().__init__()
self.indices_end_idx = indices_end_idx
self.out_dtype = out_dtype
self.out_shape = out_shape
def make_node(self, *inputs):
return Apply(
self, list(inputs), [TensorType(self.out_dtype, shape=self.out_shape)()]
)
def perform(self, node, inputs, outputs):
raise NotImplementedError("This is a stand-in Op.") # pragma: no cover
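
# As a sketch, the rewrites below construct nodes of the form
# `MixtureRV(indices_end_idx, dtype, shape)(join_axis, *indices, *components)`;
# e.g. `MixtureRV(2, "float64", ())(join_axis, I_rv, X_rv, Y_rv)`, in which
# `inputs[1:indices_end_idx]` are the mixing indices and the remaining inputs
# are the mixture components (`I_rv`, `X_rv`, and `Y_rv` are hypothetical).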
MeasurableVariable.register(MixtureRV)
def get_stack_mixture_vars(
    node: Apply,
) -> Tuple[Optional[List[TensorVariable]], Optional[Variable]]:
r"""Extract the mixture terms from a `*Subtensor*` applied to stacked `MeasurableVariable`\s."""
if not isinstance(node.op, subtensor_ops):
return None, None # pragma: no cover
join_axis = NoneConst
joined_rvs = node.inputs[0]
# First, make sure that it's some sort of concatenation
if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
# Node is not a compatible join `Op`
return None, join_axis # pragma: no cover
if isinstance(joined_rvs.owner.op, MakeVector):
mixture_rvs = joined_rvs.owner.inputs
elif isinstance(joined_rvs.owner.op, Join):
mixture_rvs = joined_rvs.owner.inputs[1:]
join_axis = joined_rvs.owner.inputs[0]
try:
join_axis = int(get_constant_value(join_axis))
except ValueError:
# TODO: Support symbolic join axes
raise NotImplementedError(
"Symbolic `Join` axes are not supported in mixtures"
)
join_axis = at.as_tensor(join_axis)
if not all(
rv.owner and isinstance(rv.owner.op, MeasurableVariable) for rv in mixture_rvs
):
# Currently, all mixture components must be `MeasurableVariable` outputs
# TODO: Allow constants and make them Dirac-deltas
# raise NotImplementedError(
# "All mixture components must be `MeasurableVariable` outputs"
# )
return None, join_axis
return mixture_rvs, join_axis
@node_rewriter(subtensor_ops)
def mixture_replace(fgraph, node):
r"""Identify mixture sub-graphs and replace them with a place-holder `Op`.
The basic idea is to find ``stack(mixture_comps)[I_rv]``, where
``mixture_comps`` is a ``list`` of `MeasurableVariable`\s and ``I_rv`` is a
`MeasurableVariable` with a discrete and finite support.
From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are
created for each ``i`` in ``enumerate(mixture_comps)``.
"""
old_mixture_rv = node.default_output()
mixture_res, join_axis = get_stack_mixture_vars(node)
if mixture_res is None or any(
rv.owner and isinstance(rv.owner.op, ValuedVariable) for rv in mixture_res
):
return None # pragma: no cover
mixing_indices = node.inputs[1:]
# We loop through mixture components and collect all the array elements
# that belong to each one (by way of their indices).
mixture_rvs = []
for i, component_rv in enumerate(mixture_res):
# We create custom types for the mixture components and assign them
# null `get_measurable_outputs` dispatches so that they aren't
# erroneously encountered in places like `conditional_logprob`.
new_node = assign_custom_measurable_outputs(component_rv.owner)
out_idx = component_rv.owner.outputs.index(component_rv)
new_comp_rv = new_node.outputs[out_idx]
mixture_rvs.append(new_comp_rv)
# Replace this sub-graph with a `MixtureRV`
mix_op = MixtureRV(
1 + len(mixing_indices),
old_mixture_rv.type.dtype,
old_mixture_rv.type.shape,
)
new_node = mix_op.make_node(*([join_axis] + mixing_indices + mixture_rvs))
new_mixture_rv = new_node.default_output()
if aesara.config.compute_test_value != "off":
# We can't use `MixtureRV` to compute a test value; instead, we'll use
# the original node's test value.
if not hasattr(old_mixture_rv.tag, "test_value"):
compute_test_value(node)
new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value
if old_mixture_rv.name:
new_mixture_rv.name = f"{old_mixture_rv.name}-mixture"
return [new_mixture_rv]
@node_rewriter((Elemwise,))
def switch_mixture_replace(fgraph, node):
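    r"""Identify switch-based mixture sub-graphs and replace them with a placeholder `Op`.

    A graph like ``at.switch(I_rv, X_rv, Y_rv)`` is a two-component mixture
    indexed by ``I_rv``: it takes the value of ``Y_rv`` when ``I_rv`` is ``0``
    and the value of ``X_rv`` when ``I_rv`` is ``1``.
    """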
if not isinstance(node.op.scalar_op, Switch):
return None # pragma: no cover
old_mixture_rv = node.default_output()
# idx, component_1, component_2 = node.inputs
mixture_rvs = []
    # `at.switch(idx, a, b)` returns `b` when `idx == 0` and `a` otherwise,
    # so the components are reversed here to match the `at.stack([b, a])[idx]`
    # convention assumed by `logprob_MixtureRV`
    for component_rv in node.inputs[:0:-1]:
if not (
component_rv.owner
and isinstance(component_rv.owner.op, MeasurableVariable)
and not isinstance(component_rv.owner.op, ValuedVariable)
):
return None
new_node = assign_custom_measurable_outputs(component_rv.owner)
out_idx = component_rv.owner.outputs.index(component_rv)
new_comp_rv = new_node.outputs[out_idx]
mixture_rvs.append(new_comp_rv)
mix_op = MixtureRV(
2,
old_mixture_rv.type.dtype,
old_mixture_rv.type.shape,
)
new_node = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)
new_mixture_rv = new_node.default_output()
if aesara.config.compute_test_value != "off":
if not hasattr(old_mixture_rv.tag, "test_value"):
compute_test_value(node)
new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value
if old_mixture_rv.name:
new_mixture_rv.name = f"{old_mixture_rv.name}-mixture"
return [new_mixture_rv]
@_logprob.register(MixtureRV)
def logprob_MixtureRV(
op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs
):
(value,) = values
join_axis = cast(Variable, inputs[0])
    indices = cast(Tuple[TensorVariable, ...], inputs[1 : op.indices_end_idx])
    comp_rvs = cast(Tuple[TensorVariable, ...], inputs[op.indices_end_idx :])
assert len(indices) > 0
if len(indices) > 1 or indices[0].ndim > 0:
if isinstance(join_axis.type, NoneTypeT):
# `join_axis` will be `NoneConst` if the "join" was a `MakeVector`
# (i.e. scalar measurable variables were combined to make a
# vector).
# Since some form of advanced indexing is necessarily occurring, we
# need to reformat the MakeVector arguments so that they fit the
# `Join` format expected by the logic below.
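            # E.g. scalar components `x` and `y` become `x[None]` and `y[None]`,
            # joined along a new leading axis (`join_axis_val = 0`).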
join_axis_val = 0
comp_rvs = [comp[None] for comp in comp_rvs]
original_shape = (len(comp_rvs),)
else:
join_axis_val = get_constant_value(join_axis).item()
original_shape = shape_tuple(comp_rvs[0])
bcast_indices = expand_indices(indices, original_shape)
logp_val = at.empty(bcast_indices[0].shape)
for m, rv in enumerate(comp_rvs):
idx_m_on_axis = at.nonzero(at.eq(bcast_indices[join_axis_val], m))
m_indices = tuple(
v[idx_m_on_axis]
for i, v in enumerate(bcast_indices)
if i != join_axis_val
)
# Drop superfluous join dimension
rv = rv[0]
# TODO: Do we really need to do this now?
# Could we construct this form earlier and
# do the lifting for everything at once, instead of
# this intentional one-off?
rv_m = rv_pull_down(rv[m_indices] if m_indices else rv)
val_m = value[idx_m_on_axis]
logp_m = logprob(rv_m, val_m)
logp_val = at.set_subtensor(logp_val[idx_m_on_axis], logp_m)
else:
logp_val = 0.0
for i, comp_rv in enumerate(comp_rvs):
comp_logp = logprob(comp_rv, value)
logp_val += ifelse(
at.eq(indices[0], i),
comp_logp,
at.zeros_like(value, dtype=comp_logp.type.dtype),
)
return logp_val
logprob_rewrites_db.register(
"mixture_replace",
EquilibriumGraphRewriter(
[mixture_replace, switch_mixture_replace],
max_use_ratio=aesara.config.optdb__max_use_ratio,
),
"basic",
"mixture",
)
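
# A usage sketch (hedged: the exact `aeppl` entry points and their signatures
# vary across versions):
#
#   import aesara.tensor as at
#   from aeppl import joint_logprob
#
#   srng = at.random.RandomStream(seed=2320)
#   I_rv = srng.bernoulli(0.5, name="I")
#   X_rv = srng.normal(0.0, 1.0, name="X")
#   Y_rv = srng.normal(10.0, 1.0, name="Y")
#   Z_rv = at.stack([X_rv, Y_rv])[I_rv]
#
#   # `mixture_replace` fires during the log-probability rewrites, so the
#   # resulting graph contains a `MixtureRV` term for `Z_rv`
#   logp, (z_vv, i_vv) = joint_logprob(Z_rv, I_rv)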