From 0612e320b0dfc65fd0fb1f2fe1f95e2e1a99568e Mon Sep 17 00:00:00 2001 From: Bram van den Heuvel Date: Tue, 30 Jun 2026 10:36:58 +0200 Subject: [PATCH] Improve proof readability --- proof.py | 140 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 131 insertions(+), 9 deletions(-) diff --git a/proof.py b/proof.py index b8f5240..b3bfe22 100644 --- a/proof.py +++ b/proof.py @@ -15,6 +15,38 @@ class App(Term): f : Term x : Term + def __repr__(self): + match self.f: + case App(f=Const("Nat.add"), x=a): + return f"{a} + {self.x}" + case App(f=Const("List.cons"), x=a): + return f"{a} :: {self.x}" + case Const("Nat.Succ"): + cursor, i = self, 0 + + while True: + match cursor: + case App(f=Const("Nat.Succ")): + cursor = cursor.x + i += 1 + case Const("Nat.Zero"): + return str(i) + case _: + xs = str(cursor) + if " " in xs: + xs = "(" + xs + ")" + return "{xs} + {i}" + case _: + fs = str(self.f) + xs = str(self.x) + + if " " in fs: + fs = "(" + fs + ")" + if " " in xs: + xs = "(" + xs + ")" + + return f"{fs} xs" + A0 = lambda fn : Const(fn) A1 = lambda fn, a : App(f=A0(fn), x=a) A2 = lambda fn, a, b : App(f=A1(fn, a), x=b) @@ -29,6 +61,15 @@ A8 = lambda fn, a, b, c, d, e, f, g, h : App(f=A7(fn, a, b, c, d, e, f, g), x=h) class Const(Term): name : str + def __repr__(self) -> str: + match self.name: + case "List.Nil": + return "[]" + case "Nat.Zero": + return "0" + case _: + return self.name + # --- Bool constructors --- def Truth() -> Term: @@ -68,21 +109,84 @@ def kernel_from_int(n : int) -> Term: class Prop: pass +@dataclass(frozen=True) +class Contradict(Prop): + s : str + @dataclass(frozen=True) class Eq(Prop): lhs : Term rhs : Term +@dataclass(frozen=True) +class Trivial(Prop): + pass + # def __repr__(self) -> str: + # return "" + # ----------------------------------------------------------------------------- # Proofs # ----------------------------------------------------------------------------- class Proof: - pass + def apply(self, goal : Prop, ctx : Context) -> Prop: + return ( + Contradict("Cannot prove using the base class") + ) + +@dataclass(frozen=True) +class Assumption(Proof): + def apply(self, goal : Prop, ctx : Context) -> Prop: + return ( + Trivial() if goal in ctx.props else + Contradict(f"Couldn't find proposition `{goal}` in assumptions") + ) + @dataclass(frozen=True) class Refl(Proof): - pass + def apply(self, goal : Prop, ctx : Context) -> Prop: + match goal: + case Eq(): + l = normalize(goal.lhs) + r = normalize(goal.rhs) + + if l == r: + return Trivial() + else: + # Reduce to differing parts + while True: + match ( l, r ): + case ( App(), App() ): + if l.x == r.x: + l, r = l.f, r.f + elif l.f == r.f: + l, r = l.x, r.x + else: + break + case _: + break + + return Contradict( + f"Couldn't normalize {goal.lhs} = {goal.rhs} any further than to {l} = {r}" + ) + + case Prop(): + return Contradict( + "Cannot prove base proposition" + ) + +# ----------------------------------------------------------------------------- +# Context +# ----------------------------------------------------------------------------- +# The context is a set of information you already know. + +@dataclass(frozen=True) +class Context: + props : list[Prop] + + def with_prop(self, prop : Prop): + return Context([ p for p in self.props ] + [ prop ]) # ----------------------------------------------------------------------------- # Normalization function @@ -188,13 +292,18 @@ def normalize(term : Term) -> Term: # Proof checker # ----------------------------------------------------------------------------- -def check(goal: Prop, proof : Proof) -> bool: - match ( goal, proof ): - case ( Eq(), Refl() ): - return normalize(goal.lhs) == normalize(goal.rhs) - - case _: - return False +ProofCheck = Union["Proven", "ProveFailure"] + +@dataclass(frozen=True) +class Proven: + pass + +@dataclass(frozen=True) +class ProveFailure: + s : str + +def check(goal: Prop, proof : Proof, ctx : Context) -> Prop: + return proof.apply(goal, ctx) # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- @@ -208,6 +317,17 @@ if __name__ == "__main__": rhs=kernel_from_int(5), ), Refl(), + Context(props=[]), + )) + + # 2 + 3 == 7 + print(check( + Eq( + lhs=A2("Nat.add", kernel_from_int(2), kernel_from_int(3)), + rhs=kernel_from_int(7), + ), + Refl(), + Context(props=[]), )) # 0 + x == x @@ -217,6 +337,7 @@ if __name__ == "__main__": rhs=Const("x") ), Refl(), + Context(props=[]), )) # x + 0 == x @@ -226,4 +347,5 @@ if __name__ == "__main__": rhs=Const("x") ), Refl(), + Context(props=[]), ))