Source code for gt4py.cartesian.frontend.nodes

# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2022, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""
Implementation of the intermediate representations used in GT4Py.

-----------
Definitions
-----------

Empty
    Empty node value (`None` is a valid Python value)

InvalidBranch
    Sentinel value for wrongly built conditional expressions

Builtin enumeration (:class:`Builtin`)
    Named Python constants
    [`NONE`, `FALSE`, `TRUE`]

DataType enumeration (:class:`DataType`)
    Native numeric data types
    [`INVALID`, `AUTO`, `DEFAULT`, `BOOL`,
    `INT8`, `INT16`, `INT32`, `INT64`, `FLOAT32`, `FLOAT64`]

UnaryOperator enumeration (:class:`UnaryOperator`)
    Unary operators
    [`POS`, `NEG`, `TRANSPOSED`, `NOT`]

BinaryOperator enumeration (:class:`BinaryOperator`)
    Binary operators
    [`ADD`, `SUB`, `MUL`, `DIV`, `POW`, `MOD`, `MATMULT`, `AND`, `OR`,
    `EQ`, `NE`, `LT`, `LE`, `GT`, `GE`]

NativeFunction enumeration (:class:`NativeFunction`)
    Native function identifier
    [`ABS`, `MIN`, `MAX`, `MOD`, `SIN`, `COS`, `TAN`, `ARCSIN`, `ARCCOS`, `ARCTAN`,
    `SINH`, `COSH`, `TANH`, `ARCSINH`, `ARCCOSH`, `ARCTANH`, `SQRT`, `EXP`, `LOG`,
    `GAMMA`, `CBRT`, `ISFINITE`, `ISINF`, `ISNAN`, `FLOOR`, `CEIL`, `TRUNC`]

LevelMarker enumeration (:class:`LevelMarker`)
    Special axis levels
    [`START`, `END`]

IterationOrder enumeration (:class:`IterationOrder`)
    Execution order
    [`BACKWARD`, `PARALLEL`, `FORWARD`]

Index (:class:`gt4py.definitions.Index`)
    Multidimensional integer offset
    [int+]

Extent (:class:`gt4py.definitions.Extent`)
    Multidimensional integer extent
    [(lower: `int`, upper: `int`)+]



-------------
Definition IR
-------------

All nodes have an optional attribute `loc` [`Location(line: int, column: int, scope: str)`]
storing a reference to the piece of source code which originated the node.

 ::

    Axis(name: str)

    Domain(parallel_axes: List[Axis], [sequential_axis: Axis])
        # LatLonGrids -> parallel_axes: ["I", "J"], sequential_axis: "K"

    Literal     = ScalarLiteral(value: Any (should match DataType), data_type: DataType)
                | BuiltinLiteral(value: Builtin)

    Ref         = VarRef(name: str, [index: int])
                | FieldRef(name: str, offset: Dict[str, int | Expr])
                # Horizontal indices must be ints

    NativeFuncCall(func: NativeFunction, args: List[Expr], data_type: DataType)

    Cast(expr: Expr, data_type: DataType)

    AxisPosition(axis: str, data_type: DataType)

    AxisIndex(axis: str, endpt: LevelMarker, offset: int, data_type: DataType)

    Expr        = Literal | Ref | NativeFuncCall | Cast | CompositeExpr | InvalidBranch | AxisPosition | AxisIndex

    CompositeExpr   = UnaryOpExpr(op: UnaryOperator, arg: Expr)
                    | BinOpExpr(op: BinaryOperator, lhs: Expr, rhs: Expr)
                    | TernaryOpExpr(condition: Expr, then_expr: Expr, else_expr: Expr)

    Decl        = FieldDecl(name: str, data_type: DataType, axes: List[str],
                            is_api: bool, layout_id: str)
                | VarDecl(name: str, data_type: DataType, length: int,
                          is_api: bool, [init: Literal])

    BlockStmt(stmts: List[Statement])

    Statement   = Decl
                | Assign(target: Ref, value: Expr)
                | If(condition: Expr, main_body: BlockStmt, else_body: BlockStmt)
                | HorizontalIf(intervals: Dict[str, AxisInterval], body: BlockStmt)
                | While(condition: Expr, body: BlockStmt)
                | BlockStmt

    AxisBound(level: LevelMarker | VarRef, offset: int)
        # bound = level + offset
        # level: LevelMarker = special START or END level
        # level: VarRef = access to `int` or `[int]` variable holding the run-time value of the level
        # offset: int

    AxisInterval(start: AxisBound, end: AxisBound)
        # start is included
        # end is excluded

    ComputationBlock(interval: AxisInterval, iteration_order: IterationOrder, body: BlockStmt)

    ArgumentInfo(name: str, is_keyword: bool, [default: Any])

    StencilDefinition(name: str,
                      domain: Domain,
                      api_signature: List[ArgumentInfo],
                      api_fields: List[FieldDecl],
                      parameters: List[VarDecl],
                      computations: List[ComputationBlock],
                      [externals: Dict[str, Any], sources: Dict[str, str]])

"""

import enum
import operator
import sys
from typing import Generator, List, Optional, Sequence, Type

import numpy as np

from gt4py.cartesian.definitions import AccessKind, CartesianSpace
from gt4py.cartesian.gtc.definitions import Extent, Index
from gt4py.cartesian.utils.attrib import Any as Any
from gt4py.cartesian.utils.attrib import Dict as DictOf
from gt4py.cartesian.utils.attrib import List as ListOf
from gt4py.cartesian.utils.attrib import Optional as OptionalOf
from gt4py.cartesian.utils.attrib import Tuple as TupleOf
from gt4py.cartesian.utils.attrib import Union as UnionOf
from gt4py.cartesian.utils.attrib import attribkwclass as attribclass
from gt4py.cartesian.utils.attrib import attribute, attributes_of


# ---- Foundations ----
class Empty:
    """Marker for "no value" in places where `None` is itself a valid value."""
class Node:
    """Common base class of the IR node classes defined in this module."""
@attribclass
class Location(Node):
    """Reference to the piece of source code from which a node originated."""

    line = attribute(of=int)
    column = attribute(of=int)
    scope = attribute(of=str, default="<source>")

    @classmethod
    def from_ast_node(cls, ast_node, scope="<source>"):
        """Build a location from an AST node.

        ``lineno`` and ``col_offset`` are read from ``ast_node`` and default to 0
        when the attribute is missing. The stored column is ``col_offset + 1``.
        """
        return cls(
            line=getattr(ast_node, "lineno", 0),
            column=getattr(ast_node, "col_offset", 0) + 1,
            scope=scope,
        )
# ---- IR: domain ----
@attribclass
class Axis(Node):
    """Axis of a domain, identified by its name."""

    name = attribute(of=str)
@enum.unique
class LevelMarker(enum.Enum):
    """Special axis levels: `START` (0) and `END` (-1)."""

    START = 0
    END = -1

    def __str__(self):
        # Print only the member name, without the class prefix.
        return self.name
@attribclass
class Domain(Node):
    """Collection of axes: a list of parallel axes plus an optional sequential one."""

    parallel_axes = attribute(of=ListOf[Axis])
    sequential_axis = attribute(of=Axis, optional=True)

    @classmethod
    def LatLonGrid(cls):
        """Create a domain with parallel axes I and J and sequential axis K."""
        space_axes = CartesianSpace.Axis
        return cls(
            parallel_axes=[Axis(name=space_axes.I.name), Axis(name=space_axes.J.name)],
            sequential_axis=Axis(name=space_axes.K.name),
        )

    @property
    def axes(self):
        """New list with the parallel axes followed by the sequential axis, if set."""
        if self.sequential_axis:
            return [*self.parallel_axes, self.sequential_axis]
        return list(self.parallel_axes)

    @property
    def axes_names(self):
        """Names of all axes, in the same order as `axes`."""
        return [ax.name for ax in self.axes]

    @property
    def domain_ndims(self):
        """Number of axes: the parallel ones plus one if a sequential axis is set."""
        return len(self.parallel_axes) + int(bool(self.sequential_axis))

    # Shorter alias for `domain_ndims`.
    ndims = domain_ndims

    def index(self, axis):
        """Return the position of `axis` (an `Axis` or an axis name) in `axes_names`.

        Raises `ValueError` (from `list.index`) if the name is not present.
        """
        axis_name = axis.name if isinstance(axis, Axis) else axis
        assert isinstance(axis_name, str)
        return self.axes_names.index(axis_name)
# ---- IR: types ----
@enum.unique
class Builtin(enum.Enum):
    """Named Python constants: `NONE`, `FALSE` and `TRUE`."""

    NONE = -1
    FALSE = 0
    TRUE = 1

    @classmethod
    def from_value(cls, value):
        """Map the Python singletons `None`, `True` and `False` to their members.

        The comparison is by identity, so e.g. `0` and `1` are not accepted.

        Raises:
            ValueError: if `value` is not one of the three singletons.
        """
        if value is None:
            return cls.NONE
        if value is True:
            return cls.TRUE
        if value is False:
            return cls.FALSE
        # Previously this case fell through to an unbound local variable.
        raise ValueError(f"Cannot convert {value!r} to a {cls.__name__} constant")

    def __str__(self):
        return self.name
@enum.unique
class DataType(enum.Enum):
    """Native numeric data types.

    The class attributes `NATIVE_TYPE_TO_NUMPY` and `NUMPY_TO_NATIVE_TYPE` are
    attached right after the class definition, because they need the members.
    """

    INVALID = -1
    AUTO = 0
    DEFAULT = 1
    BOOL = 10
    INT8 = 11
    INT16 = 12
    INT32 = 14
    INT64 = 18
    FLOAT32 = 104
    FLOAT64 = 108

    def __str__(self):
        return self.name

    @property
    def dtype(self):
        """NumPy dtype for this member.

        Raises `KeyError` for members without an entry in `NATIVE_TYPE_TO_NUMPY`
        (`INVALID` and `AUTO`).
        """
        numpy_name = self.NATIVE_TYPE_TO_NUMPY[self]
        return np.dtype(numpy_name)

    @classmethod
    def from_dtype(cls, py_dtype):
        """Return the member matching a NumPy dtype (or a type convertible to one).

        Dtypes whose name is not in `NUMPY_TO_NATIVE_TYPE` map to `INVALID`.
        """
        np_dtype = np.dtype(py_dtype) if isinstance(py_dtype, type) else py_dtype
        assert isinstance(np_dtype, np.dtype)
        return cls.NUMPY_TO_NATIVE_TYPE.get(np_dtype.name, cls.INVALID)

    @classmethod
    def merge(cls, *args):
        """Return the member with the largest enum value among `args`."""
        highest_value = max(member.value for member in args)
        return cls(highest_value)


# Member -> NumPy dtype name. `INVALID` and `AUTO` intentionally have no entry.
DataType.NATIVE_TYPE_TO_NUMPY = {
    DataType.DEFAULT: "float_",
    DataType.BOOL: "bool",
    DataType.INT8: "int8",
    DataType.INT16: "int16",
    DataType.INT32: "int32",
    DataType.INT64: "int64",
    DataType.FLOAT32: "float32",
    DataType.FLOAT64: "float64",
}

# Inverse table: NumPy dtype name -> member.
DataType.NUMPY_TO_NATIVE_TYPE = {
    numpy_name: member for member, numpy_name in DataType.NATIVE_TYPE_TO_NUMPY.items()
}


# ---- IR: expressions ----
class Expr(Node):
    """Base class of all expression nodes."""
class Literal(Expr):
    """Base class of literal expression nodes."""
class InvalidBranch(Expr):
    """Sentinel expression for wrongly built conditional expressions."""
@attribclass
class ScalarLiteral(Literal):
    """Literal value together with its data type."""

    value = attribute(of=Any)  # Potentially an array of numeric structs
    data_type = attribute(of=DataType)
    loc = attribute(of=Location, optional=True)
# @attribclass # class TupleLiteral(Node): # items = attribute(of=TupleOf[Expr]) # # @property # def length(self): # return len(self.items)
@attribclass
class BuiltinLiteral(Literal):
    """Literal holding one of the `Builtin` constants."""

    value = attribute(of=Builtin)
    loc = attribute(of=Location, optional=True)
class Ref(Expr):
    """Base class of expressions that refer to a named symbol."""
@attribclass
class VarRef(Ref):
    """Reference to a variable by name, with an optional integer index."""

    name = attribute(of=str)
    index = attribute(of=int, optional=True)
    loc = attribute(of=Location, optional=True)
@attribclass
class FieldRef(Ref):
    """Reference to a field by name, with an offset per axis and a data index."""

    name = attribute(of=str)
    offset = attribute(of=DictOf[str, UnionOf[int, Expr]])
    data_index = attribute(of=ListOf[Expr], factory=list)
    loc = attribute(of=Location, optional=True)

    @classmethod
    def at_center(
        cls,
        name: str,
        axes: Sequence[str],
        data_index: Optional[List[int]] = None,
        loc=None,
    ):
        """Create a reference to `name` with a zero offset on every axis in `axes`.

        A missing or empty `data_index` is replaced by a fresh empty list.
        """
        zero_offset = dict.fromkeys(axes, 0)
        return cls(
            name=name,
            offset=zero_offset,
            data_index=data_index if data_index else [],
            loc=loc,
        )
@attribclass
class Cast(Expr):
    """Expression `expr` paired with a target `data_type`."""

    data_type = attribute(of=DataType)
    expr = attribute(of=Expr)
    loc = attribute(of=Location, optional=True)
@attribclass
class AxisPosition(Expr):
    """Expression referring to an axis by name; `data_type` defaults to `INT32`."""

    axis = attribute(of=str)
    data_type = attribute(of=DataType, default=DataType.INT32)
@attribclass
class AxisIndex(Expr):
    """Expression made of an axis name, a level marker and an integer offset.

    `data_type` defaults to `INT32`.
    """

    axis = attribute(of=str)
    endpt = attribute(of=LevelMarker)
    offset = attribute(of=int)
    data_type = attribute(of=DataType, default=DataType.INT32)
@enum.unique
class NativeFunction(enum.Enum):
    """Native function identifiers.

    Member values come from `enum.auto()`, so they depend on declaration order.
    The class attribute `IR_OP_TO_NUM_ARGS` is attached after the class definition.
    """

    ABS = enum.auto()
    MIN = enum.auto()
    MAX = enum.auto()
    MOD = enum.auto()

    SIN = enum.auto()
    COS = enum.auto()
    TAN = enum.auto()
    ARCSIN = enum.auto()
    ARCCOS = enum.auto()
    ARCTAN = enum.auto()
    SINH = enum.auto()
    COSH = enum.auto()
    TANH = enum.auto()
    ARCSINH = enum.auto()
    ARCCOSH = enum.auto()
    ARCTANH = enum.auto()

    SQRT = enum.auto()
    EXP = enum.auto()
    LOG = enum.auto()
    GAMMA = enum.auto()
    CBRT = enum.auto()

    ISFINITE = enum.auto()
    ISINF = enum.auto()
    ISNAN = enum.auto()
    FLOOR = enum.auto()
    CEIL = enum.auto()
    TRUNC = enum.auto()

    @property
    def arity(self):
        """Number of arguments of this function, from `IR_OP_TO_NUM_ARGS`."""
        arity_table = type(self).IR_OP_TO_NUM_ARGS
        return arity_table[self]


# Number of arguments of every native function.
NativeFunction.IR_OP_TO_NUM_ARGS = {
    NativeFunction.ABS: 1,
    NativeFunction.MIN: 2,
    NativeFunction.MAX: 2,
    NativeFunction.MOD: 2,
    NativeFunction.SIN: 1,
    NativeFunction.COS: 1,
    NativeFunction.TAN: 1,
    NativeFunction.ARCSIN: 1,
    NativeFunction.ARCCOS: 1,
    NativeFunction.ARCTAN: 1,
    NativeFunction.SINH: 1,
    NativeFunction.COSH: 1,
    NativeFunction.TANH: 1,
    NativeFunction.ARCSINH: 1,
    NativeFunction.ARCCOSH: 1,
    NativeFunction.ARCTANH: 1,
    NativeFunction.SQRT: 1,
    NativeFunction.EXP: 1,
    NativeFunction.LOG: 1,
    NativeFunction.GAMMA: 1,
    NativeFunction.CBRT: 1,
    NativeFunction.ISFINITE: 1,
    NativeFunction.ISINF: 1,
    NativeFunction.ISNAN: 1,
    NativeFunction.FLOOR: 1,
    NativeFunction.CEIL: 1,
    NativeFunction.TRUNC: 1,
}
@attribclass
class NativeFuncCall(Expr):
    """Call of a native function with a list of argument expressions."""

    func = attribute(of=NativeFunction)
    args = attribute(of=ListOf[Expr])
    data_type = attribute(of=DataType)
    loc = attribute(of=Location, optional=True)
class CompositeExpr(Expr):
    """Base class of the unary, binary and ternary operator expressions."""
@enum.unique
class UnaryOperator(enum.Enum):
    """Unary operators.

    The lookup tables `IR_OP_TO_PYTHON_OP` and `IR_OP_TO_PYTHON_SYMBOL` are
    attached after the class definition. `TRANSPOSED` has no entry in either
    table, so `python_op` and `python_symbol` raise `KeyError` for it.
    """

    POS = 1
    NEG = 2
    TRANSPOSED = 5
    NOT = 11

    @property
    def python_op(self):
        """Callable from the `operator` module implementing this operator."""
        op_table = type(self).IR_OP_TO_PYTHON_OP
        return op_table[self]

    @property
    def python_symbol(self):
        """Python source spelling of this operator."""
        symbol_table = type(self).IR_OP_TO_PYTHON_SYMBOL
        return symbol_table[self]


UnaryOperator.IR_OP_TO_PYTHON_OP = {
    UnaryOperator.POS: operator.pos,
    UnaryOperator.NEG: operator.neg,
    UnaryOperator.NOT: operator.not_,
}

UnaryOperator.IR_OP_TO_PYTHON_SYMBOL = {
    UnaryOperator.POS: "+",
    UnaryOperator.NEG: "-",
    UnaryOperator.NOT: "not",
}
@attribclass
class UnaryOpExpr(CompositeExpr):
    """Unary operator applied to a single argument expression."""

    op = attribute(of=UnaryOperator)
    arg = attribute(of=Expr)
    loc = attribute(of=Location, optional=True)
@enum.unique
class BinaryOperator(enum.Enum):
    """Binary operators.

    The lookup tables `IR_OP_TO_PYTHON_OP` and `IR_OP_TO_PYTHON_SYMBOL` are
    attached after the class definition. `MATMULT` has no entry in either table,
    and `AND`/`OR` have a symbol but no callable, so the corresponding property
    raises `KeyError` for those members.
    """

    ADD = 1
    SUB = 2
    MUL = 3
    DIV = 4
    POW = 5
    MOD = 6
    MATMULT = 8

    AND = 11
    OR = 12

    EQ = 21
    NE = 22
    LT = 23
    LE = 24
    GT = 25
    GE = 26

    @property
    def python_op(self):
        """Callable from the `operator` module implementing this operator."""
        op_table = type(self).IR_OP_TO_PYTHON_OP
        return op_table[self]

    @property
    def python_symbol(self):
        """Python source spelling of this operator."""
        symbol_table = type(self).IR_OP_TO_PYTHON_SYMBOL
        return symbol_table[self]


BinaryOperator.IR_OP_TO_PYTHON_OP = {
    BinaryOperator.ADD: operator.add,
    BinaryOperator.SUB: operator.sub,
    BinaryOperator.MUL: operator.mul,
    BinaryOperator.DIV: operator.truediv,
    BinaryOperator.POW: operator.pow,
    BinaryOperator.MOD: operator.mod,
    # BinaryOperator.AND: lambda a, b: a and b,  # non short-circuit emulation
    # BinaryOperator.OR: lambda a, b: a or b,  # non short-circuit emulation
    BinaryOperator.LT: operator.lt,
    BinaryOperator.LE: operator.le,
    BinaryOperator.EQ: operator.eq,
    BinaryOperator.GE: operator.ge,
    BinaryOperator.GT: operator.gt,
    BinaryOperator.NE: operator.ne,
}

BinaryOperator.IR_OP_TO_PYTHON_SYMBOL = {
    BinaryOperator.ADD: "+",
    BinaryOperator.SUB: "-",
    BinaryOperator.MUL: "*",
    BinaryOperator.DIV: "/",
    BinaryOperator.POW: "**",
    BinaryOperator.MOD: "%",
    BinaryOperator.AND: "and",
    BinaryOperator.OR: "or",
    BinaryOperator.LT: "<",
    BinaryOperator.LE: "<=",
    BinaryOperator.EQ: "==",
    BinaryOperator.GE: ">=",
    BinaryOperator.GT: ">",
    BinaryOperator.NE: "!=",
}
@attribclass
class BinOpExpr(CompositeExpr):
    """Binary operator applied to a left-hand and a right-hand expression."""

    op = attribute(of=BinaryOperator)
    lhs = attribute(of=Expr)
    rhs = attribute(of=Expr)
    loc = attribute(of=Location, optional=True)
@attribclass
class TernaryOpExpr(CompositeExpr):
    """Conditional expression with a condition, a "then" and an "else" expression."""

    condition = attribute(of=Expr)
    then_expr = attribute(of=Expr)
    else_expr = attribute(of=Expr)
    loc = attribute(of=Location, optional=True)
# ---- IR: statements ----
class Statement(Node):
    """Base class of all statement nodes."""
# @attribclass # class ExprStmt(Statement): # expr = attribute(of=Expr) # loc = attribute(of=Location, optional=True)
class Decl(Statement):
    """Base class of declaration statements."""
@attribclass
class FieldDecl(Decl):
    """Declaration of a field.

    `data_dims` defaults to an empty list and `layout_id` to ``"_default_"``.
    """

    name = attribute(of=str)
    data_type = attribute(of=DataType)
    axes = attribute(of=ListOf[str])
    is_api = attribute(of=bool)
    data_dims = attribute(of=ListOf[int], factory=list)
    layout_id = attribute(of=str, default="_default_")
    loc = attribute(of=Location, optional=True)
@attribclass
class VarDecl(Decl):
    """Declaration of a variable, with an optional literal initializer."""

    name = attribute(of=str)
    data_type = attribute(of=DataType)
    length = attribute(of=int)
    is_api = attribute(of=bool)
    init = attribute(of=Literal, optional=True)
    loc = attribute(of=Location, optional=True)

    @property
    def is_parameter(self):
        """Same value as `is_api`."""
        return self.is_api

    @property
    def is_scalar(self):
        """True when `length` is 0."""
        return self.length == 0
@attribclass
class BlockStmt(Statement):
    """Statement made of a list of statements."""

    stmts = attribute(of=ListOf[Statement])
    loc = attribute(of=Location, optional=True)
@attribclass
class Assign(Statement):
    """Assignment of a value expression to a target reference."""

    target = attribute(of=Ref)
    value = attribute(of=Expr)
    loc = attribute(of=Location, optional=True)
@attribclass
class If(Statement):
    """Conditional statement with a main body and an optional else body."""

    condition = attribute(of=Expr)
    main_body = attribute(of=BlockStmt)
    else_body = attribute(of=BlockStmt, optional=True)
    loc = attribute(of=Location, optional=True)
@attribclass
class While(Statement):
    """Loop statement with a condition expression and a body."""

    condition = attribute(of=Expr)
    body = attribute(of=BlockStmt)
    loc = attribute(of=Location, optional=True)
# ---- IR: computations ----
@enum.unique
class IterationOrder(enum.Enum):
    """Execution order: `BACKWARD`, `PARALLEL` or `FORWARD`."""

    BACKWARD = -1
    PARALLEL = 0
    FORWARD = 1

    @property
    def symbol(self):
        """Short textual symbol for this order: ``<-``, ``||`` or ``->``."""
        # Look members up on the class rather than through another member.
        order_cls = type(self)
        symbols = {
            order_cls.BACKWARD: "<-",
            order_cls.PARALLEL: "||",
            order_cls.FORWARD: "->",
        }
        return symbols[self]

    def __str__(self):
        return self.name

    def cycle(self, *, steps: int = 1):
        """Return the member `steps` positions after this one, wrapping around.

        Members are taken in declaration order (BACKWARD, PARALLEL, FORWARD);
        negative `steps` move towards the start.

        NOTE(review): `__lshift__`/`__rshift__` referenced `cycle`, but the class
        did not define it, so both operators raised `AttributeError`. This
        cyclic-rotation semantics is an assumption -- confirm it is the intended one.
        """
        members = list(type(self))
        position = members.index(self)
        return members[(position + steps) % len(members)]

    def __lshift__(self, steps: int):
        return self.cycle(steps=-steps)

    def __rshift__(self, steps: int):
        return self.cycle(steps=steps)
@attribclass
class AxisBound(Node):
    """Bound on an axis, expressed as a level marker plus an integer offset.

    `offset` defaults to 0.
    """

    level = attribute(of=LevelMarker)
    offset = attribute(of=int, default=0)
    loc = attribute(of=Location, optional=True)
@attribclass
class AxisInterval(Node):
    """Interval on an axis: `start` is included, `end` is excluded."""

    start = attribute(of=AxisBound)
    end = attribute(of=AxisBound)
    loc = attribute(of=Location, optional=True)

    @classmethod
    def full_interval(cls, order=IterationOrder.PARALLEL):
        """Return the interval spanning the whole axis for the given order.

        For `BACKWARD` the bounds are ``END - 1`` to ``START - 1``; for any other
        order they are ``START`` to ``END``.
        """
        if order != IterationOrder.BACKWARD:
            interval = cls(
                start=AxisBound(level=LevelMarker.START, offset=0),
                end=AxisBound(level=LevelMarker.END, offset=0),
            )
        else:
            interval = cls(
                start=AxisBound(level=LevelMarker.END, offset=-1),
                end=AxisBound(level=LevelMarker.START, offset=-1),
            )
        return interval

    @property
    def is_single_index(self) -> bool:
        """True when both bounds share a level and `end` is exactly `start + 1`."""
        if not isinstance(self.start, AxisBound) or not isinstance(self.end, AxisBound):
            return False

        return self.start.level == self.end.level and self.start.offset == self.end.offset - 1

    def disjoint_from(self, other: "AxisInterval") -> bool:
        """Return True when this interval and `other` do not overlap.

        Bounds are compared as plain integers. START-relative bounds map to their
        offset and END-relative bounds to ``sys.maxsize + offset``, so END-relative
        bounds sort after START-relative ones as long as offsets stay far below
        ``sys.maxsize``.
        """

        def get_offset(bound: AxisBound) -> int:
            if bound.level == LevelMarker.START:
                return bound.offset
            return sys.maxsize + bound.offset

        self_start = get_offset(self.start)
        self_end = get_offset(self.end)

        other_start = get_offset(other.start)
        other_end = get_offset(other.end)

        # The half-open intervals overlap iff one start lies inside the other.
        return not (self_start <= other_start < self_end) and not (
            other_start <= self_start < other_end
        )
# TODO Find a better place for this in the file.
@attribclass
class HorizontalIf(Statement):
    """Statement pairing a body with a mapping from `str` keys to axis intervals."""

    intervals = attribute(of=DictOf[str, AxisInterval])
    body = attribute(of=BlockStmt)
@attribclass
class ComputationBlock(Node):
    """Block of statements with an axis interval and an iteration order."""

    interval = attribute(of=AxisInterval)
    iteration_order = attribute(of=IterationOrder)
    body = attribute(of=BlockStmt)
    loc = attribute(of=Location, optional=True)
@attribclass
class ArgumentInfo(Node):
    """Description of one argument of a stencil API signature.

    `is_keyword` defaults to False. `default` defaults to the `Empty` class
    rather than `None`.
    """

    name = attribute(of=str)
    is_keyword = attribute(of=bool, default=False)
    default = attribute(of=Any, default=Empty)
@attribclass
class StencilDefinition(Node):
    """Root node of the definition IR for a stencil.

    `externals`, `sources` and `loc` are optional; `docstring` defaults to ``""``.
    """

    name = attribute(of=str)
    domain = attribute(of=Domain)
    api_signature = attribute(of=ListOf[ArgumentInfo])
    api_fields = attribute(of=ListOf[FieldDecl])
    parameters = attribute(of=ListOf[VarDecl])
    computations = attribute(of=ListOf[ComputationBlock])
    externals = attribute(of=DictOf[str, Any], optional=True)
    sources = attribute(of=DictOf[str, str], optional=True)
    docstring = attribute(of=str, default="")
    loc = attribute(of=Location, optional=True)