import abc
from copy import copy
from functools import partial, singledispatch
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import aesara.tensor as at
from aesara.gradient import DisconnectedType, jacobian
from aesara.graph.basic import Apply, Variable, walk
from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.tensor.math import add, exp, log, mul, reciprocal, sub, true_divide
from aesara.tensor.rewriting.basic import register_useless
from aesara.tensor.var import TensorVariable
from typing_extensions import Protocol
from aeppl.abstract import (
MeasurableElemwise,
MeasurableVariable,
ValuedVariable,
_get_measurable_outputs,
assign_custom_measurable_outputs,
valued_variable,
)
from aeppl.logprob import _logprob, logprob
from aeppl.rewriting import ir_cleanup_db, measurable_ir_rewrites_db
if TYPE_CHECKING:
from aesara.graph.rewriting.basic import NodeRewriter
class TransformFnType(Protocol):
    """Structural type of the ``transform_fn`` given to `construct_elemwise_transform`.

    An implementation receives the single measurable input of an `Elemwise`
    node followed by all of the node's other inputs, and returns a pair made of
    the `RVTransform` to use and the tuple of inputs from which the new
    measurable node is to be built.
    """

    def __call__(
        self, measurable_input: MeasurableVariable, *other_inputs: Variable
    ) -> Tuple["RVTransform", Tuple[TensorVariable, ...]]:
        ...
def register_measurable_ir(
    node_rewriter: "NodeRewriter",
    *tags: str,
    **kwargs,
):
    """Register `node_rewriter` in `measurable_ir_rewrites_db` and return it unchanged.

    The rewriter is always tagged ``"basic"`` and ``"transform"``, followed by any
    extra `tags`. A ``name`` keyword argument, if given and non-empty, is used as
    the registration name; otherwise ``node_rewriter.__name__`` is used. All other
    keyword arguments are forwarded to ``measurable_ir_rewrites_db.register``.
    Returning the rewriter lets this function be used as a decorator.
    """
    name = kwargs.pop("name", None)
    if not name:
        name = node_rewriter.__name__

    measurable_ir_rewrites_db.register(
        name, node_rewriter, "basic", "transform", *tags, **kwargs
    )
    return node_rewriter
@singledispatch
def _default_transformed_rv(
    op: Op,
    node: Apply,
) -> Optional[Apply]:
    """Return a transformed-log-probability node for a `MeasurableVariable` node.

    Dispatch happens on the type of `op`. Register an implementation on this
    dispatcher to give a `MeasurableVariable` type a default transform;
    `_create_transformed_rv_op` does so for the classes it is given.

    This fallback implementation handles unregistered `Op` types and returns
    ``None``, meaning "no default transform available".
    """
    # No default transform is known for this `Op` type.
    return None
class TransformedVariable(Op):
    r"""A no-op marking a back-transformed value together with its transformed value.

    The first input is the value mapped back to the natural space and the second
    is the transformed-space value it came from. Keeping the two paired lets the
    log-probability code recover the transformed value and pass it to
    `RVTransform.log_jac_det`.

    These `Op`\s are not computable. All of them are expected to be removed
    during compilation by `remove_TransformedVariables`.
    """

    __props__ = ()
    # The output aliases the first input.
    view_map = {0: [0]}

    def make_node(self, tran_value: TensorVariable, value: TensorVariable):
        out_type = tran_value.type
        return Apply(self, [tran_value, value], [out_type()])

    def perform(self, node, inputs, outputs):
        message = "These `Op`s should be removed from graphs used for computation."
        raise NotImplementedError(message)

    def connection_pattern(self, node):
        # Only the first input is connected to the (single) output.
        first_input_connected = [True]
        second_input_connected = [False]
        return [first_input_connected, second_input_connected]

    def infer_shape(self, fgraph, node, input_shapes):
        # The output has exactly the shape of the first input.
        first_shape = input_shapes[0]
        return [first_shape]

    def grad(self, args, g_outs):
        # The gradient flows through to the first input unchanged; the second
        # input is disconnected.
        passthrough = g_outs[0]
        return passthrough, DisconnectedType()()


# Shared stateless instance (the `Op` has no properties).
transformed_variable = TransformedVariable()
@register_useless
@node_rewriter([TransformedVariable])
def remove_TransformedVariables(fgraph, node):
    """Replace a `TransformedVariable` node's output with its first input.

    This strips the non-computable pairing `Op` so the graph can be compiled.
    """
    back_transformed_value = node.inputs[0]
    return [back_transformed_value]


# Also run the removal as part of the IR clean-up rewrites.
ir_cleanup_db.register(
    "remove-TransformedVariables",
    in2out(remove_TransformedVariables),
    "basic",
)
class RVTransform(abc.ABC):
    """Abstract transformation between a variable's natural and transformed spaces.

    Subclasses must implement `forward` (natural space -> transformed space) and
    `backward` (transformed space -> natural space). `log_jac_det` has a generic
    default based on automatic differentiation.
    """

    @abc.abstractmethod
    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Apply the transformation."""

    @abc.abstractmethod
    def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Invert the transformation."""

    def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
        """Construct the log of the absolute value of the Jacobian determinant.

        The Jacobian of ``backward(value, *inputs)`` with respect to `value` is
        obtained via `aesara.gradient.jacobian`. It is promoted to at least two
        dimensions so that a determinant can be taken. Subclasses with a
        closed-form expression may override this.
        """
        natural_value = self.backward(value, *inputs)
        jac = jacobian(natural_value, [value])[0]
        abs_det = at.abs(at.nlinalg.det(at.atleast_2d(jac)))
        return at.log(abs_det)
class DefaultTransformSentinel:
    """Marker type meaning "use the variable's default transform".

    Its single instance, `DEFAULT_TRANSFORM`, can be stored in a
    ``values_to_transforms`` mapping instead of an explicit `RVTransform`.
    `transform_values` then looks up the transform through
    `_default_transformed_rv`.
    """

    def __repr__(self) -> str:
        # Readable repr instead of ``<...DefaultTransformSentinel object at 0x...>``
        # so mappings containing the sentinel are easy to inspect and debug.
        return "DEFAULT_TRANSFORM"


DEFAULT_TRANSFORM = DefaultTransformSentinel()
@node_rewriter([ValuedVariable])
def transform_values(fgraph: FunctionGraph, node: Apply):
    r"""Apply transforms to the values of value-bound measurable variables.

    It is assumed that the input value variables correspond to forward
    transformations, usually chosen in such a way that the values are
    unconstrained on the real line.

    For example, if ``Y = halfnormal(...)``, we assume the respective value
    variable is specified on the log scale and back-transform it to obtain
    ``Y`` on the natural scale.

    The main steps of this rewrite are as follows:

    1. Obtain a `RVTransform` for a `ValuedVariable`.
    2. Replace all occurrences of the original value variable with the
    "backward"-transform of a new value variable in the "forward"-transformed
    space.
    3. Signify that the original value variable is to be replaced by the new
    one.
    4. Replace the old `ValuedVariable` with a new one containing a
    `TransformedVariable` value.

    Step 3. is currently accomplished by adding an entry to the
    ``value_clone_to_value`` dictionary attached to the `FunctionGraph` (see
    `TransformValuesMapping`). The entry maps the original value variable to the
    new transformed-space value variable. Callers (e.g. `conditional_logprob`)
    are expected to check this dictionary for value variable changes.

    The new value variable mentioned in Step 2. may be of a different `Type`
    (e.g. extra/fewer dimensions) than the original value variable; this is why
    we must replace the corresponding original value variables before we
    construct log-probability graphs. This is due to the fact that we
    generally cannot replace variables with new ones that have different
    `Type`\s.

    Parameters
    ----------
    fgraph
        The graph being rewritten. It is expected to carry the
        ``values_to_transforms`` and ``value_clone_to_value`` attributes
        installed by `TransformValuesMapping`. Without ``values_to_transforms``
        the rewrite does nothing.
    node
        A `ValuedVariable` node whose inputs are ``(rv_var, value_var)``.

    Returns
    -------
    ``None`` when no transform applies. Otherwise, a replacement dictionary
    mapping ``value_var`` to its back-transformed `TransformedVariable`, and the
    node's output to a new `ValuedVariable` pairing the transformed measurable
    output with that value.
    """
    values_to_transforms = getattr(fgraph, "values_to_transforms", None)

    if values_to_transforms is None:
        return None  # pragma: no cover

    rv_var, value_var = node.inputs

    # NOTE(review): assumes `rv_var` has an owner; an ownerless input would
    # raise `AttributeError` below rather than skip the rewrite.
    rv_node = rv_var.owner

    # Only the node's default output is transformed (`rv_var` is rebound to it).
    # If either call raises `ValueError` (e.g. no default output), the rewrite
    # does not apply.
    # NOTE(review): if the valued `rv_var` was not the default output, the new
    # `ValuedVariable` pairs the default output instead; confirm this cannot
    # occur.
    try:
        rv_var = rv_node.default_output()
        rv_var_out_idx = rv_node.outputs.index(rv_var)
    except ValueError:
        return None

    transform = values_to_transforms.get(value_var, None)

    if transform is None:
        # No transform was requested for this value variable.
        return None
    elif transform is DEFAULT_TRANSFORM:
        # Ask the dispatcher for the `Op`'s default transformed node; `Op`s
        # without a registered default are left untouched.
        trans_node = _default_transformed_rv(rv_node.op, rv_node)
        if trans_node is None:
            return None
        transform = trans_node.op.transform
    else:
        # This automatically constructs a `_logprob` dispatch for the
        # transformed node, so look there for the log-probability
        # implementation
        new_op = _create_transformed_rv_op(rv_node.op, transform)
        # Create a new `Apply` node and outputs
        trans_node = rv_node.clone()
        trans_node.op = new_op
        trans_node.outputs[rv_var_out_idx].name = rv_node.outputs[rv_var_out_idx].name

    # We now assume that the old value variable represents the *transformed space*.
    # We normally initialize value variables as copies of the random variables,
    # thus, in the untransformed space, we need to apply the forward
    # transformation to get value variables with the correct `Type`s in the
    # transformed space.
    trans_value_var: TensorVariable = transform.forward(
        value_var, *trans_node.inputs
    ).type()

    if value_var.name:
        trans_value_var.name = f"{value_var.name}-trans"

    # The new transformed-space value becomes a graph input.
    fgraph.add_input(trans_value_var)

    # We need to replace all instances of the old value variables
    # with "inversely/un-" transformed versions of themselves.
    untrans_value_var = transformed_variable(
        transform.backward(trans_value_var, *trans_node.inputs), trans_value_var
    )

    # This effectively lets the caller know that a value variable has been
    # replaced (i.e. they should filter all their old value variables through
    # the replacements map).
    fgraph.value_clone_to_value[value_var] = trans_value_var

    trans_var = trans_node.outputs[rv_var_out_idx]
    new_var = valued_variable(trans_var, untrans_value_var)

    return {value_var: untrans_value_var, node.outputs[0]: new_var}
class TransformValuesMapping(Feature):
    r"""A `Feature` exposing value-transform bookkeeping on the `FunctionGraph`.

    When attached, it publishes two dictionaries as graph attributes:
    ``values_to_transforms`` (value variable -> requested transform) and
    ``value_clone_to_value`` (value-variable clone bookkeeping). It is used
    mainly by `TransformValuesRewrite`.
    """

    def __init__(self, values_to_transforms, value_clone_to_value):
        """
        Parameters
        ----------
        values_to_transforms
            Mapping between value variables and their transformations. Each
            value variable can be assigned one of `RVTransform`,
            `DEFAULT_TRANSFORM`, or ``None``. Random variables with no
            transform specified remain unchanged.
        value_clone_to_value
            Mapping between random variable value clones and their original
            value variables.
        """
        self.values_to_transforms = values_to_transforms
        self.value_clone_to_value = value_clone_to_value

    def on_attach(self, fgraph):
        # Refuse to attach twice; `AlreadyThere` is how `Feature`s signal this.
        if hasattr(fgraph, "values_to_transforms"):
            raise AlreadyThere()

        for attr_name in ("values_to_transforms", "value_clone_to_value"):
            setattr(fgraph, attr_name, getattr(self, attr_name))
class MeasurableElemwiseTransform(MeasurableElemwise):
    """Placeholder `Op` carrying the log-likelihood of a transformed `Elemwise`.

    Besides the usual `Elemwise` arguments, it stores the `RVTransform` that
    relates its output to its measurable input, and the position of that
    measurable input among the node inputs.
    """

    # Not named `transform`: that name would clash with the attribute added by
    # the `TransformValuesRewrite`.
    transform_elemwise: RVTransform
    measurable_input_idx: int

    def __init__(
        self, *args, transform: RVTransform, measurable_input_idx: int, **kwargs
    ):
        # Store the transform data before delegating to the `Elemwise` initializer.
        self.transform_elemwise, self.measurable_input_idx = (
            transform,
            measurable_input_idx,
        )
        super().__init__(*args, **kwargs)
@_get_measurable_outputs.register(MeasurableElemwiseTransform)
def _get_measurable_outputs_ElemwiseTransform(op, node):
    """Only the default output of a `MeasurableElemwiseTransform` node is measurable."""
    default_out = node.default_output()
    return [default_out]
@_logprob.register(MeasurableElemwiseTransform)
def measurable_elemwise_logprob(
    op: MeasurableElemwiseTransform, values, *inputs, **kwargs
):
    """Compute the log-probability graph for a `MeasurableElemwiseTransform`.

    The value is mapped back through the transform's `backward` method. The
    measurable input's log-probability is evaluated there, and the transform's
    log-Jacobian determinant (evaluated at the original value) is added.
    """
    # TODO: Could other rewrites affect the order of inputs?
    (value,) = values

    remaining_inputs = list(inputs)
    measurable_input = remaining_inputs.pop(op.measurable_input_idx)
    elemwise_transform = op.transform_elemwise

    # The value variable must still be back-transformed to be on the natural support of
    # the respective measurable input.
    natural_value = elemwise_transform.backward(value, *remaining_inputs)
    base_logp = logprob(measurable_input, natural_value, **kwargs)
    log_jac = elemwise_transform.log_jac_det(value, *remaining_inputs)

    return base_logp + log_jac
@register_measurable_ir
@node_rewriter([true_divide])
def measurable_true_divide(fgraph, node):
    r"""Rewrite a `true_divide` node to a `MeasurableVariable`.

    ``numerator / denominator`` is re-expressed as
    ``numerator * reciprocal(denominator)``. The result is then passed through
    the `measurable_reciprocal` and `measurable_mul` rewrites.

    TODO FIXME: We need to update/clarify the canonicalization situation so that
    these can be reliably rewritten as products of reciprocals.
    """
    numerator, denominator = node.inputs

    inv_denominator = at.reciprocal(denominator)
    # This can only succeed when `denominator` is the measurable term.
    rewritten_inv = measurable_reciprocal.transform(fgraph, inv_denominator.owner)
    if rewritten_inv:
        inv_denominator = rewritten_inv[0]

    # This succeeds when exactly one factor (`numerator` or the measurable
    # reciprocal) is measurable.
    product = at.mul(numerator, inv_denominator)
    return measurable_mul.transform(fgraph, product.owner)
@register_measurable_ir
@node_rewriter([sub])
def measurable_sub(fgraph, node):
    r"""Rewrite a `sub` node to a `MeasurableVariable`.

    ``minuend - subtrahend`` is re-expressed as ``minuend + (-1 * subtrahend)``.
    The result is then passed through the `measurable_mul` and `measurable_add`
    rewrites.

    TODO FIXME: We need to update/clarify the canonicalization situation so
    that these can be reliably rewritten.
    """
    minuend, subtrahend = node.inputs

    negated_subtrahend = at.mul(-1, subtrahend)
    # This can only succeed when `subtrahend` is the measurable term.
    rewritten_neg = measurable_mul.transform(fgraph, negated_subtrahend.owner)
    if rewritten_neg:
        negated_subtrahend = rewritten_neg[0]

    # TODO FIXME: `local_add_canonizer` will unreliably rewrite expressions like
    # `x - y` to `-y + x` (e.g. apparently when `y` is a constant?) and, as a result,
    # this will not be reached. We're leaving this in just in case, but we
    # ultimately need to fix Aesara's canonicalizations.

    # This succeeds when exactly one term (`minuend` or the measurable negated
    # subtrahend) is measurable.
    total = at.add(minuend, negated_subtrahend)
    return measurable_add.transform(fgraph, total.owner)
@register_measurable_ir
@node_rewriter([exp])
def measurable_exp(fgraph, node):
    """Rewrite an `exp` node to a `MeasurableVariable` using an `ExpTransform`."""

    def exp_transform_fn(measurable_input, *_):
        # `exp` is unary, so the only transform input is the measurable one.
        return ExpTransform(), (measurable_input,)

    return construct_elemwise_transform(fgraph, node, exp_transform_fn)
@register_measurable_ir
@node_rewriter([log])
def measurable_log(fgraph, node):
    """Rewrite a `log` node to a `MeasurableVariable` using a `LogTransform`."""

    def log_transform_fn(measurable_input, *_):
        # `log` is unary, so the only transform input is the measurable one.
        return LogTransform(), (measurable_input,)

    return construct_elemwise_transform(fgraph, node, log_transform_fn)
@register_measurable_ir
@node_rewriter([add])
def measurable_add(fgraph, node):
    """Rewrite an `add` node to a `MeasurableVariable` using a `LocTransform`.

    All non-measurable terms are summed into a single shift. The node is then
    rebuilt from ``(measurable_input, shift)``.
    """
    # A single-input `add` has no shift term. Letting it through would make
    # `transform` index into an empty `other_inputs` and raise `IndexError`.
    if len(node.inputs) < 2:
        return None

    def transform(measurable_input, *other_inputs):
        shift = at.add(*other_inputs) if len(other_inputs) > 1 else other_inputs[0]
        # The shift is the last input of the rebuilt node.
        loc_transform = LocTransform(
            transform_args_fn=lambda *inputs: inputs[-1],
        )
        return loc_transform, (measurable_input, shift)

    return construct_elemwise_transform(fgraph, node, transform)
@register_measurable_ir
@node_rewriter([mul])
def measurable_mul(fgraph, node):
    """Rewrite a `mul` node to a `MeasurableVariable` using a `ScaleTransform`.

    All non-measurable factors are multiplied into a single scale. The node is
    then rebuilt from ``(measurable_input, scale)``.
    """
    # A single-input `mul` has no scale factor. Letting it through would make
    # `transform` index into an empty `other_inputs` and raise `IndexError`.
    if len(node.inputs) < 2:
        return None

    def transform(measurable_input, *other_inputs):
        scale = at.mul(*other_inputs) if len(other_inputs) > 1 else other_inputs[0]
        # The scale is the last input of the rebuilt node.
        scale_transform = ScaleTransform(
            transform_args_fn=lambda *inputs: inputs[-1],
        )
        return scale_transform, (measurable_input, scale)

    return construct_elemwise_transform(fgraph, node, transform)
@register_measurable_ir
@node_rewriter([reciprocal])
def measurable_reciprocal(fgraph, node):
    """Rewrite a `reciprocal` node to a `MeasurableVariable` using a `ReciprocalTransform`."""

    def reciprocal_transform_fn(measurable_input, *_):
        # `reciprocal` is unary, so the only transform input is the measurable one.
        return ReciprocalTransform(), (measurable_input,)

    return construct_elemwise_transform(fgraph, node, reciprocal_transform_fn)
def _is_unvalued_measurable(var: Variable) -> bool:
    """Return whether `var` is produced by a measurable `Op` that is not a `ValuedVariable`."""
    return (
        var.owner is not None
        and isinstance(var.owner.op, MeasurableVariable)
        and not isinstance(var.owner.op, ValuedVariable)
    )


def construct_elemwise_transform(
    fgraph: FunctionGraph,
    node: Apply,
    transform_fn: "TransformFnType",
) -> Optional[List[Variable]]:
    """Construct a measurable transformation for an `Elemwise` node.

    The rewrite applies only when all of the following hold:

    - exactly one input of `node` is measurable (and not already valued);
    - that input does not have a discrete dtype;
    - no ancestor of the other inputs is measurable.

    Parameters
    ----------
    fgraph
        The `FunctionGraph` in which `node` resides.
    node
        The `Apply` node to be converted.
    transform_fn
        A function that takes a single measurable input and all the remaining
        inputs and returns a transform object and transformed inputs.

    Returns
    -------
    A single-element list holding a new variable whose `Apply` node uses a
    `MeasurableElemwiseTransform` and which replaces `node`. Returns ``None``
    when the rewrite does not apply.
    """
    scalar_op = node.op.scalar_op

    # Check that we have a single source of measurement
    measurable_inputs = [inp for inp in node.inputs if _is_unvalued_measurable(inp)]

    if len(measurable_inputs) != 1:
        return None

    measurable_input: TensorVariable = measurable_inputs[0]

    # Do not apply rewrite to discrete variables (signed/unsigned integers and
    # booleans).
    # TODO: Formalize this restriction better.
    if measurable_input.type.dtype.startswith(("int", "uint", "bool")):
        return None

    # Check that other inputs are not potentially measurable, in which case this rewrite
    # would be invalid
    # TODO FIXME: This is rather costly and redundant; find a way to avoid it
    # or make it cheaper.
    other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input)

    def expand(var: TensorVariable) -> List[TensorVariable]:
        # Stop the walk at root variables and at measurable/valued variables.
        if var.owner is None or isinstance(
            var.owner.op, (MeasurableVariable, ValuedVariable)
        ):
            return []
        return list(reversed(var.owner.inputs))

    # Test the predicate, not the variable itself: some `TensorVariable`s (e.g.
    # comparison outputs) raise `TypeError` when converted to `bool`.
    if any(
        _is_unvalued_measurable(var) for var in walk(other_inputs, expand, False)
    ):
        return None

    # Make base_measure outputs unmeasurable
    # This seems to be the only thing preventing nested rewrites from being erased
    # TODO: Is this still needed?
    # (`measurable_input` is never a `ValuedVariable` output here; see the
    # filter above.)
    measurable_input = assign_custom_measurable_outputs(measurable_input.owner)

    measurable_input_idx = 0
    transform, transform_inputs = transform_fn(measurable_input, *other_inputs)

    transform_op = MeasurableElemwiseTransform(
        scalar_op=scalar_op,
        transform=transform,
        measurable_input_idx=measurable_input_idx,
    )
    transform_out = transform_op.make_node(*transform_inputs).default_output()
    transform_out.name = node.outputs[0].name

    return [transform_out]
def _create_transformed_rv_op(
    rv_op: Op,
    transform: RVTransform,
    *,
    default: bool = False,
    cls_dict_extra: Optional[Dict] = None,
) -> Op:
    """Create a new transformed variable instance given a base `RandomVariable` `Op`.

    A subclass of ``type(rv_op)`` is created dynamically from a copy of that
    class's own namespace. It carries a ``transform`` class attribute and a
    ``name`` suffixed with the transform's name. A copy of `rv_op` is then
    re-classed to the new type, so the returned `Op` maps to an `RVTransform`
    while otherwise behaving exactly like `rv_op`.

    Two dispatch functions are registered as a side effect:

    - a `_logprob` implementation for the new type; it evaluates `rv_op`'s
      log-probability at the back-transformed value and, by default, adds the
      transform's log-Jacobian determinant;
    - a `_default_transformed_rv` implementation; it is registered on the base
      class when `default` is ``True`` and on the new type otherwise.

    Parameters
    ----------
    rv_op
        The `RandomVariable` for which we want to construct a `TransformedRV`.
    transform
        The `RVTransform` for `rv_op`.
    default
        If ``False`` do not make `transform` the default transform for `rv_op`.
    cls_dict_extra
        Additional class members to add to the constructed `TransformedRV`.
    """
    transform_label = getattr(transform, "name", "transformed")
    base_type = type(rv_op)

    namespace = base_type.__dict__.copy()
    base_label = namespace.get("name", "")
    if base_label:
        namespace["name"] = f"{base_label}_{transform_label}"
    namespace["transform"] = transform
    if cls_dict_extra is not None:
        namespace.update(cls_dict_extra)

    new_op_type = type(f"Transformed{base_type.__name__}", (base_type,), namespace)

    MeasurableVariable.register(new_op_type)

    @_logprob.register(new_op_type)
    def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
        """Compute the log-likelihood graph for a `TransformedRV`.

        We assume that the value variable was back-transformed to be on the natural
        support of the respective random variable.
        """
        (value,) = values

        assert value.owner.op == transformed_variable

        # `TransformedVariable` inputs: (back-transformed value, transformed value)
        natural_val, transformed_val = value.owner.inputs
        logp = _logprob(rv_op, (natural_val,), *inputs, **kwargs)

        if use_jacobian:
            logp += op.transform.log_jac_det(transformed_val, *inputs)

        return logp

    dispatch_type = base_type if default else new_op_type

    @_default_transformed_rv.register(dispatch_type)
    def class_transformed_rv(op, node):
        transformed_node = new_op_type().make_node(*node.inputs)
        # Keep the name of output 1, as the original did.
        transformed_node.outputs[1].name = node.outputs[1].name
        return transformed_node

    transformed_op = copy(rv_op)
    transformed_op.__class__ = new_op_type
    return transformed_op
# Like `_create_transformed_rv_op`, but registers the transform as the default
# one for the whole `RandomVariable` class (used for `DEFAULT_TRANSFORM`).
create_default_transformed_rv_op = partial(_create_transformed_rv_op, default=True)

# In the bound functions below, distribution parameters start at `inputs[3]`.
# (Presumably `rng`, `size` and `dtype` come first; confirm against Aesara.)

TransformedUniformRV = create_default_transformed_rv_op(
    at.random.uniform,
    # inputs[3] = lower; inputs[4] = upper
    IntervalTransform(lambda *inputs: (inputs[3], inputs[4])),
)
TransformedParetoRV = create_default_transformed_rv_op(
    at.random.pareto,
    # inputs[3] = b (shape/alpha); inputs[4] = scale (m). The support is
    # [scale, inf), so the lower bound is the scale, not the shape parameter.
    IntervalTransform(lambda *inputs: (inputs[4], None)),
)
TransformedTriangularRV = create_default_transformed_rv_op(
    at.random.triangular,
    # inputs[3] = lower; inputs[5] = upper
    IntervalTransform(lambda *inputs: (inputs[3], inputs[5])),
)
TransformedHalfNormalRV = create_default_transformed_rv_op(
    at.random.halfnormal,
    # inputs[3] = loc
    IntervalTransform(lambda *inputs: (inputs[3], None)),
)
TransformedWaldRV = create_default_transformed_rv_op(
    at.random.wald,
    LogTransform(),
)
TransformedExponentialRV = create_default_transformed_rv_op(
    at.random.exponential,
    LogTransform(),
)
TransformedLognormalRV = create_default_transformed_rv_op(
    at.random.lognormal,
    LogTransform(),
)
TransformedHalfCauchyRV = create_default_transformed_rv_op(
    at.random.halfcauchy,
    # NOTE(review): `LogTransform` assumes the support starts at 0; if
    # `halfcauchy` takes a non-zero `loc`, an `IntervalTransform` (as used for
    # `halfnormal`) may be needed. Confirm.
    LogTransform(),
)
TransformedGammaRV = create_default_transformed_rv_op(
    at.random.gamma,
    LogTransform(),
)
TransformedInvGammaRV = create_default_transformed_rv_op(
    at.random.invgamma,
    LogTransform(),
)
TransformedChiSquareRV = create_default_transformed_rv_op(
    at.random.chisquare,
    LogTransform(),
)
TransformedWeibullRV = create_default_transformed_rv_op(
    at.random.weibull,
    LogTransform(),
)
TransformedBetaRV = create_default_transformed_rv_op(
    at.random.beta,
    LogOddsTransform(),
)
TransformedVonMisesRV = create_default_transformed_rv_op(
    at.random.vonmises,
    CircularTransform(),
)
TransformedDirichletRV = create_default_transformed_rv_op(
    at.random.dirichlet,
    SimplexTransform(),
)