Compute Engines
import larch as lx
Larch provides a Python interface for setting up, estimating, and applying discrete choice models. Under the hood, Larch runs on a “compute engine”, which provides all the back-end code needed to execute the mathematical computations that power the models.
Available Engines
Larch currently offers two compute engines: numba and jax.
The numba compute engine is relatively fast and is the best choice for estimating basic models, such as simple multinomial logit (MNL) and nested logit (NL) models, especially with small to moderate-sized datasets. It uses just-in-time (JIT) compiled functions that run faster than regular Python; these functions are optimized for the model type, but not for each specific dataset. The numba engine is not available for mixed logit models.
The jax compute engine is fast and efficient, and is the best choice for estimating complicated models, especially mixed logit models. With this engine, each model step is compiled and optimized specifically for the data and structures used in that model. Compiled code is not cached to disk, so this optimization adds a significant amount of “fixed” overhead time to model estimation and application. For large and complex models, however, this overhead can be well worth the investment.
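The following is a minimal sketch of where that fixed overhead comes from, using `jax.jit` directly rather than Larch. The MNL log-likelihood function `mnl_loglike` and the `TRACE_COUNT` counter are hypothetical names for illustration only. A Python side effect inside a jitted function runs only while JAX traces and compiles it. The counter therefore shows that compiled code is reused for repeated calls with the same array shapes, and that JAX compiles again when the data shape changes.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

TRACE_COUNT = 0  # hypothetical counter: incremented only during tracing


@jax.jit
def mnl_loglike(beta, X, choice):
    """Hypothetical MNL log-likelihood; X has shape (n_obs, n_alts, n_vars)."""
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs at trace/compile time only
    V = X @ beta  # utilities, shape (n_obs, n_alts)
    logprob = V - logsumexp(V, axis=1, keepdims=True)
    return jnp.sum(jnp.take_along_axis(logprob, choice[:, None], axis=1))


beta = jnp.zeros(2)
X4, c4 = jnp.ones((4, 3, 2)), jnp.array([0, 1, 2, 0])
ll_a = mnl_loglike(beta, X4, c4)  # traced and compiled for shape (4, 3, 2)
ll_b = mnl_loglike(beta, X4, c4)  # same shapes: reuses the compiled code

X5, c5 = jnp.ones((5, 3, 2)), jnp.array([0, 1, 2, 0, 1])
ll_c = mnl_loglike(beta, X5, c5)  # new shape: traced and compiled again

print(TRACE_COUNT)  # 2
print(float(ll_a), -4 * math.log(3))  # both approximately -4.3944
```

The first call for each new shape pays the tracing and compilation cost. Later calls with matching shapes run the already-optimized code, which is why the overhead pays off when a large model is evaluated many times during estimation.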
Setting Engines
The compute engine can be chosen by providing a compute_engine argument at model initialization:
m = lx.Model(compute_engine="jax")
Alternatively, the engine can be selected later by changing the compute_engine attribute:
m.compute_engine = "numba"
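If you try to use an engine type that is not available, Larch raises a ValueError. In the version documented here, the accepted values are "numba", "jax", and None: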
m.compute_engine = "steam"
ValueError: invalid compute engine
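To catch a bad engine name before it reaches the model, you could encode the guidance from this page in a small helper. This is a minimal sketch under stated assumptions: `choose_engine` and `VALID_ENGINES` are hypothetical and not part of Larch. The set of valid names mirrors the check performed by the compute_engine setter in the Larch version documented here.

```python
from typing import Optional

# Hypothetical constant (not part of Larch): engine names the compute_engine
# setter accepts in the Larch version documented on this page.
VALID_ENGINES = {"numba", "jax", None}


def choose_engine(mixed_logit: bool, preferred: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: pick an engine following the guidance above."""
    if preferred is not None:
        if preferred not in VALID_ENGINES:
            raise ValueError(f"invalid compute engine: {preferred!r}")
        if mixed_logit and preferred == "numba":
            raise ValueError("the numba engine is not available for mixed logit models")
        return preferred
    # jax handles mixed logit; numba is the usual choice for basic MNL/NL models.
    return "jax" if mixed_logit else "numba"
```

The result can then be passed to the documented argument, for example `m = lx.Model(compute_engine=choose_engine(mixed_logit=False))`.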