# Source code for aeppl.censoring

from typing import TYPE_CHECKING, List, Optional

import aesara.tensor as at
import numpy as np
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import node_rewriter
from aesara.scalar.basic import ceil as scalar_ceil
from aesara.scalar.basic import clip as scalar_clip
from aesara.scalar.basic import floor as scalar_floor
from aesara.scalar.basic import round_half_to_even as scalar_round_half_to_even
from aesara.tensor.math import ceil, clip, floor, round_half_to_even
from aesara.tensor.var import TensorConstant

from aeppl.abstract import (
    MeasurableElemwise,
    MeasurableVariable,
    ValuedVariable,
    assign_custom_measurable_outputs,
)
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
from aeppl.rewriting import measurable_ir_rewrites_db

if TYPE_CHECKING:
    from aesara.graph.basic import Variable
    from aesara.graph.op import Op


class MeasurableClip(MeasurableElemwise):
    """A placeholder used to specify a log-likelihood for a clipped RV sub-graph.

    Nodes using this `Op` are created by the `find_measurable_clips` rewrite,
    and `clip_logprob` is the log-probability implementation registered for it.
    """


# Module-level `MeasurableClip` instance wrapping the scalar `clip` operation;
# `find_measurable_clips` uses it to build the replacement nodes.
measurable_clip = MeasurableClip(scalar_clip)


@node_rewriter([clip])
def find_measurable_clips(
    fgraph: FunctionGraph, node: Node
) -> Optional[List["Variable"]]:
    """Replace a `clip` applied to a measurable variable by a `MeasurableClip`.

    Parameters
    ----------
    fgraph
        The graph being rewritten; it is not inspected by this rewrite.
    node
        The candidate `clip` node, whose inputs are unpacked as
        ``(base_var, lower_bound, upper_bound)``.

    Returns
    -------
    A one-element list containing the output of the new `MeasurableClip`
    node, or ``None`` when the rewrite does not apply.

    """
    # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

    # The node has already been converted; nothing left to do.
    if isinstance(node.op, MeasurableClip):
        return None  # pragma: no cover

    original_out = node.outputs[0]
    base_var, lower_bound, upper_bound = node.inputs

    # The clipped input must be produced by a measurable `Op` that is not a
    # `ValuedVariable`.
    producer = base_var.owner
    if not producer:
        return None
    if not isinstance(producer.op, MeasurableVariable):
        return None
    if isinstance(producer.op, ValuedVariable):
        return None

    # `y = clip(x, x, ?)` and `y = clip(x, ?, x)` leave one side unclipped, so
    # the bound that is `x` itself is swapped for `-inf`/`+inf`.  This lets
    # `clip_logprob` generate a more succinct logprob graph for one-sided
    # clipped random variables.
    if lower_bound is base_var:
        lower_bound = at.constant(-np.inf)
    if upper_bound is base_var:
        upper_bound = at.constant(np.inf)

    # Make base_var unmeasurable
    unmeasurable_base_var = assign_custom_measurable_outputs(producer)

    measurable_out = measurable_clip.make_node(
        unmeasurable_base_var, lower_bound, upper_bound
    ).outputs[0]
    # Keep the name of the variable that is being replaced.
    measurable_out.name = original_out.name

    return [measurable_out]


# Add the clip rewrite to the measurable IR rewrite database under the name
# "find_measurable_clips"; the trailing strings ("basic", "censoring") are
# passed as additional positional arguments (presumably tags used to
# include/exclude the rewrite in queries -- confirm against the database API).
measurable_ir_rewrites_db.register(
    "find_measurable_clips",
    find_measurable_clips,
    "basic",
    "censoring",
)


@_logprob.register(MeasurableClip)
def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
    r"""Logprob of a clipped censored distribution

    The probability is given by

    .. math::

        \begin{cases}
            0 & \text{for } x < lower, \\
            \text{CDF}(lower, dist) & \text{for } x = lower, \\
            \text{P}(x, dist) & \text{for } lower < x < upper, \\
            1-\text{CDF}(upper, dist) & \text{for } x = upper, \\
            0 & \text{for } x > upper,
        \end{cases}

    Parameters
    ----------
    op
        The `MeasurableClip` instance this implementation was dispatched on.
    values
        A one-element sequence holding the value variable at which the
        log-probability is evaluated.
    base_rv
        The variable being clipped.  The `Op` and inputs of its owner node
        are used to dispatch `_logprob` and `_logcdf`.
    lower_bound
        The lower clipping bound.  A `TensorConstant` whose entries are all
        ``-inf`` is treated as "no lower bound".
    upper_bound
        The upper clipping bound.  A `TensorConstant` whose entries are all
        infinite is treated as "no upper bound".
    **kwargs
        Forwarded unchanged to `_logprob` and `_logcdf`.

    Returns
    -------
    The log-probability graph of the clipped variable.  When both bounds are
    present, it is wrapped in a `CheckParameterValue` that asserts
    ``lower_bound <= upper_bound``.

    """
    (value,) = values

    base_rv_op = base_rv.owner.op
    base_rv_inputs = base_rv.owner.inputs

    logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
    logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)

    if base_rv_op.name:
        logprob.name = f"{base_rv_op}_logprob"
        logcdf.name = f"{base_rv_op}_logcdf"

    is_lower_bounded, is_upper_bounded = False, False
    if not (
        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
    ):
        is_upper_bounded = True

        logccdf = at.log1mexp(logcdf)
        # For right clipped discrete RVs, we need to add an extra term
        # corresponding to the pmf at the upper bound
        if base_rv.dtype.startswith("int"):
            logccdf = at.logaddexp(logccdf, logprob)

        # All the mass above the upper bound is collected at the bound itself,
        # and values beyond the bound are impossible.
        logprob = at.switch(
            at.eq(value, upper_bound),
            logccdf,
            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
        )

    if not (
        isinstance(lower_bound, TensorConstant)
        and np.all(np.isneginf(lower_bound.value))
    ):
        is_lower_bounded = True
        # All the mass below the lower bound is collected at the bound itself,
        # and values beyond the bound are impossible.
        logprob = at.switch(
            at.eq(value, lower_bound),
            logcdf,
            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
        )

    # The bounds can only be inconsistent when both of them are finite.
    if is_lower_bounded and is_upper_bounded:
        logprob = CheckParameterValue("lower_bound <= upper_bound")(
            logprob, at.all(at.le(lower_bound, upper_bound))
        )

    return logprob
class MeasurableRound(MeasurableElemwise):
    """A placeholder used to specify a log-likelihood for a rounded RV sub-graph."""


# One `MeasurableRound` instance per supported scalar rounding operation.
measurable_ceil = MeasurableRound(scalar_ceil)
measurable_floor = MeasurableRound(scalar_floor)
measurable_round_half_to_even = MeasurableRound(scalar_round_half_to_even)


@node_rewriter([ceil])
def find_measurable_ceil(fgraph: FunctionGraph, node: Node):
    """Replace a `ceil` of a measurable variable by its measurable counterpart."""
    return construct_measurable_rounding(fgraph, node, measurable_ceil)


@node_rewriter([floor])
def find_measurable_floor(fgraph: FunctionGraph, node: Node):
    """Replace a `floor` of a measurable variable by its measurable counterpart."""
    return construct_measurable_rounding(fgraph, node, measurable_floor)


@node_rewriter([round_half_to_even])
def find_measurable_round_half_to_even(fgraph: FunctionGraph, node: Node):
    """Replace a `round_half_to_even` of a measurable variable by its measurable counterpart."""
    return construct_measurable_rounding(fgraph, node, measurable_round_half_to_even)


measurable_ir_rewrites_db.register(
    "find_measurable_ceil",
    find_measurable_ceil,
    "basic",
    "censoring",
)

measurable_ir_rewrites_db.register(
    "find_measurable_floor",
    find_measurable_floor,
    "basic",
    "censoring",
)

measurable_ir_rewrites_db.register(
    "find_measurable_round_half_to_even",
    find_measurable_round_half_to_even,
    "basic",
    "censoring",
)


def construct_measurable_rounding(
    fgraph: FunctionGraph, node: Node, rounded_op: "Op"
) -> Optional[List["Variable"]]:
    """Build the measurable replacement for a rounding node.

    Shared implementation of the `find_measurable_ceil`,
    `find_measurable_floor` and `find_measurable_round_half_to_even` rewrites.

    Parameters
    ----------
    fgraph
        The graph being rewritten; it is not inspected here.
    node
        The candidate rounding node, which has a single input and a single
        output.
    rounded_op
        The `MeasurableRound` instance used to build the replacement node.

    Returns
    -------
    A one-element list containing the output of the new measurable rounding
    node, or ``None`` when the rewrite does not apply.

    """
    if isinstance(node.op, MeasurableRound):
        return None  # pragma: no cover

    (rounded_var,) = node.outputs
    (base_var,) = node.inputs

    if not (
        base_var.owner
        and isinstance(base_var.owner.op, MeasurableVariable)
        and not isinstance(base_var.owner.op, ValuedVariable)
        # Rounding only makes sense for continuous variables
        and base_var.dtype.startswith("float")
    ):
        return None

    # Make base_var unmeasurable
    unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)

    rounded_rv = rounded_op.make_node(unmeasurable_base_var).default_output()
    # Keep the name of the variable that is being replaced.
    rounded_rv.name = rounded_var.name

    return [rounded_rv]
@_logprob.register(MeasurableRound)
def round_logprob(op, values, base_rv, **kwargs):
    r"""Logprob of a rounded censored distribution

    The probability of a distribution rounded to the nearest integer is given by

    .. math::

        \begin{cases}
            \text{CDF}(x+\frac{1}{2}, dist) - \text{CDF}(x-\frac{1}{2}, dist) & \text{for } x \in \mathbb{Z}, \\
            0 & \text{otherwise},
        \end{cases}

    The probability of a distribution rounded up is given by

    .. math::

        \begin{cases}
            \text{CDF}(x, dist) - \text{CDF}(x-1, dist) & \text{for } x \in \mathbb{Z}, \\
            0 & \text{otherwise},
        \end{cases}

    The probability of a distribution rounded down is given by

    .. math::

        \begin{cases}
            \text{CDF}(x+1, dist) - \text{CDF}(x, dist) & \text{for } x \in \mathbb{Z}, \\
            0 & \text{otherwise},
        \end{cases}

    Parameters
    ----------
    op
        The `MeasurableRound` instance this implementation was dispatched on;
        it is compared against `measurable_round_half_to_even`,
        `measurable_floor` and `measurable_ceil` to select the interval.
    values
        A one-element sequence holding the value variable at which the
        log-probability is evaluated.
    base_rv
        The variable being rounded.  The `Op` and inputs of its owner node
        are used to dispatch `_logcdf`.
    **kwargs
        Forwarded unchanged to `_logcdf`.

    Returns
    -------
    The log of the difference between the base CDF evaluated at the upper and
    lower ends of the interval that rounds to the value.

    Raises
    ------
    TypeError
        If `op` is none of the three supported rounding `Op`s.

    """
    (value,) = values

    # The value is rounded with the same operation before the interval
    # end-points are derived from it.
    if op == measurable_round_half_to_even:
        value = at.round(value)
        value_upper = value + 0.5
        value_lower = value - 0.5
    elif op == measurable_floor:
        value = at.floor(value)
        value_upper = value + 1.0
        value_lower = value
    elif op == measurable_ceil:
        value = at.ceil(value)
        value_upper = value
        value_lower = value - 1.0
    else:
        raise TypeError(f"Unsupported scalar_op {op.scalar_op}")  # pragma: no cover

    base_rv_op = base_rv.owner.op
    base_rv_inputs = base_rv.owner.inputs

    logcdf_upper = _logcdf(base_rv_op, value_upper, *base_rv_inputs, **kwargs)
    logcdf_lower = _logcdf(base_rv_op, value_lower, *base_rv_inputs, **kwargs)

    if base_rv_op.name:
        logcdf_upper.name = f"{base_rv_op}_logcdf_upper"
        logcdf_lower.name = f"{base_rv_op}_logcdf_lower"

    return logdiffexp(logcdf_upper, logcdf_lower)