get_default_dims

xarray_einstats.linalg.get_default_dims(da1_dims, da2_dims=None)

Get the dimensions corresponding to the matrices.

Parameters:
da1_dims : list of str
da2_dims : list of str, optional

Used only when the function takes multiple inputs; otherwise it keeps its default value of None.

Returns:
list of str

The names of the two dimensions that correspond to the matrix axes. Must be an iterable containing two strings.

Warning

dims is required for functions in the linalg module. This function acts as a placeholder and only raises an error indicating that dims is a required argument unless this function is monkeypatched.

It is documented here to show how to write and configure a substitute function.

Examples

By default, xarray_einstats requires the dims argument for functions in the linalg module. Not providing it raises a TypeError:

from xarray_einstats import linalg, tutorial
da = tutorial.generate_matrices_dataarray(5)
linalg.inv(da)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 3
      1 from xarray_einstats import linalg, tutorial
      2 da = tutorial.generate_matrices_dataarray(5)
----> 3 linalg.inv(da)

File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/95/lib/python3.12/site-packages/xarray_einstats/linalg.py:1056, in inv(da, dims, **kwargs)
   1041 """Wrap :func:`numpy.linalg.inv`.
   1042 
   1043 Usage examples of all arguments is available at the :ref:`linalg_tutorial` page.
   (...)   1053 DataArray
   1054 """
   1055 if dims is None:
-> 1056     dims = _attempt_default_dims("inv", da.dims)
   1057 return xr.apply_ufunc(
   1058     np.linalg.inv, da, input_core_dims=[dims], output_core_dims=[dims], **kwargs
   1059 )

File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/95/lib/python3.12/site-packages/xarray_einstats/linalg.py:130, in _attempt_default_dims(func, da1_dims, da2_dims)
    128     aux = get_default_dims(da1_dims, da2_dims)
    129 except MissingMonkeypatchError:
--> 130     raise TypeError(
    131         f"{func} missing required argument dims. Use "
    132         "xarray_einstats.linalg.default_dims context manager or pass dims explicitly"
    133     ) from None
    134 return aux

TypeError: inv missing required argument dims. Use xarray_einstats.linalg.default_dims context manager or pass dims explicitly

You need to pass the dimensions corresponding to the matrix axes explicitly:

linalg.inv(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2

However, in many cases it will be possible to identify those dimensions from the list of all dimension names in the input.

Here we show how to monkeypatch get_default_dims to get a different default behaviour. If you follow a convention to label the dimensions corresponding to the matrix axes, you can integrate this logic into xarray_einstats, which will avoid unnecessary repetition, especially if performing several chained linear algebra operations:

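def get_default_dims(dims1, dims2):
    # This convention only covers functions that take a single input
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    # Keep every dimension name that has a partner named "<name>2"
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    # Refuse to guess when there is no such pair or more than one
    if len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]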

linalg.get_default_dims = get_default_dims
linalg.inv(da)
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2
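The substitute function works only on dimension names, so its guessing rule can be checked without building any DataArray. The sketch below repeats the function from this example (adding a default of None for the second argument purely for convenience, which is an assumption of this sketch) and calls it on plain lists of names. The lists in the two failing cases are hypothetical and chosen only to trigger each error path.

```python
def get_default_dims(dims1, dims2=None):
    """Guess the matrix dims as the unique pair (name, name + "2")."""
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]


# Same dimension names as the DataArray used in this example.
guessed = get_default_dims(["batch", "experiment", "dim", "dim2"])
print(guessed)

# Hypothetical names with no "<name>2" partner: the guess fails.
try:
    get_default_dims(["batch", "row", "col"])
except TypeError as err:
    no_pair_error = str(err)

# Hypothetical names with two candidate pairs: the guess is ambiguous.
try:
    get_default_dims(["dim", "dim2", "chain", "chain2"])
except TypeError as err:
    ambiguous_error = str(err)
```

Raising an error for both the "no pair" and the "several pairs" cases keeps the default conservative: the function declines to guess rather than silently picking the wrong axes.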

You can still use dims explicitly to override those defaults.

Note

Monkeypatching get_default_dims directly works but is error-prone. Consider using the default_dims context manager instead.
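One reason direct monkeypatching is error-prone is that the replacement stays in place for the rest of the session. The sketch below is not the library's default_dims context manager, whose signature is not shown on this page. It illustrates the general idea of a scoped patch with the standard library's unittest.mock.patch.object instead. To stay runnable without xarray_einstats installed, it patches a hypothetical stand-in namespace called fake_linalg. The names placeholder and suffix_rule are also hypothetical.

```python
from types import SimpleNamespace
from unittest import mock


def placeholder(da1_dims, da2_dims=None):
    # Hypothetical stand-in for the library's placeholder function.
    raise TypeError("dims is a required argument")


def suffix_rule(da1_dims, da2_dims=None):
    # Hypothetical substitute: pick the unique (name, name + "2") pair.
    matches = [d for d in da1_dims if f"{d}2" in da1_dims]
    if da2_dims is not None or len(matches) != 1:
        raise TypeError("Unable to guess default matrix dims")
    return [matches[0], f"{matches[0]}2"]


# Stand-in namespace (an assumption of this sketch, not the real module).
fake_linalg = SimpleNamespace(get_default_dims=placeholder)

# The replacement is active only inside the with block.
with mock.patch.object(fake_linalg, "get_default_dims", suffix_rule):
    inside = fake_linalg.get_default_dims(("batch", "dim", "dim2"))

# After the block, the original placeholder is back in place.
restored = fake_linalg.get_default_dims is placeholder
```

With the real library, the same pattern would target the linalg module in place of fake_linalg. Where the default_dims context manager is available, it is the documented route and should be preferred.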