Module hebi.rewrite.rewrite_zero_ary
Expand source code
from ast import *
from ..util import CompilingNodeTransformer
from ..typed_ast import (
TypedFunctionDef,
FunctionType,
NoneInstanceType,
TypedConstant,
TypedCall,
UnitInstanceType,
)
"""
Rewrites all functions that don't take arguments
into functions that take a singleton None argument.
Also rewrites all function calls without arguments
to calls that pass Unit into the function.
We need to take care of the dataclass call there, which should not be adjusted.
"""
class RewriteZeroAry(CompilingNodeTransformer):
    """Give parameterless functions one placeholder parameter and feed it at call sites.

    Function definitions without positional parameters receive a parameter
    named ``_``; calls without arguments whose callee type lists exactly one
    unit-typed parameter receive a matching ``None`` constant.
    """

    step = "Rewriting zero-ary functions"

    def visit_FunctionDef(self, node: TypedFunctionDef) -> TypedFunctionDef:
        """Append the placeholder parameter ``_`` to a definition that declares none.

        The node is modified in place, its children are visited afterwards,
        and the same node object is returned.
        """
        params = node.args.args
        if not params:
            # Placeholder parameter, annotated with the constant ``None``.
            params.append(arg("_", Constant(None)))
            fun_typ = node.typ.typ
            assert isinstance(fun_typ, FunctionType)
            # Keep the recorded function type in step with the new signature.
            # NOTE(review): visit_Call compares against UnitInstanceType while
            # NoneInstanceType is recorded here -- presumably the two compare
            # equal; confirm in typed_ast.
            fun_typ.argtyps.append(NoneInstanceType)
        self.generic_visit(node)
        return node

    def visit_Call(self, node: TypedCall) -> TypedCall:
        """Pass a unit-typed ``None`` constant to an argument-less call of a one-unit-parameter callee.

        Calls to the plain name ``dataclass`` are returned untouched and
        their children are not visited.
        """
        callee = node.func
        is_dataclass_call = isinstance(callee, Name) and callee.id == "dataclass"
        if is_dataclass_call:
            # The dataclass call must stay exactly as written.
            return node
        takes_only_unit = callee.typ.typ.argtyps == [UnitInstanceType]
        if takes_only_unit and node.args == []:
            # An empty argument list for a one-parameter callee would normally
            # fail type checking; it can only occur because of the zero-arg
            # rewrite, so the expected extra argument is supplied here.
            node.args.append(TypedConstant(None, typ=UnitInstanceType))
        self.generic_visit(node)
        return node
Classes
class RewriteZeroAry
-
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`::

    class RewriteName(NodeTransformer):
        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Index(value=Str(s=node.id)),
                ctx=node.ctx
            )
Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this::
node = YourTransformer().visit(node)
Expand source code
class RewriteZeroAry(CompilingNodeTransformer):
    """Give parameterless functions one placeholder parameter and feed it at call sites.

    Function definitions without positional parameters receive a parameter
    named ``_``; calls without arguments whose callee type lists exactly one
    unit-typed parameter receive a matching ``None`` constant.
    """

    step = "Rewriting zero-ary functions"

    def visit_FunctionDef(self, node: TypedFunctionDef) -> TypedFunctionDef:
        """Append the placeholder parameter ``_`` to a definition that declares none.

        The node is modified in place, its children are visited afterwards,
        and the same node object is returned.
        """
        params = node.args.args
        if not params:
            # Placeholder parameter, annotated with the constant ``None``.
            params.append(arg("_", Constant(None)))
            fun_typ = node.typ.typ
            assert isinstance(fun_typ, FunctionType)
            # Keep the recorded function type in step with the new signature.
            # NOTE(review): visit_Call compares against UnitInstanceType while
            # NoneInstanceType is recorded here -- presumably the two compare
            # equal; confirm in typed_ast.
            fun_typ.argtyps.append(NoneInstanceType)
        self.generic_visit(node)
        return node

    def visit_Call(self, node: TypedCall) -> TypedCall:
        """Pass a unit-typed ``None`` constant to an argument-less call of a one-unit-parameter callee.

        Calls to the plain name ``dataclass`` are returned untouched and
        their children are not visited.
        """
        callee = node.func
        is_dataclass_call = isinstance(callee, Name) and callee.id == "dataclass"
        if is_dataclass_call:
            # The dataclass call must stay exactly as written.
            return node
        takes_only_unit = callee.typ.typ.argtyps == [UnitInstanceType]
        if takes_only_unit and node.args == []:
            # An empty argument list for a one-parameter callee would normally
            # fail type checking; it can only occur because of the zero-arg
            # rewrite, so the expected extra argument is supplied here.
            node.args.append(TypedConstant(None, typ=UnitInstanceType))
        self.generic_visit(node)
        return node
Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var step
Methods
def visit(self, node)
-
Inherited from:
CompilingNodeTransformer
.visit
Visit a node.
def visit_Call(self, node: TypedCall) ‑> TypedCall
-
Expand source code
def visit_Call(self, node: TypedCall) -> TypedCall:
    """Pass a unit-typed ``None`` constant to an argument-less call of a one-unit-parameter callee.

    Calls to the plain name ``dataclass`` are returned untouched and their
    children are not visited. Every other call has its children visited
    and is returned as the same node object.
    """
    callee = node.func
    is_dataclass_call = isinstance(callee, Name) and callee.id == "dataclass"
    if is_dataclass_call:
        # The dataclass call must stay exactly as written.
        return node
    takes_only_unit = callee.typ.typ.argtyps == [UnitInstanceType]
    if takes_only_unit and node.args == []:
        # An empty argument list for a one-parameter callee would normally
        # fail type checking; it can only occur because of the zero-arg
        # rewrite, so the expected extra argument is supplied here.
        node.args.append(TypedConstant(None, typ=UnitInstanceType))
    self.generic_visit(node)
    return node
def visit_FunctionDef(self, node: TypedFunctionDef) ‑> TypedFunctionDef
-
Expand source code
def visit_FunctionDef(self, node: TypedFunctionDef) -> TypedFunctionDef:
    """Append the placeholder parameter ``_`` to a definition that declares none.

    The node is modified in place, its children are visited afterwards,
    and the same node object is returned.
    """
    params = node.args.args
    if not params:
        # Placeholder parameter, annotated with the constant ``None``.
        params.append(arg("_", Constant(None)))
        fun_typ = node.typ.typ
        assert isinstance(fun_typ, FunctionType)
        # Keep the recorded function type in step with the new signature.
        # NOTE(review): NoneInstanceType is recorded here; confirm that it
        # is what call-site checks elsewhere in the pipeline expect.
        fun_typ.argtyps.append(NoneInstanceType)
    self.generic_visit(node)
    return node