# Source code for larch.model.jaxmodel

from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from xarray import Dataset

from .._optional import jax, jnp
from ..compiled import compiledmethod, jitmethod, reset_compiled_methods
from ..folding import fold_dataset
from ..optimize import OptimizeMixin
from .numbamodel import NumbaModel

if TYPE_CHECKING:
    from larch.util import dictx
    from larch.util.excel import ExcelWriter

# Module-level logger, named after this module (``larch.model.jaxmodel``).
logger = logging.getLogger(__name__)


def _get_jnp_array(dataset, name):
    if name not in dataset:
        return None
    return jnp.asarray(dataset[name])


def _as_jnp_array(obj):
    if obj is None:
        return None
    return jnp.asarray(obj)


class PanelMixin:
    """
    Mixin that adds a ``groupid`` attribute for panel (grouped) data.

    Changing ``groupid`` invalidates derived state by calling ``self.mangle()``,
    which the host class must provide.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: ``kwargs`` is this call's own dict, so popping ``groupid`` here
        # does not remove it from the keyword arguments held by the caller.
        self._groupid = kwargs.pop("groupid", None)

    @property
    def groupid(self):
        """The panel grouping: None, a str (e.g. a column name), or another object."""
        return self._groupid

    @groupid.setter
    def groupid(self, g):
        old = self._groupid
        simple = (str, type(None))
        if isinstance(g, simple) and isinstance(old, simple):
            # plain values compare unambiguously; only mangle on a real change
            changed = old != g
        else:
            # Array-like groupings cannot be compared safely (``array != None``
            # is elementwise and cannot be used in ``if``), so any assignment
            # involving one is treated as a change.
            changed = True
        if changed:
            self.mangle()
        self._groupid = g


class MangleOnChange:
    """
    Data descriptor that calls ``instance.mangle()`` whenever its value changes.

    The value lives on the instance under the attribute name prefixed with an
    underscore (``foo`` is stored as ``_foo``).

    Parameters
    ----------
    *req_type : type
        Optional allowed type(s). ``(bool,)`` coerces assigned values with
        ``bool()``. Any other non-empty tuple is enforced with ``isinstance``.
        An unset attribute defaults to ``req_type[0]()`` (False for bool), or
        None when no type is given.
    """

    def __init__(self, *req_type):
        self.req_type = req_type

    def __set_name__(self, owner, name):
        self.name, self.private_name = name, f"_{name}"

    def _default_value(self):
        """Value installed on first read of an unset attribute."""
        if self.req_type == (bool,):
            return False
        return self.req_type[0]() if len(self.req_type) else None

    def __get__(self, instance, owner):
        # accessed on the class itself: return the descriptor
        if instance is None:
            return self
        try:
            return getattr(instance, self.private_name)
        except AttributeError:
            # first access: install the default (without mangling) and return it
            setattr(instance, self.private_name, self._default_value())
            return getattr(instance, self.private_name)

    def __set__(self, instance, value):
        if self.req_type == (bool,):
            value = bool(value)
        elif self.req_type and not isinstance(value, self.req_type):
            raise TypeError(
                f"attribute `{self.name}` must be of type {self.req_type} not {type(value)}"
            )
        previous = getattr(instance, self.private_name, "<--missing-->")
        if previous == value:
            return
        setattr(instance, self.private_name, value)
        # Best effort: during construction the host may not be ready to mangle.
        try:
            instance.mangle()
        except AttributeError:
            pass


class Model(NumbaModel, OptimizeMixin, PanelMixin):
    """
    Choice model computable with either the numba or the jax engine.

    Extends ``NumbaModel`` with a JAX implementation of utility, probability
    and log likelihood calculations, random-draw mixture parameters, and panel
    data support (through ``groupid``, see ``PanelMixin``). When
    ``compute_engine`` is not "jax", most computational methods defer to
    ``NumbaModel``.
    """

    def __init__(self, *args, **kwargs):
        PanelMixin.__init__(self, *args, **kwargs)
        super().__init__(*args, **kwargs)
        # Written to the private slot directly so this default does not go
        # through the ``n_draws`` descriptor (and so does not mangle).
        self._n_draws = 100
        self._draws = None
        # These go through MangleOnChange descriptors.
        self.prerolled_draws = True
        self.common_draws = False

    @property
    def compute_engine(self):
        """
        str : The compute engine, "numba" or "jax".

        When unset (None), "jax" is used if jax is installed, else "numba".
        """
        engine = self._compute_engine
        if engine is None:
            if not jax:
                engine = "numba"
            else:
                engine = "jax"
        return engine

    @compute_engine.setter
    def compute_engine(self, engine):
        if engine not in {"numba", "jax", None}:
            raise ValueError("invalid compute engine")
        self._compute_engine = engine
        if self._compute_engine == "jax" and not jax:
            warnings.warn(
                "jax is not installed, falling back to numba",
                stacklevel=2,
            )
            self._compute_engine = "numba"
        if self._compute_engine == "jax" and self.use_streaming:
            warnings.warn(
                "setting use_streaming to False, jax is not yet compatible",
                stacklevel=2,
            )
            self.use_streaming = False

    # Random-draw settings; assigning a different value mangles the model.
    prerolled_draws = MangleOnChange(bool)
    common_draws = MangleOnChange(bool)
    n_draws = MangleOnChange(int)
    seed = MangleOnChange()

    def mangle(self, data=True, structure=True):
        """Invalidate cached state, including random draws and compiled methods."""
        super().mangle(data, structure)
        self._draws = None
        reset_compiled_methods(self)

    def unmangle(self, force=False, structure_only=False):
        """
        Rebuild invalidated state.

        Also re-preps mixtures and, unless ``structure_only``, folds the
        dataset by ``groupid``. A per-instance marker attribute prevents
        re-entrant calls from recursing.
        """
        if not self._mangled and not force:
            return
        marker = f"_currently_unmangling_{__file__}"
        if getattr(self, marker, False):
            return
        try:
            setattr(self, marker, True)
            super().unmangle(force=force, structure_only=structure_only)
            for mix in self.mixtures:
                mix.prep(self._parameter_bucket)
            if not structure_only:
                if self.groupid is not None and self.dataset is not None:
                    self.dataset = fold_dataset(self.dataset, self.groupid)
        finally:
            delattr(self, marker)

    def reflow_data_arrays(self):
        """Reload the internal data_arrays so they are consistent with the datatree."""
        if self.compute_engine != "jax":
            return super().reflow_data_arrays()
        if self.graph is None:
            self._data_arrays = None
            return
        datatree = self.datatree
        # Guard against a missing datatree before dereferencing it; the
        # non-preload branch previously crashed here with AttributeError.
        if datatree is not None and not self._should_preload_data:
            # Not preloading: keep only the first case of the root dataset.
            datatree = datatree.replace_datasets(
                {
                    datatree.root_node_name: datatree.root_dataset.isel(
                        {datatree.CASEID: slice(0, 1)}
                    )
                }
            )
        if datatree is not None:
            request = self.required_data()
            if isinstance(self.groupid, str):
                request["group_co"] = self.groupid
            from .data_arrays import prepare_data

            dataset, self.dataflows = prepare_data(
                datasource=datatree,
                request=request,
                float_dtype=self.float_dtype,
                cache_dir=datatree.cache_dir,
                flows=self.dataflows,
                make_unused_flows=self.use_streaming,
            )
            if isinstance(self.groupid, str):
                dataset = fold_dataset(dataset, "group")
            elif self.groupid is not None:
                dataset = fold_dataset(dataset, self.groupid)
            self.dataset = dataset
            try:
                self._data_arrays = self.dataset.dc.to_arrays(
                    self.graph,
                    float_dtype=self.float_dtype,
                )
            except KeyError:  # no defined caseid dimension, JAX only
                self._data_arrays = None
                self.work_arrays = None
            else:
                if self.work_arrays is not None:
                    self._rebuild_work_arrays()

    @property
    def data_as_loaded(self):
        """The stored ``_dataset``, without unmangling or reloading."""
        return self._dataset

    @property
    def dataset(self):
        """larch.Dataset : Data arrays as loaded for model computation."""
        super().unmangle()
        if self._dataset is None:
            self.reflow_data_arrays()
        try:
            return self._dataset
        except AttributeError:
            return None

    @dataset.setter
    def dataset(self, dataset):
        if dataset is self._dataset:
            return
        if isinstance(dataset, Dataset):
            if self.groupid is not None:
                dataset = fold_dataset(dataset, self.groupid)
            self._dataset = dataset
            self._data_arrays = None
            self._rebuild_fixed_arrays()
        else:
            raise TypeError(f"dataset must be Dataset not {type(dataset)}")

    def make_random_draws(self, engine="numpy"):
        """
        Return random draws for the mixtures, reusing cached draws if their shape fits.

        Draws have shape (n_draws, n_mixtures) when ``common_draws`` is set,
        else (n_panels, n_draws, n_mixtures). When ``prerolled_draws`` is set,
        the result is cached on ``self._draws``.

        Parameters
        ----------
        engine : {"numpy", "jax"}
            Which random number implementation to use.

        Returns
        -------
        array or None
            None when there are no draws, mixtures or panels to fill.
        """
        self.unmangle()
        for i in self.mixtures:
            i.prep(self._parameter_bucket)
        n_panels = self.dataset.dc.n_panels
        n_mixtures = len(self.mixtures)
        n_draws = self.n_draws
        draws = None
        if self._draws is not None:
            if self.common_draws and self._draws.shape == (n_draws, n_mixtures):
                draws = self._draws
            if not self.common_draws and self._draws.shape == (
                n_panels,
                n_draws,
                n_mixtures,
            ):
                draws = self._draws
        if draws is None:
            if engine == "numpy":
                seed = self.seed or 0
                if self.common_draws:
                    if n_draws > 0 and n_mixtures > 0:
                        draws, seed = self._make_random_draws_numpy(
                            n_draws, n_mixtures, seed
                        )
                else:
                    if n_draws > 0 and n_mixtures > 0 and n_panels > 0:
                        draws, seed = self._make_random_draws_numpy_2(
                            n_draws, n_mixtures, n_panels, seed
                        )
                    else:
                        draws = None
            elif engine == "jax":
                seed = self.seed or 0
                rk = jax.random.PRNGKey(seed)
                if self.common_draws:
                    if n_draws > 0 and n_mixtures > 0:
                        draws = self._make_random_draws_out(n_draws, n_mixtures, rk)[0]
                else:
                    if n_draws > 0 and n_mixtures > 0 and n_panels > 0:
                        draws = self._make_random_draws_out_2(
                            n_draws, n_mixtures, n_panels, rk
                        )[0]
                    else:
                        draws = None
            else:
                raise ValueError(f"unknown random engine {engine!r}")
        if self.prerolled_draws:
            self._draws = draws
        return draws

    def _make_random_draws_numpy(self, n_draws, n_mixtures, seed):
        """
        Stratified uniform draws of shape (n_draws, n_mixtures) using numpy.

        Each column gets one jittered point per stratum ``[k, k+1)/n_draws``,
        shuffled independently. The result is clipped below 1.

        Returns
        -------
        (draws, numpy.random.Generator)
        """
        if isinstance(seed, np.random.Generator):
            rgen = seed
        else:
            rgen = np.random.default_rng(seed)
        draws = (
            rgen.random(size=[n_draws, n_mixtures], dtype=np.float32)
            + np.arange(n_draws, dtype=np.float32)[:, np.newaxis]
        )
        for i in range(n_mixtures):
            rgen.shuffle(draws[:, i])
        return np.clip(draws / n_draws, 0, np.float32(1 - 1e-7)), rgen

    def _make_random_draws_numpy_2(self, n_draws, n_mixtures, n_panels, seed):
        """Like ``_make_random_draws_numpy`` but independently per panel: (n_panels, n_draws, n_mixtures)."""
        if isinstance(seed, np.random.Generator):
            rgen = seed
        else:
            rgen = np.random.default_rng(seed)
        draws = (
            rgen.random(size=[n_panels, n_draws, n_mixtures], dtype=np.float32)
            + np.arange(n_draws, dtype=np.float32)[np.newaxis, :, np.newaxis]
        )
        for i in range(n_mixtures):
            for p in range(n_panels):
                rgen.shuffle(draws[p, :, i])
        return np.clip(draws / n_draws, 0, np.float32(1 - 1e-7)), rgen

    # Not jitted: n_draws / n_mixtures set array shapes and would need to be
    # static, e.g. @jitmethod(static_argnums=(0,1), static_argnames=('n_draws', 'n_mixtures'))
    def _make_random_draws_out(self, n_draws, n_mixtures, random_key):
        """
        Stratified uniform draws of shape (n_draws, n_mixtures) using jax.random.

        Returns
        -------
        (draws, random_key)
            The draws and an updated key.
        """
        # NOTE(review): ``random_key`` is used for ``uniform`` here and is also
        # split inside the loop; JAX advises against reusing a key. Changing it
        # would alter reproducible draws, so it is left as-is.
        draws = jax.random.uniform(
            random_key,
            [n_draws, n_mixtures],
        )

        def body(i, carry):
            draws, rnd_key = carry
            rnd_key, subkey = jax.random.split(rnd_key)
            # add a random permutation of 0..n_draws-1 to column i (stratification)
            draws = draws.at[:, i].add(jax.random.permutation(subkey, draws.shape[0]))
            return draws, rnd_key

        draws, random_key = jax.lax.fori_loop(
            0, draws.shape[1], body, (draws, random_key)
        )
        return draws * (1 / n_draws), random_key

    # Not jitted; see the note on _make_random_draws_out.
    def _make_random_draws_out_2(self, n_draws, n_mixtures, n_cases, random_key):
        """Repeat ``_make_random_draws_out`` for ``n_cases``, giving (n_cases, n_draws, n_mixtures)."""

        def body(carry, x):
            rkey = carry
            draws, rkey = self._make_random_draws_out(n_draws, n_mixtures, rkey)
            return rkey, draws

        random_key, draws = jax.lax.scan(body, random_key, None, n_cases)
        return draws, random_key

    def check_random_draws(self, engine="numpy"):
        """Create prerolled random draws if the model has mixtures and none are cached."""
        if self.mixtures:
            if self.prerolled_draws:
                if self._draws is None:
                    self.make_random_draws(engine=engine)

    def apply_random_draws(self, parameters, draws=None):
        """
        Broadcast *parameters* over the draws and apply each mixture.

        Parameters
        ----------
        parameters : array
            1-d parameter vector.
        draws : array, optional
            Random draws whose last axis indexes mixtures. Defaults to the
            prerolled draws cached on the model.

        Returns
        -------
        array
            Shape ``(*draws.shape[:-1], len(parameters))``.

        Raises
        ------
        ValueError
            If no draws are given and none are cached.
        """
        if draws is None:
            draws = self._draws
        if draws is None:
            raise ValueError("no random draws available, call make_random_draws first")
        parameters = jnp.broadcast_to(
            parameters, [*draws.shape[:-1], parameters.shape[0]]
        )
        for mix_n, mix in enumerate(self.mixtures):
            u = draws[..., mix_n]
            parameters = mix.roll(u, parameters)
        return parameters

    @jitmethod
    def _jax_quantity(self, params, databundle):
        """Quantity for one case: a length-n_nodes vector, nonzero only on elemental alternatives."""
        ca = databundle.get("ca", None)
        n_alts = self.dataset.dc.n_alts
        n_nodes = len(self.graph)
        u = jnp.zeros([n_nodes])
        if ca is not None and self._fixed_arrays.qca_param_slot.size:
            u = u.at[:n_alts].add(
                jnp.dot(
                    ca[..., self._fixed_arrays.qca_data_slot],
                    jnp.exp(params[self._fixed_arrays.qca_param_slot])
                    * self._fixed_arrays.qca_scale,
                )
            )
        return u

    @jitmethod
    def jax_quantity(self, params):
        """Quantity for all cases, vmapping ``_jax_quantity`` over the leading case axes."""
        ca = _get_jnp_array(self.dataset, "ca")
        av = _as_jnp_array(self._data_arrays.av)
        # number of leading (case / panel) axes to vmap over
        if av is not None:
            depth = av.ndim - 1
        elif ca is not None:
            depth = ca.ndim - 2
        else:
            raise ValueError("missing data")
        f = self._jax_quantity
        for _level in range(depth):
            f = jax.vmap(f, in_axes=(None, 0))
        return f(params, {"ca": ca, "av": av})

    @jitmethod
    def _jax_utility(self, params, databundle):
        """
        Utility for one case: a length-n_nodes vector (nest slots left at zero).

        Combines quantity (log-scaled by theta), ``ca`` terms, ``co`` terms
        and alternative-specific constants. Unavailable alternatives get
        -inf unless ``availability_any`` is set.
        """
        ca = databundle.get("ca", None)
        co = databundle.get("co", None)
        av = databundle.get("av", None)
        if co is None:
            n_vars_co = 0
        else:
            n_vars_co = co.shape[-1]
        n_alts = self.dataset.dc.n_alts
        n_nodes = len(self.graph)
        # per-alternative coefficients on co data; the extra last column holds constants
        x = jnp.zeros([self.dataset.dc.n_alts, n_vars_co + 1])
        x = x.at[self._fixed_arrays.uco_alt_slot, self._fixed_arrays.uco_data_slot].add(
            params[self._fixed_arrays.uco_param_slot] * self._fixed_arrays.uco_scale
        )
        if self._fixed_arrays.qca_param_slot.size:
            theta = params[self._fixed_arrays.qscale_param_slot]
            q = self._jax_quantity(params, databundle)
            log_q = jnp.log(jnp.clip(q, 1e-15, 1e15))
            u = jnp.clip(log_q, -1e15) * theta
        else:
            u = jnp.zeros([n_nodes])
        if ca is not None and self._fixed_arrays.uca_param_slot.size:
            u = u.at[:n_alts].add(
                jnp.dot(
                    ca[..., self._fixed_arrays.uca_data_slot],
                    params[self._fixed_arrays.uca_param_slot]
                    * self._fixed_arrays.uca_scale,
                )
            )
        if co is not None:
            temp = jnp.dot(co, x[:, :-1].T)
            u = u.at[:n_alts].add(temp)
        u = u.at[:n_alts].add(x[:, -1].T)
        if av is not None and not self.availability_any:
            u = u.at[:n_alts].set(jnp.where(av[:n_alts], u[:n_alts], -jnp.inf))
        return u

    @jitmethod
    def jax_utility(self, params):
        """Utility for all cases, vmapping ``_jax_utility`` over the leading case axes."""
        ca = _get_jnp_array(self.dataset, "ca")
        co = _get_jnp_array(self.dataset, "co")
        av = _as_jnp_array(self._data_arrays.av)
        if av is not None:
            depth = av.ndim - 1
        elif co is not None:
            depth = co.ndim - 1
        elif ca is not None:
            depth = ca.ndim - 2
        else:
            raise ValueError("missing data")
        f = self._jax_utility
        for _level in range(depth):
            f = jax.vmap(f, in_axes=(None, 0))
        return f(params, {"ca": ca, "co": co, "av": av})

    def __utility_for_nest(self, slot):
        """
        Build a function computing the logsum utility of the nest at *slot*.

        The returned ``u_nest(params, utility_array, array_av)`` adds the
        shifted log-sum-exp of the children (scaled by mu when more than one
        child is available) into ``utility_array[..., slot]``.
        """
        nest_code = self.graph.standard_sort[slot]
        child_slots = self.graph.successor_slots(nest_code)
        mu_name = self.graph.nodes[nest_code].get("parameter")
        if mu_name is None:
            mu_slot = -1
        else:
            mu_slot = self.get_param_loc(mu_name)

        def u_nest(params, utility_array, array_av):
            mu = params[mu_slot] if mu_slot >= 0 else 1.0
            shifter = utility_array[jnp.asarray(child_slots)].max()
            carry = jnp.zeros(utility_array.shape[:-1])
            num_av = array_av[..., slot]

            def body(carry, child_slot):
                carry = jnp.add(
                    carry,
                    jnp.where(
                        num_av > 1,
                        jnp.exp(
                            jnp.clip(utility_array[..., child_slot] - shifter, -1e37)
                            / mu
                        ),
                        jnp.exp(
                            jnp.clip(utility_array[..., child_slot] - shifter, -1e37)
                        ),
                    ),
                )
                return carry, None

            carry, _ = jax.lax.scan(body, carry, child_slots)
            utility_array = utility_array.at[..., slot].add(
                jnp.where(
                    num_av > 1,
                    jnp.clip(jnp.log(carry), -1e38) * mu,
                    jnp.clip(jnp.log(carry), -1e39),
                )
                + shifter
            )
            return utility_array

        return u_nest

    def _mu_slots(self):
        """
        Parameter index of each node's logsum parameter, in standard-sort order.

        Nodes without a parameter get ``n_params``, an index one past the
        parameter vector (``u_rollup`` pads that slot with 1.0).
        """
        n_params = self.n_params
        mu_names = [
            self.graph.nodes[i].get("parameter") for i in self.graph.standard_sort
        ]
        mu_slots = jnp.asarray(
            [
                (self.get_param_loc(mu_name) if mu_name is not None else n_params)
                for mu_name in mu_names
            ]
        )
        return mu_slots

    @compiledmethod
    def utility_for_nests(self):
        """
        Build the function that rolls elemental utilities up through the nests.

        The returned callable has signature
        ``(utility_array, params, av, mu_slots) -> utility_array``. Which
        implementation is used depends on the number of non-elemental nodes:
        none beyond the root, a couple of nests, or many nests.
        """
        n_alts = self.graph.n_elementals()
        n_nodes = len(self.graph)

        def u_nesting_none(out, beta, array_av, mu_slots):
            # TODO Maybe filter on av
            return out.at[..., -1].set(jnp.log(jnp.exp(out[..., :-1]).sum(-1)))

        if n_nodes - n_alts <= 1:
            return u_nesting_none

        if n_nodes - n_alts < 3:

            def u_nesting_few(out, beta, array_av, mu_slots):
                for slot in range(n_alts, n_nodes):
                    out = self.__utility_for_nest(slot)(beta, out, array_av)
                return out

            return u_nesting_few

        # many nests, use more efficient loop
        def add_u_to_parent(u, params, child_slot, parent_slot, avail, mu_slots):
            mu = params[mu_slots[parent_slot]]
            u = u.at[..., parent_slot].add(
                jnp.where(
                    avail[..., parent_slot] > 1,
                    jnp.exp(jnp.clip(u[..., child_slot], -1e37) / mu),
                    jnp.exp(jnp.clip(u[..., child_slot], -85)),
                )
            )
            return u

        def max_u_to_parent(u, bs, child_slot, parent_slot):
            bs = bs.at[..., parent_slot].set(
                jnp.maximum(
                    u[..., child_slot],
                    bs[..., parent_slot],
                )
            )
            return bs

        def log_self(u, bs, params, self_slot, avail, mu_slots):
            mu = params[mu_slots[self_slot]]
            u = u.at[..., self_slot].set(
                jnp.maximum(
                    jnp.where(
                        avail[..., self_slot] > 1,
                        jnp.log(u[..., self_slot]) * mu,
                        jnp.log(u[..., self_slot]),
                    ),
                    bs[..., self_slot],
                )
            )
            return u

        # one row per graph edge: (up_slot, dn_slot, firstvisit, allocslot)
        slotarray = np.stack(self.graph.edge_slot_arrays()).T

        def u_rollup(utility_array, parameter_vector, avail_ca, mu_slots):
            n_params = parameter_vector.size
            # pad with a trailing 1.0 so nests without a mu parameter use mu=1
            params = jnp.ones(n_params + 1, dtype=parameter_vector.dtype)
            params = params.at[:n_params].set(parameter_vector)
            # the backstop prevents underflow when mu is too small
            backstop = jnp.full_like(utility_array, -jnp.inf)

            def body(carry, xs):
                u, bs, params = carry
                up_slot, dn_slot, firstvisit, allocslot = xs
                # On the first visit to a nest child, convert its accumulated
                # sum of exponentials into a logsum before pushing it upward.
                u = jax.lax.cond(
                    (firstvisit >= 0) & (dn_slot >= n_alts),
                    lambda u: log_self(u, bs, params, dn_slot, avail_ca, mu_slots),
                    lambda u: u,
                    operand=u,
                )
                u = add_u_to_parent(u, params, dn_slot, up_slot, avail_ca, mu_slots)
                bs = max_u_to_parent(u, bs, dn_slot, up_slot)
                return (u, bs, params), None

            (utility_array, backstop, _ignore_1), _ignore_2 = jax.lax.scan(
                body, (utility_array, backstop, params), slotarray
            )
            # log utility at root
            utility_array = utility_array.at[..., -1].set(
                jnp.clip(jnp.log(utility_array[..., -1]), -1e38)
            )
            return utility_array

        return u_rollup

    def __probability_for_nest(self, slot):
        """
        Build a function propagating log probability from the nest at *slot* to its children.

        Each child's log probability is set to the parent's log probability
        plus ``(U_child - U_nest) / mu``.
        """
        nest_code = self.graph.standard_sort[slot]
        child_slots = self.graph.successor_slots(nest_code)
        mu_name = self.graph.nodes[nest_code].get("parameter")
        if mu_name is None:
            mu_slot = -1
        else:
            mu_slot = self.get_param_loc(mu_name)

        def probability_nest(params, utility_array, probability_array):
            mu = params[mu_slot] if mu_slot >= 0 else 1.0
            u_nest = utility_array[..., slot]

            def body(carry, child_slot):
                diff = jnp.clip(utility_array[..., child_slot], -1e33) - jnp.clip(
                    u_nest, -1e33
                )
                add_me = diff / mu
                carry = carry.at[..., child_slot].set(add_me + carry[..., slot])
                return carry, None

            probability_array, _ = jax.lax.scan(body, probability_array, child_slots)
            return probability_array

        return probability_nest

    @jitmethod
    def _jax_log_probability_bundle(self, params, databundle):
        """
        Log probability for a positional bundle ``(ca, co, ch, av, wt)``.

        ``ch`` and ``wt`` are unpacked but not used.
        """
        ca, co, ch, av, wt = databundle
        # ``_jax_log_probability`` takes a single mapping; passing the arrays
        # positionally (as before) always raised TypeError.
        return self._jax_log_probability(params, {"ca": ca, "co": co, "av": av})

    @jitmethod
    def _jax_utility_include_nests(self, params, databundle):
        """Utility for one case including nest logsums, computed on downshifted utilities."""
        av = databundle.get("av", None)
        n_alts = self.dataset.dc.n_alts
        utility_array = self._jax_utility(params, databundle)
        # downshift to prevent over/underflow
        shifter = utility_array[:n_alts].max(axis=-1)
        if av is not None and not self.availability_any:
            utility_array = utility_array.at[:n_alts].add(
                jnp.where(av[:n_alts], -shifter, 0)
            )
        else:
            utility_array = utility_array.at[:n_alts].add(-shifter)
        mu_slots = self._mu_slots()
        utility_array = self.utility_for_nests(utility_array, params, av, mu_slots)
        return utility_array + shifter

    @jitmethod
    def jax_utility_include_nests(self, params):
        """Utility including nests for all cases (vmapped over leading case axes)."""
        ca = _get_jnp_array(self.dataset, "ca")
        co = _get_jnp_array(self.dataset, "co")
        av = _as_jnp_array(self._data_arrays.av)
        if av is not None:
            depth = av.ndim - 1
        elif co is not None:
            depth = co.ndim - 1
        elif ca is not None:
            depth = ca.ndim - 2
        else:
            raise ValueError("missing data")
        f = self._jax_utility_include_nests
        for _level in range(depth):
            f = jax.vmap(f, in_axes=(None, 0))
        return f(params, {"ca": ca, "co": co, "av": av})

    @staticmethod
    def _reject_case_slicing(start_case, stop_case, step_case):
        """Raise NotImplementedError if any case-slicing argument is set (jax engine)."""
        for label, value in (
            ("start_case", start_case),
            ("stop_case", stop_case),
            ("step_case", step_case),
        ):
            if value is not None:
                raise NotImplementedError(f"{label} with engine=jax")

    def quantity(
        self,
        x=None,
        *,
        start_case=None,
        stop_case=None,
        step_case=None,
        return_type=None,
    ):
        """
        Compute the quantity term for each case and alternative.

        With a non-jax compute engine this defers to ``NumbaModel.quantity``.

        Parameters
        ----------
        x : optional
            New parameter values, assigned to ``self.pvals`` before computing.
        start_case, stop_case, step_case : int, optional
            Case slicing; not supported by the jax engine.
        return_type : optional
            Passed to the numba engine; the jax engine returns a JAX array.

        Raises
        ------
        NotImplementedError
            If case slicing is requested with the jax engine.
        """
        if self.compute_engine != "jax":
            return super().quantity(
                x=x,
                start_case=start_case,
                stop_case=stop_case,
                step_case=step_case,
                return_type=return_type,
            )
        # Reject slicing rather than silently returning every case.
        self._reject_case_slicing(start_case, stop_case, step_case)
        if x is not None:
            self.pvals = x
        return self.jax_quantity(self.pvals)

    def utility(
        self,
        x=None,
        *,
        start_case=None,
        stop_case=None,
        step_case=None,
        return_format=None,
    ):
        """
        Compute the utility of each alternative for each case.

        With a non-jax compute engine this defers to ``NumbaModel.utility``.

        Parameters
        ----------
        x : optional
            New parameter values, assigned to ``self.pvals`` before computing.
        start_case, stop_case, step_case : int, optional
            Case slicing; not supported by the jax engine.
        return_format : optional
            Passed to the numba engine; the jax engine returns a JAX array.

        Raises
        ------
        NotImplementedError
            If case slicing is requested with the jax engine.
        """
        if self.compute_engine != "jax":
            return super().utility(
                x=x,
                start_case=start_case,
                stop_case=stop_case,
                step_case=step_case,
                return_format=return_format,
            )
        # Reject slicing rather than silently returning every case.
        self._reject_case_slicing(start_case, stop_case, step_case)
        if x is not None:
            self.pvals = x
        return self.jax_utility(self.pvals)

    @jitmethod
    def _jax_log_probability(self, params, databundle):
        """Log probability of every node (alternatives and nests) for one case."""
        av = databundle.get("av", None)
        n_alts = self.dataset.dc.n_alts
        n_nodes = len(self.graph)
        utility_array = self._jax_utility(params, databundle)
        # downshift to prevent over/underflow
        shifter = utility_array[:n_alts].max(axis=-1)
        if av is not None and not self.availability_any:
            utility_array = utility_array.at[:n_alts].add(
                jnp.where(av[:n_alts], -shifter, 0)
            )
        else:
            utility_array = utility_array.at[:n_alts].add(-shifter)
        mu_slots = self._mu_slots()
        utility_array = self.utility_for_nests(utility_array, params, av, mu_slots)
        logprobability = jnp.zeros_like(utility_array)
        # Walk nests from the last slot (root) downward so each parent's log
        # probability is filled in before it is propagated to its children.
        for slot in range(n_nodes, n_alts, -1):
            logprobability = self.__probability_for_nest(slot - 1)(
                params, utility_array, logprobability
            )
        # clipping log probability at 0, prevents underflow/overflow
        # when one of the nests only includes very bad alternatives.
        logprobability = jnp.clip(logprobability, None, 0)
        return logprobability

    @jitmethod
    def jax_log_probability(self, params):
        """Log probability of every node for all cases (vmapped over case axes)."""
        ca = _get_jnp_array(self.dataset, "ca")
        co = _get_jnp_array(self.dataset, "co")
        av = _as_jnp_array(self._data_arrays.av)
        if av is not None:
            depth = av.ndim - 1
        elif co is not None:
            depth = co.ndim - 1
        elif ca is not None:
            depth = ca.ndim - 2
        else:
            raise ValueError("missing data")
        f = self._jax_log_probability
        for _level in range(depth):
            f = jax.vmap(f, in_axes=(None, 0))
        return f(params, {"ca": ca, "co": co, "av": av})

    @jitmethod
    def _jax_probability(self, params, databundle):
        """Probability of each elemental alternative for one case."""
        n_alts = self.dataset.dc.n_alts
        return jnp.exp(self._jax_log_probability(params, databundle)[:n_alts])

    @jitmethod
    def jax_probability(self, params):
        """Probability of each elemental alternative for all cases."""
        ca = _get_jnp_array(self.dataset, "ca")
        co = _get_jnp_array(self.dataset, "co")
        av = _as_jnp_array(self._data_arrays.av)
        if av is not None:
            depth = av.ndim - 1
        elif co is not None:
            depth = co.ndim - 1
        elif ca is not None:
            depth = ca.ndim - 2
        else:
            raise ValueError("missing data")
        f = self._jax_probability
        for _level in range(depth):
            f = jax.vmap(f, in_axes=(None, 0))
        return f(params, {"ca": ca, "co": co, "av": av})

    @jitmethod
    def _jax_likelihood(self, params, databundle):
        """Per-alternative likelihood for one case: the probability where chosen, else 1."""
        n_alts = self.dataset.dc.n_alts
        ch = databundle.get("ch", None)[:n_alts]
        pr = self._jax_probability(params, databundle)[:n_alts]
        likely = jnp.where(ch, pr, 1.0)  # TODO make power if needed
        return likely

    @jitmethod(static_argnums=(3,), static_argnames="n_draws")
    def _jax_random_params(self, params, databundle, groupbundle=None, n_draws=100):
        """
        Randomized parameter vectors for one case or group.

        *groupbundle* supplies prerolled ``"draws"`` or, when draws are not
        prerolled, a random key ``"rk"`` used to generate them.
        """
        if self.prerolled_draws:
            draws = groupbundle.get("draws", None)
        else:
            rk = groupbundle.get("rk", None)
            draws, _ = self._make_random_draws_out(n_draws, len(self.mixtures), rk)
        rand_params = self.apply_random_draws(params, draws)
        return rand_params

    @jitmethod
    def jax_random_params(self, params):
        """Randomized parameter vectors for every case (or group, for panel data)."""
        ca = _get_jnp_array(self.dataset, "ca")
        co = _get_jnp_array(self.dataset, "co")
        av = _as_jnp_array(self._data_arrays.av)
        ch = _as_jnp_array(self._data_arrays.ch)
        n_draws = self.n_draws
        seed = self.seed or 42
        if av is not None:
            depth = av.ndim - 1
            shape = av.shape[:-1]
        elif co is not None:
            depth = co.ndim - 1
            shape = co.shape[:-1]
        elif ca is not None:
            depth = ca.ndim - 2
            shape = ca.shape[:-2]
        elif ch is not None:
            depth = ch.ndim - 1
            shape = ch.shape[:-1]
        else:
            raise ValueError("missing data")
        random_key = jax.random.PRNGKey(seed)
        if self.groupid is not None:
            # the innermost (within-group) axis is handled inside the per-group function
            depth = depth - 1
            shape = shape[:-1]
        f = self._jax_random_params  # params, databundle, groupbundle=None, n_draws=100
        from .random import keysplit

        # common draws are shared by every case, so they are not vmapped
        commons = None if self.common_draws else 0
        for _i in range(depth):
            f = jax.vmap(f, in_axes=(None, 0, commons, None))
            if not self.prerolled_draws:
                random_key, shape = keysplit(random_key, shape)
        if self.prerolled_draws:
            return f(
                params,
                {"ca": ca, "co": co, "av": av, "ch": ch},
                {"draws": self._draws},
                n_draws,
            )
        else:
            return f(
                params,
                {"ca": ca, "co": co, "av": av, "ch": ch},
                {"rk": random_key},
                n_draws,
            )

    @jitmethod(static_argnums=(3,), static_argnames="n_draws")
    def _jax_loglike_casewise(self, params, databundle, groupbundle=None, n_draws=100):
        """
        Log likelihood for one case (or one group, for panel data).

        Without mixtures this is the sum of log probabilities of the chosen
        alternatives. With mixtures, the likelihood is simulated: the product
        over chosen alternatives (and over cases within a panel group) is
        averaged over the random draws, then logged.
        """
        if len(self.mixtures) == 0:
            logpr = self._jax_log_probability(params, databundle)
            ch = databundle.get("ch", None)
            n_alts = self.dataset.dc.n_alts
            # mask unchosen alternatives so -inf * 0 cannot produce NaN
            return jnp.where(ch[:n_alts], logpr[:n_alts] * ch[:n_alts], 0).sum()
        else:
            if self.prerolled_draws:
                draws = groupbundle.get("draws", None)
            else:
                rk = groupbundle.get("rk", None)
                draws, _ = self._make_random_draws_out(n_draws, len(self.mixtures), rk)
            ch = databundle.get("ch", None)
            rand_params = self.apply_random_draws(params, draws)
            if ch.ndim == 2:  # PANEL DATA
                # vmap over ingroup
                likelihood_f = jax.vmap(
                    self._jax_likelihood,
                    in_axes=(None, 0),
                )
                # vmap over draws
                likelihood = jax.vmap(
                    likelihood_f,
                    in_axes=(0, None),
                    out_axes=-1,
                )(rand_params, databundle)
                # collapse likelihood over all alternatives
                likelihood = likelihood.prod([0, 1])
                # average over all draws
                likelihood = likelihood.mean(0)
                return jnp.log(likelihood)
            else:
                # vmap over draws
                likelihood = jax.vmap(
                    self._jax_likelihood,
                    in_axes=(0, None),
                    out_axes=-1,
                )(rand_params, databundle)
                # collapse likelihood over all alternatives
                likelihood = likelihood.prod(0)
                # average over all draws
                likelihood = likelihood.mean(0)
                return jnp.log(likelihood)

    @jitmethod
    def jax_loglike_casewise(self, params):
        """Log likelihood for every case (or group), vmapped over case axes."""
        ca = _get_jnp_array(self.dataset, "ca")
        co = _get_jnp_array(self.dataset, "co")
        av = _as_jnp_array(self._data_arrays.av)
        ch = _as_jnp_array(self._data_arrays.ch)
        n_draws = self.n_draws
        seed = self.seed or 42
        if av is not None:
            depth = av.ndim - 1
            shape = av.shape[:-1]
        elif co is not None:
            depth = co.ndim - 1
            shape = co.shape[:-1]
        elif ca is not None:
            depth = ca.ndim - 2
            shape = ca.shape[:-2]
        elif ch is not None:
            depth = ch.ndim - 1
            shape = ch.shape[:-1]
        else:
            raise ValueError("missing data")
        random_key = jax.random.PRNGKey(seed)
        if self.groupid is not None:
            # the innermost (within-group) axis is handled inside the per-group function
            depth = depth - 1
            shape = shape[:-1]
        f = self._jax_loglike_casewise  # params, databundle, groupbundle=None, n_draws=100
        from .random import keysplit

        # common draws are shared by every case, so they are not vmapped
        commons = None if self.common_draws else 0
        for _i in range(depth):
            f = jax.vmap(f, in_axes=(None, 0, commons, None))
            if not self.prerolled_draws:
                random_key, shape = keysplit(random_key, shape)
        if self.prerolled_draws:
            return f(
                params,
                {"ca": ca, "co": co, "av": av, "ch": ch},
                {"draws": self._draws},
                n_draws,
            )
        else:
            return f(
                params,
                {"ca": ca, "co": co, "av": av, "ch": ch},
                {"rk": random_key},
                n_draws,
            )

    @jitmethod
    def jax_loglike(self, params):
        """Total log likelihood: case-wise log likelihood weighted by ``wt`` and summed."""
        wt = _as_jnp_array(self._data_arrays.wt)
        return (self.jax_loglike_casewise(params) * wt).sum()

    def loglike(
        self,
        x=None,
        *,
        start_case: int | None = None,
        stop_case: int | None = None,
        step_case: int | None = None,
        check_if_best: bool = True,
        error_if_bad: bool = True,
        **kwargs,
    ):
        """
        Compute the log likelihood of the model.

        With a non-jax compute engine this defers to ``NumbaModel.loglike``.

        Parameters
        ----------
        x : optional
            New parameter values, assigned to ``self.pvals`` before computing.
        start_case, stop_case, step_case : int, optional
            Case slicing; not supported by the jax engine.
        check_if_best : bool, default True
            Pass the result to ``self._check_if_best``.
        error_if_bad : bool, default True
            Raise ValueError if the result is NaN or infinite, naming
            offending cases where possible.
        **kwargs
            Passed to the numba engine; ignored by the jax engine.

        Returns
        -------
        float
        """
        if self.compute_engine != "jax":
            return super().loglike(
                x=x,
                start_case=start_case,
                stop_case=stop_case,
                step_case=step_case,
                check_if_best=check_if_best,
                error_if_bad=error_if_bad,
                **kwargs,
            )
        self._reject_case_slicing(start_case, stop_case, step_case)
        if x is not None:
            self.pvals = x
        self.check_random_draws()
        result = float(self.jax_loglike(self.pvals))
        # case slicing was rejected above, so only the flag matters here
        if check_if_best:
            self._check_if_best(result)
        if error_if_bad:

            def _which_cases(caseids, indexes):
                # e.g. " in 7 cases, including CASEIDs: 1, 2, 3, 4, 5, ..."
                text = f" in {len(indexes)} cases, including CASEIDs:"
                text += f" {caseids[indexes[0]]}"
                text += "".join(f", {caseids[i]}" for i in indexes[1:5])
                if len(indexes) > 5:
                    text += ", ..."
                return text

            for f, tag in [(np.isnan, "NaN"), (np.isinf, "Inf")]:
                if not f(result):
                    continue
                caseids = self.dataset.dc.caseids()
                bad_wt_indexes = np.where(f(self._data_arrays.wt))[0]
                if len(bad_wt_indexes) > 0:
                    raise ValueError(
                        f"weight is {tag}" + _which_cases(caseids, bad_wt_indexes)
                    )
                ll_casewise = self.jax_loglike_casewise(self.pvals)
                msg = f"log likelihood is {tag}"
                bad_case_indexes = np.where(f(ll_casewise))[0]
                if len(bad_case_indexes) > 0:
                    msg += _which_cases(caseids, bad_case_indexes)
                else:
                    msg += " but not in any individual cases"
                raise ValueError(msg)
        return result

    def d_loglike(
        self,
        x=None,
        *,
        start_case=None,
        stop_case=None,
        step_case=None,
        return_series=False,
        **kwargs,
    ):
        """
        Compute the derivative of the log likelihood, one value per parameter.

        With the jax engine, entries for held-fast parameters
        (``pholdfast != 0``) are zeroed.

        Parameters
        ----------
        x : optional
            New parameter values, assigned to ``self.pvals`` before computing.
        start_case, stop_case, step_case : int, optional
            Case slicing; not supported by the jax engine.
        return_series : bool, default False
            Return a pandas Series indexed by parameter name.
        **kwargs
            Rejected by the jax engine. NOTE(review): not forwarded to the
            numba engine.

        Raises
        ------
        NotImplementedError
            If case slicing or extra keyword arguments are used with jax.
        """
        if self.compute_engine != "jax":
            return super().d_loglike(
                x=x,
                start_case=start_case,
                stop_case=stop_case,
                step_case=step_case,
                return_series=return_series,
            )
        self._reject_case_slicing(start_case, stop_case, step_case)
        if kwargs:
            raise NotImplementedError(f"{kwargs.popitem()[0]} with engine=jax")
        if x is not None:
            self.pvals = x
        self.check_random_draws()
        result = self.jax_d_loglike(self.pvals) * (self.pholdfast == 0)
        if return_series:
            result = pd.Series(result, index=self.pnames)
        return result

    def loglike_casewise(
        self,
        x=None,
        *,
        start_case=None,
        stop_case=None,
        step_case=None,
        **kwargs,
    ):
        """
        Compute the log likelihood of each case (or group, for panel data).

        With the jax engine the result is a numpy array; extra ``**kwargs``
        are ignored there.

        Raises
        ------
        NotImplementedError
            If case slicing is requested with the jax engine.
        """
        if self.compute_engine != "jax":
            return super().loglike_casewise(
                x=x,
                start_case=start_case,
                stop_case=stop_case,
                step_case=step_case,
                **kwargs,
            )
        self._reject_case_slicing(start_case, stop_case, step_case)
        if x is not None:
            self.pvals = x
        self.check_random_draws()
        result = np.asarray(self.jax_loglike_casewise(self.pvals))
        return result

    def maximize_loglike(
        self,
        *args,
        **kwargs,
    ) -> dictx:
        """
        Maximize the log likelihood.

        Parameters
        ----------
        method : str, optional
            The optimization method to use.  See scipy.optimize for
            most possibilities, or use 'BHHH'. Defaults to SLSQP if
            there are any constraints or finite parameter bounds,
            otherwise defaults to BHHH.
        quiet : bool, default False
            Whether to suppress the dashboard.
        options : dict, optional
            These options are passed through to the `scipy.optimize.minimize`
            function.
        maxiter : int, optional
            Maximum number of iterations.  This argument is just added to
            `options` for most methods.

        Returns
        -------
        larch.util.dictx
            A dictionary of results, including final log likelihood,
            elapsed time, and other statistics.  The exact items
            included in output will vary by estimation method.
        """
        if self.compute_engine == "jax":
            self.check_random_draws()
            return self.jax_maximize_loglike(*args, **kwargs)
        else:
            return super().maximize_loglike(*args, **kwargs)

    def estimate(self, *args, **kwargs):
        """
        Maximize loglike, and then calculate parameter covariance.

        This convenience method runs the following methods in order:
        - maximize_loglike
        - calculate_parameter_covariance

        All arguments are passed through to maximize_loglike.

        Returns
        -------
        dictx
        """
        result = self.maximize_loglike(*args, **kwargs)
        self.calculate_parameter_covariance()
        return result

    def loglike_null(self, use_cache=True):
        """
        Compute the log likelihood at null values.

        Evaluate the log likelihood at the values in the "nullvalue" column
        of the parameter frame, with the currently loaded data.  Note that
        the null value for each parameter may not be zero (for example, the
        default null value for logsum parameters in a nested logit model is 1).
        With the jax engine the current parameter values are left unchanged.

        Parameters
        ----------
        use_cache : bool, default True
            Use the cached value if available.  Set to -1 to
            raise an exception if there is no cached value.

        Returns
        -------
        float
        """
        if self.compute_engine != "jax":
            return super().loglike_null(use_cache)
        if self._cached_loglike_null is not None and use_cache:
            return self._cached_loglike_null
        elif use_cache == -1:
            raise ValueError("no cached value")
        else:
            self.check_random_draws()
            self._cached_loglike_null = float(self.jax_loglike(self.pnullvals))
            return self._cached_loglike_null

    def mixture_summary(self) -> pd.DataFrame:
        """
        Create a summary of the mixture parameters as a pandas DataFrame.

        For parameters with random distributions, this summarizes the
        distribution of the parameters.  The summary includes the mean,
        standard deviation, and quartiles of the parameters, as well as
        the share of positive, negative, and zero values.  The summary is
        computed statistically from the random draws, not analytically
        from the parameter values, and so may be slightly off from the
        "true" values if the number of draws is too small (but the
        estimation results will also be impacted in this case).

        Returns
        -------
        pandas.DataFrame
            One row per parameter that varies across draws.
        """
        # Ensure prerolled draws exist (they are dropped whenever the model
        # is mangled), as the other jax entry points do before evaluating.
        self.check_random_draws()
        random_params = self.jax_random_params(self.pvals)
        # reduce over every axis except the trailing parameter axis
        x = np.arange(random_params.ndim - 1)
        means = random_params.mean(x)
        stds = random_params.std(x)
        mins = random_params.min(x)
        maxs = random_params.max(x)
        positive = (random_params > 0).mean(x)
        negative = (random_params < 0).mean(x)
        zero = (random_params == 0).mean(x)
        q25 = jnp.quantile(random_params, 0.25, tuple(x))
        q50 = jnp.quantile(random_params, 0.5, tuple(x))
        q75 = jnp.quantile(random_params, 0.75, tuple(x))
        nonzero_variance = np.asarray((maxs - mins) > 0)
        result = pd.DataFrame(
            {
                "mean": means,
                "std": stds,
                "share +": positive,
                "share -": negative,
                "share ø": zero,
                "q25": q25,
                "median": q50,
                "q75": q75,
            },
            index=self.pnames,
        )
        return result.loc[nonzero_variance].copy()

    def mixture_density(
        self,
        param_name: str,
        limits: tuple[float, float] | None = (0.01, 0.99),
        **sns_theme,
    ):
        """
        Create a density plot of a mixture parameter.

        This method requires the seaborn package to be installed.

        Parameters
        ----------
        param_name : str
            The name of the parameter to plot.
        limits : tuple[float, float], optional
            The quantiles to use as the limits of the density plot.
            If None, the limits are not set.
        **sns_theme
            Passed to ``seaborn.set_theme``, which changes seaborn's global theme.

        Returns
        -------
        matplotlib.Axes
        """
        import seaborn as sns

        sns.set_theme(**sns_theme)
        i = self.get_param_loc(param_name)
        # make sure prerolled draws exist before sampling parameters
        self.check_random_draws()
        random_params = self.jax_random_params(self.pvals)[..., i].flatten()
        if limits is not None:
            q01 = jnp.quantile(random_params, limits[0])
            q99 = jnp.quantile(random_params, limits[1])
        else:
            q01 = q99 = None
        result = sns.kdeplot(random_params, cut=0, clip=(q01, q99))
        result.set_xlabel(param_name)
        return result

    def to_xlsx(
        self,
        filename,
        save_now=True,
        data_statistics: bool = True,
        nesting: bool = True,
        embed_model: bool = True,
    ) -> ExcelWriter | None:
        """
        Write the estimation results to an Excel file.

        Parameters
        ----------
        filename : str
            The name of the file to write.
        save_now : bool, default True
            Whether to save the file immediately. If False, the
            ExcelWriter object is returned.
        data_statistics : bool, default True
            Whether to include data statistics in the Excel file.
        nesting : bool, default True
            Whether to include nesting statistics in the Excel file.
        embed_model : bool, default True
            Whether to embed the model in the Excel file.

        Returns
        -------
        larch.util.excel.ExcelWriter or None
            None when ``save_now`` is True (the writer is closed).
        """
        from larch.util.excel import _make_excel_writer

        result = _make_excel_writer(self, filename, save_now=False)
        result._post_init(
            filename,
            model=self,
            data_statistics=data_statistics,
            nesting=nesting,
            embed=embed_model,
        )
        if save_now:
            result.close()
        else:
            return result