import warnings
from copy import copy
from typing import TYPE_CHECKING, Optional, Sequence
import aesara
import aesara.tensor as at
import numpy as np
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import Apply, Constant
from aesara.graph.op import Op
from aesara.tensor.basic import make_vector
from aesara.tensor.random.utils import broadcast_params, normalize_size_param
from aesara.tensor.var import TensorVariable
from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
if TYPE_CHECKING:
from aesara.tensor.random.utils import RandomStream
class DiracDelta(Op):
"""An `Op` that represents a Dirac-delta distribution."""
__props__ = ("rtol", "atol")
def __init__(self, rtol=1e-5, atol=1e-8):
self.rtol = rtol
self.atol = atol
def make_node(self, x):
x = at.as_tensor(x)
return Apply(self, [x], [x.type()])
def do_constant_folding(self, fgraph, node):
# Without this, the `Op` would be removed from the graph during
# canonicalization
return False
def perform(self, node, inp, out):
(x,) = inp
(z,) = out
warnings.warn(
"DiracDelta is a dummy Op that shouldn't be used in a compiled graph"
)
z[0] = x
def infer_shape(self, fgraph, node, input_shapes):
return input_shapes
dirac_delta = DiracDelta()
MeasurableVariable.register(DiracDelta)
def non_constant(x):
x = at.as_tensor_variable(x)
if isinstance(x, Constant):
# XXX: This isn't good for `size` parameters, because it could result
# in `at.get_vector_length` exceptions.
res = x.type()
res.tag = copy(res.tag)
if aesara.config.compute_test_value != "off":
res.tag.test_value = x.data
res.name = x.name
return res
else:
return x
def switching_process(
comp_rvs: Sequence[TensorVariable],
states: TensorVariable,
):
"""Construct a switching process over arbitrary univariate mixtures and a state sequence.
This simply constructs a graph of the following form:
at.stack(comp_rvs)[states, *idx]
where ``idx`` makes sure that `states` selects mixture components along all
the other axes.
Parameters
----------
comp_rvs
A list containing `MeasurableVariable` objects for each mixture component.
states
The hidden state sequence. It should have a number of states
equal to the size of `comp_dists`.
"""
states = at.as_tensor(states, dtype=np.int64)
comp_rvs_bcast = at.broadcast_arrays(*[at.as_tensor(rv) for rv in comp_rvs])
M_rv = at.stack(comp_rvs_bcast)
indices = (states,) + tuple(at.arange(d) for d in tuple(M_rv.shape)[1:])
rv_var = M_rv[indices]
return rv_var
class DiscreteMarkovChainFactory(OpFromGraph):
"""An `Op` constructed from an Aesara graph that represents a discrete Markov chain.
This "composite" `Op` allows us to mark a sub-graph as measurable and
assign a `_logprob` dispatch implementation.
As far as broadcasting is concerned, this `Op` has the following
`RandomVariable`-like properties:
ndim_supp = 1
ndims_params = (3, 1)
TODO: It would be nice to express this as a `Blockwise` `Op`.
"""
default_output = 0
MeasurableVariable.register(DiscreteMarkovChainFactory)
@_get_measurable_outputs.register(DiscreteMarkovChainFactory)
def _get_measurable_outputs_DiscreteMarkovChainFactory(op, node):
return [node.outputs[0]]
def create_discrete_mc_op(srng, size, Gammas, gamma_0):
"""Construct a `DiscreteMarkovChainFactory` `Op`.
This returns a `Scan` that performs the follow:
states[0] = categorical(gamma_0)
for t in range(1, N):
states[t] = categorical(Gammas[t, state[t-1]])
The Aesara graph representing the above is wrapped in an `OpFromGraph` so
that we can easily assign it a specific log-probability.
TODO: Eventually, AePPL should be capable of parsing more sophisticated
`Scan`s and producing nearly the same log-likelihoods, and the use of
`OpFromGraph` will no longer be necessary.
"""
# Again, we need to preserve the length of this symbolic vector, so we do
# this.
size_param = make_vector(
*[non_constant(size[i]) for i in range(at.get_vector_length(size))]
)
size_param.name = "size"
# We make shallow copies so that unwanted ancestors don't appear in the
# graph.
Gammas_param = non_constant(Gammas).clone()
Gammas_param.name = "Gammas_param"
gamma_0_param = non_constant(gamma_0).clone()
gamma_0_param.name = "gamma_0_param"
bcast_Gammas_param, bcast_gamma_0_param = broadcast_params(
(Gammas_param, gamma_0_param), (3, 1)
)
# Sample state 0 in each state sequence
state_0 = srng.categorical(
bcast_gamma_0_param,
size=tuple(size_param) + tuple(bcast_gamma_0_param.shape[:-1]),
# size=at.join(0, size_param, bcast_gamma_0_param.shape[:-1]),
)
N = bcast_Gammas_param.shape[-3]
states_shape = tuple(state_0.shape) + (N,)
bcast_Gammas_param = at.broadcast_to(
bcast_Gammas_param, states_shape + tuple(bcast_Gammas_param.shape[-2:])
)
def loop_fn(n, state_nm1, Gammas_inner):
gamma_t = Gammas_inner[..., n, :, :]
idx = tuple(at.ogrid[[slice(None, d) for d in tuple(state_0.shape)]]) + (
state_nm1.T,
)
gamma_t = gamma_t[idx]
state_n = srng.categorical(gamma_t)
return state_n.T
res, updates = aesara.scan(
loop_fn,
outputs_info=[{"initial": state_0.T, "taps": [-1]}],
sequences=[at.arange(N)],
non_sequences=[bcast_Gammas_param],
# strict=True,
)
# TODO FIXME: This is unreliable and needs to be replaced with explicit
# update support in `OpFromGraph`.
update_outputs = [state_0.owner.inputs[0].default_update]
update_outputs.extend(updates.values())
return (
DiscreteMarkovChainFactory(
[size_param, Gammas_param, gamma_0_param],
[res.T] + update_outputs,
inline=True,
on_unused_input="ignore",
),
updates,
)
[docs]def discrete_markov_chain(
Gammas: TensorVariable,
gamma_0: TensorVariable,
size=None,
srng: Optional["RandomStream"] = None,
**kwargs
):
"""Construct a first-order discrete Markov chain distribution.
This defines a random vector that consists of state indicator values (i.e.
``0`` to ``M - 1`` where ``M`` is the dimensionality of the state space)
that are driven by a discrete Markov chain.
Given an array of transition probability matrices ``Gamma`` and initial
state probabilities ``gamma0``, `discrete_markov_chain` represents the
probability distribution of ``states`` defined by:
.. code::
states[0] = categorical(gamma_0)
for t in range(1, N):
states[t] = categorical(Gammas[t, state[t-1]])
Example
-------
.. code::
import aesara.tensor as at
import numpy as np
num_steps = 10
Gammas_base = np.array([[.5, .5], [.5, .5]])
Gammas = np.broadcast(Gammas_base, (num_steps, 2, 2))
gamma0 = np.r_[0.5, 0.5]
dmc, updates = discrete_markov_chain(test_Gamma, test_gamma_0)
Parameters
----------
Gammas
An array of transition probability matrices. `Gammas` takes the
shape ``... x N x M x M`` for a state sequence of length ``N`` having
``M``-many distinct states. Each row, ``r``, in a transition probability
matrix gives the probability of transitioning from state ``r`` to each
other state.
gamma_0
The initial state probabilities. The last dimension should be length ``M``,
i.e. the number of distinct states.
"""
gamma_0 = at.as_tensor_variable(gamma_0)
assert Gammas.ndim >= 3
Gammas = at.as_tensor_variable(Gammas)
size = normalize_size_param(size)
if srng is None:
srng = at.random.RandomStream()
dmc_op, updates = create_discrete_mc_op(srng, size, Gammas, gamma_0)
rv_var = dmc_op(size, Gammas, gamma_0)
# TODO FIXME: This is unreliable and needs to be replaced with explicit
# update support in `OpFromGraph`.
updates = {
rv_var.owner.inputs[-2]: rv_var.owner.outputs[-2],
rv_var.owner.inputs[-1]: rv_var.owner.outputs[-1],
}
testval = kwargs.pop("testval", None)
if testval is not None:
rv_var.tag.test_value = testval
return rv_var, updates