# new-lang/proof.py
#
# 230 lines
# 6.3 KiB
# Python

from __future__ import annotations
from typing import Any, Callable, Union
from dataclasses import dataclass
# -----------------------------------------------------------------------------
# Terms
# -----------------------------------------------------------------------------
class Term:
    """Common base class of every term; ``App`` and ``Const`` derive from it."""
@dataclass(frozen=True)
class App(Term):
    """Application of the term ``f`` to the single argument term ``x``.

    Multi-argument calls are represented as left-nested applications, e.g.
    ``App(App(Const("Nat.add"), a), b)``.
    """

    # Term in function position.
    f: Term
    # Term in argument position.
    x: Term
def _apply_spine(fn, *args):
    """Build ``(...((fn a) b) ...)``: the constant ``fn`` applied left to right."""
    term = Const(fn)
    for arg in args:
        term = App(f=term, x=arg)
    return term


# A<n>(fn, ...) applies the constant named ``fn`` to exactly n argument terms.
def A0(fn):
    return _apply_spine(fn)


def A1(fn, a):
    return _apply_spine(fn, a)


def A2(fn, a, b):
    return _apply_spine(fn, a, b)


def A3(fn, a, b, c):
    return _apply_spine(fn, a, b, c)


def A4(fn, a, b, c, d):
    return _apply_spine(fn, a, b, c, d)


def A5(fn, a, b, c, d, e):
    return _apply_spine(fn, a, b, c, d, e)


def A6(fn, a, b, c, d, e, f):
    return _apply_spine(fn, a, b, c, d, e, f)


def A7(fn, a, b, c, d, e, f, g):
    return _apply_spine(fn, a, b, c, d, e, f, g)


def A8(fn, a, b, c, d, e, f, g, h):
    return _apply_spine(fn, a, b, c, d, e, f, g, h)
@dataclass(frozen=True)
class Const(Term):
    """A named constant, e.g. ``Const("Nat.Zero")`` or the opaque ``Const("x")``."""

    # Name of the constant; rewrite rules are looked up by this name.
    name: str
# --- Bool constructors ---
def Truth() -> Term:
    """Return the ``Bool.Truth`` constant."""
    return Const("Bool.Truth")
def Contradiction() -> Term:
    """Return the ``Bool.Contradiction`` constant."""
    return Const("Bool.Contradiction")
# --- List constructors ---
def Nil() -> Term:
    """Return the ``List.Nil`` constant (the empty list)."""
    return Const("List.Nil")
def Cons(head: Term, tail: Term) -> Term:
    """Return ``List.Cons head tail`` as a left-nested application."""
    partially_applied = App(f=Const("List.Cons"), x=head)
    return App(f=partially_applied, x=tail)
# --- Nat constructors ---
def Zero() -> Term:
    """Return the ``Nat.Zero`` constant."""
    return Const("Nat.Zero")
def Succ(n: Term) -> Term:
    """Return ``Nat.Succ n``, the successor of the term ``n``."""
    return App(f=Const("Nat.Succ"), x=n)
def kernel_from_int(n: int) -> Term:
    """Encode a non-negative Python integer as a unary ``Nat`` term.

    ``0`` becomes ``Zero()`` and ``n`` becomes ``Succ`` applied ``n`` times
    to ``Zero()``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("Int is not Nat")
    # Built with a loop rather than recursion: the recursive form needs one
    # Python stack frame per successor and raises RecursionError once ``n``
    # approaches the interpreter's recursion limit.
    term = Zero()
    for _ in range(n):
        term = Succ(term)
    return term
# -----------------------------------------------------------------------------
# Propositions
# -----------------------------------------------------------------------------
class Prop:
    """Common base class of every proposition; ``Eq`` derives from it."""
@dataclass(frozen=True)
class Eq(Prop):
    """The proposition that the terms ``lhs`` and ``rhs`` are equal."""

    # Left-hand side of the equation.
    lhs: Term
    # Right-hand side of the equation.
    rhs: Term
# -----------------------------------------------------------------------------
# Proofs
# -----------------------------------------------------------------------------
class Proof:
    """Common base class of every proof object; ``Refl`` derives from it."""
@dataclass(frozen=True)
class Refl(Proof):
    """Proof by reflexivity.

    ``check`` accepts it for an ``Eq`` goal when both sides normalize to
    equal terms.  It carries no data.
    """
# -----------------------------------------------------------------------------
# Normalization function
# -----------------------------------------------------------------------------
@dataclass(frozen=True)
class F:
    """A rewrite rule for one named function.

    ``normalize`` calls ``func`` with exactly ``arity`` argument terms.
    ``func`` returns the rewritten term, or ``None`` when it cannot rewrite
    the given arguments.
    """

    # Number of argument terms ``func`` is called with.
    arity: int
    # The rewrite itself; ``None`` means "no rewrite applies".
    func: Callable[..., Term | None]
def F_(a):
    """Return a factory that wraps a callable into an ``F`` rule of arity ``a``."""

    def wrap(f):
        return F(arity=a, func=f)

    return wrap


# Shorthand factories for arities 0 through 7.
F0, F1, F2, F3, F4, F5, F6, F7 = map(F_, range(8))
def apply_rules() -> dict[str, F]:
def __bool_not(x : Term) -> Term | None:
match x:
case Const("Bool.Truth"):
return Contradiction()
case Const("Bool.Contradiction"):
return Truth()
def __list_isempty(x : Term) -> Term | None:
match x:
case Const("List.Nil"):
return Truth()
case App(f=App(f=Const("List.Cons"))):
return Contradiction()
def __list_length(x : Term) -> Term | None:
match x:
case Const("List.Nil"):
return Zero()
case App(f=App(f=Const("List.Cons"), x=head), x=tail):
return Succ(A1("List.length", tail))
def __nat_add(x : Term, y : Term) -> Term | None:
match x:
case Const("Nat.Zero"):
return y
case App(f=Const("Nat.Succ"), x=x_):
return Succ(A2("Nat.add", x_, y))
def __nat_iszero(x : Term) -> Term | None:
match x:
case Const("Nat.Zero"):
return Truth()
case App(f=Const("Nat.Succ")):
return Contradiction()
return {
"Bool.not": F1(__bool_not),
"List.isEmpty": F1(__list_isempty),
"List.length": F1(__list_length),
"Nat.add" : F2(__nat_add),
"Nat.isZero" : F1(__nat_iszero),
}
def normalize(term : Term) -> Term:
match term:
case App():
f = normalize(term.f)
x = normalize(term.x)
cursor = term
items = [ term ]
while True:
match cursor.f:
case App():
cursor = cursor.f
items.append(cursor)
case Const():
name = cursor.f.name
rule = apply_rules().get(name, None)
if rule is None:
return App(f, x) # Unknown function
if len(items) != rule.arity:
# If len(items) is too small, we don't have
# enough inputs yet to evaluate the function
# If len(items) is too large, we already evaluated
# the function before and couldn't optimize.
return App(f, x)
out = rule.func(*reversed([i.x for i in items]))
return normalize(out) if out is not None else App(f, x)
case Const():
return term
case Term():
return term
# -----------------------------------------------------------------------------
# Proof checker
# -----------------------------------------------------------------------------
def check(goal: Prop, proof: Proof) -> bool:
    """Return whether ``proof`` establishes ``goal``.

    The only supported combination is an ``Eq`` goal with a ``Refl`` proof,
    which holds when both sides normalize to equal terms.  Every other
    combination is rejected.
    """
    if not (isinstance(goal, Eq) and isinstance(proof, Refl)):
        return False
    left = normalize(goal.lhs)
    right = normalize(goal.rhs)
    return left == right
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    x = Const("x")
    demo_goals = [
        # 2 + 3 == 5
        Eq(
            lhs=A2("Nat.add", kernel_from_int(2), kernel_from_int(3)),
            rhs=kernel_from_int(5),
        ),
        # 0 + x == x
        Eq(lhs=A2("Nat.add", Zero(), x), rhs=x),
        # x + 0 == x
        # Refl cannot prove this: Nat.add only inspects its first argument,
        # and an opaque ``x`` matches neither of its cases.
        Eq(lhs=A2("Nat.add", x, Zero()), rhs=x),
    ]
    # Print one verdict per goal, in order.
    for goal in demo_goals:
        print(check(goal, Refl()))