Source code for gt4py.cartesian.gtc.common

# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2022, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import enum
import functools
import typing
from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import scipy.special

from gt4py import eve
from gt4py.cartesian.gtc.utils import dimension_flags_to_names, flatten_list
from gt4py.eve import datamodels


class GTCPreconditionError(eve.exceptions.EveError, RuntimeError):
    """Error type for a failed GTC pass precondition (see `message_template`)."""

    message_template = "GTC pass precondition error: [{info}]"

    def __init__(self, *, expected: str, **kwargs: Any) -> None:
        # `expected` is keyword-only; it is forwarded with any extra keyword
        # arguments to the base-class initializer.
        super().__init__(expected=expected, **kwargs)  # type: ignore
class GTCPostconditionError(eve.exceptions.EveError, RuntimeError):
    """Error type for a failed GTC pass postcondition (see `message_template`)."""

    message_template = "GTC pass postcondition error: [{info}]"

    def __init__(self, *, expected: str, **kwargs: Any) -> None:
        # `expected` is keyword-only; it is forwarded with any extra keyword
        # arguments to the base-class initializer.
        super().__init__(expected=expected, **kwargs)  # type: ignore
class AssignmentKind(eve.StrEnum):
    """Kind of assignment: plain or combined with operations."""

    # Member values are the assignment operator symbols.
    PLAIN = "="
    ADD = "+="
    SUB = "-="
    MUL = "*="
    DIV = "/="
@enum.unique
class UnaryOperator(eve.StrEnum):
    """Unary operator identifier."""

    # Member values are the operator symbols / keywords.
    POS = "+"
    NEG = "-"
    NOT = "not"
@enum.unique
class ArithmeticOperator(eve.StrEnum):
    """Arithmetic operators."""

    # Member values are the operator symbols.
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    MATMULT = "@"
@enum.unique
class ComparisonOperator(eve.StrEnum):
    """Comparison operators."""

    # Member values are the operator symbols.
    GT = ">"
    LT = "<"
    GE = ">="
    LE = "<="
    EQ = "=="
    NE = "!="
@enum.unique
class LogicalOperator(eve.StrEnum):
    """Logical operators."""

    # Member values are the operator keywords.
    AND = "and"
    OR = "or"
@enum.unique
class DataType(eve.IntEnum):
    """Data type identifier.

    `AUTO` has the value 0 and is therefore falsy; several validators in this
    module rely on `if dtype:` to mean "a dtype has been set".
    """

    # IDs from gt4py.cartesian
    INVALID = -1
    AUTO = 0
    DEFAULT = 1
    BOOL = 10
    INT8 = 11
    INT16 = 12
    INT32 = 14
    INT64 = 18
    FLOAT32 = 104
    FLOAT64 = 108

    def isbool(self) -> bool:
        """Return True if this is the boolean data type."""
        return self == DataType.BOOL

    def isinteger(self) -> bool:
        """Return True if this is one of the signed integer data types."""
        # INT16 must be listed as well: it is a declared integer type, and
        # leaving it out makes the integer-index validators reject int16.
        return self in (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64)

    def isfloat(self) -> bool:
        """Return True if this is one of the floating point data types."""
        return self in (DataType.FLOAT32, DataType.FLOAT64)
@enum.unique
class LoopOrder(eve.StrEnum):
    """Loop order identifier."""

    # Member values are the lowercase member names.
    PARALLEL = "parallel"
    FORWARD = "forward"
    BACKWARD = "backward"
@enum.unique
class BuiltInLiteral(eve.StrEnum):
    """Identifiers for symbolic literals (extreme values, zero/one, booleans)."""

    MAX_VALUE = "max"
    MIN_VALUE = "min"
    ZERO = "zero"
    ONE = "one"
    TRUE = "true"
    FALSE = "false"
@enum.unique
class NativeFunction(eve.StrEnum):
    """Native function identifiers, each with an associated argument count."""

    ABS = "abs"
    MIN = "min"
    MAX = "max"
    MOD = "mod"

    SIN = "sin"
    COS = "cos"
    TAN = "tan"
    ARCSIN = "arcsin"
    ARCCOS = "arccos"
    ARCTAN = "arctan"
    SINH = "sinh"
    COSH = "cosh"
    TANH = "tanh"
    ARCSINH = "arcsinh"
    ARCCOSH = "arccosh"
    ARCTANH = "arctanh"

    SQRT = "sqrt"
    POW = "pow"
    EXP = "exp"
    LOG = "log"
    GAMMA = "gamma"
    CBRT = "cbrt"

    ISFINITE = "isfinite"
    ISINF = "isinf"
    ISNAN = "isnan"
    FLOOR = "floor"
    CEIL = "ceil"
    TRUNC = "trunc"

    # Only annotated here (no value), so it does not become an enum member;
    # the table itself is assigned onto the class after the class body.
    IR_OP_TO_NUM_ARGS: ClassVar[Dict["NativeFunction", int]]

    @property
    def arity(self) -> int:
        """Number of arguments this function accepts, looked up in `IR_OP_TO_NUM_ARGS`."""
        return self.IR_OP_TO_NUM_ARGS[self]
# Argument count for each native function; read by `NativeFunction.arity`.
# It is assigned here, after the class statement, because the keys are the
# enum members themselves.
NativeFunction.IR_OP_TO_NUM_ARGS = {
    NativeFunction(f): v  # instead of noqa on every line
    for f, v in {
        NativeFunction.ABS: 1,
        NativeFunction.MIN: 2,
        NativeFunction.MAX: 2,
        NativeFunction.MOD: 2,
        NativeFunction.SIN: 1,
        NativeFunction.COS: 1,
        NativeFunction.TAN: 1,
        NativeFunction.ARCSIN: 1,
        NativeFunction.ARCCOS: 1,
        NativeFunction.ARCTAN: 1,
        NativeFunction.SINH: 1,
        NativeFunction.COSH: 1,
        NativeFunction.TANH: 1,
        NativeFunction.ARCSINH: 1,
        NativeFunction.ARCCOSH: 1,
        NativeFunction.ARCTANH: 1,
        NativeFunction.SQRT: 1,
        NativeFunction.POW: 2,
        NativeFunction.EXP: 1,
        NativeFunction.LOG: 1,
        NativeFunction.GAMMA: 1,
        NativeFunction.CBRT: 1,
        NativeFunction.ISFINITE: 1,
        NativeFunction.ISINF: 1,
        NativeFunction.ISNAN: 1,
        NativeFunction.FLOOR: 1,
        NativeFunction.CEIL: 1,
        NativeFunction.TRUNC: 1,
    }.items()
}
@enum.unique
class LevelMarker(eve.StrEnum):
    """Level marker with two members, `START` and `END`."""

    START = "start"
    END = "end"
@enum.unique
class ExprKind(eve.IntEnum):
    """Expression kind: scalar or field."""

    # `typing.cast` returns its second argument unchanged at runtime, so the
    # member values still come from `enum.auto()`; the cast and the
    # annotations only affect static type checking.
    SCALAR: "ExprKind" = typing.cast("ExprKind", enum.auto())
    FIELD: "ExprKind" = typing.cast("ExprKind", enum.auto())
class LocNode(eve.Node):
    """Node carrying an optional source location."""

    # Defaults to None, i.e. no location recorded.
    loc: Optional[eve.SourceLocation] = None
@eve.utils.noninstantiable
class Expr(LocNode):
    """
    Expression base class.

    All expressions have
    - an optional `dtype` (defaults to `DataType.AUTO`)
    - an expression `kind` (scalar or field, defaults to `ExprKind.FIELD`)
    """

    # Both kind and dtype are set to default here and root validators propagate the correct value after the __init__ is called.
    kind: ExprKind = ExprKind.FIELD
    dtype: DataType = DataType.AUTO
@eve.utils.noninstantiable
class Stmt(LocNode):
    """Statement base class; adds no fields beyond those of `LocNode`."""

    pass
def verify_condition_is_boolean(parent_node_cls: datamodels.DataModel, cond: Expr) -> None:
    """Check that a condition expression is boolean.

    A falsy `cond.dtype` (such as `DataType.AUTO`, whose value is 0) is
    accepted without further checks.

    Parameters
    ----------
    parent_node_cls:
        The object owning the condition; only its type is used, in the error message.
    cond:
        The condition expression to check.

    Raises
    ------
    ValueError
        If `cond.dtype` is truthy and is not `DataType.BOOL`.
    """
    condition_dtype = cond.dtype
    if not condition_dtype or condition_dtype is DataType.BOOL:
        return
    raise ValueError(f"Condition in `{type(parent_node_cls)}` must be boolean.")
def verify_and_get_common_dtype(
    node_cls: Type[datamodels.DataModel], exprs: List[Expr], *, strict: bool = True
) -> Optional[DataType]:
    """Determine the dtype shared by a non-empty list of expressions.

    Parameters
    ----------
    node_cls:
        Node class on whose behalf the check runs; its name appears in the error message.
    exprs:
        Expressions to inspect; must not be empty.
    strict:
        If True, all dtypes must be equal. If False, the largest dtype
        (by enum value) is returned instead.

    Returns
    -------
    The common (or upcast) dtype, or None if any expression still has
    `DataType.AUTO`.

    Raises
    ------
    ValueError
        In strict mode, if the dtypes differ.
    """
    assert len(exprs) > 0

    if any(expr.dtype is DataType.AUTO for expr in exprs):
        # At least one dtype is not resolved yet: nothing to verify.
        return None

    found: List[DataType] = [expr.dtype for expr in exprs]  # type: ignore  # guaranteed to be not None

    if not strict:
        # upcasting
        return max(found)

    reference = found[0]
    if any(dt != reference for dt in found):
        raise ValueError(
            f"Type mismatch in `{node_cls.__name__}`. Types are "
            + ", ".join(dt.name for dt in found)
        )
    return reference
def compute_kind(*values: Expr) -> ExprKind:
    """Return `ExprKind.FIELD` if any of `values` is a field, else `ExprKind.SCALAR`."""
    for value in values:
        if value.kind == ExprKind.FIELD:
            return ExprKind.FIELD
    return ExprKind.SCALAR
class Literal(eve.Node):
    """Literal value node with a mandatory `dtype`; `kind` defaults to scalar."""

    # TODO(havogt) reconsider if `str` is a good representation for value,
    # maybe it should be Union[float,int,str] etc?
    value: Union[BuiltInLiteral, str]
    dtype: DataType
    kind: ExprKind = ExprKind.SCALAR
# Type variables used to parametrize the generic nodes declared in this module.
StmtT = TypeVar("StmtT", bound=Stmt)
ExprT = TypeVar("ExprT", bound=Expr)
TargetT = TypeVar("TargetT", bound=Expr)
VariableKOffsetT = TypeVar("VariableKOffsetT")  # no bound declared
class CartesianOffset(eve.Node):
    """Integer offset along the `i`, `j` and `k` axes."""

    i: int
    j: int
    k: int

    @classmethod
    def zero(cls) -> "CartesianOffset":
        """Build the offset with all three components equal to zero."""
        origin = {"i": 0, "j": 0, "k": 0}
        return cls(**origin)

    def to_dict(self) -> Dict[str, int]:
        """Return the components as a dict keyed by axis name."""
        return {axis: getattr(self, axis) for axis in ("i", "j", "k")}
class VariableKOffset(eve.GenericNode, Generic[ExprT]):
    """Offset whose `k` component is given by an expression."""

    k: ExprT

    def to_dict(self) -> Dict[str, Optional[int]]:
        """Return zero for `i` and `j`, and None for the `k` entry."""
        return dict(i=0, j=0, k=None)

    @datamodels.validator("k")
    def offset_expr_is_int(self, attribute: datamodels.Attribute, value: Any) -> None:
        """Reject a `k` expression whose dtype is set and is not an integer type."""
        offset_dtype = typing.cast(Expr, value).dtype
        if offset_dtype is DataType.AUTO:
            # dtype not resolved yet: nothing to check.
            return
        if not offset_dtype.isinteger():
            raise ValueError("Variable vertical index must be an integer expression")
class ScalarAccess(LocNode):
    """Access to a named symbol; `kind` defaults to scalar."""

    name: eve.Coerced[eve.SymbolRef]
    kind: ExprKind = ExprKind.SCALAR
class FieldAccess(eve.GenericNode, Generic[ExprT, VariableKOffsetT]):
    """Access to a named field at an offset, with optional data indices.

    `kind` defaults to `ExprKind.FIELD` and `data_index` to an empty list.
    """

    name: eve.Coerced[eve.SymbolRef]
    offset: Union[CartesianOffset, VariableKOffsetT]
    data_index: List[ExprT] = eve.field(default_factory=list)
    kind: ExprKind = ExprKind.FIELD

    @classmethod
    def centered(cls, *, name: str, loc: Optional[eve.SourceLocation] = None) -> "FieldAccess":
        """Build an access to `name` at the zero Cartesian offset.

        `loc` may be omitted; it is annotated as `Optional` because its
        default is None.
        """
        return cls(name=name, loc=loc, offset=CartesianOffset.zero())

    @datamodels.validator("data_index")
    def data_index_exprs_are_int(self, attribute: datamodels.Attribute, value: Any) -> None:
        """Reject data indices whose dtype is set and is not an integer type."""
        value = typing.cast(List[Expr], value)
        # Indices with `DataType.AUTO` are skipped: their dtype is not known yet.
        if value and any(
            index.dtype is not DataType.AUTO and not index.dtype.isinteger() for index in value
        ):
            raise ValueError("Data indices must be integer expressions")
class BlockStmt(eve.GenericNode, eve.SymbolTableTrait, Generic[StmtT]):
    """Generic block holding a list of statements; mixes in `eve.SymbolTableTrait`."""

    body: List[StmtT]
class IfStmt(eve.GenericNode, Generic[StmtT, ExprT]):
    """
    Generic if statement.

    Verifies that `cond` is a boolean expr (if `dtype` is set).
    """

    cond: ExprT
    true_branch: StmtT
    false_branch: Optional[StmtT] = None  # the else-branch may be absent

    @datamodels.validator("cond")
    def condition_is_boolean(self, attribute: datamodels.Attribute, value: Expr) -> None:
        # Delegates to the module-level check, which raises ValueError on a
        # set, non-boolean dtype.
        verify_condition_is_boolean(self, value)
class While(eve.GenericNode, Generic[StmtT, ExprT]):
    """
    Generic while loop.

    Verifies that `cond` is a boolean expr (if `dtype` is set).
    """

    cond: ExprT
    body: List[StmtT]

    @datamodels.validator("cond")
    def condition_is_boolean(self, attribute: datamodels.Attribute, value: Expr) -> None:
        # Delegates to the module-level check, which raises ValueError on a
        # set, non-boolean dtype.
        verify_condition_is_boolean(self, value)
class AssignStmt(eve.GenericNode, Generic[TargetT, ExprT]):
    """Generic assignment of the `right` expression to the `left` target."""

    left: TargetT
    right: ExprT
def _make_root_validator(impl: datamodels.RootValidator) -> datamodels.RootValidator:
    """Wrap `impl` in `classmethod` and pass it through `datamodels.root_validator`."""
    # The cast only informs the type checker; at runtime it returns the
    # classmethod object unchanged.
    as_classmethod = typing.cast(datamodels.RootValidator, classmethod(impl))
    return datamodels.root_validator(as_classmethod)
def assign_stmt_dtype_validation(*, strict: bool) -> datamodels.RootValidator:
    """Create a root validator checking the dtypes of both sides of an `AssignStmt`.

    `strict` is forwarded to `verify_and_get_common_dtype`, which raises
    `ValueError` on a mismatch in strict mode.
    """

    def _impl(
        cls: Type[datamodels.DataModel],
        instance: datamodels.DataModel,
    ) -> None:
        assert isinstance(instance, AssignStmt)
        both_sides = [instance.left, instance.right]
        # Only the check matters here; the returned dtype is discarded.
        verify_and_get_common_dtype(cls, both_sides, strict=strict)

    return _make_root_validator(_impl)
class UnaryOp(eve.GenericNode, Generic[ExprT]):
    """
    Generic unary operation with type propagation.

    `dtype` and `kind` are copied from the operand, and the operator is
    checked against the operand dtype (`NOT` needs a boolean operand, the
    other operators need a non-boolean one).
    """

    op: UnaryOperator
    expr: ExprT

    @datamodels.root_validator
    @classmethod
    def dtype_propagation(cls: Type[UnaryOp], instance: UnaryOp) -> None:
        """Copy the operand's dtype onto this node."""
        operand_dtype = instance.expr.dtype
        instance.dtype = operand_dtype  # type: ignore[attr-defined]

    @datamodels.root_validator
    @classmethod
    def kind_propagation(cls: Type[UnaryOp], instance: UnaryOp) -> None:
        """Copy the operand's kind onto this node."""
        operand_kind = instance.expr.kind
        instance.kind = operand_kind  # type: ignore[attr-defined]

    @datamodels.root_validator
    @classmethod
    def op_to_dtype_check(cls: Type[UnaryOp], instance: UnaryOp) -> None:
        """Raise `ValueError` if the operator does not fit the operand dtype."""
        operand_dtype = instance.expr.dtype
        if not operand_dtype:
            # Falsy dtype (e.g. AUTO): nothing to check yet.
            return

        operand_is_bool = operand_dtype == DataType.BOOL
        if instance.op == UnaryOperator.NOT:
            if not operand_is_bool:
                raise ValueError("Unary operator `NOT` only allowed with boolean expression.")
        elif operand_is_bool:
            raise ValueError(
                f"Unary operator `{instance.op.name}` not allowed with boolean expression."
            )
class BinaryOp(eve.GenericNode, Generic[ExprT]):
    """Generic binary operation with type propagation.

    The generic BinaryOp already contains logic for
    - strict type checking if the `dtype` for `left` and `right` is set.
    - type propagation (taking `operator` type into account).
    """

    # NOTE(review): the only validator defined in this class body propagates
    # `kind`; the dtype logic described above is implemented by
    # `binary_op_dtype_propagation` in this module.

    # consider parametrizing on op
    op: Union[ArithmeticOperator, ComparisonOperator, LogicalOperator]
    left: ExprT
    right: ExprT

    @datamodels.root_validator
    @classmethod
    def kind_propagation(cls: Type[BinaryOp], instance: BinaryOp) -> None:
        # The result is a field as soon as either operand is a field.
        instance.kind = compute_kind(instance.left, instance.right)  # type: ignore[attr-defined]
def binary_op_dtype_propagation(*, strict: bool) -> datamodels.RootValidator:
    """Create a root validator that checks operand dtypes and sets the result dtype.

    If a common operand dtype can be determined:
    - arithmetic operators need non-boolean operands and yield the common dtype,
    - logical operators need boolean operands and yield `DataType.BOOL`,
    - comparison operators always yield `DataType.BOOL`.

    Violations raise `ValueError`. `strict` is forwarded to
    `verify_and_get_common_dtype`.
    """

    def _impl(cls: Type[BinaryOp], instance: BinaryOp) -> None:
        operand_dtype = verify_and_get_common_dtype(
            cls, [instance.left, instance.right], strict=strict
        )
        if not operand_dtype:
            # No common dtype available (yet): leave `dtype` untouched.
            return

        operator = instance.op
        operands_are_bool = operand_dtype is DataType.BOOL

        if isinstance(operator, ArithmeticOperator):
            if operands_are_bool:
                raise ValueError("Boolean expression is not allowed with arithmetic operation.")
            instance.dtype = operand_dtype  # type: ignore[attr-defined]
        elif isinstance(operator, LogicalOperator):
            if not operands_are_bool:
                raise ValueError("Arithmetic expression is not allowed in boolean operation.")
            instance.dtype = DataType.BOOL  # type: ignore[attr-defined]
        elif isinstance(operator, ComparisonOperator):
            instance.dtype = DataType.BOOL  # type: ignore[attr-defined]

    return _make_root_validator(_impl)
class TernaryOp(eve.GenericNode, Generic[ExprT]):
    """
    Generic ternary operation with type propagation.

    The generic TernaryOp already contains logic for
    - strict type checking if the `dtype` for `true_expr` and `false_expr` is set.
    - type checking for `cond`
    - type propagation.
    """

    # NOTE(review): this class body defines the `cond` check and `kind`
    # propagation; the dtype logic for the two branches is implemented by
    # `ternary_op_dtype_propagation` in this module.

    # consider parametrizing cond type and expr separately
    cond: ExprT
    true_expr: ExprT
    false_expr: ExprT

    @datamodels.validator("cond")
    def condition_is_boolean(self, attribute: datamodels.Attribute, value: Expr) -> None:
        # The delegated check returns None, or raises ValueError.
        return verify_condition_is_boolean(self, value)

    @datamodels.root_validator
    @classmethod
    def kind_propagation(cls: Type[TernaryOp], instance: TernaryOp) -> None:
        # Only the two branches determine the kind; `cond` is not considered.
        instance.kind = compute_kind(instance.true_expr, instance.false_expr)  # type: ignore[attr-defined]
def ternary_op_dtype_propagation(*, strict: bool) -> datamodels.RootValidator:
    """Create a root validator that sets a ternary node's dtype from its branches.

    The dtype is only set when `verify_and_get_common_dtype` returns a truthy
    common dtype for `true_expr` and `false_expr`; `strict` is forwarded to it.
    """

    def _impl(cls: Type[TernaryOp], instance: TernaryOp) -> None:
        branches = [instance.true_expr, instance.false_expr]
        result_dtype = verify_and_get_common_dtype(cls, branches, strict=strict)
        if not result_dtype:
            return
        instance.dtype = result_dtype  # type: ignore[attr-defined]

    return _make_root_validator(_impl)
class Cast(eve.GenericNode, Generic[ExprT]):
    """Generic cast of an expression to an explicitly given `dtype`."""

    dtype: DataType
    expr: ExprT

    @datamodels.root_validator
    @classmethod
    def kind_propagation(cls: Type[Cast], instance: Cast) -> None:
        # The kind follows the operand's kind (field vs. scalar).
        instance.kind = compute_kind(instance.expr)  # type: ignore[attr-defined]
class NativeFuncCall(eve.GenericNode, Generic[ExprT]):
    """Generic call of a native function with a list of argument expressions."""

    func: NativeFunction
    args: List[ExprT]

    @datamodels.root_validator
    @classmethod
    def arity_check(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None:
        """Raise `ValueError` if the argument count differs from the function's arity."""
        expected = instance.func.arity
        received = len(instance.args)
        if expected != received:
            # Message grammar fixed: "were passed" (previously "where passed").
            raise ValueError(
                f"{instance.func} accepts {expected} arguments, {received} were passed."
            )

    @datamodels.root_validator
    @classmethod
    def kind_propagation(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None:
        """Set `kind` to field if any argument is a field, otherwise scalar."""
        instance.kind = compute_kind(*instance.args)  # type: ignore[attr-defined]
def native_func_call_dtype_propagation(*, strict: bool = True) -> datamodels.RootValidator:
    """Create a root validator that sets the dtype of a native function call.

    `ISFINITE`, `ISINF` and `ISNAN` always yield `DataType.BOOL`. For all
    other functions the common dtype of the arguments is used, when one can
    be determined; `strict` is forwarded to `verify_and_get_common_dtype`.
    """

    def _impl(cls: Type[NativeFuncCall], instance: NativeFuncCall) -> None:
        bool_valued = (NativeFunction.ISFINITE, NativeFunction.ISINF, NativeFunction.ISNAN)
        if instance.func in bool_valued:
            instance.dtype = DataType.BOOL  # type: ignore[attr-defined]
            return

        # assumes all NativeFunction args have a common dtype
        args_dtype = verify_and_get_common_dtype(cls, instance.args, strict=strict)
        if args_dtype:
            instance.dtype = args_dtype  # type: ignore[attr-defined]

    return _make_root_validator(_impl)
def validate_dtype_is_set() -> datamodels.RootValidator:
    """Create a root validator that rejects sub-nodes with a falsy `dtype`.

    Every `eve.Node` among the (flattened) field values of the instance is
    walked, and all visited values having a `dtype` attribute are collected.
    A `ValueError` listing the offenders is raised if any collected `dtype`
    is falsy.
    """

    def _impl(cls: Type[ExprT], instance: ExprT) -> None:
        candidates: List[ExprT] = []
        for child in flatten_list(datamodels.astuple(instance)):
            if not isinstance(child, eve.Node):
                continue
            candidates.extend(child.walk_values().if_hasattr("dtype"))

        nodes_without_dtype = [node for node in candidates if not node.dtype]
        if nodes_without_dtype:
            raise ValueError(f"Nodes without dtype detected {nodes_without_dtype}")

    return _make_root_validator(_impl)
class _LvalueDimsValidator(eve.VisitorWithSymbolTableTrait):
    """Visitor checking that assignment targets have dimensions allowed by the loop order.

    Fully three-dimensional (ijk) fields may always be assigned; fields with
    only the i and j dimensions may be assigned in non-parallel loops only.
    """

    def __init__(self, vertical_loop_type: Type[eve.Node], decl_type: Type[eve.Node]) -> None:
        """Store both node types after checking their class annotations.

        Raises `ValueError` if `vertical_loop_type` is not annotated with
        `loop_order: LoopOrder`, or if `decl_type` is not annotated with
        `dimensions: Tuple[bool, bool, bool]`.
        """
        loop_order_annotation = vertical_loop_type.__annotations__.get("loop_order")
        if loop_order_annotation is not LoopOrder:
            raise ValueError(
                f"Vertical loop type {vertical_loop_type} has no `loop_order` attribute"
            )

        dimensions_annotation = decl_type.__annotations__.get("dimensions")
        has_dimension_flags = dimensions_annotation == Tuple[bool, bool, bool]
        if not has_dimension_flags:
            raise ValueError(
                f"Field decl type {decl_type} must have a `dimensions` "
                "attribute of type `Tuple[bool, bool, bool]`."
            )

        self.vertical_loop_type = vertical_loop_type
        self.decl_type = decl_type

    def visit_Node(
        self, node: eve.Node, *, loop_order: Optional[LoopOrder] = None, **kwargs: Any
    ) -> None:
        """Visit children, switching to the node's own loop order on a vertical loop."""
        if isinstance(node, self.vertical_loop_type):
            active_order = node.loop_order  # type: ignore[attr-defined]  # cannot narrow based on `vertical_loop_type`
        else:
            active_order = loop_order
        self.generic_visit(node, loop_order=active_order, **kwargs)

    def visit_AssignStmt(
        self, node: AssignStmt, *, loop_order: LoopOrder, symtable: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Check the assignment target's declaration against the allowed dimension flags.

        Targets whose declaration is not an instance of `decl_type` are skipped.
        Raises `ValueError` for an unknown symbol or a disallowed dimension set.
        """
        target_name = node.left.name
        decl = symtable.get(target_name, None)
        if decl is None:
            raise ValueError(f"Symbol {target_name} not found.")
        if not isinstance(decl, self.decl_type):
            return None

        permitted = self._allowed_flags(loop_order)
        flags = decl.dimensions  # type: ignore[attr-defined]  # Decls are defined on derived IRs, not common, so `dimensions` is unknown.
        if flags in permitted:
            return None

        dims = dimension_flags_to_names(flags)
        raise ValueError(
            f"Not allowed to assign to {dims}-field `{target_name}` in {loop_order.name}."
        )

    def _allowed_flags(self, loop_order: LoopOrder) -> List[Tuple[bool, bool, bool]]:
        """Return the dimension-flag tuples that may be assigned under `loop_order`."""
        # ij only allowed in FORWARD and BACKWARD
        sequential_only = [] if loop_order is LoopOrder.PARALLEL else [(True, True, False)]
        # ijk always allowed
        return [(True, True, True), *sequential_only]


# TODO(ricoh) consider making gtir.Decl & oir.Decl common and / or adding a VerticalLoop baseclass
# TODO(ricoh) in common instead of passing type arguments
def validate_lvalue_dims(
    vertical_loop_type: Type[eve.Node], decl_type: Type[eve.Node]
) -> datamodels.RootValidator:
    """
    Create a root validator checking lvalue dimensions via the root node symbol table.

    The validated tree is expected to look like this::

        Root(`SymTableTrait`)
        |- *
           |- `vertical_loop_type`
              |- loop_order: `LoopOrder`
              |- *
                 |- AssignStmt(`AssignStmt`)
                    |- left: `Node`, validated only if reference to `decl_type` in symtable
        |- symtable_: Symtable[name, Union[`decl_type`, *]]

        DeclType
        |- dimensions: `Tuple[bool, bool, bool]`

    Parameters
    ----------
    vertical_loop_type:
        A node type with an `LoopOrder` attribute named `loop_order`
    decl_type:
        A declaration type with field dimension information in the format
        `Tuple[bool, bool, bool]` in an attribute named `dimensions`.
    """

    def _impl(cls: Type[datamodels.DataModel], instance: datamodels.DataModel) -> None:
        # The visitor is created on every validation run, so its own
        # argument checks also happen at validation time.
        checker = _LvalueDimsValidator(vertical_loop_type, decl_type)
        checker.visit(instance)

    return _make_root_validator(_impl)
class AxisBound(eve.Node):
    """Bound given as an integer offset relative to a level marker.

    Bounds are ordered: any `START` bound is smaller than any `END` bound,
    and bounds on the same level are ordered by offset.
    """

    level: LevelMarker
    offset: int = 0

    @classmethod
    def from_start(cls, offset: int) -> AxisBound:
        """Build a bound `offset` away from the start level."""
        return cls(offset=offset, level=LevelMarker.START)

    @classmethod
    def from_end(cls, offset: int) -> AxisBound:
        """Build a bound `offset` away from the end level."""
        return cls(offset=offset, level=LevelMarker.END)

    @classmethod
    def start(cls, offset: int = 0) -> AxisBound:
        """Same as `from_start`, with `offset` defaulting to 0."""
        return cls.from_start(offset)

    @classmethod
    def end(cls, offset: int = 0) -> AxisBound:
        """Same as `from_end`, with `offset` defaulting to 0."""
        return cls.from_end(offset)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, AxisBound):
            return (self.level, self.offset) == (other.level, other.offset)
        return False

    def __lt__(self, other: AxisBound) -> bool:
        if not isinstance(other, AxisBound):
            return NotImplemented
        if self.level == other.level:
            return self.offset < other.offset
        # Different levels: smaller only when going from START to END.
        return self.level == LevelMarker.START and other.level == LevelMarker.END

    def __le__(self, other: AxisBound) -> bool:
        if not isinstance(other, AxisBound):
            return NotImplemented
        if self < other:
            return True
        return self == other

    def __gt__(self, other: AxisBound) -> bool:
        if not isinstance(other, AxisBound):
            return NotImplemented
        return not (self < other or self == other)

    def __ge__(self, other: AxisBound) -> bool:
        if not isinstance(other, AxisBound):
            return NotImplemented
        return not (self < other)
class HorizontalInterval(eve.Node):
    """Represents an interval of the index space in the horizontal.

    This is separate from `gtir.Interval` because the endpoints may
    be outside the compute domain.

    A `start` or `end` of None leaves that side of the interval unbounded.
    """

    start: Optional[AxisBound]
    end: Optional[AxisBound]

    @classmethod
    def compute_domain(cls, start_offset: int = 0, end_offset: int = 0) -> "HorizontalInterval":
        """Build an interval from the start level to the end level, with optional offsets."""
        lower = AxisBound.start(start_offset)
        upper = AxisBound.end(end_offset)
        return cls(start=lower, end=upper)

    @classmethod
    def full(cls) -> HorizontalInterval:
        """Build the interval that is unbounded on both sides."""
        return cls(start=None, end=None)

    @classmethod
    def at_endpt(
        cls, level: LevelMarker, start_offset: int, end_offset: Optional[int] = None
    ) -> "HorizontalInterval":
        """Build an interval whose two bounds are relative to the same `level`.

        If `end_offset` is omitted, it defaults to `start_offset + 1`.
        """
        stop_offset = start_offset + 1 if end_offset is None else end_offset
        lower = AxisBound(level=level, offset=start_offset)
        upper = AxisBound(level=level, offset=stop_offset)
        return cls(start=lower, end=upper)

    @datamodels.root_validator
    @classmethod
    def check_start_before_end(cls: Type[HorizontalInterval], instance: HorizontalInterval) -> None:
        """Raise `ValueError` if both bounds are given and `start` is after `end`."""
        lower, upper = instance.start, instance.end
        if not (lower and upper):
            return
        if not (lower <= upper):
            raise ValueError(f"End ({upper}) is not after or equal to start ({lower})")

    def is_single_index(self) -> bool:
        """Return True if both bounds share a level and their offsets differ by exactly 1."""
        lower, upper = self.start, self.end
        if lower is None or upper is None:
            return False
        if lower.level != upper.level:
            return False
        return abs(upper.offset - lower.offset) == 1

    def overlaps(self, other: HorizontalInterval) -> bool:
        """Return True if this interval and `other` overlap.

        The interval starting first is taken as the left one; the two overlap
        if the left one has no end, or if the right one starts before the
        left one ends. Intervals with equal starts always overlap.
        """
        own_start, other_start = self.start, other.start

        if own_start is None and other_start is None:
            return True

        if own_start is None:
            left_interval, right_interval = self, other
        elif other_start is None or other_start < own_start:
            left_interval, right_interval = other, self
        elif own_start < other_start:
            left_interval, right_interval = self, other
        else:
            assert own_start == other_start
            return True

        return left_interval.end is None or (
            right_interval.start is not None and right_interval.start < left_interval.end
        )
class HorizontalMask(LocNode):
    """Expr to represent a convex portion of the horizontal iteration space."""

    # One interval per horizontal axis.
    i: HorizontalInterval
    j: HorizontalInterval

    @property
    def intervals(self) -> Tuple[HorizontalInterval, HorizontalInterval]:
        """Both intervals as a tuple, in `(i, j)` order."""
        return (self.i, self.j)
class HorizontalRestriction(eve.GenericNode, Generic[StmtT]):
    """A specialization of the horizontal space."""

    # Pairs a horizontal mask with a list of statements.
    mask: HorizontalMask
    body: List[StmtT]
def data_type_to_typestr(dtype: DataType) -> str:
    """Convert a `DataType` into the `str` attribute of the matching NumPy dtype.

    Raises
    ------
    TypeError
        If `dtype` is not a `DataType` instance.
    ValueError
        If `dtype` is `INVALID`, `AUTO` or `DEFAULT`, which have no NumPy counterpart.
    """
    if not isinstance(dtype, DataType):
        raise TypeError("Can only convert instances of DataType to typestr.")

    numpy_names = {
        DataType.BOOL: "bool",
        DataType.INT8: "int8",
        DataType.INT16: "int16",
        DataType.INT32: "int32",
        DataType.INT64: "int64",
        DataType.FLOAT32: "float32",
        DataType.FLOAT64: "float64",
    }
    numpy_name = numpy_names.get(dtype)
    if numpy_name is None:
        raise ValueError("Can not convert INVALID, AUTO or DEFAULT to typestr.")
    return np.dtype(numpy_name).str
@functools.lru_cache(maxsize=None, typed=True)  # typed since uniqueness is only guaranteed per enum
def op_to_ufunc(
    op: Union[
        UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction
    ]
) -> np.ufunc:
    """Return the NumPy/SciPy ufunc implementing a GTC operator or native function.

    Parameters
    ----------
    op:
        A member of one of the operator enums or of `NativeFunction`.

    Raises
    ------
    TypeError
        If `op` is not an instance of one of the supported enums.
    KeyError
        If `op` belongs to a supported enum but has no table entry
        (`ArithmeticOperator.MATMULT` is not listed).
    """
    table: Dict[
        Union[
            UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction
        ],
        np.ufunc,
    ]
    # Can't put all in single table since UnaryOperator.POS == BinaryOperator.ADD
    if isinstance(op, UnaryOperator):
        table = {
            UnaryOperator.POS: np.positive,
            UnaryOperator.NEG: np.negative,
            UnaryOperator.NOT: np.logical_not,
        }
    elif isinstance(op, ArithmeticOperator):
        table = {
            ArithmeticOperator.ADD: np.add,
            ArithmeticOperator.SUB: np.subtract,
            ArithmeticOperator.MUL: np.multiply,
            ArithmeticOperator.DIV: np.true_divide,
        }
    elif isinstance(op, ComparisonOperator):
        table = {
            ComparisonOperator.GT: np.greater,
            ComparisonOperator.LT: np.less,
            ComparisonOperator.GE: np.greater_equal,
            ComparisonOperator.LE: np.less_equal,
            ComparisonOperator.EQ: np.equal,
            ComparisonOperator.NE: np.not_equal,
        }
    elif isinstance(op, LogicalOperator):
        table = {
            LogicalOperator.AND: np.logical_and,
            LogicalOperator.OR: np.logical_or,
        }
    elif isinstance(op, NativeFunction):
        table = {
            NativeFunction.ABS: np.abs,
            NativeFunction.MIN: np.minimum,
            NativeFunction.MAX: np.maximum,
            NativeFunction.MOD: np.remainder,
            NativeFunction.SIN: np.sin,
            NativeFunction.COS: np.cos,
            NativeFunction.TAN: np.tan,
            NativeFunction.ARCSIN: np.arcsin,
            NativeFunction.ARCCOS: np.arccos,
            NativeFunction.ARCTAN: np.arctan,
            NativeFunction.SINH: np.sinh,
            NativeFunction.COSH: np.cosh,
            NativeFunction.TANH: np.tanh,
            NativeFunction.ARCSINH: np.arcsinh,
            NativeFunction.ARCCOSH: np.arccosh,
            NativeFunction.ARCTANH: np.arctanh,
            NativeFunction.SQRT: np.sqrt,
            NativeFunction.POW: np.power,
            NativeFunction.EXP: np.exp,
            NativeFunction.LOG: np.log,
            NativeFunction.GAMMA: scipy.special.gamma,
            NativeFunction.CBRT: np.cbrt,
            NativeFunction.ISFINITE: np.isfinite,
            NativeFunction.ISINF: np.isinf,
            NativeFunction.ISNAN: np.isnan,
            NativeFunction.FLOOR: np.floor,
            NativeFunction.CEIL: np.ceil,
            NativeFunction.TRUNC: np.trunc,
        }
    else:
        # The message names the actual conversion target (a ufunc); it
        # previously said "typestr".
        raise TypeError(
            "Can only convert instances of GTC operators and supported native functions to ufunc."
        )
    return table[op]
@functools.lru_cache(maxsize=None)
def typestr_to_data_type(typestr: str) -> DataType:
    """Convert a typestr such as `"<f8"` into the matching `DataType`.

    The second character is read as the kind code and the remainder as the
    item size. `DataType.INVALID` is returned for inputs that are not
    strings, are shorter than three characters, have a non-numeric size
    part, or name an unknown kind/size combination.
    """
    well_formed = (
        isinstance(typestr, str) and len(typestr) >= 3 and typestr[2:].isnumeric()
    )
    if not well_formed:
        return DataType.INVALID  # type: ignore

    kind_code = typestr[1]
    item_size = int(typestr[2:])
    known_types = {
        ("b", 1): DataType.BOOL,
        ("i", 1): DataType.INT8,
        ("i", 2): DataType.INT16,
        ("i", 4): DataType.INT32,
        ("i", 8): DataType.INT64,
        ("f", 4): DataType.FLOAT32,
        ("f", 8): DataType.FLOAT64,
    }
    return known_types.get((kind_code, item_size), DataType.INVALID)  # type: ignore