# 201: Exampville Mode Choice

Welcome to Exampville, the best simulated town in this here part of the internet!

Exampville is a demonstration provided with Larch that walks through some of the 
data and tools that a transportation planner might use when building a travel model. 

In [None]:
# HIDDEN
from pytest import approx

import larch as lx

In [None]:
import numpy as np

import larch as lx
from larch import P, X

In this example notebook, we will walk through the creation of a tour mode choice model.

To begin, we'll reload the household, person, tour, and skims data from the
data setup example.

In [None]:
hh, pp, tour, skims = lx.example(200, ["hh", "pp", "tour", "skims"])

## Preprocessing

The Exampville data output contains a set of files similar to what we might
find for a real travel survey: network skims, and tables of households, persons,
and tours.  We'll need to connect these tables together to create a composite dataset
for mode choice model estimation, using the DataTree structure.
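
Conceptually, each relationship in the DataTree below is a key lookup that runs from a tour to another table. The following is a plain-Python sketch of that idea, not Larch's or sharrow's actual implementation. The small records are hypothetical stand-ins for the real Exampville tables. Finding the home-to-destination and destination-to-home auto times for a tour chains two lookups:

```python
# Hypothetical toy tables; the real Exampville data are much larger.
households = {
    101: {"HOMETAZ": 1, "INCOME": 45000},
    102: {"HOMETAZ": 2, "INCOME": 22000},
}
tours = [
    {"TOURID": 1, "HHID": 101, "DTAZ": 2},
    {"TOURID": 2, "HHID": 102, "DTAZ": 1},
]
# Skim values keyed by (origin TAZ, destination TAZ); illustrative numbers only.
auto_time = {(1, 1): 3.0, (1, 2): 12.5, (2, 1): 13.0, (2, 2): 2.5}


def outbound_auto_time(tour):
    """Follow tour -> household (via HHID), then (HOMETAZ, DTAZ) -> skim value."""
    hh = households[tour["HHID"]]
    return auto_time[(hh["HOMETAZ"], tour["DTAZ"])]


def return_auto_time(tour):
    """The return trip swaps origin and destination, like the `do` skims."""
    hh = households[tour["HHID"]]
    return auto_time[(tour["DTAZ"], hh["HOMETAZ"])]


times = [(outbound_auto_time(t), return_auto_time(t)) for t in tours]
print(times)  # [(12.5, 13.0), (13.0, 12.5)]
```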

In [None]:
from addicty import Dict

Mode = Dict(
    DA=1,
    SR=2,
    Walk=3,
    Bike=4,
    Transit=5,
).freeze()

In [None]:
tour_dataset = lx.Dataset.construct.from_idco(tour.set_index("TOURID"), alts=Mode)

In [None]:
od_skims = lx.Dataset.construct.from_omx(skims)

In [None]:
dt = lx.DataTree(
    tour=tour_dataset,
    hh=hh.set_index("HHID"),
    person=pp.set_index("PERSONID"),
    od=od_skims,
    do=od_skims,
    relationships=(
        "tours.HHID @ hh.HHID",
        "tours.PERSONID @ person.PERSONID",
        "hh.HOMETAZ @ od.otaz",
        "tours.DTAZ @ od.dtaz",
        "hh.HOMETAZ @ do.dtaz",
        "tours.DTAZ @ do.otaz",
    ),
)

In Exampville, there are only two kinds of tours:

- work (purpose=1), and
- non-work (purpose=2).

We want to estimate a mode choice model for work tours,
so we'll begin by excluding all the other tours:

In [None]:
dt_work = dt.query_cases("TOURPURP == 1")
dt_work.n_cases

If we wanted to, we could also filter the data with a more complex filter, accessing variables on tables other than the tours data.  For example, to only include low income households:

In [None]:
dt_work_low_income = dt.query_cases("TOURPURP == 1 and INCOME < 30000")
dt_work_low_income.n_cases
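
The `INCOME` variable used in that filter lives on the households table, not the tours table. To evaluate it for each tour, the DataTree has to follow the `HHID` relationship. Here is a plain-Python sketch of the same filtering logic. It uses hypothetical toy records rather than the Exampville data and is not how Larch evaluates `query_cases` internally:

```python
# Hypothetical toy tables.
households = {101: {"INCOME": 45000}, 102: {"INCOME": 22000}, 103: {"INCOME": 18000}}
tours = [
    {"TOURID": 1, "HHID": 101, "TOURPURP": 1},
    {"TOURID": 2, "HHID": 102, "TOURPURP": 1},
    {"TOURID": 3, "HHID": 103, "TOURPURP": 2},
    {"TOURID": 4, "HHID": 103, "TOURPURP": 1},
]


def keep(tour):
    # TOURPURP comes from the tour itself; INCOME is looked up via HHID.
    income = households[tour["HHID"]]["INCOME"]
    return tour["TOURPURP"] == 1 and income < 30000


low_income_work = [t["TOURID"] for t in tours if keep(t)]
print(low_income_work)  # [2, 4]
```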

## Model Definition

Now we are ready to create our model.

In [None]:
m = lx.Model(datatree=dt_work)
m.title = "Exampville Work Tour Mode Choice v1"

We will explicitly define the set of utility functions
we want to use.  Because the DataTree serving data to this
model contains exclusively `idco` format data (one row per
case, rather than one row per case and alternative), we'll
use only the `utility_co` mapping to define a unique utility
function for each alternative.

In [None]:
m.utility_co[Mode.DA] = (
    +P.InVehTime * X.AUTO_TIME + P.Cost * X.AUTO_COST  # dollars per mile
)

m.utility_co[Mode.SR] = (
    +P.ASC_SR
    + P.InVehTime * X.AUTO_TIME
    + P.Cost * (X.AUTO_COST * 0.5)  # dollars per mile, half share
    + P("LogIncome:SR") * X("log(INCOME)")
)

m.utility_co[Mode.Walk] = (
    +P.ASC_Walk + P.NonMotorTime * X.WALK_TIME + P("LogIncome:Walk") * X("log(INCOME)")
)

m.utility_co[Mode.Bike] = (
    +P.ASC_Bike + P.NonMotorTime * X.BIKE_TIME + P("LogIncome:Bike") * X("log(INCOME)")
)

m.utility_co[Mode.Transit] = (
    +P.ASC_Transit
    + P.InVehTime * X.TRANSIT_IVTT
    + P.OutVehTime * X.TRANSIT_OVTT
    + P.Cost * X.TRANSIT_FARE
    + P("LogIncome:Transit") * X("log(INCOME)")
)
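
Each of these utility functions is linear in the parameters. Every `P` term is multiplied by the value of its `X` expression for the tour, and the products are summed. The sketch below evaluates the shared-ride utility by hand for one hypothetical tour. The parameter and data values are made up for illustration; they are not estimation results:

```python
import math

# Hypothetical parameter values, for illustration only.
params = {"ASC_SR": 1.0, "InVehTime": -0.1, "Cost": -0.2, "LogIncome:SR": -0.2}

# Hypothetical idco data for a single tour.
tour = {"AUTO_TIME": 20.0, "AUTO_COST": 3.0, "INCOME": 50000.0}


def utility_sr(p, x):
    """Shared-ride utility, mirroring the structure of m.utility_co[Mode.SR]."""
    return (
        p["ASC_SR"]
        + p["InVehTime"] * x["AUTO_TIME"]
        + p["Cost"] * (x["AUTO_COST"] * 0.5)  # half of the auto cost
        + p["LogIncome:SR"] * math.log(x["INCOME"])
    )


v = utility_sr(params, tour)
print(round(v, 4))
```

Because `Cost` and `InVehTime` appear in several alternatives' utility functions, a single estimated value for each is shared across those alternatives.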

To write a nested logit model, we'll attach some nesting nodes to the
model's `graph`.  Each call to `new_node` allows us to define the set of
codes for the child nodes (elemental alternatives, or lower level nests)
as well as giving the new nest a name and assigning a logsum parameter.
The return value of this method is the node code for the newly created
nest, which can then be used as a child code when creating
a higher level nest.  We do this below, adding the 'Car' nest into the
'Motor' nest.

In [None]:
Car = m.graph.new_node(parameter="Mu:Car", children=[Mode.DA, Mode.SR], name="Car")
NonMotor = m.graph.new_node(
    parameter="Mu:NonMotor", children=[Mode.Walk, Mode.Bike], name="NonMotor"
)
Motor = m.graph.new_node(
    parameter="Mu:Motor", children=[Car, Mode.Transit], name="Motor"
)
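
To see what the `Mu` parameters do, here is a stdlib-only sketch of nested logit probabilities for this same tree. It uses the common textbook formulation. A nest's utility is its logsum, `V_nest = mu * log(sum(exp(V_child / mu)))`. The probability of a child given its parent is `exp(V_child / mu) / exp(V_nest / mu)`. The utilities and mu values are hypothetical, and this is not Larch's internal implementation:

```python
import math

# Nesting tree mirroring the graph above: node -> (mu, children).
# Elemental alternatives have no entry here. Mu values are hypothetical.
tree = {
    "Root": (1.0, ["Motor", "NonMotor"]),
    "Motor": (0.8, ["Car", "Transit"]),
    "Car": (0.3, ["DA", "SR"]),
    "NonMotor": (0.85, ["Walk", "Bike"]),
}


def logsum(node, v):
    """Utility of a node: the alternative's own V, or the nest's logsum."""
    if node not in tree:
        return v[node]
    mu, children = tree[node]
    return mu * math.log(sum(math.exp(logsum(c, v) / mu) for c in children))


def probabilities(v, node="Root", p_node=1.0, out=None):
    """Multiply conditional probabilities down the tree to get elemental P."""
    out = {} if out is None else out
    if node not in tree:
        out[node] = p_node
        return out
    mu, children = tree[node]
    v_node = logsum(node, v)
    for c in children:
        p_child = math.exp((logsum(c, v) - v_node) / mu)
        probabilities(v, c, p_node * p_child, out)
    return out


# Hypothetical utilities for one tour.
v = {"DA": 0.0, "SR": -1.0, "Transit": -1.5, "Walk": -2.0, "Bike": -2.5}
p = probabilities(v)
```

Setting every mu to 1 collapses the tree to a plain multinomial logit. Smaller mu values, like the hypothetical 0.3 on the Car nest, make the alternatives within a nest closer substitutes for each other.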

Let's visually check on the nesting structure.

In [None]:
m.graph

The tour mode choice model's choice variable is indicated by 
the code value in 'TOURMODE', and this can be 
defined for the model using `choice_co_code`.

In [None]:
m.choice_co_code = "TOURMODE"
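
`choice_co_code` tells Larch that a single `idco` column holds the code of the chosen alternative. Conceptually, that is equivalent to a one-hot array with one row per tour and one column per alternative. Here is a plain-Python sketch with hypothetical `TOURMODE` values:

```python
# Alternative codes, matching the Mode mapping defined earlier.
alt_codes = [1, 2, 3, 4, 5]  # DA, SR, Walk, Bike, Transit

# Hypothetical TOURMODE values for four tours.
tourmode = [1, 5, 1, 3]

# One row per tour, with a 1 in the column of the chosen alternative.
choice_array = [[1 if code == chosen else 0 for code in alt_codes] for chosen in tourmode]
```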

We can also give a dictionary of availability conditions based 
on values in the `idco` data, using the `availability_co_vars`
attribute.  Alternatives that are always available can be indicated
by setting the criterion to 1.

In [None]:
m.availability_co_vars = {
    Mode.DA: "AGE >= 16",
    Mode.SR: 1,
    Mode.Walk: "WALK_TIME < 60",
    Mode.Bike: "BIKE_TIME < 60",
    Mode.Transit: "TRANSIT_FARE>0",
}
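
Each availability condition is evaluated for every tour. An unavailable alternative simply drops out of that tour's choice set and receives zero probability. The sketch below illustrates the idea. It uses Python lambdas in place of Larch's string expressions, hypothetical utilities, and a hypothetical tour for a 15-year-old with a zero transit fare:

```python
import math

# Lambdas standing in for the availability expressions above.
availability = {
    "DA": lambda x: x["AGE"] >= 16,
    "SR": lambda x: True,
    "Walk": lambda x: x["WALK_TIME"] < 60,
    "Bike": lambda x: x["BIKE_TIME"] < 60,
    "Transit": lambda x: x["TRANSIT_FARE"] > 0,
}

# Hypothetical tour record.
tour = {"AGE": 15, "WALK_TIME": 45.0, "BIKE_TIME": 15.0, "TRANSIT_FARE": 0.0}
avail = {alt: bool(rule(tour)) for alt, rule in availability.items()}

# Multinomial logit over the available alternatives only (hypothetical utilities).
v = {"DA": 0.0, "SR": -1.0, "Walk": -2.0, "Bike": -2.5, "Transit": -1.5}
denom = sum(math.exp(v[a]) for a in v if avail[a])
prob = {a: (math.exp(v[a]) / denom if avail[a] else 0.0) for a in v}
```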

Then let's prepare this data for estimation.  Even though the
data is already in memory, the `load_data` method is used to 
pre-process the data, extracting the required values, pre-computing 
the values of fixed expressions, and assembling the results into
contiguous arrays suitable for computing the log likelihood values
efficiently.

## Model Estimation

We can check on some important statistics of this loaded data even
before we estimate the model.

In [None]:
m.choice_avail_summary()
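
The `chosen` and `available` columns in this summary are counts over tours. Given a one-hot choice array and a boolean availability array (toy values here, not Exampville data), those tallies are just column sums. It is also worth confirming that no tour chose an alternative marked unavailable, since that alternative would have zero probability and make the log likelihood undefined:

```python
alts = ["DA", "SR", "Walk", "Bike", "Transit"]

# Toy data: one row per tour, one column per alternative.
choice = [
    [1, 0, 0, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0],
]
avail = [
    [True, True, True, True, True],
    [True, True, False, True, True],
    [False, True, True, True, False],
]

chosen = {a: sum(row[j] for row in choice) for j, a in enumerate(alts)}
available = {a: sum(row[j] for row in avail) for j, a in enumerate(alts)}

# Indices of tours whose chosen alternative is not available (should be empty).
problems = [
    i
    for i, (c_row, a_row) in enumerate(zip(choice, avail))
    if any(c and not a for c, a in zip(c_row, a_row))
]
```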

In [None]:
# TEST
summary = m.choice_avail_summary()
assert (
    (summary.to_markdown())
    == """
|                            | name     |   chosen | available   | availability condition   |
|:---------------------------|:---------|---------:|:------------|:-------------------------|
| 1                          | DA       |     6052 | 7564        | AGE >= 16                |
| 2                          | SR       |      810 | 7564        | 1                        |
| 3                          | Walk     |      196 | 4179        | WALK_TIME < 60           |
| 4                          | Bike     |       72 | 7564        | BIKE_TIME < 60           |
| 5                          | Transit  |      434 | 4199        | TRANSIT_FARE>0           |
| 6                          | Car      |     6862 | 7564        |                          |
| 7                          | NonMotor |      268 | 7564        |                          |
| 8                          | Motor    |     7296 | 7564        |                          |
| < Total All Alternatives > |          |     7564 | <NA>        |                          |
"""[1:-1]
)

If we are satisfied with the statistics we see above, we
can go ahead and estimate the model, here using the `numba` compute engine.

In [None]:
m.compute_engine = "numba"

In [None]:
# TEST
# testing the JAX engine
mj = m.copy()
mj.compute_engine = "jax"

In [None]:
result = m.maximize_loglike(method="bhhh")
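
`maximize_loglike` searches for the parameter values that maximize the log likelihood. The log likelihood is the sum, over tours, of the log of the modeled probability of the alternative actually chosen. `method="bhhh"` refers to the Berndt-Hall-Hall-Hausman approach, which approximates the Hessian by the sum of outer products of per-case gradients. As a minimal sketch of the objective only, here is a multinomial logit log likelihood with a single hypothetical coefficient on toy data. The nested model above uses nested probabilities, but the summation over cases is the same:

```python
import math


def mnl_loglike(beta, data):
    """Sum over cases of log P(chosen), with V_j = beta * x_j for available j."""
    ll = 0.0
    for x, avail, chosen in data:
        v = [beta * xj for xj in x]
        denom = sum(math.exp(vj) for vj, aj in zip(v, avail) if aj)
        ll += v[chosen] - math.log(denom)
    return ll


# Toy data: (attribute per alternative, availability, index of chosen alternative).
data = [
    ([10.0, 20.0, 30.0], [True, True, True], 0),
    ([25.0, 15.0, 40.0], [True, True, True], 1),
    ([12.0, 18.0, 5.0], [True, True, True], 2),
]
```

At `beta = 0` every available alternative is equally likely, so the log likelihood is `-3 * log(3)` here. In this toy data each tour chose its lowest-valued alternative, so a negative `beta` fits better.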

In [None]:
# TEST
assert result.loglike == approx(-3493.0397298749467)

After we find the best fitting parameters, we can compute
some variance-covariance statistics, including standard errors of
the estimates and t statistics, using `calculate_parameter_covariance`.

In [None]:
m.calculate_parameter_covariance();
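
Under the usual maximum likelihood theory, the variance-covariance matrix of the estimates is approximated by the inverse of the negative Hessian of the log likelihood at the optimum. The standard errors are the square roots of its diagonal, and each t statistic is the estimate divided by its standard error (testing against zero). The small stdlib sketch below shows that arithmetic with hypothetical estimates and a hypothetical 2x2 negative Hessian. It does not describe exactly which matrix Larch inverts:

```python
import math

# Hypothetical estimates and negative Hessian at the optimum (2 parameters).
estimates = [1.0, -0.6]
neg_hessian = [[4.0, 1.0], [1.0, 25.0]]


def inverse_2x2(m):
    """Closed-form inverse of a 2x2 matrix."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


covariance = inverse_2x2(neg_hessian)
std_errors = [math.sqrt(covariance[i][i]) for i in range(2)]
t_stats = [est / se for est, se in zip(estimates, std_errors)]
```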

Then we can review the results in a variety of report tables.

In [None]:
m.parameter_summary()

In [None]:
# TEST
assert m.loglike(
    {
        "ASC_Bike": -0.258,
        "ASC_SR": 1.42,
        "ASC_Transit": 6.75,
        "ASC_Walk": 8.62,
        "Cost": -0.175,
        "InVehTime": -0.123,
        "LogIncome:Bike": -0.196,
        "LogIncome:SR": -0.193,
        "LogIncome:Transit": -0.557,
        "LogIncome:Walk": -0.522,
        "Mu:Car": 0.259,
        "Mu:Motor": 0.801,
        "Mu:NonMotor": 0.853,
        "NonMotorTime": -0.265,
        "OutVehTime": -0.254,
    }
) == approx(-3493.1651598166823)
assert m.d_loglike() == approx(
    np.array(
        [
            -5.923349e-01,
            -4.578097e01,
            6.587518e-01,
            -2.008887e00,
            3.204704e01,
            4.917038e00,
            -6.412804e00,
            -4.797529e02,
            7.073657e00,
            -2.129807e01,
            -8.594817e01,
            -1.037919e00,
            -7.935426e-01,
            -4.520675e01,
            5.006008e-02,
        ]
    )
)
assert m.bhhh() == approx(
    np.asarray(
        [
            [
                8.075759e01,
                -6.429084e00,
                -9.146018e-01,
                -1.393596e01,
                -4.421948e01,
                -3.935382e02,
                8.602698e02,
                -6.723960e01,
                -9.535575e00,
                -1.492397e02,
                -2.146034e01,
                -1.493242e00,
                6.045004e01,
                3.948673e02,
                -4.710155e00,
            ],
            [
                -6.429084e00,
                1.018920e04,
                -2.546029e01,
                -6.871663e00,
                -6.805603e03,
                1.943158e02,
                -6.723960e01,
                1.055976e05,
                -2.541149e02,
                -8.057482e01,
                1.825392e04,
                7.838297e00,
                -4.304127e00,
                -1.932471e02,
                -1.419322e02,
            ],
            [
                -9.146018e-01,
                -2.546029e01,
                1.672353e02,
                -3.990500e00,
                2.372922e02,
                -9.974702e02,
                -9.535575e00,
                -2.541149e02,
                1.779416e03,
                -4.201672e01,
                -6.399661e01,
                -6.428867e01,
                -1.050907e00,
                -8.546463e01,
                6.388929e02,
            ],
            [
                -1.393596e01,
                -6.871663e00,
                -3.990500e00,
                1.433439e02,
                -4.968905e01,
                -3.955136e02,
                -1.492397e02,
                -8.057482e01,
                -4.201672e01,
                1.519186e03,
                -3.567450e01,
                -4.654976e00,
                1.229872e01,
                2.638594e03,
                -4.559200e00,
            ],
            [
                -4.421948e01,
                -6.805603e03,
                2.372922e02,
                -4.968905e01,
                6.561081e03,
                -8.477408e02,
                -4.698072e02,
                -7.058237e04,
                2.486329e03,
                -5.228979e02,
                -1.119508e04,
                -6.949258e01,
                -3.032948e01,
                -1.506044e03,
                9.141991e02,
            ],
            [
                -3.935382e02,
                1.943158e02,
                -9.974702e02,
                -3.955136e02,
                -8.477408e02,
                1.319753e04,
                -4.168034e03,
                1.947043e03,
                -1.063983e04,
                -4.143149e03,
                5.665783e02,
                6.851572e02,
                -2.827231e02,
                -1.244020e04,
                -4.186515e03,
            ],
            [
                8.602698e02,
                -6.723960e01,
                -9.535575e00,
                -1.492397e02,
                -4.698072e02,
                -4.168034e03,
                9.230747e03,
                -7.055590e02,
                -9.940447e01,
                -1.607689e03,
                -2.232192e02,
                -1.584709e01,
                6.506544e02,
                4.176751e03,
                -5.105354e01,
            ],
            [
                -6.723960e01,
                1.055976e05,
                -2.541149e02,
                -8.057482e01,
                -7.058237e04,
                1.947043e03,
                -7.055590e02,
                1.099888e06,
                -2.533960e03,
                -9.330855e02,
                1.932402e05,
                6.734681e01,
                -4.840060e01,
                -2.146091e03,
                -1.431983e03,
            ],
            [
                -9.535575e00,
                -2.541149e02,
                1.779416e03,
                -4.201672e01,
                2.486329e03,
                -1.063983e04,
                -9.940447e01,
                -2.533960e03,
                1.904483e04,
                -4.467831e02,
                -6.474706e02,
                -6.487259e02,
                -1.189795e01,
                -9.018875e02,
                6.709564e03,
            ],
            [
                -1.492397e02,
                -8.057482e01,
                -4.201672e01,
                1.519186e03,
                -5.228979e02,
                -4.143149e03,
                -1.607689e03,
                -9.330855e02,
                -4.467831e02,
                1.620863e04,
                -3.878460e02,
                -5.077262e01,
                1.308380e02,
                2.785630e04,
                -4.844429e01,
            ],
            [
                -2.146034e01,
                1.825392e04,
                -6.399661e01,
                -3.567450e01,
                -1.119508e04,
                5.665783e02,
                -2.232192e02,
                1.932402e05,
                -6.474706e02,
                -3.878460e02,
                3.641852e04,
                1.680026e01,
                -1.929609e01,
                -8.572419e02,
                -3.366271e02,
            ],
            [
                -1.493242e00,
                7.838297e00,
                -6.428867e01,
                -4.654976e00,
                -6.949258e01,
                6.851572e02,
                -1.584709e01,
                6.734681e01,
                -6.487259e02,
                -5.077262e01,
                1.680026e01,
                4.883765e02,
                -1.511243e00,
                -1.040714e02,
                9.782830e02,
            ],
            [
                6.045004e01,
                -4.304127e00,
                -1.050907e00,
                1.229872e01,
                -3.032948e01,
                -2.827231e02,
                6.506544e02,
                -4.840060e01,
                -1.189795e01,
                1.308380e02,
                -1.929609e01,
                -1.511243e00,
                1.274382e02,
                7.420690e02,
                -1.648080e00,
            ],
            [
                3.948673e02,
                -1.932471e02,
                -8.546463e01,
                2.638594e03,
                -1.506044e03,
                -1.244020e04,
                4.176751e03,
                -2.146091e03,
                -9.018875e02,
                2.785630e04,
                -8.572419e02,
                -1.040714e02,
                7.420690e02,
                5.998454e04,
                -1.462547e02,
            ],
            [
                -4.710155e00,
                -1.419322e02,
                6.388929e02,
                -4.559200e00,
                9.141991e02,
                -4.186515e03,
                -5.105354e01,
                -1.431983e03,
                6.709564e03,
                -4.844429e01,
                -3.366271e02,
                9.782830e02,
                -1.648080e00,
                -1.462547e02,
                6.705056e03,
            ],
        ]
    )
)

In [None]:
# TEST
assert m.loglike(
    {
        "ASC_Bike": 0,
        "ASC_SR": 0,
        "ASC_Transit": 0,
        "ASC_Walk": 0,
        "Cost": -0.175,
        "InVehTime": 0,
        "LogIncome:Bike": 0,
        "LogIncome:SR": 0,
        "LogIncome:Transit": 0,
        "LogIncome:Walk": 0,
        "Mu:Car": 0.259,
        "Mu:Motor": 0.801,
        "Mu:NonMotor": 0.853,
        "NonMotorTime": 0,
        "OutVehTime": 0,
    }
) == approx(-13998.675244346756)
assert m.d_loglike() == approx(
    [
        -2635.23326,
        -10134.810415,
        -705.264876,
        -1043.736457,
        10341.511798,
        32697.264187,
        -28621.538072,
        -111952.485356,
        -7921.881902,
        -11292.174797,
        8475.077026,
        547.118835,
        -1553.37647,
        -88469.408946,
        -50705.443209,
    ]
)

In [None]:
# TEST
assert m.loglike(
    {
        "ASC_Bike": 0,
        "ASC_SR": 0,
        "ASC_Transit": 0,
        "ASC_Walk": 0,
        "Cost": -0.175,
        "InVehTime": 0,
        "LogIncome:Bike": 0,
        "LogIncome:SR": 0,
        "LogIncome:Transit": 0,
        "LogIncome:Walk": 0,
        "Mu:Car": 0.259,
        "Mu:Motor": 1.0,
        "Mu:NonMotor": 0.853,
        "NonMotorTime": 0,
        "OutVehTime": 0,
    }
) == approx(-13874.599159933234)
assert m.d_loglike() == approx(
    [
        -2519.134354,
        -10332.08641,
        -535.431417,
        -1024.805223,
        10396.304454,
        30467.260085,
        -27354.84768,
        -114131.694692,
        -6024.215889,
        -11088.974147,
        8278.911004,
        680.534926,
        -1527.132206,
        -85008.121443,
        -39466.996726,
    ]
)

In [None]:
# TEST
assert m.loglike(
    {
        "ASC_Bike": 0,
        "ASC_SR": 0,
        "ASC_Transit": 0,
        "ASC_Walk": 0,
        "Cost": -0.175,
        "InVehTime": 0,
        "LogIncome:Bike": 0,
        "LogIncome:SR": 0,
        "LogIncome:Transit": 0,
        "LogIncome:Walk": 0,
        "Mu:Car": 1.0,
        "Mu:Motor": 1.0,
        "Mu:NonMotor": 0.853,
        "NonMotorTime": 0,
        "OutVehTime": 0,
    }
) == approx(-11070.479718500861)
assert m.d_loglike() == approx(
    [
        -2010.425803,
        -1143.121097,
        -354.09085,
        -784.933422,
        3768.268759,
        23780.554672,
        -21833.896314,
        -12810.449485,
        -4047.774585,
        -8505.711496,
        2476.484342,
        171.162676,
        -1194.599282,
        -67783.778892,
        -31832.259099,
    ]
)

In [None]:
# TEST
assert m.loglike(
    {
        "ASC_Bike": 0,
        "ASC_SR": 0,
        "ASC_Transit": 0,
        "ASC_Walk": 0,
        "Cost": -0.175,
        "InVehTime": 0,
        "LogIncome:Bike": 0,
        "LogIncome:SR": 0,
        "LogIncome:Transit": 0,
        "LogIncome:Walk": 0,
        "Mu:Car": 1.0,
        "Mu:Motor": 1.0,
        "Mu:NonMotor": 1.0,
        "NonMotorTime": 0,
        "OutVehTime": 0,
    }
) == approx(-11251.479763710428)
assert m.d_loglike() == approx(
    [
        -2052.097618,
        -1092.631079,
        -346.361534,
        -849.180852,
        3834.917306,
        24339.366944,
        -22286.249505,
        -12266.75887,
        -3964.809683,
        -9193.707237,
        2544.498541,
        188.063682,
        -1268.016907,
        -70294.694308,
        -31671.372503,
    ]
)

In [None]:
# TEST
m.pvals = result.x.values
assert result.x.values == approx(
    np.array(
        [
            -0.258487,
            1.422852,
            6.754264,
            8.621445,
            -0.175691,
            -0.123711,
            -0.196929,
            -0.193804,
            -0.557133,
            -0.522779,
            0.259289,
            0.801594,
            0.853706,
            -0.265583,
            -0.254791,
        ]
    ),
    rel=1e-2,
)

In [None]:
# TEST
assert m.pstderr == approx(
    np.array(
        [
            1.339537,
            1.001639,
            2.064447,
            1.138892,
            0.119573,
            0.029206,
            0.123539,
            0.135447,
            0.169267,
            0.100378,
            0.18073,
            0.20086,
            0.112139,
            0.016306,
            0.064567,
        ]
    ),
    rel=1e-2,
)

In [None]:
# TEST
assert m.parameter_summary().data["Signif"].to_dict() == {
    "ASC_Bike": "",
    "ASC_SR": "",
    "ASC_Transit": "**",
    "ASC_Walk": "***",
    "Cost": "",
    "InVehTime": "***",
    "LogIncome:Bike": "",
    "LogIncome:SR": "",
    "LogIncome:Transit": "***",
    "LogIncome:Walk": "***",
    "Mu:Car": "***",
    "Mu:Motor": "",
    "Mu:NonMotor": "",
    "NonMotorTime": "***",
    "OutVehTime": "***",
}

In [None]:
m.estimation_statistics()

In [None]:
# TEST
# testing the JAX engine
mj.set_cap(20)

In [None]:
# TEST
# testing the JAX engine
resultj = mj.maximize_loglike(stderr=False)

In [None]:
# TEST
assert resultj.loglike == approx(-3493.0397298749467)

## Save and Report Model

In [None]:
report = lx.Reporter(title=m.title)

In [None]:
report.append("# Parameter Summary")
report.append(m.parameter_summary())
report

In [None]:
report << "# Estimation Statistics" << m.estimation_statistics()

In [None]:
report << "# Utility Functions" << m.utility_functions()

In [None]:
report.save(
    "exampville_mode_choice.html",
    overwrite=True,
    metadata=m.dumps(),
)