from copy import copy
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Tuple, cast

import aesara
import aesara.tensor as at
import numpy as np
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.rewriting.basic import node_rewriter
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.scan.op import Scan
from aesara.scan.rewriting import scan_eqopt1, scan_eqopt2
from aesara.scan.utils import ScanArgs
from aesara.tensor.random.type import RandomType
from aesara.tensor.subtensor import Subtensor, indices_from_subtensor
from aesara.tensor.var import TensorVariable
from aesara.updates import OrderedUpdates

from aeppl.abstract import (
    MeasurableVariable,
    ValuedVariable,
    _get_measurable_outputs,
    valued_variable,
)
from aeppl.joint_logprob import conditional_logprob
from aeppl.logprob import _logprob
from aeppl.rewriting import (
    inc_subtensor_ops,
    logprob_rewrites_db,
    measurable_ir_rewrites_db,
)

if TYPE_CHECKING:
    from aesara.graph.basic import Apply, Variable

class MeasurableScan(Scan):
    """A `Scan` placeholder marking a scan sub-graph that has a log-likelihood."""

    def __str__(self):
        # Reuse the parent `Scan` representation and tag it as measurable.
        return f"measurable_{super().__str__()}"


def convert_outer_out_to_in(
    input_scan_args: ScanArgs,
    outer_out_vars: Iterable[TensorVariable],
    new_outer_input_vars: Dict[TensorVariable, TensorVariable],
    inner_out_fn: Callable[
        [Dict[TensorVariable, TensorVariable]], Iterable[TensorVariable]
    ],
) -> ScanArgs:
    r"""Convert outer-graph outputs into outer-graph inputs.

    Parameters
    ==========
    input_scan_args
        The source `Scan` arguments.  A copy is made with `copy.copy` and all
        field re-assignments are performed on that copy.
    outer_out_vars
        The outer-graph output variables that are to be converted into an
        outer-graph input.
    new_outer_input_vars
        The variables used for the new outer-graph inputs, keyed by the
        elements of `outer_out_vars`.  Every element of `outer_out_vars` must
        be a key of this mapping.
    inner_out_fn
        A function that takes the remapped outer-out variables and produces new
        inner-graph outputs.  This can be used to transform the
        `outer_out_vars`\s' corresponding inner-graph outputs into something
        else entirely, like log-probabilities.  It is called once with a
        mapping from the replaced inner-outputs to the new inner-inputs, and
        its results are appended to the inner nit-sot outputs.

    Returns
    =======
    A `ScanArgs` object for a `Scan` in which `outer_out_vars` has been converted to an
    outer-graph input.

    Raises
    ======
    AssertionError
        If an element of `outer_out_vars` is not found among the ``outer_out*``
        fields, or is missing from `new_outer_input_vars`.

    """

    output_scan_args = copy(input_scan_args)
    inner_outs_to_new_inner_ins = {}

    # We want to remove elements from both lists and tuples, because the
    # members of `ScanArgs` could switch from being `list`s to `tuple`s
    # soon
    def remove(x, i):
        return x[:i] + x[i + 1 :]

    # Map inner-outputs to outer-outputs
    old_inner_outs_to_outer_outs = {}

    for oo_var in outer_out_vars:
        var_info = output_scan_args.find_among_fields(
            oo_var, field_filter=lambda x: x.startswith("outer_out")
        )

        assert var_info is not None
        assert oo_var in new_outer_input_vars

        io_var = output_scan_args.get_alt_field(var_info, "inner_out")
        old_inner_outs_to_outer_outs[io_var] = oo_var

    # In this loop, we gather information about the new inner-inputs that have
    # been created and what their corresponding inner-outputs were, and we
    # update the outer and inner-inputs to reflect the addition of new
    # inner-inputs.
    for old_inner_out_var, oo_var in old_inner_outs_to_outer_outs.items():
        # Couldn't one do the same with `var_info`?
        inner_out_info = output_scan_args.find_among_fields(
            old_inner_out_var, field_filter=lambda x: x.startswith("inner_out")
        )

        output_scan_args.remove_from_fields(old_inner_out_var, rm_dependents=False)

        # Remove the old outer-output variable.
        # Not sure if this really matters, since we don't use the outer-outputs
        # when building a new `Scan`, but doing it keeps the `ScanArgs` object
        # consistent.
        output_scan_args.remove_from_fields(oo_var, rm_dependents=False)

        # Use the index for the specific inner-graph sub-collection to which this
        # variable belongs (e.g. index `1` among the inner-graph sit-sot terms)
        var_idx = inner_out_info.index

        # The old inner-output variable becomes a new inner-input
        new_inner_in_var = old_inner_out_var.clone()
        # Only tag named variables; an unnamed one would otherwise end up
        # labeled "None_vv".
        if old_inner_out_var.name:
            new_inner_in_var.name = f"{old_inner_out_var.name}_vv"

        inner_outs_to_new_inner_ins[old_inner_out_var] = new_inner_in_var

        # NOTE(review): the branch conditions below were reconstructed from
        # the field names that each branch manipulates -- confirm.
        is_mit_sot = inner_out_info.name.endswith("mit_sot")
        is_sit_sot = inner_out_info.name.endswith("sit_sot")

        # If we're replacing a [m|s]it-sot, then we need to add a new nit-sot
        add_nit_sot = False
        if is_mit_sot:
            inner_in_mit_sot_var = cast(
                Tuple[int, ...], tuple(output_scan_args.inner_in_mit_sot[var_idx])
            )
            new_inner_in_seqs = inner_in_mit_sot_var + (new_inner_in_var,)
            new_inner_in_mit_sot = remove(output_scan_args.inner_in_mit_sot, var_idx)
            new_outer_in_mit_sot = remove(output_scan_args.outer_in_mit_sot, var_idx)
            new_inner_in_sit_sot = tuple(output_scan_args.inner_in_sit_sot)
            new_outer_in_sit_sot = tuple(output_scan_args.outer_in_sit_sot)
            add_nit_sot = True
        elif is_sit_sot:
            new_inner_in_seqs = (output_scan_args.inner_in_sit_sot[var_idx],) + (
                new_inner_in_var,
            )
            new_inner_in_sit_sot = remove(output_scan_args.inner_in_sit_sot, var_idx)
            new_outer_in_sit_sot = remove(output_scan_args.outer_in_sit_sot, var_idx)
            new_inner_in_mit_sot = tuple(output_scan_args.inner_in_mit_sot)
            new_outer_in_mit_sot = tuple(output_scan_args.outer_in_mit_sot)
            add_nit_sot = True
        else:
            new_inner_in_seqs = (new_inner_in_var,)
            new_inner_in_mit_sot = tuple(output_scan_args.inner_in_mit_sot)
            new_outer_in_mit_sot = tuple(output_scan_args.outer_in_mit_sot)
            new_inner_in_sit_sot = tuple(output_scan_args.inner_in_sit_sot)
            new_outer_in_sit_sot = tuple(output_scan_args.outer_in_sit_sot)

        output_scan_args.inner_in_mit_sot = list(new_inner_in_mit_sot)
        output_scan_args.inner_in_sit_sot = list(new_inner_in_sit_sot)
        output_scan_args.outer_in_mit_sot = list(new_outer_in_mit_sot)
        output_scan_args.outer_in_sit_sot = list(new_outer_in_sit_sot)

        # Taps of the new sequences: the taps the variable already had as a
        # recurrent input, plus tap `0` for the value at the current step.
        if is_mit_sot:
            mit_sot_var_taps = cast(
                Tuple[int, ...], tuple(output_scan_args.mit_sot_in_slices[var_idx])
            )
            taps = mit_sot_var_taps + (0,)
            new_mit_sot_in_slices = remove(output_scan_args.mit_sot_in_slices, var_idx)
        elif is_sit_sot:
            taps = (-1, 0)
            new_mit_sot_in_slices = tuple(output_scan_args.mit_sot_in_slices)
        else:
            taps = (0,)
            new_mit_sot_in_slices = tuple(output_scan_args.mit_sot_in_slices)

        output_scan_args.mit_sot_in_slices = list(new_mit_sot_in_slices)

        # Order the taps (and their inner-inputs) from oldest to newest
        taps, new_inner_in_seqs = zip(
            *sorted(zip(taps, new_inner_in_seqs), key=lambda x: x[0])
        )

        # The outer slices built below run from the newest tap to the oldest
        # one, so the inner sequences are appended in that (reversed) order.
        new_inner_in_seqs = tuple(output_scan_args.inner_in_seqs) + tuple(
            reversed(new_inner_in_seqs)
        )

        output_scan_args.inner_in_seqs = list(new_inner_in_seqs)

        # E.g. sorted taps `(-1, 0)` produce the slices `[1:None]` and
        # `[0:-1]`, i.e. the "current" values followed by the lagged values.
        slice_seqs = zip(
            -np.asarray(taps), [n if n < 0 else None for n in reversed(taps)]
        )

        # XXX: If the caller passes the variables output by `aesara.scan`, it's
        # likely that this will fail, because those variables can sometimes be
        # slices of the actual outer-inputs (e.g. `out[1:]` instead of `out`
        # when `taps=[-1]`).
        var_slices = [new_outer_input_vars[oo_var][b:e] for b, e in slice_seqs]
        n_steps = at.min([at.shape(n)[0] for n in var_slices])

        # Each converted variable overwrites the step count
        output_scan_args.n_steps = n_steps

        new_outer_in_seqs = tuple(output_scan_args.outer_in_seqs) + tuple(
            v[:n_steps] for v in var_slices
        )

        output_scan_args.outer_in_seqs = list(new_outer_in_seqs)

        if add_nit_sot:
            new_outer_in_nit_sot = tuple(output_scan_args.outer_in_nit_sot) + (n_steps,)
        else:
            new_outer_in_nit_sot = tuple(output_scan_args.outer_in_nit_sot)

        output_scan_args.outer_in_nit_sot = list(new_outer_in_nit_sot)

    # Now, we can add new inner-outputs for the custom calculations.
    # We don't need to create corresponding outer-outputs, because `Scan` will
    # do that when we call `Scan.make_node`.  All we need is a consistent
    # outer-inputs and inner-graph spec., which we should have in
    # `output_scan_args`.
    remapped_io_to_ii = inner_outs_to_new_inner_ins
    new_inner_out_nit_sot = tuple(output_scan_args.inner_out_nit_sot) + tuple(
        inner_out_fn(remapped_io_to_ii)
    )

    output_scan_args.inner_out_nit_sot = list(new_inner_out_nit_sot)

    return output_scan_args

def get_random_outer_outputs(
    scan_args: ScanArgs,
) -> List[Tuple[int, TensorVariable, TensorVariable]]:
    """Get the `MeasurableVariable` outputs of a `Scan` (well, its `ScanArgs`).

    Parameters
    ==========
    scan_args
        The `ScanArgs` whose outer-outputs are inspected.

    Returns
    =======
    A list of tuples containing the index of each outer-output variable, the
    outer-output variable itself, and the inner-output variable that
    is an instance of `MeasurableVariable`.  The index counts positions among
    the outer-outputs whose type is not a `RandomType`.

    """
    rv_vars = []
    for n, oo_var in enumerate(
        [o for o in scan_args.outer_outputs if not isinstance(o.type, RandomType)]
    ):
        oo_info = scan_args.find_among_fields(oo_var)
        # Strip the leading "outer_out_" from the field name, leaving the
        # output kind that is appended to "inner_out_" below.
        # NOTE(review): `oo_info.name` was reconstructed from a damaged
        # expression -- confirm.
        io_type = oo_info.name[(oo_info.name.index("_", 6) + 1) :]
        inner_out_type = "inner_out_{}".format(io_type)
        io_var = getattr(scan_args, inner_out_type)[oo_info.index]
        if io_var.owner and isinstance(io_var.owner.op, MeasurableVariable):
            rv_vars.append((n, oo_var, io_var))
    return rv_vars

def construct_scan(
    scan_args: ScanArgs, **kwargs
) -> Tuple[List[TensorVariable], OrderedUpdates]:
    """Build a `Scan` from `scan_args` and apply it to its outer-inputs.

    Parameters
    ==========
    scan_args
        Supplies the inner inputs/outputs used to create the `Scan` `Op` and
        the outer-inputs to which the `Op` is applied.
    **kwargs
        Forwarded unchanged to the `Scan` constructor (e.g. ``mode``).

    Returns
    =======
    The outputs of the new `Scan` node and an `OrderedUpdates` that pairs
    ``scan_args.outer_in_shared`` with ``scan_args.outer_out_shared``.

    """
    # NOTE(review): the third positional argument was lost in the damaged
    # source; `scan_args.info` is the reconstruction -- confirm.
    scan_op = Scan(
        scan_args.inner_inputs, scan_args.inner_outputs, scan_args.info, **kwargs
    )
    node = scan_op.make_node(*scan_args.outer_inputs)
    updates = OrderedUpdates(zip(scan_args.outer_in_shared, scan_args.outer_out_shared))
    return node.outputs, updates

@_logprob.register(MeasurableScan)
def logprob_ScanRV(op, values, *inputs, name=None, **kwargs):
    """Build the log-probability `Scan` for a `MeasurableScan`.

    Parameters
    ==========
    op
        The `MeasurableScan` `Op`.
    values
        The value variables, paired in order with the measurable outer-outputs
        returned by `get_random_outer_outputs`.
    *inputs
        The outer-inputs used to re-create the `Scan` node via `op.make_node`.
    name, **kwargs
        Accepted for compatibility with the `_logprob` dispatch; unused here.

    Returns
    =======
    The outputs of a new `Scan` in which the measurable outer-outputs have
    been converted into inputs and replaced by log-probability outputs.  As a
    side effect, each update produced by the new `Scan` is stored in the
    ``default_update`` attribute of its key.

    """
    new_node = op.make_node(*inputs)
    scan_args = ScanArgs.from_node(new_node)

    rv_outer_outs = get_random_outer_outputs(scan_args)

    # Only the outer-output variables are needed here
    _, rv_vars, _ = zip(*rv_outer_outs)
    value_map = dict(zip(rv_vars, values))

    def create_inner_out_logp(
        value_map: Dict[TensorVariable, TensorVariable]
    ) -> Iterable[TensorVariable]:
        """Create a log-likelihood inner-output for a `Scan`."""
        logp_parts, _ = conditional_logprob(realized=value_map)
        return logp_parts.values()

    logp_scan_args = convert_outer_out_to_in(
        scan_args,
        rv_vars,
        value_map,
        inner_out_fn=create_inner_out_logp,
    )

    # Remove the shared variables corresponding to replaced terms.

    # TODO FIXME: This is a really dirty approach, because it effectively
    # assumes that all sampling is being removed, and, thus, all shared updates
    # relating to `RandomType`s.  Instead, we should be more precise and only
    # remove the `RandomType`s associated with `values`.
    logp_scan_args.outer_in_shared = [
        i for i in logp_scan_args.outer_in_shared if not isinstance(i.type, RandomType)
    ]
    logp_scan_args.inner_in_shared = [
        i for i in logp_scan_args.inner_in_shared if not isinstance(i.type, RandomType)
    ]
    logp_scan_args.inner_out_shared = [
        i for i in logp_scan_args.inner_out_shared if not isinstance(i.type, RandomType)
    ]
    # XXX TODO: Remove this properly
    # logp_scan_args.outer_out_shared = []

    logp_scan_out, updates = construct_scan(logp_scan_args, mode=op.mode)

    # Automatically pick up updates so that we don't have to pass them around
    for key, value in updates.items():
        key.default_update = value

    return logp_scan_out
@node_rewriter([Scan])
def find_measurable_scans(fgraph, node):
    r"""Finds `Scan`\s for which a `logprob` can be computed.

    This will convert said `Scan`\s into `MeasurableScan`\s.  It also updates
    random variable and value variable mappings that have been specified for
    parts of a `Scan`\s outputs (e.g. everything except the initial values).

    Returns ``None`` (i.e. no rewrite) when the node is already a
    `MeasurableScan`, when `fgraph` has no ``shape_feature``, or when a
    measurable inner-graph node has clients that are not inner-outputs.
    """

    if isinstance(node.op, MeasurableScan):
        return None

    if not hasattr(fgraph, "shape_feature"):
        return None  # pragma: no cover

    curr_scanargs = ScanArgs.from_node(node)

    # Find the un-output `MeasurableVariable`s created in the inner-graph
    clients: Dict["Variable", List["Variable"]] = {}

    local_fgraph_topo = aesara.graph.basic.io_toposort(
        curr_scanargs.inner_inputs,
        [o for o in curr_scanargs.inner_outputs if not isinstance(o.type, RandomType)],
        clients=clients,
    )
    for n in local_fgraph_topo:
        if isinstance(n.op, MeasurableVariable):
            non_output_node_clients = [
                c for c in clients[n] if c not in curr_scanargs.inner_outputs
            ]

            if non_output_node_clients:
                # This node is a `MeasurableVariable`, but it depends on
                # variable that's not being output?
                # TODO: Why can't we make this a `MeasurableScan`?
                return None

    # NOTE(review): the third positional argument was lost in the damaged
    # source; `curr_scanargs.info` is the reconstruction -- confirm.
    op = MeasurableScan(
        curr_scanargs.inner_inputs,
        curr_scanargs.inner_outputs,
        curr_scanargs.info,
        mode=node.op.mode,
    )
    new_node = op.make_node(*curr_scanargs.outer_inputs)

    return update_scan_value_vars(fgraph, curr_scanargs, node, new_node)


def update_scan_value_vars(
    fgraph, curr_scanargs: ScanArgs, node: "Apply", new_node: "Apply"
) -> Dict["Variable", "Variable"]:
    r"""Remap user inputs that have been specified in terms of `Subtensor`\s of this `Scan`'s node's outputs.

    For example, the output that the user got was something like ``out[1:]``
    for ``outputs_info = [{"initial": x0, "taps": [-1]}]``, so they likely
    passed ``{out[1:]: x_1T_vv}`` to ``joint_logprob``.  Since ``out[1:]`` isn't
    really the output of a `Scan`, but a `Subtensor` of the output `out` of a
    `Scan`, we need to account for that.

    We do so by replacing the bound/valued `Subtensor` with the un-sliced
    `Scan` outputs and a sufficiently padded version of the bound value.

    The returned dictionary maps variables in `fgraph` to their replacements:
    the outputs of `node` are mapped to the outputs of `new_node`, except for
    outputs that are valued through a `Subtensor`; for those, the output of
    the `ValuedVariable` is mapped to a new valued version of the full output.

    TODO: Find a simpler--and perhaps more general--way of handling these
    `Subtensor`\s.
    """

    # if not any(isinstance(out.owner.op, ValuedVariable) for out in node.outputs):
    #     return new_node.outputs

    # Get any `Subtensor` outputs that have been applied to outputs of this
    # `Scan` (and get the corresponding indices of the outputs from this
    # `Scan`)
    output_clients: List[Tuple["Variable", int]] = [
        # This is expected to work for `Subtensor` `Op`s,
        # because they only ever have one output
        (cl.default_output(), i)
        for i, out in enumerate(node.outputs)
        for cl, _ in fgraph.get_clients(out)
        if isinstance(cl.op, Subtensor)
    ]

    # The `ValuedVariable` outputs applied to those `Subtensor`s
    indirect_rv_vars = [
        (oo.default_output(), o_i)
        for o_n, o_i in output_clients
        for oo, _ in fgraph.get_clients(o_n)
        if isinstance(oo.op, ValuedVariable)
    ]

    replacements = dict(zip(node.outputs, new_node.outputs))

    if not indirect_rv_vars:
        return replacements

    # We need this for the `clone` in the loop that follows
    if aesara.config.compute_test_value != "off":
        compute_test_value(node)

    # We're going to replace the user's random variable/value variable mappings
    # with ones that map directly to outputs of this `Scan`.
    for rv_val_var, out_idx in indirect_rv_vars:
        # `rv_var` is the `*Subtensor*`
        rv_var, val_var = rv_val_var.owner.inputs

        # The full/un-`Subtensor`ed `Scan` output that we need to use
        full_out = node.outputs[out_idx]

        assert rv_var.owner.inputs[0] == full_out

        # A new value variable that spans the full output.
        # We don't want the old graph to appear in the new log-probability
        # graph, so we use the shape feature to (hopefully) get the shape
        # without the entire `Scan` itself.
        full_out_shape = tuple(
            fgraph.shape_feature.get_shape(fgraph, full_out, i)
            for i in range(full_out.ndim)
        )
        new_val_var = at.empty(full_out_shape, dtype=full_out.dtype)

        # Set the parts of this new value variable that applied to the
        # user-specified value variable to the user's value variable
        subtensor_indices = indices_from_subtensor(
            rv_var.owner.inputs[1:], rv_var.owner.op.idx_list
        )
        # E.g. for a single `-1` TAPS, `s_0T[1:] = s_1T` where `s_0T` is
        # `new_val_var` and `s_1T` is the user-specified value variable
        # that only spans times `t=1` to `t=T`.
        new_val_var = at.set_subtensor(new_val_var[subtensor_indices], val_var)

        # This is the outer-input that sets `s_0T[i] = taps[i]` where `i`
        # is a TAP index (e.g. a TAP of `-1` maps to index `0` in a vector
        # of the entire series).
        var_info = curr_scanargs.find_among_fields(full_out)
        # Strip the leading "outer_out_" from the field name.
        # NOTE(review): `var_info.name` was reconstructed from a damaged
        # expression -- confirm.
        alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]
        outer_input_var = getattr(curr_scanargs, f"outer_in_{alt_type}")[var_info.index]

        # These outer-inputs are using by `aesara.scan.utils.expand_empty`, and
        # are expected to consist of only a single `set_subtensor` call.
        # That's why we can simply replace the first argument of the node.
        assert isinstance(outer_input_var.owner.op, inc_subtensor_ops)

        # We're going to set those values on our `new_val_var` so that it can
        # serve as a complete replacement for the old input `outer_input_var`.
        # from aesara.graph import clone_replace
        #
        new_val_var = outer_input_var.owner.clone_with_new_inputs(
            [new_val_var] + outer_input_var.owner.inputs[1:]
        ).default_output()

        new_full_out = new_node.outputs[out_idx]
        new_full_out = valued_variable.make_node(
            new_full_out, new_val_var
        ).default_output()

        del replacements[full_out]
        # We have to replace the output of the `ValuedVariable` in this case
        replacements[rv_val_var] = new_full_out

    return replacements


@node_rewriter([Scan])
def add_opts_to_inner_graphs(fgraph, node):
    """Update the `Mode`(s) used to compile the inner-graph of a `Scan` `Op`.

    This is how we add the measurable IR rewrites to the "body"
    (i.e. inner-graph) of a `Scan` loop.

    Returns ``None`` when `node` is not a `Scan` or when its mode is already
    flagged with ``had_logprob_rewrites``; otherwise, a mapping from the old
    node's outputs to the outputs of the rebuilt `Scan`.
    """

    if not isinstance(node.op, Scan):
        return None

    # Avoid unnecessarily re-applying this rewrite
    if getattr(node.op.mode, "had_logprob_rewrites", False):
        return None

    inner_fgraph = FunctionGraph(
        node.op.inner_inputs,
        node.op.inner_outputs,
        clone=True,
        copy_inputs=False,
        copy_orphans=False,
    )

    logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(
        inner_fgraph
    )

    new_outputs = list(inner_fgraph.outputs)

    # TODO FIXME: This is pretty hackish.
    new_mode = copy(node.op.mode)
    new_mode.had_logprob_rewrites = True

    # NOTE(review): the third positional argument was lost in the damaged
    # source; `node.op.info` is the reconstruction -- confirm.
    op = Scan(node.op.inner_inputs, new_outputs, node.op.info, mode=new_mode)
    new_node = op.make_node(*node.inputs)

    return dict(zip(node.outputs, new_node.outputs))


@_get_measurable_outputs.register(MeasurableScan)
def _get_measurable_outputs_MeasurableScan(op, node):
    """Return the outputs of `node` whose type is not a `RandomType`."""
    # TODO: This should probably use `get_random_outer_outputs`
    # scan_args = ScanArgs.from_node(node)
    # rv_outer_outs = get_random_outer_outputs(scan_args)
    return [o for o in node.outputs if not isinstance(o.type, RandomType)]


measurable_ir_rewrites_db.register(
    "add_opts_to_inner_graphs",
    add_opts_to_inner_graphs,
    # out2in(
    #     add_opts_to_inner_graphs, name="add_opts_to_inner_graphs", ignore_newtrees=True
    # ),
    "basic",
    "scan",
)

measurable_ir_rewrites_db.register(
    "find_measurable_scans",
    find_measurable_scans,
    "basic",
    "scan",
)

# Add scan canonicalizations that aren't in the canonicalization DB
logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, "basic", "scan")
logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, "basic", "scan")