from __future__ import annotations
import re
import warnings
import numpy as np
import pandas as pd
import sharrow as sh
import xarray as xr
from pandas.errors import UndefinedVariableError
from xarray.core import dtypes
from . import construct as construct
from . import flow as flow
from . import patch as patch
from .dim_names import ALTID as _ALTID
from .dim_names import ALTIDX as _ALTIDX
from .dim_names import CASEALT as _CASEALT
from .dim_names import CASEID as _CASEID
from .dim_names import CASEPTR as _CASEPTR
from .patch import register_dataarray_classmethod
# from .dim_names import GROUPID as _GROUPID
# from .dim_names import INGROUP as _INGROUP
try:
from sharrow import DataArray as _sharrow_DataArray
from sharrow import Dataset as _sharrow_Dataset
from sharrow import DataTree as _sharrow_DataTree
from sharrow.accessors import register_dataarray_method
except ImportError:
warnings.warn("larch.dataset requires the sharrow library", stacklevel=2)
class _noclass:
pass
_sharrow_Dataset = xr.Dataset
_sharrow_DataArray = xr.DataArray
_sharrow_DataTree = _noclass
register_dataarray_method = lambda x: x
DataArray = _sharrow_DataArray
@register_dataarray_classmethod
def zeros(cls, *coords, dtype=np.float64, name=None, attrs=None):
"""
Construct a DataArray filled with zeros.
Parameters
----------
coords : Tuple[array-like]
A sequence of coordinate vectors. Ideally each should have a
`name` attribute that names a dimension, otherwise placeholder
names are used.
dtype : dtype, default np.float64
dtype of the new array. If omitted, it defaults to np.float64.
name : str or None, optional
Name of this array.
attrs : dict_like or None, optional
Attributes to assign to the new instance. By default, an empty
attribute dictionary is initialized.
Returns
-------
DataArray
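Examples
--------
A minimal sketch; the coordinate name ``altid`` and its values are illustrative:
>>> alts = xr.DataArray([101, 102, 103], dims="altid", name="altid")
>>> base = DataArray.zeros(alts, name="base")
>>> base.shape
(3,)
>>> float(base.sum())
0.0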
"""
dims = []
shape = []
coo = {}
for n, c in enumerate(coords):
i = getattr(c, "name", f"dim_{n}")
dims.append(i)
shape.append(len(c))
coo[i] = c
return cls(
data=np.zeros(shape, dtype=dtype),
coords=coo,
dims=dims,
name=name,
attrs=attrs,
)
@register_dataarray_classmethod
def ones(cls, *coords, dtype=np.float64, name=None, attrs=None):
"""
Construct a DataArray filled with ones.
Parameters
----------
coords : Tuple[array-like]
A sequence of coordinate vectors. Ideally each should have a
`name` attribute that names a dimension, otherwise placeholder
names are used.
dtype : dtype, default np.float64
dtype of the new array. If omitted, it defaults to np.float64.
name : str or None, optional
Name of this array.
attrs : dict_like or None, optional
Attributes to assign to the new instance. By default, an empty
attribute dictionary is initialized.
Returns
-------
DataArray
"""
dims = []
shape = []
coo = {}
for n, c in enumerate(coords):
i = getattr(c, "name", f"dim_{n}")
dims.append(i)
shape.append(len(c))
coo[i] = c
return cls(
data=np.ones(shape, dtype=dtype),
coords=coo,
dims=dims,
name=name,
attrs=attrs,
)
@register_dataarray_classmethod
def from_zarr(cls, *args, name=None, **kwargs):
    """
    Load a DataArray from a Zarr store, passing arguments to `xarray.open_zarr`.
    Parameters
    ----------
    name : str, optional
        Variable to load. If not given and the store holds exactly one
        non-coordinate variable, that variable is loaded; otherwise a
        ValueError is raised.
    """
    dataset = xr.open_zarr(*args, **kwargs)
    if name is None:
        names = set(dataset.variables) - set(dataset.coords)
        if len(names) == 1:
            name = names.pop()
        else:
            raise ValueError("cannot infer name to load")
    return dataset[name]
@register_dataarray_method
def value_counts(self, index_name="index"):
"""
Count the number of times each unique value appears in the array.
Parameters
----------
index_name : str, default 'index'
Name of index dimension in result.
Returns
-------
DataArray
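Examples
--------
A minimal sketch with illustrative values:
>>> arr = DataArray(np.asarray([4, 4, 7, 4, 9]), dims="x")
>>> counts = arr.value_counts("value")
>>> counts.coords["value"].values.tolist()
[4, 7, 9]
>>> counts.values.tolist()
[3, 1, 1]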
"""
values, freqs = np.unique(self, return_counts=True)
return self.__class__(freqs, dims=index_name, coords={index_name: values})
Dataset = _sharrow_Dataset
class DataTree(_sharrow_DataTree):
DatasetType = Dataset
def __init__(
self,
graph=None,
root_node_name=None,
extra_funcs=(),
extra_vars=None,
cache_dir=None,
relationships=(),
force_digitization=False,
**kwargs,
):
super().__init__(
graph=graph,
root_node_name=root_node_name,
extra_funcs=extra_funcs,
extra_vars=extra_vars,
cache_dir=cache_dir,
relationships=relationships,
force_digitization=force_digitization,
**kwargs,
)
dim_order = []
c = self.root_dataset.dc.CASEID
if c is None and len(self.root_dataset.sizes) == 1:
self.root_dataset.dc.CASEID = list(self.root_dataset.sizes.keys())[0]
c = self.root_dataset.dc.CASEID
if c is not None:
dim_order.append(c)
a = self.root_dataset.dc.ALTID
if a is not None:
dim_order.append(a)
self.dim_order = tuple(dim_order)
def idco_subtree(self):
    """Return a subtree limited to the case-level (idco) data."""
    if "idcoVars" in self.subspaces:
        return self.subspaces["idcoVars"].dc.as_tree()
    return self.drop_dims(self.ALTID, ignore_missing_dims=True)
@property
def dc(self):
return self
@property
def CASEID(self):
"""Str : The _caseid_ dimension of the root Dataset."""
result = self.root_dataset.dc.CASEID
if result is None:
warnings.warn("no defined CASEID", stacklevel=2)
return _CASEID
return result
@property
def ALTID(self):
"""Str : The _altid_ dimension of the root Dataset."""
result = self.root_dataset.dc.ALTID
if result is None:
warnings.warn("no defined ALTID", stacklevel=2)
return _ALTID
return result
@property
def CASEALT(self):
"""Str : The _casealt_ dimension of the root Dataset, if defined."""
result = self.root_dataset.attrs.get(_CASEALT, None)
return result
@property
def ALTIDX(self):
"""Str : The _alt_idx_ dimension of the root Dataset, if defined."""
result = self.root_dataset.attrs.get(_ALTIDX, None)
return result
@property
def CASEPTR(self):
"""Str : The _caseptr_ dimension of the root Dataset, if defined."""
result = self.root_dataset.attrs.get(_CASEPTR, None)
return result
@property
def n_cases(self):
"""Int : The size of the _caseid_ dimension of the root Dataset."""
return self.root_dataset.sizes[self.CASEID]
@property
def n_alts(self):
"""Int : The size of the _altid_ dimension of the root Dataset."""
return self.root_dataset.sizes[self.ALTID]
def query_cases(self, query, parser="pandas", engine=None):
"""
Return a new DataTree, with a query filter applied to the root Dataset.
Parameters
----------
query : str
Python expressions to be evaluated against the data variables
in the root dataset. The expressions will be evaluated using the pandas
eval() function, and can contain any valid Python expressions but cannot
contain any Python statements.
parser : {"pandas", "python"}, default: "pandas"
The parser to use to construct the syntax tree from the expression.
The default of 'pandas' parses code slightly different than standard
Python. Alternatively, you can parse an expression using the 'python'
parser to retain strict Python semantics.
engine : {"python", "numexpr", None}, default: None
The engine used to evaluate the expression. Supported engines are:
- None: tries to use numexpr, falls back to python
- "numexpr": evaluates expressions using numexpr
- "python": performs operations as if you had eval’d in top level python
Returns
-------
DataTree
A new DataTree with the same contents as this DataTree, except each
array of the root Dataset is indexed by the results of the query on
the CASEID dimension.
See Also
--------
Dataset.query_cases
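Examples
--------
A hypothetical filter, assuming a DataTree ``tree`` whose root dataset
defines an ``income`` variable (names here are illustrative)::
    filtered = tree.query_cases("income > 30000")
    assert filtered.n_cases <= tree.n_cases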
"""
obj = self.copy()
try:
obj.root_dataset = obj.root_dataset.dc.query_cases(
query, parser=parser, engine=engine
)
except UndefinedVariableError:
filter = self.idco_subtree().get_expr(
query, allow_native=False, engine="sharrow", dtype="bool_"
)
obj.root_dataset = obj.root_dataset.dc.isel({self.CASEID: filter})
return obj
def slice_cases(self, *case_slice):
    """Return a new DataTree with the root Dataset narrowed to a slice of cases."""
    if len(case_slice) == 1 and isinstance(case_slice[0], slice):
        case_slice = case_slice[0]
    else:
        case_slice = slice(*case_slice)
    return self.replace_datasets(
        {self.root_node_name: self.root_dataset.isel({self.CASEID: case_slice})}
    )
def caseids(self) -> pd.Index:
"""
Access the caseids coordinates as an index.
Returns
-------
pandas.Index
"""
try:
return self.root_dataset.indexes[self.CASEID]
except KeyError:
for _k, v in self.subspaces.items():
if self.CASEID in v.indexes:
return v.indexes[self.CASEID]
raise
def altids(self) -> pd.Index:
"""
Access the altids coordinates as an index.
Returns
-------
pd.Index
"""
try:
return self.root_dataset.indexes[self.ALTID]
except KeyError:
for _k, v in self.subspaces.items():
if self.ALTID in v.indexes:
return v.indexes[self.ALTID]
raise
def set_altnames(self, alt_names):
"""
Set the alternative names for this DataTree.
Parameters
----------
alt_names : Mapping or array-like
A mapping of (integer) codes to names, or an array of names
of the same length and order as the alternatives already
defined in the root Dataset.
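Examples
--------
A sketch with illustrative codes and names::
    tree.set_altnames({101: "car", 102: "bus", 103: "walk"})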
"""
self.root_dataset = self.root_dataset.dc.set_altnames(alt_names)
def alts_mapping(self):
    """Return a mapping from alternative codes to alternative names."""
    return self.root_dataset.dc.alts_mapping
def alts_name_to_id(self):
    """Return a mapping from alternative names to alternative codes."""
    return {name: code for (code, name) in self.alts_mapping().items()}
def setup_flow(self, *args, **kwargs) -> sh.Flow:
"""
Set up a new Flow for analysis using the structure of this DataTree.
Parameters
----------
definition_spec : dict[str,str]
Gives the names and expressions that define the variables to
create in this new `Flow`.
cache_dir : Path-like, optional
A location to write out generated python and numba code. If not
provided, a unique temporary directory is created.
name : str, optional
The name of this Flow used for writing out cached files. If not
provided, a unique name is generated. If `cache_dir` is given,
be sure to avoid name conflicts with other flows in the same
directory.
dtype : str, default "float32"
The name of the numpy dtype that will be used for the output.
boundscheck : bool, default False
If True, boundscheck enables bounds checking for array indices, and
out of bounds accesses will raise IndexError. The default is to not
do bounds checking, which is faster but can produce garbage results
or segfaults if there are problems, so try turning this on for
debugging if you are getting unexplained errors or crashes.
error_model : {'numpy', 'python'}, default 'numpy'
The error_model option controls the divide-by-zero behavior. Setting
it to ‘python’ causes divide-by-zero to raise exception like
CPython. Setting it to ‘numpy’ causes divide-by-zero to set the
result to +/-inf or nan.
nopython : bool, default True
Compile using numba's `nopython` mode. Provided for debugging only,
as there's little point in turning this off for production code, as
all the speed benefits of sharrow will be lost.
fastmath : bool, default True
If true, fastmath enables the use of "fast" floating point transforms,
which can improve performance but can result in tiny distortions in
results. See numba docs for details.
parallel : bool, default True
Enable or disable parallel computation for certain functions.
readme : str, optional
A string to inject as a comment at the top of the flow Python file.
flow_library : Mapping[str,Flow], optional
An in-memory cache of precompiled Flow objects. Using this can result
in performance improvements when repeatedly using the same definitions.
extra_hash_data : Tuple[Hashable], optional
Additional data used for generating the flow hash. Useful to prevent
conflicts when using a flow_library with multiple similar flows.
write_hash_audit : bool, default True
Writes a hash audit log into a comment in the flow Python file, for
debugging purposes.
hashing_level : int, default 1
Level of detail to write into flow hashes. Increase detail to avoid
hash conflicts for similar flows. Level 2 adds information about
names used in expressions and digital encodings to the flow hash,
which prevents conflicts but requires more pre-computation to generate
the hash.
dim_exclude : Collection[str], optional
Exclude these root dataset dimensions from this flow.
Returns
-------
Flow
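Examples
--------
A sketch with hypothetical names; ``tree`` is an existing DataTree whose
datasets define ``income`` and ``cost`` variables::
    flow = tree.setup_flow({
        "log_income": "np.log1p(income)",
        "cost_per_income": "cost / income",
    })
    arr = flow.load()  # numpy array of the computed values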
"""
if "dim_exclude" not in kwargs:
if "_exclude_dims_" in self.root_dataset.attrs:
kwargs["dim_exclude"] = self.root_dataset.attrs["_exclude_dims_"]
try:
return super().setup_flow(*args, **kwargs)
except ValueError as err:
regex = re.match("^unable to rewrite (.*) to itself$", str(err))
if regex:
raise ValueError(
f"Setup failed for variable {regex.group(1)}. Check the expression "
f"and the names of the variables in the dataset."
) from err
else:
raise err
def merge(
objects,
compat="no_conflicts",
join="outer",
fill_value=dtypes.NA,
combine_attrs="override",
*,
caseid=None,
alts=None,
):
"""
Merge any number of xarray objects into a single larch.Dataset as variables.
Parameters
----------
objects : iterable of Dataset or iterable of DataArray or iterable of dict-like
Merge together all variables from these objects. If any of them are
DataArray objects, they must have a name.
compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional
String indicating how to compare variables of the same name for
potential conflicts:
- "broadcast_equals": all values must be equal when variables are
broadcast against each other to ensure common dimensions.
- "equals": all values and dimensions must be the same.
- "identical": all values, dimensions and attributes must be the
same.
- "no_conflicts": only values which are not null in both datasets
must be equal. The returned dataset then contains the combination
of all non-null values.
- "override": skip comparing and pick variable from first dataset
join : {"outer", "inner", "left", "right", "exact"}, optional
String indicating how to combine differing indexes in objects.
- "outer": use the union of object indexes
- "inner": use the intersection of object indexes
- "left": use indexes from the first object with each dimension
- "right": use indexes from the last object with each dimension
- "exact": instead of aligning, raise `ValueError` when indexes to be
aligned are not equal
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
fill_value : scalar or dict-like, optional
Value to use for newly missing values. If a dict-like, maps
variable names to fill values. Use a data array's name to
refer to its values.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
- "drop": empty attrs on returned Dataset.
- "identical": all attrs must be the same on every object.
- "no_conflicts": attrs from all objects are combined, any that have
the same name must also have the same value.
- "drop_conflicts": attrs from all objects are combined, any that have
the same name but different values are dropped.
- "override": skip comparing and copy attrs from the first dataset to
the result.
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
as its only parameters.
caseid : str, optional, keyword only
This named dimension will be marked as the '_caseid_' dimension.
alts : str or Mapping or array-like, keyword only
If given as a str, this named dimension will be marked as the
'_altid_' dimension. Otherwise, give a Mapping that defines
alternative names and (integer) codes or an array of codes.
Returns
-------
Dataset
Dataset with combined variables from each object.
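Examples
--------
A minimal sketch; dimension and variable names are illustrative::
    co = xr.Dataset({"income": ("caseid", [50.0, 75.0])}, coords={"caseid": [1, 2]})
    ca = xr.Dataset(
        {"cost": (("caseid", "altid"), [[1.0, 2.0], [3.0, 4.0]])},
        coords={"caseid": [1, 2], "altid": [101, 102]},
    )
    combined = merge([co, ca], caseid="caseid", alts="altid")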
"""
return Dataset.construct(
    xr.merge(
        objects,
        compat=compat,
        join=join,
        fill_value=fill_value,
        combine_attrs=combine_attrs,
    ),
    caseid=caseid,
    alts=alts,
)
# @nb.njit
# def ce_dissolve_zero_variance(ce_data, ce_caseptr):
# """
#
# Parameters
# ----------
# ce_data : array-like, shape [n_casealts] one-dim only
# ce_altidx
# ce_caseptr
# n_alts
#
# Returns
# -------
# out : ndarray
# flag : int
# 1 if variance was detected, 0 if no variance was found and
# the `out` array is valid.
# """
# failed = 0
# if ce_caseptr.ndim == 2:
# ce_caseptr1 = ce_caseptr[:,-1]
# else:
# ce_caseptr1 = ce_caseptr[1:]
# shape = (ce_caseptr1.shape[0], )
# out = np.zeros(shape, dtype=ce_data.dtype)
# c = 0
# out[0] = ce_data[0]
# for row in range(ce_data.shape[0]):
# if row == ce_caseptr1[c]:
# c += 1
# out[c] = ce_data[row]
# else:
# if out[c] != ce_data[row]:
# failed = 1
# break
# return out, failed
# @nb.njit
# def case_ptr_to_indexes(n_casealts, case_ptrs):
# case_index = np.zeros(n_casealts, dtype=np.int64)
# for c in range(case_ptrs.shape[0]-1):
# case_index[case_ptrs[c]:case_ptrs[c + 1]] = c
# return case_index