Compute Engines

import larch as lx

Larch provides a Python interface for setting up, estimating, and applying discrete choice models. Under the hood, Larch runs on a “compute engine”, which provides all the back-end code needed to execute the mathematical computations that power the models.

Available Engines

Larch currently offers two compute engines: numba and jax.

The numba compute engine is relatively fast and is the best choice for estimating basic models, such as simple multinomial logit (MNL) and nested logit (NL) models, especially with small to moderately sized datasets. It uses just-in-time (“jit”) compiled functions that run faster than regular Python; these functions are optimized for the model type but not for each specific dataset. The numba engine is not available for mixed logit models.

The jax compute engine runs fast and efficiently, and is the best choice for estimating complicated models, especially mixed logit models. When using this engine, each model step is compiled and optimized specifically for the data and structures used in that model. Compiled code is not cached to disk, so this optimization adds a significant “fixed” overhead to model estimation and application. However, for large and complex models this overhead can be well worth the investment.
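The trade-off between a fixed compilation overhead and faster evaluations can be reasoned about with simple arithmetic. The sketch below is not part of Larch: the function name break_even_evaluations and all of the timings are hypothetical, chosen only to illustrate the idea. It finds the smallest number of evaluations at which an engine with a one-time overhead becomes cheaper overall than one without.

```python
import math


def break_even_evaluations(compile_overhead_s, slow_eval_s, fast_eval_s):
    """Smallest evaluation count n for which
    compile_overhead_s + n * fast_eval_s < n * slow_eval_s.

    Returns None when the compiled path is not faster per evaluation,
    because the overhead is then never recovered.
    """
    per_eval_saving = slow_eval_s - fast_eval_s
    if per_eval_saving <= 0:
        return None
    # n must be strictly greater than overhead / saving
    return math.floor(compile_overhead_s / per_eval_saving) + 1


# Hypothetical timings (seconds), not measurements of any real engine:
# 30 s of one-time compilation, 2.0 s vs 0.5 s per evaluation.
n = break_even_evaluations(30.0, 2.0, 0.5)
print(n)
```

With these made-up numbers the overhead pays for itself once the model needs more than 20 evaluations; with real models the timings would have to be measured.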

Setting Engines

The compute engine can be chosen by providing a compute_engine argument at model initialization:

m = lx.Model(compute_engine="jax")

Alternatively, the engine can be changed later by assigning to the compute_engine attribute:

m.compute_engine = "numba"

If you try to set an engine name that Larch does not recognize, a ValueError is raised:

m.compute_engine = "steam"
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 1
----> 1 m.compute_engine = "steam"

File /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/larch/model/basemodel.py:252, in BaseModel.__setattr__(self, name, value)
    250 def __setattr__(self, name, value):
    251     if name.startswith("_") or hasattr(self, "_" + name) or hasattr(self, name):
--> 252         object.__setattr__(self, name, value)
    253     else:
    254         raise TypeError(
    255             f"Cannot set {name!r} on object of type {self.__class__.__name__}"
    256         )

File /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/larch/model/jaxmodel.py:120, in Model.compute_engine(self, engine)
    117 @compute_engine.setter
    118 def compute_engine(self, engine):
    119     if engine not in {"numba", "jax", None}:
--> 120         raise ValueError("invalid compute engine")
    121     self._compute_engine = engine
    122     if self._compute_engine == "jax" and not jax:

ValueError: invalid compute engine