Add minimal proof checker
parent
d54125daa6
commit
cca866ae62
|
|
@ -0,0 +1,232 @@
|
|||
from __future__ import annotations
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Terms
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Term:
|
||||
pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class App(Term):
|
||||
f : Term
|
||||
x : Term
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Const(Term):
|
||||
name : str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Var(Term):
|
||||
name : str
|
||||
|
||||
# --- Bool constructors ---
|
||||
|
||||
Bool = Union["Truth", "Contradiction"]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Truth(Term):
|
||||
pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Contradiction(Term):
|
||||
pass
|
||||
|
||||
# --- Nat constructors ---
|
||||
|
||||
Nat = Union["Zero", "Succ"]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Zero(Term):
|
||||
pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Succ(Term):
|
||||
n : Term
|
||||
|
||||
def kernel_from_int(n : int) -> Nat:
|
||||
if n == 0:
|
||||
return Zero()
|
||||
elif n < 0:
|
||||
raise ValueError("Int is not Nat")
|
||||
else:
|
||||
return Succ(n=kernel_from_int(n=n-1))
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Propositions
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Prop:
|
||||
pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Eq(Prop):
|
||||
lhs : Term
|
||||
rhs : Term
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Proofs
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Proof:
|
||||
pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Refl(Proof):
|
||||
pass
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Normalization function
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class F:
|
||||
arity : int
|
||||
func : Callable[..., Term | None]
|
||||
|
||||
def apply_rules() -> dict[str, F]:
|
||||
def __bool_not(x : Term) -> Term | None:
|
||||
match x:
|
||||
case Truth():
|
||||
return Contradiction()
|
||||
|
||||
case Contradiction():
|
||||
return Truth()
|
||||
|
||||
def __nat_add(x : Term, y : Term) -> Term | None:
|
||||
match x:
|
||||
case Zero():
|
||||
return y
|
||||
|
||||
case Succ():
|
||||
return Succ(
|
||||
n=normalize(App(
|
||||
f=App(Const("Nat.add"), x=x.n),
|
||||
x=y,
|
||||
))
|
||||
)
|
||||
|
||||
def __nat_iszero(x : Term) -> Term | None:
|
||||
match x:
|
||||
case Zero():
|
||||
return Truth()
|
||||
|
||||
case Succ():
|
||||
return Contradiction()
|
||||
|
||||
return {
|
||||
"Bool.not": F(arity=1, func=__bool_not),
|
||||
"Nat.add" : F(arity=2, func=__nat_add),
|
||||
"Nat.isZero" : F(arity=1, func=__nat_iszero),
|
||||
}
|
||||
|
||||
|
||||
def normalize(term : Term) -> Term:
|
||||
match term:
|
||||
case App():
|
||||
f = normalize(term.f)
|
||||
x = normalize(term.x)
|
||||
|
||||
cursor = term
|
||||
items = [ term ]
|
||||
|
||||
while True:
|
||||
match cursor.f:
|
||||
case App():
|
||||
cursor = cursor.f
|
||||
items.append(cursor)
|
||||
|
||||
case Const():
|
||||
name = cursor.f.name
|
||||
rule = apply_rules().get(name, None)
|
||||
|
||||
if rule is None:
|
||||
return App(f, x) # Unknown function
|
||||
if len(items) != rule.arity:
|
||||
# If len(items) is too small, we don't have
|
||||
# enough inputs yet to evaluate the function
|
||||
# If len(items) is too large, we already evaluated
|
||||
# the function before and couldn't optimize.
|
||||
return App(f, x)
|
||||
|
||||
out = rule.func(*reversed([i.x for i in items]))
|
||||
|
||||
return out if out is not None else App(f, x)
|
||||
|
||||
case Contradiction():
|
||||
if len(items) != 2:
|
||||
# Either evaluating too early or too late
|
||||
return term
|
||||
|
||||
# Return the "second" item
|
||||
return x
|
||||
|
||||
case Truth():
|
||||
if len(items) != 2:
|
||||
# Either evaluating too early or too late
|
||||
return App(f, x)
|
||||
|
||||
# Return the "first" item
|
||||
return cursor.x
|
||||
|
||||
case Const():
|
||||
return term
|
||||
|
||||
case Succ():
|
||||
return Succ(n=normalize(term.n))
|
||||
|
||||
case Term():
|
||||
return term
|
||||
|
||||
case Var():
|
||||
return term
|
||||
|
||||
case Zero():
|
||||
return term
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Proof checker
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
def check(goal: Prop, proof : Proof) -> bool:
|
||||
match ( goal, proof ):
|
||||
case ( Eq(), Refl() ):
|
||||
return normalize(goal.lhs) == normalize(goal.rhs)
|
||||
|
||||
case _:
|
||||
return False
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# -----------------------------------------------------------------------------
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 2 + 3 == 5
|
||||
print(check(
|
||||
Eq(
|
||||
lhs=App(f=App(f=Const("Nat.add"), x=kernel_from_int(2)), x=kernel_from_int(3)),
|
||||
rhs=kernel_from_int(5),
|
||||
),
|
||||
Refl(),
|
||||
))
|
||||
|
||||
# 0 + x == x
|
||||
print(check(
|
||||
Eq(
|
||||
lhs=App(f=App(f=Const("Nat.add"), x=Zero()), x=Var("x")),
|
||||
rhs=Var("x")
|
||||
),
|
||||
Refl(),
|
||||
))
|
||||
|
||||
# x + 0 == x
|
||||
print(check(
|
||||
Eq(
|
||||
lhs=App(f=App(f=Const("Nat.add"), x=Var("x")), x=Zero()),
|
||||
rhs=Var("x")
|
||||
),
|
||||
Refl(),
|
||||
))
|
||||
Loading…
Reference in New Issue