from __future__ import annotations from typing import Any, Callable, Union from dataclasses import dataclass # ----------------------------------------------------------------------------- # Terms # ----------------------------------------------------------------------------- class Term: pass @dataclass(frozen=True) class App(Term): f : Term x : Term @dataclass(frozen=True) class Const(Term): name : str @dataclass(frozen=True) class Var(Term): name : str # --- Bool constructors --- Bool = Union["Truth", "Contradiction"] @dataclass(frozen=True) class Truth(Term): pass @dataclass(frozen=True) class Contradiction(Term): pass # --- Nat constructors --- Nat = Union["Zero", "Succ"] @dataclass(frozen=True) class Zero(Term): pass @dataclass(frozen=True) class Succ(Term): n : Term def kernel_from_int(n : int) -> Nat: if n == 0: return Zero() elif n < 0: raise ValueError("Int is not Nat") else: return Succ(n=kernel_from_int(n=n-1)) # ----------------------------------------------------------------------------- # Propositions # ----------------------------------------------------------------------------- class Prop: pass @dataclass(frozen=True) class Eq(Prop): lhs : Term rhs : Term # ----------------------------------------------------------------------------- # Proofs # ----------------------------------------------------------------------------- class Proof: pass @dataclass(frozen=True) class Refl(Proof): pass # ----------------------------------------------------------------------------- # Normalization function # ----------------------------------------------------------------------------- @dataclass(frozen=True) class F: arity : int func : Callable[..., Term | None] def apply_rules() -> dict[str, F]: def __bool_not(x : Term) -> Term | None: match x: case Truth(): return Contradiction() case Contradiction(): return Truth() def __nat_add(x : Term, y : Term) -> Term | None: match x: case Zero(): return y case Succ(): return Succ( n=normalize(App( f=App(Const("Nat.add"), x=x.n), x=y, )) ) def __nat_iszero(x : Term) -> Term | None: match x: case Zero(): return Truth() case Succ(): return Contradiction() return { "Bool.not": F(arity=1, func=__bool_not), "Nat.add" : F(arity=2, func=__nat_add), "Nat.isZero" : F(arity=1, func=__nat_iszero), } def normalize(term : Term) -> Term: match term: case App(): f = normalize(term.f) x = normalize(term.x) cursor = term items = [ term ] while True: match cursor.f: case App(): cursor = cursor.f items.append(cursor) case Const(): name = cursor.f.name rule = apply_rules().get(name, None) if rule is None: return App(f, x) # Unknown function if len(items) != rule.arity: # If len(items) is too small, we don't have # enough inputs yet to evaluate the function # If len(items) is too large, we already evaluated # the function before and couldn't optimize. return App(f, x) out = rule.func(*reversed([i.x for i in items])) return out if out is not None else App(f, x) case Contradiction(): if len(items) != 2: # Either evaluating too early or too late return term # Return the "second" item return x case Truth(): if len(items) != 2: # Either evaluating too early or too late return App(f, x) # Return the "first" item return cursor.x case Const(): return term case Succ(): return Succ(n=normalize(term.n)) case Term(): return term case Var(): return term case Zero(): return term # ----------------------------------------------------------------------------- # Proof checker # ----------------------------------------------------------------------------- def check(goal: Prop, proof : Proof) -> bool: match ( goal, proof ): case ( Eq(), Refl() ): return normalize(goal.lhs) == normalize(goal.rhs) case _: return False # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- if __name__ == "__main__": # 2 + 3 == 5 print(check( Eq( lhs=App(f=App(f=Const("Nat.add"), x=kernel_from_int(2)), x=kernel_from_int(3)), rhs=kernel_from_int(5), ), Refl(), )) # 0 + x == x print(check( Eq( lhs=App(f=App(f=Const("Nat.add"), x=Zero()), x=Var("x")), rhs=Var("x") ), Refl(), )) # x + 0 == x print(check( Eq( lhs=App(f=App(f=Const("Nat.add"), x=Var("x")), x=Zero()), rhs=Var("x") ), Refl(), ))