Module hebi.util

Expand source code
import ast
import pycardano
from enum import Enum, auto

from .typed_ast import *

import pluthon as plt
import uplc.ast as uplc


# Recursive integer power written as a pluthon lambda.
# Parameters: "f" is the function value used for the recursive call (it is
# applied to itself again below), "x" is the base and "y" the exponent.
# NOTE(review): the caller is presumably expected to pass PowImpl itself as
# "f" - confirm at the call site.
PowImpl = plt.Lambda(
    ["f", "x", "y"],
    plt.Ite(
        # any exponent <= 0 (including negative ones) yields 1
        plt.LessThanEqualsInteger(plt.Var("y"), plt.Integer(0)),
        plt.Integer(1),
        # otherwise: x * f(f, x, y - 1)
        plt.MultiplyInteger(
            plt.Var("x"),
            plt.Apply(
                plt.Var("f"),
                plt.Var("f"),
                plt.Var("x"),
                plt.SubtractInteger(plt.Var("y"), plt.Integer(1)),
            ),
        ),
    ),
)


class PythonBuiltIn(Enum):
    """Supported Python builtins; member values are their pluthon implementations.

    Every lambda below takes a leading parameter that the body ignores
    (named "_"); only ``pow`` uses its first parameter, for self-recursion.
    ``len`` and ``reversed`` carry ``auto()`` placeholders instead of an
    implementation: in this module they are typed through ``LenImpl`` and
    ``ReversedImpl`` in ``PythonBuiltInTypes``.
    """

    # fold the list with And, starting from True
    all = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.Var("xs"),
            plt.Lambda(["x", "a"], plt.And(plt.Var("x"), plt.Var("a"))),
            plt.Bool(True),
        ),
    )
    # fold the list with Or, starting from False
    any = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.Var("xs"),
            plt.Lambda(["x", "a"], plt.Or(plt.Var("x"), plt.Var("a"))),
            plt.Bool(False),
        ),
    )
    # negate x when it is below zero, otherwise return it unchanged
    abs = plt.Lambda(
        ["_", "x"],
        plt.Ite(
            plt.LessThanInteger(plt.Var("x"), plt.Integer(0)),
            plt.Negate(plt.Var("x")),
            plt.Var("x"),
        ),
    )
    # maps an integer to a unicode code point and decodes it
    # reference: https://en.wikipedia.org/wiki/UTF-8#Encoding
    chr = plt.Lambda(
        ["_", "x"],
        plt.DecodeUtf8(
            plt.Ite(
                # negative values are rejected with the same message as too large ones
                plt.LessThanInteger(plt.Var("x"), plt.Integer(0x0)),
                plt.TraceError("ValueError: chr() arg not in range(0x110000)"),
                plt.Ite(
                    plt.LessThanInteger(plt.Var("x"), plt.Integer(0x80)),
                    # encoding of 0x0 - 0x80
                    plt.ConsByteString(plt.Var("x"), plt.ByteString(b"")),
                    plt.Ite(
                        plt.LessThanInteger(plt.Var("x"), plt.Integer(0x800)),
                        # encoding of 0x80 - 0x800
                        plt.ConsByteString(
                            # we do bit manipulation using integer arithmetic here - nice
                            plt.AddInteger(
                                plt.Integer(0b110 << 5),
                                plt.DivideInteger(plt.Var("x"), plt.Integer(1 << 6)),
                            ),
                            plt.ConsByteString(
                                plt.AddInteger(
                                    plt.Integer(0b10 << 6),
                                    plt.ModInteger(plt.Var("x"), plt.Integer(1 << 6)),
                                ),
                                plt.ByteString(b""),
                            ),
                        ),
                        plt.Ite(
                            plt.LessThanInteger(plt.Var("x"), plt.Integer(0x10000)),
                            # encoding of 0x800 - 0x10000
                            plt.ConsByteString(
                                plt.AddInteger(
                                    plt.Integer(0b1110 << 4),
                                    plt.DivideInteger(
                                        plt.Var("x"), plt.Integer(1 << 12)
                                    ),
                                ),
                                plt.ConsByteString(
                                    plt.AddInteger(
                                        plt.Integer(0b10 << 6),
                                        plt.DivideInteger(
                                            plt.ModInteger(
                                                plt.Var("x"), plt.Integer(1 << 12)
                                            ),
                                            plt.Integer(1 << 6),
                                        ),
                                    ),
                                    plt.ConsByteString(
                                        plt.AddInteger(
                                            plt.Integer(0b10 << 6),
                                            plt.ModInteger(
                                                plt.Var("x"), plt.Integer(1 << 6)
                                            ),
                                        ),
                                        plt.ByteString(b""),
                                    ),
                                ),
                            ),
                            plt.Ite(
                                plt.LessThanInteger(
                                    plt.Var("x"), plt.Integer(0x110000)
                                ),
                                # encoding of 0x10000 - 0x10FFFF
                                plt.ConsByteString(
                                    plt.AddInteger(
                                        plt.Integer(0b11110 << 3),
                                        plt.DivideInteger(
                                            plt.Var("x"), plt.Integer(1 << 18)
                                        ),
                                    ),
                                    plt.ConsByteString(
                                        plt.AddInteger(
                                            plt.Integer(0b10 << 6),
                                            plt.DivideInteger(
                                                plt.ModInteger(
                                                    plt.Var("x"), plt.Integer(1 << 18)
                                                ),
                                                plt.Integer(1 << 12),
                                            ),
                                        ),
                                        plt.ConsByteString(
                                            plt.AddInteger(
                                                plt.Integer(0b10 << 6),
                                                plt.DivideInteger(
                                                    plt.ModInteger(
                                                        plt.Var("x"),
                                                        plt.Integer(1 << 12),
                                                    ),
                                                    plt.Integer(1 << 6),
                                                ),
                                            ),
                                            plt.ConsByteString(
                                                plt.AddInteger(
                                                    plt.Integer(0b10 << 6),
                                                    plt.ModInteger(
                                                        plt.Var("x"),
                                                        plt.Integer(1 << 6),
                                                    ),
                                                ),
                                                plt.ByteString(b""),
                                            ),
                                        ),
                                    ),
                                ),
                                # everything from 0x110000 upwards is out of range
                                plt.TraceError(
                                    "ValueError: chr() arg not in range(0x110000)"
                                ),
                            ),
                        ),
                    ),
                ),
            )
        ),
    )
    # not a lambda: the member value is just the None value
    breakpoint = plt.NoneData()
    # builds the "0x"-prefixed lowercase hex representation, "-" first for negatives
    hex = plt.Lambda(
        ["_", "x"],
        plt.DecodeUtf8(
            plt.Let(
                [
                    (
                        # ASCII codes of the hex digits of i, least significant
                        # digit first; empty list for i <= 0
                        "hexlist",
                        plt.RecFun(
                            plt.Lambda(
                                ["f", "i"],
                                plt.Ite(
                                    plt.LessThanEqualsInteger(
                                        plt.Var("i"), plt.Integer(0)
                                    ),
                                    plt.EmptyIntegerList(),
                                    plt.MkCons(
                                        plt.Let(
                                            [
                                                (
                                                    "mod",
                                                    plt.ModInteger(
                                                        plt.Var("i"), plt.Integer(16)
                                                    ),
                                                ),
                                            ],
                                            # digits 0-9 are offset from "0",
                                            # digits 10-15 from "a"
                                            plt.AddInteger(
                                                plt.Var("mod"),
                                                plt.IfThenElse(
                                                    plt.LessThanInteger(
                                                        plt.Var("mod"), plt.Integer(10)
                                                    ),
                                                    plt.Integer(ord("0")),
                                                    plt.Integer(ord("a") - 10),
                                                ),
                                            ),
                                        ),
                                        plt.Apply(
                                            plt.Var("f"),
                                            plt.Var("f"),
                                            plt.DivideInteger(
                                                plt.Var("i"), plt.Integer(16)
                                            ),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                    ),
                    (
                        # prepends every digit of hexlist to the byte string built
                        # so far (presumably walking the list front to back, which
                        # puts the most significant digit first - confirm FoldList)
                        "mkstr",
                        plt.Lambda(
                            ["i"],
                            plt.FoldList(
                                plt.Apply(plt.Var("hexlist"), plt.Var("i")),
                                plt.Lambda(
                                    ["b", "i"],
                                    plt.ConsByteString(plt.Var("i"), plt.Var("b")),
                                ),
                                plt.ByteString(b""),
                            ),
                        ),
                    ),
                ],
                plt.Ite(
                    # zero is special-cased because hexlist yields no digits for it
                    plt.EqualsInteger(plt.Var("x"), plt.Integer(0)),
                    plt.ByteString(b"0x0"),
                    plt.Ite(
                        plt.LessThanInteger(plt.Var("x"), plt.Integer(0)),
                        plt.ConsByteString(
                            plt.Integer(ord("-")),
                            plt.AppendByteString(
                                plt.ByteString(b"0x"),
                                plt.Apply(plt.Var("mkstr"), plt.Negate(plt.Var("x"))),
                            ),
                        ),
                        plt.AppendByteString(
                            plt.ByteString(b"0x"),
                            plt.Apply(plt.Var("mkstr"), plt.Var("x")),
                        ),
                    ),
                ),
            )
        ),
    )
    # placeholder value; typed via LenImpl in PythonBuiltInTypes
    len = auto()
    # folds over the tail starting from the head, keeping the larger value;
    # no special handling for an empty list is present here
    max = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.TailList(plt.Var("xs")),
            plt.Lambda(
                ["x", "a"],
                plt.IfThenElse(
                    plt.LessThanInteger(plt.Var("a"), plt.Var("x")),
                    plt.Var("x"),
                    plt.Var("a"),
                ),
            ),
            plt.HeadList(plt.Var("xs")),
        ),
    )
    # same shape as max, but keeps the smaller value
    min = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.TailList(plt.Var("xs")),
            plt.Lambda(
                ["x", "a"],
                plt.IfThenElse(
                    plt.LessThanInteger(plt.Var("a"), plt.Var("x")),
                    plt.Var("a"),
                    plt.Var("x"),
                ),
            ),
            plt.HeadList(plt.Var("xs")),
        ),
    )
    # traces x and evaluates to the None value
    print = plt.Lambda(
        ["_", "x"],
        plt.Trace(plt.Var("x"), plt.NoneData()),
    )
    # NOTE: only correctly defined for positive y
    pow = PowImpl
    # builds the "0o"-prefixed octal representation; same structure as hex
    oct = plt.Lambda(
        ["_", "x"],
        plt.DecodeUtf8(
            plt.Let(
                [
                    (
                        # ASCII codes of the octal digits of i, least significant
                        # digit first; empty list for i <= 0
                        "octlist",
                        plt.RecFun(
                            plt.Lambda(
                                ["f", "i"],
                                plt.Ite(
                                    plt.LessThanEqualsInteger(
                                        plt.Var("i"), plt.Integer(0)
                                    ),
                                    plt.EmptyIntegerList(),
                                    plt.MkCons(
                                        plt.AddInteger(
                                            plt.ModInteger(
                                                plt.Var("i"), plt.Integer(8)
                                            ),
                                            plt.Integer(ord("0")),
                                        ),
                                        plt.Apply(
                                            plt.Var("f"),
                                            plt.Var("f"),
                                            plt.DivideInteger(
                                                plt.Var("i"), plt.Integer(8)
                                            ),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                    ),
                    (
                        # prepends every digit of octlist to the byte string built so far
                        "mkoct",
                        plt.Lambda(
                            ["i"],
                            plt.FoldList(
                                plt.Apply(plt.Var("octlist"), plt.Var("i")),
                                plt.Lambda(
                                    ["b", "i"],
                                    plt.ConsByteString(plt.Var("i"), plt.Var("b")),
                                ),
                                plt.ByteString(b""),
                            ),
                        ),
                    ),
                ],
                plt.Ite(
                    # zero is special-cased because octlist yields no digits for it
                    plt.EqualsInteger(plt.Var("x"), plt.Integer(0)),
                    plt.ByteString(b"0o0"),
                    plt.Ite(
                        plt.LessThanInteger(plt.Var("x"), plt.Integer(0)),
                        plt.ConsByteString(
                            plt.Integer(ord("-")),
                            plt.AppendByteString(
                                plt.ByteString(b"0o"),
                                plt.Apply(plt.Var("mkoct"), plt.Negate(plt.Var("x"))),
                            ),
                        ),
                        plt.AppendByteString(
                            plt.ByteString(b"0o"),
                            plt.Apply(plt.Var("mkoct"), plt.Var("x")),
                        ),
                    ),
                ),
            )
        ),
    )
    # single-argument form only: delegates to plt.Range with the given limit
    range = plt.Lambda(
        ["_", "limit"],
        plt.Range(plt.Var("limit")),
    )
    # placeholder value; typed via ReversedImpl in PythonBuiltInTypes
    reversed = auto()
    # folds the list with the AddInteger builtin, starting from 0
    sum = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.Var("xs"), plt.BuiltIn(uplc.BuiltInFun.AddInteger), plt.Integer(0)
        ),
    )


class LenImpl(PolymorphicFunction):
    """Polymorphic ``len``: implemented for byte strings and for lists."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """Return the signature of ``len`` for the given argument types.

        Asserts that exactly one argument is passed and that it is an
        ``InstanceType``. The result type is always ``IntegerInstanceType``.
        """
        has_single_arg = len(args) == 1
        assert (
            has_single_arg
        ), f"'len' takes only one argument, but {len(args)} were given"
        first_is_instance = isinstance(args[0], InstanceType)
        assert first_is_instance, "Can only determine length of instances"
        return FunctionType(args, IntegerInstanceType)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """Return the pluthon lambda computing ``len`` for the first argument type.

        Raises:
            NotImplementedError: if the type is neither the byte string
                instance type nor an instance of a list type.
        """
        arg = args[0]
        assert isinstance(arg, InstanceType), "Can only determine length of instances"

        if arg == ByteStringInstanceType:
            byte_length = plt.LengthOfByteString(plt.Var("x"))
            return plt.Lambda(["_", "x"], byte_length)

        if isinstance(arg.typ, ListType):
            # simple list length function: add one per element, starting at 0
            count_one_more = plt.Lambda(
                ["a", "_"], plt.AddInteger(plt.Var("a"), plt.Integer(1))
            )
            list_length = plt.FoldList(plt.Var("x"), count_one_more, plt.Integer(0))
            return plt.Lambda(["_", "x"], list_length)

        raise NotImplementedError(f"'len' is not implemented for type {arg}")


class ReversedImpl(PolymorphicFunction):
    """Polymorphic ``reversed``: implemented for instances of list types."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """Return the signature of ``reversed`` for the given argument types.

        Asserts that exactly one argument is passed and that it is an
        instance of a list type. The result has the same type as the argument.
        """
        has_single_arg = len(args) == 1
        assert (
            has_single_arg
        ), f"'reversed' takes only one argument, but {len(args)} were given"
        typ = args[0]
        assert isinstance(typ, InstanceType), "Can only reverse instances"
        assert isinstance(typ.typ, ListType), "Can only reverse instances of lists"
        # returns list of same type
        return FunctionType(args, typ)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """Return the pluthon lambda reversing a list of the first argument type.

        Raises:
            NotImplementedError: if the argument is not an instance of a list type.
        """
        arg = args[0]
        assert isinstance(arg, InstanceType), "Can only reverse instances"

        if not isinstance(arg.typ, ListType):
            raise NotImplementedError(f"'reversed' is not implemented for type {arg}")

        # start from an empty list of the element type and cons every element
        # onto the list built so far
        empty_l = empty_list(arg.typ.typ)
        prepend = plt.Lambda(["a", "x"], plt.MkCons(plt.Var("x"), plt.Var("a")))
        return plt.Lambda(
            ["_", "xs"],
            plt.FoldList(plt.Var("xs"), prepend, empty_l),
        )


# Maps every PythonBuiltIn member to the type of the corresponding function.
# Fixed-signature builtins use FunctionType(argument types, return type);
# len and reversed are polymorphic and resolved through LenImpl / ReversedImpl.
PythonBuiltInTypes = {
    # all / any: list of bool -> bool
    PythonBuiltIn.all: InstanceType(
        FunctionType(
            [InstanceType(ListType(BoolInstanceType))],
            BoolInstanceType,
        )
    ),
    PythonBuiltIn.any: InstanceType(
        FunctionType(
            [InstanceType(ListType(BoolInstanceType))],
            BoolInstanceType,
        )
    ),
    # abs: int -> int
    PythonBuiltIn.abs: InstanceType(
        FunctionType(
            [IntegerInstanceType],
            IntegerInstanceType,
        )
    ),
    # chr: int -> str
    PythonBuiltIn.chr: InstanceType(
        FunctionType(
            [IntegerInstanceType],
            StringInstanceType,
        )
    ),
    # breakpoint: no arguments -> None
    PythonBuiltIn.breakpoint: InstanceType(FunctionType([], NoneInstanceType)),
    PythonBuiltIn.len: InstanceType(PolymorphicFunctionType(LenImpl())),
    # hex: int -> str
    PythonBuiltIn.hex: InstanceType(
        FunctionType(
            [IntegerInstanceType],
            StringInstanceType,
        )
    ),
    # max / min: list of int -> int
    PythonBuiltIn.max: InstanceType(
        FunctionType(
            [InstanceType(ListType(IntegerInstanceType))],
            IntegerInstanceType,
        )
    ),
    PythonBuiltIn.min: InstanceType(
        FunctionType(
            [InstanceType(ListType(IntegerInstanceType))],
            IntegerInstanceType,
        )
    ),
    # print: a single str -> None
    PythonBuiltIn.print: InstanceType(
        FunctionType([StringInstanceType], NoneInstanceType)
    ),
    # pow: (int, int) -> int
    PythonBuiltIn.pow: InstanceType(
        FunctionType(
            [IntegerInstanceType, IntegerInstanceType],
            IntegerInstanceType,
        )
    ),
    # oct: int -> str
    PythonBuiltIn.oct: InstanceType(
        FunctionType(
            [IntegerInstanceType],
            StringInstanceType,
        )
    ),
    # range: a single int limit -> list of int
    PythonBuiltIn.range: InstanceType(
        FunctionType(
            [IntegerInstanceType],
            InstanceType(ListType(IntegerInstanceType)),
        )
    ),
    PythonBuiltIn.reversed: InstanceType(PolymorphicFunctionType(ReversedImpl())),
    # sum: list of int -> int
    PythonBuiltIn.sum: InstanceType(
        FunctionType(
            [InstanceType(ListType(IntegerInstanceType))],
            IntegerInstanceType,
        )
    ),
}


class CompilerError(Exception):
    """Exception that carries an original error plus where it occurred.

    Attributes are set from the constructor arguments: ``orig_err`` (the
    wrapped exception), ``node`` (the AST node being processed) and
    ``compilation_step`` (a string naming the step).
    """

    def __init__(self, orig_err: Exception, node: ast.AST, compilation_step: str):
        """Store the three arguments on the instance under the same names."""
        self.compilation_step = compilation_step
        self.node = node
        self.orig_err = orig_err


class CompilingNodeTransformer(TypedNodeTransformer):
    """Node transformer that reports failures as ``CompilerError``.

    ``step`` names the compilation step; it is attached to every
    ``CompilerError`` created by :meth:`visit`.
    """

    step = "Node transformation"

    def visit(self, node):
        """Visit ``node`` and wrap any failure in a ``CompilerError``.

        A ``CompilerError`` raised while visiting is passed through
        unchanged; any other ``Exception`` is wrapped together with ``node``
        and ``self.step``.

        Raises:
            CompilerError: if visiting ``node`` raises an ``Exception``.
        """
        try:
            return super().visit(node)
        except CompilerError:
            # Already wrapped further down: re-raise the active exception as is.
            raise
        except Exception as e:
            # Chain explicitly so the original error is recorded as the cause.
            raise CompilerError(e, node, self.step) from e


class CompilingNodeVisitor(TypedNodeVisitor):
    """Node visitor that reports failures as ``CompilerError``.

    ``step`` names the compilation step; it is attached to every
    ``CompilerError`` created by :meth:`visit`.
    """

    step = "Node visiting"

    def visit(self, node):
        """Visit ``node`` and wrap any failure in a ``CompilerError``.

        A ``CompilerError`` raised while visiting is passed through
        unchanged; any other ``Exception`` is wrapped together with ``node``
        and ``self.step``.

        Raises:
            CompilerError: if visiting ``node`` raises an ``Exception``.
        """
        try:
            return super().visit(node)
        except CompilerError:
            # Already wrapped further down: re-raise the active exception as is.
            raise
        except Exception as e:
            # Chain explicitly so the original error is recorded as the cause.
            raise CompilerError(e, node, self.step) from e


def data_from_json(j: typing.Dict[str, typing.Any]) -> uplc.PlutusData:
    """Convert a JSON-style datum description into uplc ``PlutusData``.

    ``j`` is dispatched on the first matching key, checked in this order:
    ``"bytes"`` (a hex string), ``"int"``, ``"list"``, ``"map"`` (a list of
    ``{"k": ..., "v": ...}`` entries) and ``"constructor"`` together with
    ``"fields"``. Nested descriptions are converted recursively.

    Raises:
        NotImplementedError: if ``j`` contains none of the known keys.
    """
    if "bytes" in j:
        return uplc.PlutusByteString(bytes.fromhex(j["bytes"]))
    if "int" in j:
        return uplc.PlutusInteger(int(j["int"]))
    if "list" in j:
        return uplc.PlutusList([data_from_json(item) for item in j["list"]])
    if "map" in j:
        # Keys and values are JSON descriptions themselves, so convert them
        # like list items and constructor fields. Storing the raw dicts does
        # not work: a dict is unhashable and cannot be used as a mapping key.
        # NOTE(review): assumes the converted uplc keys are hashable - confirm
        # for composite (list/map) keys.
        return uplc.PlutusMap(
            {
                data_from_json(entry["k"]): data_from_json(entry["v"])
                for entry in j["map"]
            }
        )
    if "constructor" in j and "fields" in j:
        return uplc.PlutusConstr(
            j["constructor"], [data_from_json(field) for field in j["fields"]]
        )
    raise NotImplementedError(f"Unknown datum representation {j}")


def datum_to_cbor(d: pycardano.Datum) -> bytes:
    """Serialize ``d`` with ``pycardano.PlutusData.to_cbor`` using ``encoding="bytes"``."""
    serialize = pycardano.PlutusData.to_cbor
    return serialize(d, encoding="bytes")


def datum_to_json(d: pycardano.Datum) -> str:
    """Serialize ``d`` with ``pycardano.PlutusData.to_json``."""
    serialize = pycardano.PlutusData.to_json
    return serialize(d)


class ReturnExtractor(TypedNodeVisitor):
    """Visitor that collects the Return nodes it visits in ``self.returns``."""

    def __init__(self):
        """Start with an empty list of collected Return nodes."""
        self.returns = list()

    def visit_Return(self, node: Return) -> None:
        """Record ``node``; its children are not visited from here."""
        self.returns += [node]

Functions

def data_from_json(j: Dict[str, Any]) ‑> uplc.ast.PlutusData
Expand source code
def data_from_json(j: typing.Dict[str, typing.Any]) -> uplc.PlutusData:
    """Convert a JSON-style datum description into uplc ``PlutusData``.

    ``j`` is dispatched on the first matching key, checked in this order:
    ``"bytes"`` (a hex string), ``"int"``, ``"list"``, ``"map"`` (a list of
    ``{"k": ..., "v": ...}`` entries) and ``"constructor"`` together with
    ``"fields"``. Nested descriptions are converted recursively.

    Raises:
        NotImplementedError: if ``j`` contains none of the known keys.
    """
    if "bytes" in j:
        return uplc.PlutusByteString(bytes.fromhex(j["bytes"]))
    if "int" in j:
        return uplc.PlutusInteger(int(j["int"]))
    if "list" in j:
        return uplc.PlutusList([data_from_json(item) for item in j["list"]])
    if "map" in j:
        # Keys and values are JSON descriptions themselves, so convert them
        # like list items and constructor fields. Storing the raw dicts does
        # not work: a dict is unhashable and cannot be used as a mapping key.
        # NOTE(review): assumes the converted uplc keys are hashable - confirm
        # for composite (list/map) keys.
        return uplc.PlutusMap(
            {
                data_from_json(entry["k"]): data_from_json(entry["v"])
                for entry in j["map"]
            }
        )
    if "constructor" in j and "fields" in j:
        return uplc.PlutusConstr(
            j["constructor"], [data_from_json(field) for field in j["fields"]]
        )
    raise NotImplementedError(f"Unknown datum representation {j}")
def datum_to_cbor(d: Union[pycardano.plutus.PlutusData, dict, pycardano.serialization.IndefiniteList, int, bytes, pycardano.serialization.RawCBOR, pycardano.plutus.RawPlutusData]) ‑> bytes
Expand source code
def datum_to_cbor(d: pycardano.Datum) -> bytes:
    """Serialize ``d`` with ``pycardano.PlutusData.to_cbor`` using ``encoding="bytes"``."""
    serialize = pycardano.PlutusData.to_cbor
    return serialize(d, encoding="bytes")
def datum_to_json(d: Union[pycardano.plutus.PlutusData, dict, pycardano.serialization.IndefiniteList, int, bytes, pycardano.serialization.RawCBOR, pycardano.plutus.RawPlutusData]) ‑> str
Expand source code
def datum_to_json(d: pycardano.Datum) -> str:
    """Serialize ``d`` with ``pycardano.PlutusData.to_json``."""
    serialize = pycardano.PlutusData.to_json
    return serialize(d)

Classes

class CompilerError (orig_err: Exception, node: _ast.AST, compilation_step: str)

Common base class for all non-exit exceptions.

Expand source code
class CompilerError(Exception):
    """Exception that carries an original error plus where it occurred.

    Attributes are set from the constructor arguments: ``orig_err`` (the
    wrapped exception), ``node`` (the AST node being processed) and
    ``compilation_step`` (a string naming the step).
    """

    def __init__(self, orig_err: Exception, node: ast.AST, compilation_step: str):
        """Store the three arguments on the instance under the same names."""
        self.compilation_step = compilation_step
        self.node = node
        self.orig_err = orig_err

Ancestors

  • builtins.Exception
  • builtins.BaseException
class CompilingNodeTransformer

A :class:NodeVisitor subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Index(value=Str(s=node.id)),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:generic_visit method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class CompilingNodeTransformer(TypedNodeTransformer):
    """Node transformer that reports failures as ``CompilerError``.

    ``step`` names the compilation step; it is attached to every
    ``CompilerError`` created by :meth:`visit`.
    """

    step = "Node transformation"

    def visit(self, node):
        """Visit ``node`` and wrap any failure in a ``CompilerError``.

        A ``CompilerError`` raised while visiting is passed through
        unchanged; any other ``Exception`` is wrapped together with ``node``
        and ``self.step``.

        Raises:
            CompilerError: if visiting ``node`` raises an ``Exception``.
        """
        try:
            return super().visit(node)
        except CompilerError:
            # Already wrapped further down: re-raise the active exception as is.
            raise
        except Exception as e:
            # Chain explicitly so the original error is recorded as the cause.
            raise CompilerError(e, node, self.step) from e

Ancestors

Subclasses

Class variables

var step

Methods

def visit(self, node)

Inherited from: TypedNodeTransformer.visit

Visit a node.

Expand source code
def visit(self, node):
    """Visit ``node`` and wrap any failure in a ``CompilerError``.

    A ``CompilerError`` raised while visiting is passed through unchanged;
    any other ``Exception`` is wrapped together with ``node`` and
    ``self.step``.

    Raises:
        CompilerError: if visiting ``node`` raises an ``Exception``.
    """
    try:
        return super().visit(node)
    except CompilerError:
        # Already wrapped further down: re-raise the active exception as is.
        raise
    except Exception as e:
        # Chain explicitly so the original error is recorded as the cause.
        raise CompilerError(e, node, self.step) from e
class CompilingNodeVisitor

A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the visit method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are 'visit_' + class name of the node. So a TryFinally node visit function would be visit_TryFinally. This behavior can be changed by overriding the visit method. If no visitor function exists for a node (return value None) the generic_visit visitor is used instead.

Don't use the NodeVisitor if you want to apply changes to nodes during traversing. For this a special visitor exists (NodeTransformer) that allows modifications.

Expand source code
class CompilingNodeVisitor(TypedNodeVisitor):
    """Node visitor that reports failures as ``CompilerError``.

    ``step`` names the compilation step; it is attached to every
    ``CompilerError`` created by :meth:`visit`.
    """

    step = "Node visiting"

    def visit(self, node):
        """Visit ``node`` and wrap any failure in a ``CompilerError``.

        A ``CompilerError`` raised while visiting is passed through
        unchanged; any other ``Exception`` is wrapped together with ``node``
        and ``self.step``.

        Raises:
            CompilerError: if visiting ``node`` raises an ``Exception``.
        """
        try:
            return super().visit(node)
        except CompilerError:
            # Already wrapped further down: re-raise the active exception as is.
            raise
        except Exception as e:
            # Chain explicitly so the original error is recorded as the cause.
            raise CompilerError(e, node, self.step) from e

Ancestors

Subclasses

Class variables

var step

Methods

def visit(self, node)

Inherited from: TypedNodeVisitor.visit

Visit a node.

Expand source code
def visit(self, node):
    """Visit ``node`` and wrap any failure in a ``CompilerError``.

    A ``CompilerError`` raised while visiting is passed through unchanged;
    any other ``Exception`` is wrapped together with ``node`` and
    ``self.step``.

    Raises:
        CompilerError: if visiting ``node`` raises an ``Exception``.
    """
    try:
        return super().visit(node)
    except CompilerError:
        # Already wrapped further down: re-raise the active exception as is.
        raise
    except Exception as e:
        # Chain explicitly so the original error is recorded as the cause.
        raise CompilerError(e, node, self.step) from e
class LenImpl
Expand source code
class LenImpl(PolymorphicFunction):
    """Polymorphic ``len``: implemented for byte strings and for lists."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """Return the signature of ``len`` for the given argument types.

        Asserts that exactly one argument is passed and that it is an
        ``InstanceType``. The result type is always ``IntegerInstanceType``.
        """
        has_single_arg = len(args) == 1
        assert (
            has_single_arg
        ), f"'len' takes only one argument, but {len(args)} were given"
        first_is_instance = isinstance(args[0], InstanceType)
        assert first_is_instance, "Can only determine length of instances"
        return FunctionType(args, IntegerInstanceType)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """Return the pluthon lambda computing ``len`` for the first argument type.

        Raises:
            NotImplementedError: if the type is neither the byte string
                instance type nor an instance of a list type.
        """
        arg = args[0]
        assert isinstance(arg, InstanceType), "Can only determine length of instances"

        if arg == ByteStringInstanceType:
            byte_length = plt.LengthOfByteString(plt.Var("x"))
            return plt.Lambda(["_", "x"], byte_length)

        if isinstance(arg.typ, ListType):
            # simple list length function: add one per element, starting at 0
            count_one_more = plt.Lambda(
                ["a", "_"], plt.AddInteger(plt.Var("a"), plt.Integer(1))
            )
            list_length = plt.FoldList(plt.Var("x"), count_one_more, plt.Integer(0))
            return plt.Lambda(["_", "x"], list_length)

        raise NotImplementedError(f"'len' is not implemented for type {arg}")

Ancestors

Methods

def impl_from_args(self, args: List[Type]) ‑> pluthon.pluthon_ast.AST
Expand source code
def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
    """Return the pluthon lambda computing ``len`` for the first argument type.

    Raises:
        NotImplementedError: if the type is neither the byte string
            instance type nor an instance of a list type.
    """
    arg = args[0]
    assert isinstance(arg, InstanceType), "Can only determine length of instances"

    if arg == ByteStringInstanceType:
        byte_length = plt.LengthOfByteString(plt.Var("x"))
        return plt.Lambda(["_", "x"], byte_length)

    if isinstance(arg.typ, ListType):
        # simple list length function: add one per element, starting at 0
        count_one_more = plt.Lambda(
            ["a", "_"], plt.AddInteger(plt.Var("a"), plt.Integer(1))
        )
        list_length = plt.FoldList(plt.Var("x"), count_one_more, plt.Integer(0))
        return plt.Lambda(["_", "x"], list_length)

    raise NotImplementedError(f"'len' is not implemented for type {arg}")
def type_from_args(self, args: List[Type]) ‑> FunctionType
Expand source code
def type_from_args(self, args: typing.List[Type]) -> FunctionType:
    """Return the signature of ``len`` for the given argument types.

    Asserts that exactly one argument is passed and that it is an
    ``InstanceType``. The result type is always ``IntegerInstanceType``.
    """
    has_single_arg = len(args) == 1
    assert (
        has_single_arg
    ), f"'len' takes only one argument, but {len(args)} were given"
    first_is_instance = isinstance(args[0], InstanceType)
    assert first_is_instance, "Can only determine length of instances"
    return FunctionType(args, IntegerInstanceType)
class PythonBuiltIn (value, names=None, *, module=None, qualname=None, type=None, start=1)

An enumeration.

Expand source code
class PythonBuiltIn(Enum):
    """Python built-in functions and their pluthon implementations.

    Each member's value is the pluthon AST implementing the built-in. Most
    implementations are lambdas whose first parameter (``_``) is ignored;
    ``pow`` instead names its first parameter ``f`` and applies it to recurse.

    ``len`` and ``reversed`` carry plain integer placeholders instead of an
    AST (presumably because their implementation depends on the argument
    type and is produced elsewhere -- verify against the callers).
    """

    # conjunction over a list, starting from True
    all = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.Var("xs"),
            plt.Lambda(["x", "a"], plt.And(plt.Var("x"), plt.Var("a"))),
            plt.Bool(True),
        ),
    )
    # disjunction over a list, starting from False
    any = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.Var("xs"),
            plt.Lambda(["x", "a"], plt.Or(plt.Var("x"), plt.Var("a"))),
            plt.Bool(False),
        ),
    )
    # negates x when it is below zero, otherwise returns x unchanged
    abs = plt.Lambda(
        ["_", "x"],
        plt.Ite(
            plt.LessThanInteger(plt.Var("x"), plt.Integer(0)),
            plt.Negate(plt.Var("x")),
            plt.Var("x"),
        ),
    )
    # maps an integer to a unicode code point and decodes it
    # reference: https://en.wikipedia.org/wiki/UTF-8#Encoding
    chr = plt.Lambda(
        ["_", "x"],
        plt.DecodeUtf8(
            plt.Ite(
                plt.LessThanInteger(plt.Var("x"), plt.Integer(0x0)),
                plt.TraceError("ValueError: chr() arg not in range(0x110000)"),
                plt.Ite(
                    plt.LessThanInteger(plt.Var("x"), plt.Integer(0x80)),
                    # encoding of 0x0 - 0x7F (single byte)
                    plt.ConsByteString(plt.Var("x"), plt.ByteString(b"")),
                    plt.Ite(
                        plt.LessThanInteger(plt.Var("x"), plt.Integer(0x800)),
                        # encoding of 0x80 - 0x7FF (two bytes)
                        plt.ConsByteString(
                            # we do bit manipulation using integer arithmetic here - nice
                            plt.AddInteger(
                                plt.Integer(0b110 << 5),
                                plt.DivideInteger(plt.Var("x"), plt.Integer(1 << 6)),
                            ),
                            plt.ConsByteString(
                                plt.AddInteger(
                                    plt.Integer(0b10 << 6),
                                    plt.ModInteger(plt.Var("x"), plt.Integer(1 << 6)),
                                ),
                                plt.ByteString(b""),
                            ),
                        ),
                        plt.Ite(
                            plt.LessThanInteger(plt.Var("x"), plt.Integer(0x10000)),
                            # encoding of 0x800 - 0xFFFF (three bytes)
                            plt.ConsByteString(
                                plt.AddInteger(
                                    plt.Integer(0b1110 << 4),
                                    plt.DivideInteger(
                                        plt.Var("x"), plt.Integer(1 << 12)
                                    ),
                                ),
                                plt.ConsByteString(
                                    plt.AddInteger(
                                        plt.Integer(0b10 << 6),
                                        plt.DivideInteger(
                                            plt.ModInteger(
                                                plt.Var("x"), plt.Integer(1 << 12)
                                            ),
                                            plt.Integer(1 << 6),
                                        ),
                                    ),
                                    plt.ConsByteString(
                                        plt.AddInteger(
                                            plt.Integer(0b10 << 6),
                                            plt.ModInteger(
                                                plt.Var("x"), plt.Integer(1 << 6)
                                            ),
                                        ),
                                        plt.ByteString(b""),
                                    ),
                                ),
                            ),
                            plt.Ite(
                                plt.LessThanInteger(
                                    plt.Var("x"), plt.Integer(0x110000)
                                ),
                                # encoding of 0x10000 - 0x10FFFF (four bytes)
                                plt.ConsByteString(
                                    plt.AddInteger(
                                        plt.Integer(0b11110 << 3),
                                        plt.DivideInteger(
                                            plt.Var("x"), plt.Integer(1 << 18)
                                        ),
                                    ),
                                    plt.ConsByteString(
                                        plt.AddInteger(
                                            plt.Integer(0b10 << 6),
                                            plt.DivideInteger(
                                                plt.ModInteger(
                                                    plt.Var("x"), plt.Integer(1 << 18)
                                                ),
                                                plt.Integer(1 << 12),
                                            ),
                                        ),
                                        plt.ConsByteString(
                                            plt.AddInteger(
                                                plt.Integer(0b10 << 6),
                                                plt.DivideInteger(
                                                    plt.ModInteger(
                                                        plt.Var("x"),
                                                        plt.Integer(1 << 12),
                                                    ),
                                                    plt.Integer(1 << 6),
                                                ),
                                            ),
                                            plt.ConsByteString(
                                                plt.AddInteger(
                                                    plt.Integer(0b10 << 6),
                                                    plt.ModInteger(
                                                        plt.Var("x"),
                                                        plt.Integer(1 << 6),
                                                    ),
                                                ),
                                                plt.ByteString(b""),
                                            ),
                                        ),
                                    ),
                                ),
                                plt.TraceError(
                                    "ValueError: chr() arg not in range(0x110000)"
                                ),
                            ),
                        ),
                    ),
                ),
            )
        ),
    )
    breakpoint = plt.NoneData()
    # renders an integer as "0x..." (with a leading "-" for negative input)
    hex = plt.Lambda(
        ["_", "x"],
        plt.DecodeUtf8(
            plt.Let(
                [
                    (
                        # list of ASCII codes of the hex digits of i,
                        # least significant digit first
                        "hexlist",
                        plt.RecFun(
                            plt.Lambda(
                                ["f", "i"],
                                plt.Ite(
                                    plt.LessThanEqualsInteger(
                                        plt.Var("i"), plt.Integer(0)
                                    ),
                                    plt.EmptyIntegerList(),
                                    plt.MkCons(
                                        plt.Let(
                                            [
                                                (
                                                    "mod",
                                                    plt.ModInteger(
                                                        plt.Var("i"), plt.Integer(16)
                                                    ),
                                                ),
                                            ],
                                            # digits 0-9 are offset from "0",
                                            # digits 10-15 from "a"
                                            plt.AddInteger(
                                                plt.Var("mod"),
                                                plt.IfThenElse(
                                                    plt.LessThanInteger(
                                                        plt.Var("mod"), plt.Integer(10)
                                                    ),
                                                    plt.Integer(ord("0")),
                                                    plt.Integer(ord("a") - 10),
                                                ),
                                            ),
                                        ),
                                        plt.Apply(
                                            plt.Var("f"),
                                            plt.Var("f"),
                                            plt.DivideInteger(
                                                plt.Var("i"), plt.Integer(16)
                                            ),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                    ),
                    (
                        # prepends every digit of hexlist to a bytestring,
                        # which puts the most significant digit first
                        "mkstr",
                        plt.Lambda(
                            ["i"],
                            plt.FoldList(
                                plt.Apply(plt.Var("hexlist"), plt.Var("i")),
                                plt.Lambda(
                                    ["b", "i"],
                                    plt.ConsByteString(plt.Var("i"), plt.Var("b")),
                                ),
                                plt.ByteString(b""),
                            ),
                        ),
                    ),
                ],
                plt.Ite(
                    # zero has no digits in hexlist, so it is special-cased
                    plt.EqualsInteger(plt.Var("x"), plt.Integer(0)),
                    plt.ByteString(b"0x0"),
                    plt.Ite(
                        plt.LessThanInteger(plt.Var("x"), plt.Integer(0)),
                        plt.ConsByteString(
                            plt.Integer(ord("-")),
                            plt.AppendByteString(
                                plt.ByteString(b"0x"),
                                plt.Apply(plt.Var("mkstr"), plt.Negate(plt.Var("x"))),
                            ),
                        ),
                        plt.AppendByteString(
                            plt.ByteString(b"0x"),
                            plt.Apply(plt.Var("mkstr"), plt.Var("x")),
                        ),
                    ),
                ),
            )
        ),
    )
    # Explicit placeholder instead of auto(): every other member value is a
    # non-numeric AST object, and from Python 3.13 on auto() requires sortable,
    # incrementable previous values. 1 is the value the old fallback produced.
    len = 1
    # folds over the tail with the head as the initial value
    max = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.TailList(plt.Var("xs")),
            plt.Lambda(
                ["x", "a"],
                plt.IfThenElse(
                    plt.LessThanInteger(plt.Var("a"), plt.Var("x")),
                    plt.Var("x"),
                    plt.Var("a"),
                ),
            ),
            plt.HeadList(plt.Var("xs")),
        ),
    )
    # folds over the tail with the head as the initial value
    min = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.TailList(plt.Var("xs")),
            plt.Lambda(
                ["x", "a"],
                plt.IfThenElse(
                    plt.LessThanInteger(plt.Var("a"), plt.Var("x")),
                    plt.Var("a"),
                    plt.Var("x"),
                ),
            ),
            plt.HeadList(plt.Var("xs")),
        ),
    )
    # traces x and evaluates to NoneData
    print = plt.Lambda(
        ["_", "x"],
        plt.Trace(plt.Var("x"), plt.NoneData()),
    )
    # NOTE: only correctly defined for non-negative y
    # (PowImpl evaluates to 1 for every y <= 0)
    pow = PowImpl
    # renders an integer as "0o..." (with a leading "-" for negative input)
    oct = plt.Lambda(
        ["_", "x"],
        plt.DecodeUtf8(
            plt.Let(
                [
                    (
                        # list of ASCII codes of the octal digits of i,
                        # least significant digit first
                        "octlist",
                        plt.RecFun(
                            plt.Lambda(
                                ["f", "i"],
                                plt.Ite(
                                    plt.LessThanEqualsInteger(
                                        plt.Var("i"), plt.Integer(0)
                                    ),
                                    plt.EmptyIntegerList(),
                                    plt.MkCons(
                                        plt.AddInteger(
                                            plt.ModInteger(
                                                plt.Var("i"), plt.Integer(8)
                                            ),
                                            plt.Integer(ord("0")),
                                        ),
                                        plt.Apply(
                                            plt.Var("f"),
                                            plt.Var("f"),
                                            plt.DivideInteger(
                                                plt.Var("i"), plt.Integer(8)
                                            ),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                    ),
                    (
                        # prepends every digit of octlist to a bytestring,
                        # which puts the most significant digit first
                        "mkoct",
                        plt.Lambda(
                            ["i"],
                            plt.FoldList(
                                plt.Apply(plt.Var("octlist"), plt.Var("i")),
                                plt.Lambda(
                                    ["b", "i"],
                                    plt.ConsByteString(plt.Var("i"), plt.Var("b")),
                                ),
                                plt.ByteString(b""),
                            ),
                        ),
                    ),
                ],
                plt.Ite(
                    # zero has no digits in octlist, so it is special-cased
                    plt.EqualsInteger(plt.Var("x"), plt.Integer(0)),
                    plt.ByteString(b"0o0"),
                    plt.Ite(
                        plt.LessThanInteger(plt.Var("x"), plt.Integer(0)),
                        plt.ConsByteString(
                            plt.Integer(ord("-")),
                            plt.AppendByteString(
                                plt.ByteString(b"0o"),
                                plt.Apply(plt.Var("mkoct"), plt.Negate(plt.Var("x"))),
                            ),
                        ),
                        plt.AppendByteString(
                            plt.ByteString(b"0o"),
                            plt.Apply(plt.Var("mkoct"), plt.Var("x")),
                        ),
                    ),
                ),
            )
        ),
    )
    range = plt.Lambda(
        ["_", "limit"],
        plt.Range(plt.Var("limit")),
    )
    # Explicit placeholder instead of auto(), for the same reason as `len`;
    # 2 is the value the old auto() fallback produced (len's value plus one).
    reversed = 2
    # adds up all list elements, starting from 0
    sum = plt.Lambda(
        ["_", "xs"],
        plt.FoldList(
            plt.Var("xs"), plt.BuiltIn(uplc.BuiltInFun.AddInteger), plt.Integer(0)
        ),
    )

Ancestors

  • enum.Enum

Class variables

var abs
var all
var any
var breakpoint
var chr
var hex
var len
var max
var min
var oct
var pow
var print
var range
var reversed
var sum
class ReturnExtractor

Utility to find all Return statements in an AST subtree

Expand source code
class ReturnExtractor(TypedNodeVisitor):
    """Visitor that gathers the Return statements found in an AST subtree.

    Every Return node handed to ``visit_Return`` is stored, in the order it
    was encountered, in the ``returns`` list.
    """

    def __init__(self):
        # Return nodes collected so far
        self.returns = list()

    def visit_Return(self, node: Return) -> None:
        """Remember ``node``; this handler does not descend into it."""
        collected = self.returns
        collected.append(node)

Ancestors

Methods

def visit(self, node)

Inherited from: TypedNodeVisitor.visit

Visit a node.

def visit_Return(self, node: _ast.Return) ‑> None
Expand source code
def visit_Return(self, node: Return) -> None:
    """Record the visited Return node at the end of ``self.returns``."""
    collected = self.returns
    collected.append(node)
class ReversedImpl
Expand source code
class ReversedImpl(PolymorphicFunction):
    """Polymorphic implementation of ``reversed`` for list instances."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """Return the function type of ``reversed`` applied to ``args``.

        The result type equals the (single) argument type.

        Raises:
            AssertionError: if not exactly one argument type is given, or if
                it is not an instance of a list type.
        """
        n_args = len(args)
        assert (
            n_args == 1
        ), f"'reversed' takes only one argument, but {n_args} were given"
        list_type = args[0]
        assert isinstance(list_type, InstanceType), "Can only reverse instances"
        assert isinstance(
            list_type.typ, ListType
        ), "Can only reverse instances of lists"
        # the reversed list has the very same type as the input
        return FunctionType(args, list_type)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """Build the pluthon implementation of ``reversed`` for ``args``.

        Returns a two-parameter lambda (``_``, ``xs``) that folds over ``xs``
        and prepends every element to an initially empty list.

        Raises:
            AssertionError: if the first argument type is not an
                ``InstanceType``.
            NotImplementedError: if it is not an instance of a list type.
        """
        arg_type = args[0]
        assert isinstance(arg_type, InstanceType), "Can only reverse instances"
        if not isinstance(arg_type.typ, ListType):
            raise NotImplementedError(
                f"'reversed' is not implemented for type {arg_type}"
            )
        # empty list of the element type, used as the fold's start value
        start = empty_list(arg_type.typ.typ)
        # fold step: put the current element in front of the accumulator
        prepend = plt.Lambda(["a", "x"], plt.MkCons(plt.Var("x"), plt.Var("a")))
        return plt.Lambda(
            ["_", "xs"],
            plt.FoldList(plt.Var("xs"), prepend, start),
        )

Ancestors

Methods

def impl_from_args(self, args: List[Type]) ‑> pluthon.pluthon_ast.AST
Expand source code
def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
    """Build the pluthon implementation of ``reversed`` for ``args``.

    Returns a two-parameter lambda (``_``, ``xs``) that folds over ``xs`` and
    prepends every element to an initially empty list.

    Raises:
        AssertionError: if the first argument type is not an ``InstanceType``.
        NotImplementedError: if it is not an instance of a list type.
    """
    arg_type = args[0]
    assert isinstance(arg_type, InstanceType), "Can only reverse instances"
    if not isinstance(arg_type.typ, ListType):
        raise NotImplementedError(
            f"'reversed' is not implemented for type {arg_type}"
        )
    # empty list of the element type, used as the fold's start value
    start = empty_list(arg_type.typ.typ)
    # fold step: put the current element in front of the accumulator
    prepend = plt.Lambda(["a", "x"], plt.MkCons(plt.Var("x"), plt.Var("a")))
    return plt.Lambda(
        ["_", "xs"],
        plt.FoldList(plt.Var("xs"), prepend, start),
    )
def type_from_args(self, args: List[Type]) ‑> FunctionType
Expand source code
def type_from_args(self, args: typing.List[Type]) -> FunctionType:
    """Return the function type of ``reversed`` applied to ``args``.

    The result type equals the (single) argument type.

    Raises:
        AssertionError: if not exactly one argument type is given, or if it
            is not an instance of a list type.
    """
    n_args = len(args)
    assert (
        n_args == 1
    ), f"'reversed' takes only one argument, but {n_args} were given"
    list_type = args[0]
    assert isinstance(list_type, InstanceType), "Can only reverse instances"
    assert isinstance(list_type.typ, ListType), "Can only reverse instances of lists"
    # the reversed list has the very same type as the input
    return FunctionType(args, list_type)