from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Union

# -----------------------------------------------------------------------------
# Terms
# -----------------------------------------------------------------------------


class Term:
    """Base class of all kernel terms."""


@dataclass(frozen=True)
class App(Term):
    """Curried application of ``f`` to the single argument ``x``.

    A multi-argument call ``fn a b`` is represented as ``App(App(fn, a), b)``.
    """

    f: Term  # function part; itself an App for multi-argument calls
    x: Term  # the argument


@dataclass(frozen=True)
class Const(Term):
    """A named constant, e.g. a constructor ("Nat.Zero") or a function ("Nat.add")."""

    name: str


def apply_const(fn: str, *args: Term) -> Term:
    """Apply the constant named ``fn`` to ``args``, one ``App`` node per argument.

    ``apply_const("f", a, b) == App(f=App(f=Const("f"), x=a), x=b)``. With no
    arguments the bare ``Const(fn)`` is returned.
    """
    term: Term = Const(fn)
    for arg in args:
        term = App(f=term, x=arg)
    return term


# Fixed-arity shorthands ``An(fn, a1, ..., an)`` kept for existing callers; each
# is exactly ``apply_const`` with the same arguments.
def A0(fn: str) -> Term:
    return apply_const(fn)


def A1(fn: str, a: Term) -> Term:
    return apply_const(fn, a)


def A2(fn: str, a: Term, b: Term) -> Term:
    return apply_const(fn, a, b)


def A3(fn: str, a: Term, b: Term, c: Term) -> Term:
    return apply_const(fn, a, b, c)


def A4(fn: str, a: Term, b: Term, c: Term, d: Term) -> Term:
    return apply_const(fn, a, b, c, d)


def A5(fn: str, a: Term, b: Term, c: Term, d: Term, e: Term) -> Term:
    return apply_const(fn, a, b, c, d, e)


def A6(fn: str, a: Term, b: Term, c: Term, d: Term, e: Term, f: Term) -> Term:
    return apply_const(fn, a, b, c, d, e, f)


def A7(
    fn: str, a: Term, b: Term, c: Term, d: Term, e: Term, f: Term, g: Term
) -> Term:
    return apply_const(fn, a, b, c, d, e, f, g)


def A8(
    fn: str, a: Term, b: Term, c: Term, d: Term, e: Term, f: Term, g: Term, h: Term
) -> Term:
    return apply_const(fn, a, b, c, d, e, f, g, h)


# --- Bool constructors ---


def Truth() -> Term:
    """The boolean constant ``Bool.Truth``."""
    return A0("Bool.Truth")


def Contradiction() -> Term:
    """The boolean constant ``Bool.Contradiction``."""
    return A0("Bool.Contradiction")


# --- List constructors ---


def Nil() -> Term:
    """The empty list ``List.Nil``."""
    return A0("List.Nil")


def Cons(head: Term, tail: Term) -> Term:
    """The list ``List.Cons head tail``."""
    return A2("List.Cons", head, tail)


# --- Nat constructors ---


def Zero() -> Term:
    """The natural number zero, ``Nat.Zero``."""
    return A0("Nat.Zero")


def Succ(n: Term) -> Term:
    """The successor of ``n``, ``Nat.Succ n``."""
    return A1("Nat.Succ", n)


def kernel_from_int(n: int) -> Term:
    """Encode the Python integer ``n`` as the unary term ``Succ(...Succ(Zero()))``.

    Raises:
        ValueError: if ``n`` is negative (more precisely, if counting ``n`` down
            by one never reaches exactly 0).
    """
    # Count the Succ layers first and build bottom-up, instead of recursing once
    # per unit, so large n does not hit Python's recursion limit.
    depth = 0
    while n != 0:
        if n < 0:
            raise ValueError("Int is not Nat")
        n -= 1
        depth += 1
    term = Zero()
    for _ in range(depth):
        term = Succ(term)
    return term


# -----------------------------------------------------------------------------
# Propositions
# -----------------------------------------------------------------------------


class Prop:
    """Base class of propositions (goals) handed to the proof checker."""


@dataclass(frozen=True)
class Eq(Prop):
    """The proposition ``lhs == rhs``."""

    lhs: Term
    rhs: Term


# -----------------------------------------------------------------------------
# Proofs
# -----------------------------------------------------------------------------


class Proof:
    """Base class of proof objects."""


@dataclass(frozen=True)
class Refl(Proof):
    """Proof by reflexivity of an ``Eq`` goal (carries no data)."""
----------------------------------------------------------------------------- # Normalization function # ----------------------------------------------------------------------------- @dataclass(frozen=True) class F: arity : int func : Callable[..., Term | None] F_ = lambda a : lambda f : F(arity=a, func=f) F0, F1, F2, F3 = F_(0), F_(1), F_(2), F_(3) F4, F5, F6, F7 = F_(4), F_(5), F_(6), F_(7) def apply_rules() -> dict[str, F]: def __bool_not(x : Term) -> Term | None: match x: case Const("Bool.Truth"): return Contradiction() case Const("Bool.Contradiction"): return Truth() def __list_isempty(x : Term) -> Term | None: match x: case Const("List.Nil"): return Truth() case App(f=App(f=Const("List.Cons"))): return Contradiction() def __list_length(x : Term) -> Term | None: match x: case Const("List.Nil"): return Zero() case App(f=App(f=Const("List.Cons"), x=head), x=tail): return Succ(A1("List.length", tail)) def __nat_add(x : Term, y : Term) -> Term | None: match x: case Const("Nat.Zero"): return y case App(f=Const("Nat.Succ"), x=x_): return Succ(A2("Nat.add", x_, y)) def __nat_iszero(x : Term) -> Term | None: match x: case Const("Nat.Zero"): return Truth() case App(f=Const("Nat.Succ")): return Contradiction() return { "Bool.not": F1(__bool_not), "List.isEmpty": F1(__list_isempty), "List.length": F1(__list_length), "Nat.add" : F2(__nat_add), "Nat.isZero" : F1(__nat_iszero), } def normalize(term : Term) -> Term: match term: case App(): f = normalize(term.f) x = normalize(term.x) cursor = term items = [ term ] while True: match cursor.f: case App(): cursor = cursor.f items.append(cursor) case Const(): name = cursor.f.name rule = apply_rules().get(name, None) if rule is None: return App(f, x) # Unknown function if len(items) != rule.arity: # If len(items) is too small, we don't have # enough inputs yet to evaluate the function # If len(items) is too large, we already evaluated # the function before and couldn't optimize. 
return App(f, x) out = rule.func(*reversed([i.x for i in items])) return normalize(out) if out is not None else App(f, x) case Const(): return term case Term(): return term # ----------------------------------------------------------------------------- # Proof checker # ----------------------------------------------------------------------------- def check(goal: Prop, proof : Proof) -> bool: match ( goal, proof ): case ( Eq(), Refl() ): return normalize(goal.lhs) == normalize(goal.rhs) case _: return False # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- if __name__ == "__main__": # 2 + 3 == 5 print(check( Eq( lhs=A2("Nat.add", kernel_from_int(2), kernel_from_int(3)), rhs=kernel_from_int(5), ), Refl(), )) # 0 + x == x print(check( Eq( lhs=A2("Nat.add", Zero(), Const("x")), rhs=Const("x") ), Refl(), )) # x + 0 == x print(check( Eq( lhs=A2("Nat.add", Const("x"), Zero()), rhs=Const("x") ), Refl(), ))