# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2022, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
import functools
import itertools
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Set, Union, cast
from devtools import debug # noqa: F401
from typing_extensions import Protocol
from gt4py import eve
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.common import CartesianOffset, ExprKind
from gt4py.cartesian.gtc.gtcpp import gtcpp
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import (
collect_symbol_names,
symbol_name_creator,
)
# OIR -> gtcpp mapping:
# - Each HorizontalExecution becomes a Functor (and a Stage).
# - Each VerticalLoop becomes a MultiStage.
def _extract_accessors(node: eve.Node, temp_names: Set[str]) -> List[gtcpp.GTAccessor]:
    """Build the accessor list of a GT functor from the accessor references inside ``node``.

    Args:
        node: Subtree to scan for ``gtcpp.AccessorRef`` nodes (e.g. a ``gtcpp.GTApplyMethod``).
        temp_names: Names of temporaries. Their accessors always get ``ndim == 3``.

    Returns:
        One ``gtcpp.GTAccessor`` per distinct accessed name, numbered from 0 in order of first
        appearance while walking ``node``. Each accessor has the following properties:

        - Its extent is the zero extent plus the offsets of all references to that name.
        - Its intent is ``INOUT`` if the name appears as the left-hand side of an
          assignment anywhere in ``node``, and ``IN`` otherwise.
    """
    extents: Dict[str, gtcpp.GTExtent] = {}
    ndims: Dict[str, int] = {}
    for ref in node.walk_values().if_isinstance(gtcpp.AccessorRef):
        name = ref.name
        extents[name] = extents.get(name, gtcpp.GTExtent.zero()) + ref.offset
        # Temporaries are always 3-dimensional here; other fields gain one dimension per
        # data index used on the reference.
        ndims[name] = 3 if name in temp_names else 3 + len(ref.data_index)

    assigned_names: Set[str] = (
        node.walk_values()
        .if_isinstance(gtcpp.AssignStmt)
        .getattr("left")
        .if_isinstance(gtcpp.AccessorRef)
        .getattr("name")
        .to_set()
    )

    accessors: List[gtcpp.GTAccessor] = []
    for index, (name, extent) in enumerate(extents.items()):
        intent = gtcpp.Intent.INOUT if name in assigned_names else gtcpp.Intent.IN
        accessors.append(
            gtcpp.GTAccessor(
                name=name,
                id=index,
                intent=intent,
                extent=extent,
                ndim=ndims[name],
            )
        )
    return accessors
def _make_axis_offset_expr(
    bound: common.AxisBound,
    axis_index: int,
    axis_length_accessor: Callable[[int], gtcpp.AccessorRef],
) -> gtcpp.Expr:
    """Translate a horizontal axis bound into an integer gtcpp expression.

    Args:
        bound: The axis bound to translate.
        axis_index: Index of the axis the bound belongs to.
        axis_length_accessor: Called with ``axis_index`` to obtain an accessor to the length
            of that axis. It is only invoked for bounds relative to ``LevelMarker.END``.

    Returns:
        For a bound at ``END``: ``<axis length> + bound.offset``.
        Otherwise: the ``INT32`` literal ``bound.offset``.
    """

    def offset_literal() -> gtcpp.Literal:
        return gtcpp.Literal(value=str(bound.offset), dtype=common.DataType.INT32)

    if bound.level != common.LevelMarker.END:
        return offset_literal()

    axis_length = axis_length_accessor(axis_index)
    return gtcpp.BinaryOp(
        op=common.ArithmeticOperator.ADD,
        left=axis_length,
        right=offset_literal(),
    )
class SymbolNameCreator(Protocol):
    """Callable protocol producing symbol names for generated declarations.

    In this module an implementation is obtained from
    ``symbol_name_creator(collect_symbol_names(...))``. Presumably the returned names are
    guaranteed not to clash with symbols already present in the stencil; confirm this in
    that helper.
    """

    def __call__(self, name: str) -> str:
        """Return a symbol name for the requested base ``name``."""
class OIRToGTCpp(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait):
    """Translate an ``oir.Stencil`` into a GridTools C++ IR program (``gtcpp.Program``).

    Mapping:
      - each ``oir.HorizontalExecution`` becomes a ``gtcpp.GTFunctor`` plus a
        ``gtcpp.GTStage`` referring to it;
      - each ``oir.VerticalLoop`` becomes a ``gtcpp.GTMultiStage``.

    Translation state travels through the ``visit_*`` calls as keyword arguments:
      - ``prog_ctx`` (``ProgramContext``) collects the generated functors;
      - ``comp_ctx`` (``GTComputationContext``) collects temporaries, computation arguments
        and extra declarations (positionals and axis lengths).
    """

    @dataclass
    class ProgramContext:
        """Program-level state: the functors generated so far, in creation order."""

        functors: List[gtcpp.GTFunctor] = field(default_factory=list)

        def add_functor(self, functor: gtcpp.GTFunctor) -> "OIRToGTCpp.ProgramContext":
            """Append ``functor`` and return ``self`` to allow chaining."""
            self.functors.append(functor)
            return self

    @dataclass
    class GTComputationContext:
        """Computation-level state shared by all multi-stages of one stencil."""

        # Produces names for newly generated declarations (positionals, axis lengths).
        create_symbol_name: SymbolNameCreator
        # Temporaries of the stencil.
        temporaries: List[gtcpp.Temporary] = field(default_factory=list)
        # Positional (index) declarations keyed by axis index (0 -> i, 1 -> j, 2 -> k).
        positionals: Dict[int, gtcpp.Positional] = field(default_factory=dict)
        # Axis-length declarations keyed by axis index.
        axis_lengths: Dict[int, gtcpp.AxisLength] = field(default_factory=dict)
        # Names of non-temporary accessors used by any stage.
        _arguments: Set[str] = field(default_factory=set)

        def add_temporaries(
            self, temporaries: List[gtcpp.Temporary]
        ) -> "OIRToGTCpp.GTComputationContext":
            """Append ``temporaries`` and return ``self`` to allow chaining."""
            self.temporaries.extend(temporaries)
            return self

        @property
        def arguments(self) -> List[gtcpp.Arg]:
            """Return the computation arguments, sorted by name.

            The names are sorted because set iteration order for strings depends on the
            per-process hash seed. Iterating the set directly would make the generated
            code differ between otherwise identical runs.
            """
            return [gtcpp.Arg(name=name) for name in sorted(self._arguments)]

        def add_arguments(self, arguments: Set[str]) -> "OIRToGTCpp.GTComputationContext":
            """Register argument names (duplicates are ignored) and return ``self``."""
            self._arguments.update(arguments)
            return self

        @staticmethod
        def _make_scalar_accessor(name: str) -> gtcpp.AccessorRef:
            """Return a zero-offset ``INT32`` scalar accessor reference to ``name``."""
            return gtcpp.AccessorRef(
                name=name,
                offset=CartesianOffset.zero(),
                kind=ExprKind.SCALAR,
                dtype=common.DataType.INT32,
            )

        def make_positional(self, axis: int) -> gtcpp.AccessorRef:
            """Return an accessor to the index along ``axis``, declaring it on first use.

            The name creator is only asked for a name the first time an axis is seen. Later
            calls reuse the stored declaration instead of requesting names that would be
            discarded.
            """
            if axis not in self.positionals:
                axis_name = ["I", "J", "K"][axis].lower()
                name = self.create_symbol_name(f"ax{axis}_ind")
                self.positionals[axis] = gtcpp.Positional(name=name, axis_name=axis_name)
            return self._make_scalar_accessor(self.positionals[axis].name)

        def make_length(self, axis: int) -> gtcpp.AccessorRef:
            """Return an accessor to the length of ``axis``, declaring it on first use."""
            if axis not in self.axis_lengths:
                name = self.create_symbol_name(f"ax{axis}_len")
                self.axis_lengths[axis] = gtcpp.AxisLength(name=name, axis=axis)
            return self._make_scalar_accessor(self.axis_lengths[axis].name)

        @property
        def extra_decls(self) -> List[gtcpp.ComputationDecl]:
            """Return all positional declarations, followed by all axis-length declarations."""
            return list(self.positionals.values()) + list(self.axis_lengths.values())

    # -- Expressions: one-to-one translations ------------------------------------------------

    def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> gtcpp.Literal:
        return gtcpp.Literal(value=node.value, dtype=node.dtype)

    def visit_UnaryOp(self, node: oir.UnaryOp, **kwargs: Any) -> gtcpp.UnaryOp:
        return gtcpp.UnaryOp(op=node.op, expr=self.visit(node.expr, **kwargs))

    def visit_BinaryOp(self, node: oir.BinaryOp, **kwargs: Any) -> gtcpp.BinaryOp:
        return gtcpp.BinaryOp(
            op=node.op,
            left=self.visit(node.left, **kwargs),
            right=self.visit(node.right, **kwargs),
        )

    def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> gtcpp.TernaryOp:
        return gtcpp.TernaryOp(
            cond=self.visit(node.cond, **kwargs),
            true_expr=self.visit(node.true_expr, **kwargs),
            false_expr=self.visit(node.false_expr, **kwargs),
        )

    def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> gtcpp.NativeFuncCall:
        return gtcpp.NativeFuncCall(func=node.func, args=self.visit(node.args, **kwargs))

    def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> gtcpp.Cast:
        return gtcpp.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs))

    def visit_Temporary(self, node: oir.Temporary, **kwargs: Any) -> gtcpp.Temporary:
        return gtcpp.Temporary(name=node.name, dtype=node.dtype, data_dims=node.data_dims)

    def visit_VariableKOffset(
        self, node: oir.VariableKOffset, **kwargs: Any
    ) -> gtcpp.VariableKOffset:
        return gtcpp.VariableKOffset(k=self.visit(node.k, **kwargs))

    def visit_FieldAccess(self, node: oir.FieldAccess, **kwargs: Any) -> gtcpp.AccessorRef:
        return gtcpp.AccessorRef(
            name=node.name,
            offset=self.visit(node.offset, **kwargs),
            data_index=self.visit(node.data_index, **kwargs),
            dtype=node.dtype,
        )

    def visit_ScalarAccess(
        self, node: oir.ScalarAccess, **kwargs: Any
    ) -> Union[gtcpp.AccessorRef, gtcpp.LocalAccess]:
        """Translate a scalar access.

        Stencil scalar parameters (``oir.ScalarDecl`` in the symbol table) become zero-offset
        accessor references. Every other scalar is translated to a local variable access.
        """
        assert "symtable" in kwargs
        if node.name in kwargs["symtable"]:
            symbol = kwargs["symtable"][node.name]
            if isinstance(symbol, oir.ScalarDecl):
                return gtcpp.AccessorRef(
                    name=symbol.name, offset=CartesianOffset.zero(), dtype=symbol.dtype
                )
            assert isinstance(symbol, oir.LocalScalar)
        return gtcpp.LocalAccess(name=node.name, dtype=node.dtype)

    # -- Vertical intervals --------------------------------------------------------------------

    def visit_AxisBound(
        self, node: oir.AxisBound, *, is_start: bool, **kwargs: Any
    ) -> gtcpp.GTLevel:
        """Translate a vertical bound into a GridTools level ``(splitter, offset)``.

        ``START`` maps to splitter 0 and ``END`` to splitter 1. Offsets are adjusted as
        follows:

        - A non-negative *start* offset relative to ``START`` is shifted up by one.
        - A non-positive *end* offset relative to ``END`` is shifted down by one.

        NOTE(review): presumably this converts OIR's half-open, zero-based offsets into
        GridTools' inclusive level offsets, which skip 0. Confirm this against the
        GridTools documentation.

        Raises:
            ValueError: if ``node.level`` is neither ``START`` nor ``END``.
        """
        if node.level == common.LevelMarker.START:
            splitter = 0
            offset = node.offset + 1 if (node.offset >= 0 and is_start) else node.offset
        elif node.level == common.LevelMarker.END:
            splitter = 1
            offset = node.offset - 1 if (node.offset <= 0 and not is_start) else node.offset
        else:
            raise ValueError("Cannot handle dynamic levels")
        return gtcpp.GTLevel(splitter=splitter, offset=offset)

    def visit_Interval(self, node: oir.Interval, **kwargs: Any) -> gtcpp.GTInterval:
        """Translate the start bound with ``is_start=True`` and the end bound with ``False``."""
        return gtcpp.GTInterval(
            from_level=self.visit(node.start, is_start=True),
            to_level=self.visit(node.end, is_start=False),
        )

    # -- Statements ----------------------------------------------------------------------------

    def _mask_to_expr(
        self, mask: common.HorizontalMask, comp_ctx: "GTComputationContext"
    ) -> gtcpp.Expr:
        """Build a boolean gtcpp expression that is true inside the horizontal ``mask``.

        The expression is built per axis from that axis's interval:

        - A single-index interval yields ``index == start``.
        - Any other interval yields ``index >= start`` and/or ``index < end``, one
          comparison for each bound that is not ``None``.

        All comparisons are combined with logical AND. If no comparison is produced, the
        result is the literal ``TRUE``. Positional and axis-length declarations are
        registered in ``comp_ctx`` as needed.
        """
        mask_expr: List[gtcpp.Expr] = []
        for axis_index, interval in enumerate(mask.intervals):
            if interval.is_single_index():
                assert interval.start is not None
                mask_expr.append(
                    gtcpp.BinaryOp(
                        op=common.ComparisonOperator.EQ,
                        left=comp_ctx.make_positional(axis_index),
                        right=_make_axis_offset_expr(
                            interval.start, axis_index, comp_ctx.make_length
                        ),
                    )
                )
            else:
                for op, endpt in zip(
                    (common.ComparisonOperator.GE, common.ComparisonOperator.LT),
                    (interval.start, interval.end),
                ):
                    if endpt is None:
                        continue
                    mask_expr.append(
                        gtcpp.BinaryOp(
                            op=op,
                            left=comp_ctx.make_positional(axis_index),
                            right=_make_axis_offset_expr(endpt, axis_index, comp_ctx.make_length),
                        )
                    )
        return (
            functools.reduce(
                lambda a, b: gtcpp.BinaryOp(op=common.LogicalOperator.AND, left=a, right=b),
                mask_expr,
            )
            if mask_expr
            else gtcpp.Literal(value=common.BuiltInLiteral.TRUE, dtype=common.DataType.BOOL)
        )

    def visit_HorizontalRestriction(
        self, node: oir.HorizontalRestriction, **kwargs: Any
    ) -> gtcpp.IfStmt:
        """Wrap the restricted body in an ``if`` on the horizontal mask (needs ``comp_ctx``)."""
        mask = self._mask_to_expr(node.mask, kwargs["comp_ctx"])
        return gtcpp.IfStmt(
            cond=mask, true_branch=gtcpp.BlockStmt(body=self.visit(node.body, **kwargs))
        )

    def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs: Any) -> gtcpp.AssignStmt:
        assert "symtable" in kwargs
        return gtcpp.AssignStmt(
            left=self.visit(node.left, **kwargs), right=self.visit(node.right, **kwargs)
        )

    def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> gtcpp.IfStmt:
        return gtcpp.IfStmt(
            cond=self.visit(node.mask, **kwargs),
            true_branch=gtcpp.BlockStmt(body=self.visit(node.body, **kwargs)),
        )

    def visit_While(self, node: oir.While, **kwargs: Any) -> gtcpp.While:
        return gtcpp.While(
            cond=self.visit(node.cond, **kwargs), body=self.visit(node.body, **kwargs)
        )

    # -- Structure -----------------------------------------------------------------------------

    def visit_HorizontalExecution(
        self,
        node: oir.HorizontalExecution,
        *,
        prog_ctx: "ProgramContext",
        comp_ctx: "GTComputationContext",
        interval: oir.Interval,
        **kwargs: Any,
    ) -> gtcpp.GTStage:
        """Emit a functor for ``node`` into ``prog_ctx`` and return the stage that calls it.

        The functor has a single apply method over ``interval`` (translated from OIR). Its
        accessors are derived from the translated body. Accessor names that are not
        temporaries are registered as computation arguments in ``comp_ctx``.
        """
        assert "symtable" in kwargs
        apply_method = gtcpp.GTApplyMethod(
            interval=self.visit(interval, **kwargs),
            body=self.visit(node.body, comp_ctx=comp_ctx, **kwargs),
            local_variables=self.visit(node.declarations, **kwargs),
        )
        tmp_names = {tmp.name for tmp in comp_ctx.temporaries}
        accessors = _extract_accessors(apply_method, tmp_names)
        stage_args = [gtcpp.Arg(name=acc.name) for acc in accessors]
        comp_ctx.add_arguments(
            {str(arg.name) for arg in stage_args if arg.name not in tmp_names}
        )

        # Unique per node object for the duration of the translation.
        functor_name = type(node).__name__ + str(id(node))
        prog_ctx.add_functor(
            gtcpp.GTFunctor(
                name=functor_name,
                applies=[apply_method],
                param_list=gtcpp.GTParamList(accessors=accessors),
            )
        )
        return gtcpp.GTStage(functor=functor_name, args=stage_args)

    def visit_VerticalLoop(
        self,
        node: oir.VerticalLoop,
        *,
        comp_ctx: GTComputationContext,
        **kwargs: Any,
    ) -> gtcpp.GTMultiStage:
        """Translate all sections into one multi-stage, preserving the order of their stages."""
        # The following visit assumes that temporaries are already available in comp_ctx.
        # NOTE(review): no visitor in this file consumes the `default` keyword passed below;
        # confirm whether it is still needed.
        stages = list(
            itertools.chain(
                *(
                    self.visit(
                        section.horizontal_executions,
                        interval=section.interval,
                        default=([], []),
                        comp_ctx=comp_ctx,
                        **kwargs,
                    )
                    for section in node.sections
                )
            )
        )
        caches = self.visit(node.caches)
        return gtcpp.GTMultiStage(loop_order=node.loop_order, stages=stages, caches=caches)

    def visit_IJCache(self, node: oir.IJCache, **kwargs: Any) -> gtcpp.IJCache:
        return gtcpp.IJCache(name=node.name, loc=node.loc)

    def visit_KCache(self, node: oir.KCache, **kwargs: Any) -> gtcpp.KCache:
        return gtcpp.KCache(name=node.name, fill=node.fill, flush=node.flush, loc=node.loc)

    # -- Declarations --------------------------------------------------------------------------

    def visit_FieldDecl(self, node: oir.FieldDecl, **kwargs: Any) -> gtcpp.FieldDecl:
        return gtcpp.FieldDecl(
            name=node.name, dtype=node.dtype, dimensions=node.dimensions, data_dims=node.data_dims
        )

    def visit_ScalarDecl(self, node: oir.ScalarDecl, **kwargs: Any) -> gtcpp.GlobalParamDecl:
        return gtcpp.GlobalParamDecl(name=node.name, dtype=node.dtype)

    def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> gtcpp.LocalVarDecl:
        return gtcpp.LocalVarDecl(name=node.name, dtype=node.dtype, loc=node.loc)

    def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> gtcpp.Program:
        """Entry point: translate a whole stencil into a ``gtcpp.Program``.

        Temporaries are registered in the computation context before the vertical loops are
        visited, because stage translation needs their names. Every stencil-level
        declaration is asserted to be an ``oir.Temporary``.
        """
        prog_ctx = self.ProgramContext()
        comp_ctx = self.GTComputationContext(
            create_symbol_name=cast(
                SymbolNameCreator, symbol_name_creator(collect_symbol_names(node))
            )
        )

        assert all(isinstance(decl, oir.Temporary) for decl in node.declarations)
        comp_ctx.add_temporaries(self.visit(node.declarations))

        multi_stages = self.visit(
            node.vertical_loops, prog_ctx=prog_ctx, comp_ctx=comp_ctx, **kwargs
        )

        gt_computation = gtcpp.GTComputationCall(
            arguments=comp_ctx.arguments,
            extra_decls=comp_ctx.extra_decls,
            temporaries=comp_ctx.temporaries,
            multi_stages=multi_stages,
        )
        parameters = self.visit(node.params)
        return gtcpp.Program(
            name=node.name,
            parameters=parameters,
            functors=prog_ctx.functors,
            gt_computation=gt_computation,
        )