Source code for tfields.bases.bases

#!/usr/bin/env
# encoding: utf-8
"""
Author:     Daniel Boeckenhoff
Mail:       daniel.boeckenhoff@ipp.mpg.de

part of tfields library
Tools for sympy coordinate transformation
"""
import tfields
import numpy as np
import sympy
import sympy.diffgeom
from six import string_types
import warnings


[docs]def get_coord_system(base): """ Args: base (str or sympy.diffgeom.get_coord_system) Return: sympy.diffgeom.get_coord_system """ if isinstance(base, string_types) or ( isinstance(base, np.ndarray) and base.dtype.kind in {"U", "S"} ): base = getattr(tfields.bases, str(base)) if not isinstance(base, sympy.diffgeom.CoordSystem): bse_tpe = type(base) expctd_tpe = type(sympy.diffgeom.CoordSystem) raise TypeError( "Wrong type of coord_system base {bse_tpe}. " "Expected {expctd_tpe}".format(**locals()) ) return base
[docs]def get_coord_system_name(base): """ Args: base (str or sympy.diffgeom.get_coord_system) Returns: str: name of base """ if isinstance(base, sympy.diffgeom.CoordSystem): base = getattr(base, "name") # if not (isinstance(base, string_types) or base is None): # baseType = type(base) # raise ValueError("Coordinate system must be string_type." # " Retrieved value '{base}' of type {baseType}." # .format(**locals())) return str(base)
[docs]def lambdified_trafo(base_old, base_new): """ Args: base_old (sympy.CoordSystem) base_new (sympy.CoordSystem) Examples: >>> import numpy as np >>> import tfields Transform cartestian to cylinder or spherical >>> a = np.array([[3,4,0]]) >>> trafo = tfields.bases.lambdified_trafo(tfields.bases.cartesian, ... tfields.bases.cylinder) >>> new = np.concatenate([trafo(*coords).T for coords in a]) >>> assert new[0, 0] == 5 >>> trafo = tfields.bases.lambdified_trafo(tfields.bases.cartesian, ... tfields.bases.spherical) >>> new = np.concatenate([trafo(*coords).T for coords in a]) >>> assert new[0, 0] == 5 """ coords = tuple(base_old.coord_function(i) for i in range(base_old.dim)) f = sympy.lambdify( coords, base_old.coord_tuple_transform_to(base_new, list(coords)), modules="numpy", ) return f
[docs]def transform(array, base_old, base_new, **kwargs): """ Transform the input array in place Args: array (np.ndarray) base_old (str or sympy.CoordSystem): base_new (str or sympy.CoordSystem): Examples: Cylindrical coordinates >>> import numpy as np >>> import tfields >>> cart = np.array([[0, 0, 0], ... [1, 0, 0], ... [1, 1, 0], ... [0, 1, 0], ... [-1, 1, 0], ... [-1, 0, 0], ... [-1, -1, 0], ... [0, -1, 0], ... [1, -1, 0], ... [0, 0, 1]]) >>> cyl = tfields.bases.transform(cart, 'cartesian', 'cylinder') >>> cyl Transform cylinder to spherical. No connection is defined so routing via cartesian >>> import numpy as np >>> import tfields >>> b = np.array([[5, np.arctan(4. / 3), 0]]) >>> newB = b.copy() >>> tfields.bases.transform(b, 'cylinder', 'spherical') >>> assert newB[0, 0] == 5 >>> assert round(newB[0, 1], 10) == round(b[0, 1], 10) """ base_old = get_coord_system(base_old) base_new = get_coord_system(base_new) # very fast trafos in numpy only short_trafo = None try: short_trafo = getattr(base_old, f"to_{base_new.name}".format(**locals())) except AttributeError: pass if short_trafo: short_trafo(array, **kwargs) return if base_new not in base_old.transforms: for baseTmp in base_new.transforms: if baseTmp in base_old.transforms: transform(array, base_old, baseTmp, **kwargs) transform(array, baseTmp, base_new, **kwargs) return raise ValueError(f"Transformation {base_old} -> {base_new} not found.") # trafo via lambdified sympy expressions trafo = tfields.bases.lambdified_trafo(base_old, base_new) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message="invalid value encountered in double_scalars" ) array[:] = np.concatenate([trafo(*coords).T for coords in array])
if __name__ == "__main__": # pragma: no cover import doctest doctest.testmod()