Add minimal proof checker

main
Bram van den Heuvel 2026-06-29 11:59:35 +02:00
parent d54125daa6
commit cca866ae62
1 changed files with 232 additions and 0 deletions

232
proof.py Normal file
View File

@ -0,0 +1,232 @@
from __future__ import annotations
from typing import Any, Callable, Union
from dataclasses import dataclass
# -----------------------------------------------------------------------------
# Terms
# -----------------------------------------------------------------------------
class Term:
pass
@dataclass(frozen=True)
class App(Term):
f : Term
x : Term
@dataclass(frozen=True)
class Const(Term):
name : str
@dataclass(frozen=True)
class Var(Term):
name : str
# --- Bool constructors ---
Bool = Union["Truth", "Contradiction"]
@dataclass(frozen=True)
class Truth(Term):
pass
@dataclass(frozen=True)
class Contradiction(Term):
pass
# --- Nat constructors ---
Nat = Union["Zero", "Succ"]
@dataclass(frozen=True)
class Zero(Term):
pass
@dataclass(frozen=True)
class Succ(Term):
n : Term
def kernel_from_int(n : int) -> Nat:
if n == 0:
return Zero()
elif n < 0:
raise ValueError("Int is not Nat")
else:
return Succ(n=kernel_from_int(n=n-1))
# -----------------------------------------------------------------------------
# Propositions
# -----------------------------------------------------------------------------
class Prop:
pass
@dataclass(frozen=True)
class Eq(Prop):
lhs : Term
rhs : Term
# -----------------------------------------------------------------------------
# Proofs
# -----------------------------------------------------------------------------
class Proof:
pass
@dataclass(frozen=True)
class Refl(Proof):
pass
# -----------------------------------------------------------------------------
# Normalization function
# -----------------------------------------------------------------------------
@dataclass(frozen=True)
class F:
arity : int
func : Callable[..., Term | None]
def apply_rules() -> dict[str, F]:
def __bool_not(x : Term) -> Term | None:
match x:
case Truth():
return Contradiction()
case Contradiction():
return Truth()
def __nat_add(x : Term, y : Term) -> Term | None:
match x:
case Zero():
return y
case Succ():
return Succ(
n=normalize(App(
f=App(Const("Nat.add"), x=x.n),
x=y,
))
)
def __nat_iszero(x : Term) -> Term | None:
match x:
case Zero():
return Truth()
case Succ():
return Contradiction()
return {
"Bool.not": F(arity=1, func=__bool_not),
"Nat.add" : F(arity=2, func=__nat_add),
"Nat.isZero" : F(arity=1, func=__nat_iszero),
}
def normalize(term : Term) -> Term:
match term:
case App():
f = normalize(term.f)
x = normalize(term.x)
cursor = term
items = [ term ]
while True:
match cursor.f:
case App():
cursor = cursor.f
items.append(cursor)
case Const():
name = cursor.f.name
rule = apply_rules().get(name, None)
if rule is None:
return App(f, x) # Unknown function
if len(items) != rule.arity:
# If len(items) is too small, we don't have
# enough inputs yet to evaluate the function
# If len(items) is too large, we already evaluated
# the function before and couldn't optimize.
return App(f, x)
out = rule.func(*reversed([i.x for i in items]))
return out if out is not None else App(f, x)
case Contradiction():
if len(items) != 2:
# Either evaluating too early or too late
return term
# Return the "second" item
return x
case Truth():
if len(items) != 2:
# Either evaluating too early or too late
return App(f, x)
# Return the "first" item
return cursor.x
case Const():
return term
case Succ():
return Succ(n=normalize(term.n))
case Term():
return term
case Var():
return term
case Zero():
return term
# -----------------------------------------------------------------------------
# Proof checker
# -----------------------------------------------------------------------------
def check(goal: Prop, proof : Proof) -> bool:
match ( goal, proof ):
case ( Eq(), Refl() ):
return normalize(goal.lhs) == normalize(goal.rhs)
case _:
return False
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
if __name__ == "__main__":
# 2 + 3 == 5
print(check(
Eq(
lhs=App(f=App(f=Const("Nat.add"), x=kernel_from_int(2)), x=kernel_from_int(3)),
rhs=kernel_from_int(5),
),
Refl(),
))
# 0 + x == x
print(check(
Eq(
lhs=App(f=App(f=Const("Nat.add"), x=Zero()), x=Var("x")),
rhs=Var("x")
),
Refl(),
))
# x + 0 == x
print(check(
Eq(
lhs=App(f=App(f=Const("Nat.add"), x=Var("x")), x=Zero()),
rhs=Var("x")
),
Refl(),
))