Source code for larch.model.mixtures

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import MutableSequence

from .._optional import jax, jnp, js
from .param_core import ParameterBucket


class Mixture(ABC):
    """
    A mixing distribution for a model parameter.

    This is an abstract base class.  Concrete subclasses are required to
    provide `param_names`, `prep`, `roll`, and `to_dict`.
    """

    def __init__(self):
        # Owner of this mixture; starts out unset.
        self._parent = None

    @abstractmethod
    def param_names(self):
        """
        Named parameters referenced by this mixture, with default values.

        Returns
        -------
        dict
        """
        raise NotImplementedError

    @abstractmethod
    def prep(self, bucket: ParameterBucket):
        """
        Prepare this mixture for use with a parameter bucket.

        The concrete mixtures in this module use this step to look up the
        positions of their named parameters in the bucket.

        Parameters
        ----------
        bucket : ParameterBucket
        """
        raise NotImplementedError

    @abstractmethod
    def roll(self, draws: jax.Array, parameters: jax.Array) -> jax.Array:
        """
        Apply this mixing distribution to some random draws.

        Parameters
        ----------
        draws : jax.Array, shape [...]
            A set of pseudo-random draws, nominally uniformly distributed
            in the range 0 to 1.
        parameters : jax.Array, shape [..., n_params]
            An array of parameters, previously broadcasted to the same
            shape as the draws, plus the parameter dimension itself.

        Returns
        -------
        parameters : jax.Array, shape [..., n_params]
            The parameters, with the computed distribution of the target
            parameter overlaid.
        """
        raise NotImplementedError

    @abstractmethod
    def to_dict(self):
        """
        Describe this mixture as a dictionary.

        The concrete mixtures in this module return a ``"type"`` entry
        holding the class name, plus one entry per constructor argument.

        Returns
        -------
        dict
        """
        raise NotImplementedError
class MixtureList(MutableSequence):
    """
    A mutable sequence of `Mixture` objects that also acts as a descriptor.

    Used as a class attribute, it gives every owner instance its own private
    `MixtureList`, stored on the instance under ``"_private_" + name``.
    Mutating the list calls the owner's ``mangle()`` method, when the owner
    has one.
    """

    def __init__(self, init=None):
        """
        Parameters
        ----------
        init : iterable of Mixture, optional
            Initial members of the list.

        Raises
        ------
        TypeError
            If any member of `init` is not a `Mixture`.
        """
        self._parent = None
        self._mixtures = []
        if init is not None:
            for i in init:
                if isinstance(i, Mixture):
                    self._mixtures.append(i)
                else:
                    raise TypeError(
                        f"members of {self.__class__.__name__} must be Mixture"
                    )

    def set_parent(self, instance):
        """Record `instance` as the parent of this list and of every member."""
        self._parent = instance
        for i in self._mixtures:
            i._parent = instance

    def __fresh(self, instance):
        """Create an empty list bound to `instance` and store it there."""
        newself = MixtureList()
        # `_instance` is what this method has always set; it is kept so any
        # external reader of that attribute continues to work.
        newself._instance = instance
        # Also record the owner as `_parent`, which is the attribute that
        # `__mangle` and the item setters actually consult.  Without this, a
        # list first reached through `__get__` never notified its owner.
        newself._parent = instance
        setattr(instance, self.private_name, newself)
        return newself

    def __mangle(self):
        """
        Notify the parent, if any, that the list has changed.

        Any AttributeError is suppressed: this covers a missing parent and
        a parent without ``mangle``, and also swallows an AttributeError
        raised from inside ``mangle`` itself.
        """
        try:
            self._parent.mangle()
        except AttributeError:
            pass

    def __get__(self, instance, owner):
        """Return the per-instance list, creating it on first access."""
        if instance is None:
            return self
        newself = getattr(instance, self.private_name, None)
        if newself is None:
            newself = self.__fresh(instance)
        return newself

    def __set__(self, instance, values):
        """Replace the per-instance list contents with `values`."""
        newself = getattr(instance, self.private_name, None)
        if newself is None:
            newself = self.__fresh(instance)
        else:
            newself._mixtures.clear()
        newself.__init__(values)
        # __init__ resets `_parent`, so the owner link is restored here.
        newself.set_parent(instance)
        newself.__mangle()

    def __delete__(self, instance):
        """Empty the per-instance list, notifying the owner if it had items."""
        newself = getattr(instance, self.private_name, None)
        if newself is not None and len(newself):
            newself.__mangle()
        if newself is None:
            newself = self.__fresh(instance)
        else:
            newself._mixtures.clear()
        newself.__init__()
        # __init__ resets `_parent` to None; restore it so that later
        # mutations of the emptied list still reach the owner.
        newself.set_parent(instance)

    def __set_name__(self, owner, name):
        """Remember the attribute name and derive the per-instance slot name."""
        self.name = name
        self.private_name = "_private_" + name

    def __getitem__(self, item):
        return self._mixtures[item]

    def __setitem__(self, key: int, value):
        """
        Replace the member at `key`.

        Raises
        ------
        TypeError
            If `value` is not a `Mixture`.
        """
        if not isinstance(value, Mixture):
            raise TypeError("items must be of type Mixture")
        self._mixtures[key] = value
        self.__mangle()
        self._mixtures[key]._parent = self._parent

    def __delitem__(self, key):
        del self._mixtures[key]
        self.__mangle()

    def __len__(self):
        return len(self._mixtures)

    def insert(self, index, value):
        """
        Insert `value` before `index`.

        Raises
        ------
        TypeError
            If `value` is not a `Mixture`.
        """
        if not isinstance(value, Mixture):
            raise TypeError("items must be of type Mixture")
        self._mixtures.insert(index, value)
        self.__mangle()
        self._mixtures[index]._parent = self._parent

    def __repr__(self):
        return repr(self._mixtures)

    def _is_duplicate(self, value):
        """Return True if some member compares equal to `value`."""
        return any(i == value for i in self._mixtures)

    def to_list(self):
        """Return the ``to_dict()`` output of every member, in order."""
        return [i.to_dict() for i in self._mixtures]

    def from_list(self, j):
        """
        Replace the contents of this list from a list of dictionaries.

        Each dictionary needs a ``"type"`` entry naming a `Mixture` subclass
        defined in this module; the remaining entries are passed to that
        class as keyword arguments.  The dictionaries in `j` are not
        modified.  The existing contents are replaced only after every
        entry has been built successfully.

        Parameters
        ----------
        j : iterable of dict

        Raises
        ------
        ValueError
            If an entry has no ``"type"`` (or it is None).
        KeyError
            If ``"type"`` does not name anything in this module.
        TypeError
            If ``"type"`` names something that is not a `Mixture` subclass.
        """
        rebuilt = []
        for spec in j:
            # Work on a copy so the caller's dictionaries stay intact.
            kwargs = dict(spec)
            kind = kwargs.pop("type", None)
            if kind is None:
                raise ValueError("missing mixture type")
            cls = globals()[kind]
            # Only Mixture subclasses may be built from serialized data; a
            # bare globals() lookup would otherwise call any module global.
            if not (isinstance(cls, type) and issubclass(cls, Mixture)):
                raise TypeError(f"{kind!r} is not a Mixture type")
            member = cls(**kwargs)
            member._parent = self._parent
            rebuilt.append(member)
        self._mixtures[:] = rebuilt
        # The contents changed, so notify the parent like other mutators do.
        self.__mangle()
class Normal(Mixture):
    """
    A normal distribution applied to a model parameter.

    Parameters
    ----------
    mean : str
        Name of the parameter giving the mean.  This is also the parameter
        whose values are overwritten by `roll`.
    std : str
        Name of the parameter giving the standard deviation.
    """

    def __init__(self, mean: str, std: str):
        super().__init__()
        self.mean_ = mean
        self.std_ = std
        # Positions of the two parameters; -1 until `prep` has been called.
        self.imean = -1
        self.istd = -1
        self.default_mean = 0.0
        self.default_std = 0.001

    def __repr__(self):
        return f"{self.__class__.__name__}({self.mean_!r}, {self.std_!r})"

    def __eq__(self, other):
        """
        Compare by concrete class and by the two parameter names.

        Subclasses inherit this method, so the concrete classes must match:
        a `Normal` and one of its subclasses built on the same parameter
        names are different distributions and compare unequal.
        """
        if not isinstance(other, Normal):
            return NotImplemented
        return (
            type(other) is type(self)
            and self.mean_ == other.mean_
            and self.std_ == other.std_
        )

    def param_names(self):
        """
        Named parameters referenced by this mixture, with default values.

        Returns
        -------
        dict
            Maps the mean and std parameter names to their defaults.
        """
        return {
            self.mean_: self.default_mean,
            self.std_: self.default_std,
        }

    def prep(self, bucket: ParameterBucket):
        """Look up and store the positions of the mean and std parameters."""
        self.imean = bucket.get_param_loc(self.mean_)
        self.istd = bucket.get_param_loc(self.std_)

    def roll(self, draw_vec, parameters):
        """
        Overlay normally distributed values on the mean parameter.

        Parameters
        ----------
        draw_vec : array
            Draws passed to ``js.stats.norm.ppf`` as the quantiles.
        parameters : array, shape [..., n_params]
            Parameter array; the mean and std entries along the last axis
            supply the location and scale.

        Returns
        -------
        array
            `parameters` with the mean entry replaced by the computed values.
        """
        # `prep` must have resolved both positions before rolling.
        assert self.imean >= 0
        assert self.istd >= 0
        v = js.stats.norm.ppf(
            draw_vec, parameters[..., self.imean], parameters[..., self.istd]
        )
        parameters = parameters.at[..., self.imean].set(v)
        return parameters

    def to_dict(self):
        """Return the class name and constructor arguments as a dict."""
        return dict(
            type=self.__class__.__name__,
            mean=self.mean_,
            std=self.std_,
        )
class LogNormal(Normal):
    """
    A log-normal distribution applied to a model parameter.

    The "mean" and "std" parameters are the mean and standard deviation of
    the underlying normal distribution, not the mean and standard deviation
    of the log-normal distribution.
    """

    def roll(self, draw_vec, parameters):
        """
        Overlay log-normally distributed values on the mean parameter.

        The value from ``js.stats.norm.ppf`` is exponentiated before being
        written into the mean entry of `parameters`.
        """
        assert self.imean >= 0
        assert self.istd >= 0
        loc = parameters[..., self.imean]
        scale = parameters[..., self.istd]
        underlying = js.stats.norm.ppf(draw_vec, loc, scale)
        return parameters.at[..., self.imean].set(jnp.exp(underlying))
class NegLogNormal(Normal):
    """
    The negative of a log-normal distribution applied to a model parameter.

    This is convenient when it is desired to ensure the parameter must be
    negative, as expected for coefficients associated with travel time or
    travel costs.

    The "mean" and "std" parameters are the mean and standard deviation of
    the underlying normal distribution, not the mean and standard deviation
    of the log-normal distribution.
    """

    def roll(self, draw_vec, parameters):
        """
        Overlay negated log-normally distributed values on the mean parameter.

        The value from ``js.stats.norm.ppf`` is exponentiated and negated
        before being written into the mean entry of `parameters`.
        """
        assert self.imean >= 0
        assert self.istd >= 0
        loc = parameters[..., self.imean]
        scale = parameters[..., self.istd]
        underlying = js.stats.norm.ppf(draw_vec, loc, scale)
        return parameters.at[..., self.imean].set(-jnp.exp(underlying))
class Uniform(Mixture):
    """
    A uniform distribution applied to a model parameter.

    Parameters
    ----------
    low : str
        Name of the parameter giving the lower bound.  This is also the
        parameter whose values are overwritten by `roll`.
    high : str
        Name of the parameter giving the upper bound.
    """

    def __init__(self, low: str, high: str):
        super().__init__()
        self.low_, self.high_ = low, high
        # Positions of the two parameters; -1 until `prep` has been called.
        self.ilow = self.ihigh = -1
        self.default_low, self.default_high = 0.0, 1.0

    def __repr__(self):
        return f"{self.__class__.__name__}({self.low_!r}, {self.high_!r})"

    def __eq__(self, other):
        """Equal to another `Uniform` naming the same two parameters."""
        if not isinstance(other, Uniform):
            return False
        return self.low_ == other.low_ and self.high_ == other.high_

    def param_names(self):
        """
        Named parameters referenced by this mixture, with default values.

        Returns
        -------
        dict
            Maps the low and high parameter names to their defaults.
        """
        defaults = {}
        defaults[self.low_] = self.default_low
        defaults[self.high_] = self.default_high
        return defaults

    def prep(self, bucket: ParameterBucket):
        """Look up and store the positions of the low and high parameters."""
        self.ilow = bucket.get_param_loc(self.low_)
        self.ihigh = bucket.get_param_loc(self.high_)

    def roll(self, draw_vec, parameters):
        """
        Overlay uniformly distributed values on the low parameter.

        Each draw is scaled by the gap between the high and low entries and
        added to the low entry; the result is written into the low entry of
        `parameters`.
        """
        assert self.ilow >= 0
        assert self.ihigh >= 0
        lower = parameters[..., self.ilow]
        upper = parameters[..., self.ihigh]
        return parameters.at[..., self.ilow].set(lower + draw_vec * (upper - lower))

    def to_dict(self):
        """Return the class name and constructor arguments as a dict."""
        return {
            "type": self.__class__.__name__,
            "low": self.low_,
            "high": self.high_,
        }