From cca866ae624791e6ac7460670e68b535639c9391 Mon Sep 17 00:00:00 2001 From: Bram van den Heuvel Date: Mon, 29 Jun 2026 11:59:35 +0200 Subject: [PATCH] Add minimal proof checker --- proof.py | 232 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 proof.py diff --git a/proof.py b/proof.py new file mode 100644 index 0000000..c3046dd --- /dev/null +++ b/proof.py @@ -0,0 +1,232 @@ +from __future__ import annotations +from typing import Any, Callable, Union + +from dataclasses import dataclass + +# ----------------------------------------------------------------------------- +# Terms +# ----------------------------------------------------------------------------- + +class Term: + pass + +@dataclass(frozen=True) +class App(Term): + f : Term + x : Term + +@dataclass(frozen=True) +class Const(Term): + name : str + +@dataclass(frozen=True) +class Var(Term): + name : str + +# --- Bool constructors --- + +Bool = Union["Truth", "Contradiction"] + +@dataclass(frozen=True) +class Truth(Term): + pass + +@dataclass(frozen=True) +class Contradiction(Term): + pass + +# --- Nat constructors --- + +Nat = Union["Zero", "Succ"] + +@dataclass(frozen=True) +class Zero(Term): + pass + +@dataclass(frozen=True) +class Succ(Term): + n : Term + +def kernel_from_int(n : int) -> Nat: + if n == 0: + return Zero() + elif n < 0: + raise ValueError("Int is not Nat") + else: + return Succ(n=kernel_from_int(n=n-1)) + +# ----------------------------------------------------------------------------- +# Propositions +# ----------------------------------------------------------------------------- + +class Prop: + pass + +@dataclass(frozen=True) +class Eq(Prop): + lhs : Term + rhs : Term + +# ----------------------------------------------------------------------------- +# Proofs +# ----------------------------------------------------------------------------- + +class Proof: + pass + +@dataclass(frozen=True) +class Refl(Proof): + pass + +# ----------------------------------------------------------------------------- +# Normalization function +# ----------------------------------------------------------------------------- + +@dataclass(frozen=True) +class F: + arity : int + func : Callable[..., Term | None] + +def apply_rules() -> dict[str, F]: + def __bool_not(x : Term) -> Term | None: + match x: + case Truth(): + return Contradiction() + + case Contradiction(): + return Truth() + + def __nat_add(x : Term, y : Term) -> Term | None: + match x: + case Zero(): + return y + + case Succ(): + return Succ( + n=normalize(App( + f=App(Const("Nat.add"), x=x.n), + x=y, + )) + ) + + def __nat_iszero(x : Term) -> Term | None: + match x: + case Zero(): + return Truth() + + case Succ(): + return Contradiction() + + return { + "Bool.not": F(arity=1, func=__bool_not), + "Nat.add" : F(arity=2, func=__nat_add), + "Nat.isZero" : F(arity=1, func=__nat_iszero), + } + + +def normalize(term : Term) -> Term: + match term: + case App(): + f = normalize(term.f) + x = normalize(term.x) + + cursor = term + items = [ term ] + + while True: + match cursor.f: + case App(): + cursor = cursor.f + items.append(cursor) + + case Const(): + name = cursor.f.name + rule = apply_rules().get(name, None) + + if rule is None: + return App(f, x) # Unknown function + if len(items) != rule.arity: + # If len(items) is too small, we don't have + # enough inputs yet to evaluate the function + # If len(items) is too large, we already evaluated + # the function before and couldn't optimize. + return App(f, x) + + out = rule.func(*reversed([i.x for i in items])) + + return out if out is not None else App(f, x) + + case Contradiction(): + if len(items) != 2: + # Either evaluating too early or too late + return term + + # Return the "second" item + return x + + case Truth(): + if len(items) != 2: + # Either evaluating too early or too late + return App(f, x) + + # Return the "first" item + return cursor.x + + case Const(): + return term + + case Succ(): + return Succ(n=normalize(term.n)) + + case Term(): + return term + + case Var(): + return term + + case Zero(): + return term + +# ----------------------------------------------------------------------------- +# Proof checker +# ----------------------------------------------------------------------------- + +def check(goal: Prop, proof : Proof) -> bool: + match ( goal, proof ): + case ( Eq(), Refl() ): + return normalize(goal.lhs) == normalize(goal.rhs) + + case _: + return False + +# ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + # 2 + 3 == 5 + print(check( + Eq( + lhs=App(f=App(f=Const("Nat.add"), x=kernel_from_int(2)), x=kernel_from_int(3)), + rhs=kernel_from_int(5), + ), + Refl(), + )) + + # 0 + x == x + print(check( + Eq( + lhs=App(f=App(f=Const("Nat.add"), x=Zero()), x=Var("x")), + rhs=Var("x") + ), + Refl(), + )) + + # x + 0 == x + print(check( + Eq( + lhs=App(f=App(f=Const("Nat.add"), x=Var("x")), x=Zero()), + rhs=Var("x") + ), + Refl(), + ))