352 lines
9.9 KiB
Python
352 lines
9.9 KiB
Python
from __future__ import annotations
|
|
from typing import Any, Callable, Union
|
|
|
|
from dataclasses import dataclass
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Terms
|
|
# -----------------------------------------------------------------------------
|
|
|
|
class Term:
    """Base class of every kernel term; concrete terms are `App` and `Const`."""
|
|
|
|
@dataclass(frozen=True)
|
|
class App(Term):
|
|
f : Term
|
|
x : Term
|
|
|
|
def __repr__(self):
|
|
match self.f:
|
|
case App(f=Const("Nat.add"), x=a):
|
|
return f"{a} + {self.x}"
|
|
case App(f=Const("List.cons"), x=a):
|
|
return f"{a} :: {self.x}"
|
|
case Const("Nat.Succ"):
|
|
cursor, i = self, 0
|
|
|
|
while True:
|
|
match cursor:
|
|
case App(f=Const("Nat.Succ")):
|
|
cursor = cursor.x
|
|
i += 1
|
|
case Const("Nat.Zero"):
|
|
return str(i)
|
|
case _:
|
|
xs = str(cursor)
|
|
if " " in xs:
|
|
xs = "(" + xs + ")"
|
|
return "{xs} + {i}"
|
|
case _:
|
|
fs = str(self.f)
|
|
xs = str(self.x)
|
|
|
|
if " " in fs:
|
|
fs = "(" + fs + ")"
|
|
if " " in xs:
|
|
xs = "(" + xs + ")"
|
|
|
|
return f"{fs} xs"
|
|
|
|
def _spine(fn, *args):
    """Build the curried application of the constant ``fn`` to ``args``.

    ``_spine("f", a, b)`` is ``App(App(Const("f"), a), b)``: each argument
    wraps the term built so far as a new outermost ``App``.
    """
    term = Const(fn)
    for arg in args:
        term = App(f=term, x=arg)
    return term


# Fixed-arity shorthands: `An(fn, ...)` applies the constant named `fn` to
# exactly n arguments, left to right.
def A0(fn):
    return _spine(fn)


def A1(fn, a):
    return _spine(fn, a)


def A2(fn, a, b):
    return _spine(fn, a, b)


def A3(fn, a, b, c):
    return _spine(fn, a, b, c)


def A4(fn, a, b, c, d):
    return _spine(fn, a, b, c, d)


def A5(fn, a, b, c, d, e):
    return _spine(fn, a, b, c, d, e)


def A6(fn, a, b, c, d, e, f):
    return _spine(fn, a, b, c, d, e, f)


def A7(fn, a, b, c, d, e, f, g):
    return _spine(fn, a, b, c, d, e, f, g)


def A8(fn, a, b, c, d, e, f, g, h):
    return _spine(fn, a, b, c, d, e, f, g, h)
|
|
|
|
@dataclass(frozen=True)
class Const(Term):
    """A named constant: a constructor, a function symbol, or a free variable."""

    name : str  # e.g. "Nat.Zero", "Nat.add", or a variable such as "x"

    def __repr__(self) -> str:
        # The empty list and zero get their conventional symbols; every other
        # constant prints as its bare name.
        if self.name == "List.Nil":
            return "[]"
        if self.name == "Nat.Zero":
            return "0"
        return self.name
|
|
|
|
# --- Bool constructors ---
|
|
|
|
def Truth() -> Term:
    """Return the boolean constant ``Bool.Truth``."""
    return Const("Bool.Truth")
|
|
|
|
def Contradiction() -> Term:
    """Return the boolean constant ``Bool.Contradiction``."""
    return Const("Bool.Contradiction")
|
|
|
|
# --- List constructors ---
|
|
|
|
def Nil() -> Term:
    """Return the empty list constant ``List.Nil``."""
    return Const("List.Nil")
|
|
|
|
def Cons(head : Term, tail : Term) -> Term:
    """Return the list ``head :: tail`` (``List.Cons`` applied to both)."""
    partial = A1("List.Cons", head)
    return App(partial, tail)
|
|
|
|
# --- Nat constructors ---
|
|
|
|
def Zero() -> Term:
    """Return the natural-number constant ``Nat.Zero``."""
    return Const("Nat.Zero")
|
|
|
|
def Succ(n : Term) -> Term:
    """Return the successor of ``n`` (``Nat.Succ`` applied to ``n``)."""
    return App(Const("Nat.Succ"), n)
|
|
|
|
def kernel_from_int(n : int) -> Term:
    """Encode a non-negative Python int as a unary ``Nat`` term.

    ``0`` becomes ``Zero()`` and ``k + 1`` becomes ``Succ(<encoding of k>)``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("Int is not Nat")

    # Build the Succ tower bottom-up in a loop. The previous recursive
    # version hit Python's recursion limit for n in the low thousands.
    term = Zero()
    for _ in range(n):
        term = Succ(term)
    return term
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Propositions
|
|
# -----------------------------------------------------------------------------
|
|
|
|
class Prop:
    """Base class of propositions (`Eq`, `Trivial`, `Contradict`)."""
|
|
|
|
@dataclass(frozen=True)
class Contradict(Prop):
    """Result of a failed proof attempt, carrying the reason it failed."""

    s : str  # human-readable explanation of the failure
|
|
|
|
@dataclass(frozen=True)
class Eq(Prop):
    """The proposition ``lhs = rhs`` between two terms."""

    lhs : Term  # left-hand side of the equation
    rhs : Term  # right-hand side of the equation
|
|
|
|
@dataclass(frozen=True)
class Trivial(Prop):
    """A goal with nothing left to prove; proofs return this on success."""
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Proofs
|
|
# -----------------------------------------------------------------------------
|
|
|
|
class Proof:
    """Base class for proof strategies; subclasses override `apply`."""

    def apply(self, goal : Prop, ctx : Context) -> Prop:
        """Attempt to discharge ``goal`` under ``ctx``.

        The base implementation proves nothing and always reports failure.
        """
        failure = Contradict("Cannot prove using the base class")
        return failure
|
|
|
|
@dataclass(frozen=True)
class Assumption(Proof):
    """Proves a goal that is already one of the context's propositions."""

    def apply(self, goal : Prop, ctx : Context) -> Prop:
        """Return ``Trivial()`` if ``goal`` is in ``ctx.props``, else a ``Contradict``."""
        if goal in ctx.props:
            return Trivial()
        return Contradict(f"Couldn't find proposition `{goal}` in assumptions")
|
|
|
|
|
|
@dataclass(frozen=True)
class Refl(Proof):
    """Proof by reflexivity: an equation holds if both sides normalize equally."""

    def apply(self, goal : Prop, ctx : Context) -> Prop:
        """Prove an ``Eq`` goal by normalizing both sides and comparing them.

        On failure, the reported terms are narrowed by peeling off
        application layers on which both sides agree, so the message shows
        where they diverge. Any other ``Prop`` is rejected. ``ctx`` is not
        consulted. A non-``Prop`` goal falls through and yields ``None``.
        """
        if not isinstance(goal, Eq):
            if isinstance(goal, Prop):
                return Contradict(
                    "Cannot prove base proposition"
                )
            return None

        lhs = normalize(goal.lhs)
        rhs = normalize(goal.rhs)
        if lhs == rhs:
            return Trivial()

        # While both sides are applications sharing one component, descend
        # into the component that still differs.
        while isinstance(lhs, App) and isinstance(rhs, App):
            if lhs.x == rhs.x:
                lhs, rhs = lhs.f, rhs.f
            elif lhs.f == rhs.f:
                lhs, rhs = lhs.x, rhs.x
            else:
                break

        return Contradict(
            f"Couldn't normalize {goal.lhs} = {goal.rhs} any further than to {lhs} = {rhs}"
        )
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Context
|
|
# -----------------------------------------------------------------------------
|
|
# The context is a set of information you already know.
|
|
|
|
@dataclass(frozen=True)
class Context:
    """The propositions assumed to hold while checking a proof."""

    props : list[Prop]  # assumptions, in the order they were added

    def with_prop(self, prop : Prop) -> Context:
        """Return a new context with ``prop`` appended; ``self`` is left untouched."""
        extended = [*self.props, prop]
        return Context(extended)
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Normalization function
|
|
# -----------------------------------------------------------------------------
|
|
|
|
@dataclass(frozen=True)
class F:
    """A reduction rule: a callable and the number of arguments it consumes.

    ``normalize`` calls ``func`` only on applications with exactly ``arity``
    arguments. A ``None`` result means "no reduction applies".
    """

    arity : int                        # number of arguments the rule consumes
    func : Callable[..., Term | None]  # the reduction itself
|
|
|
|
def F_(a):
    """Return a factory that wraps a callable as an ``F`` rule of arity ``a``."""
    def make_rule(f):
        return F(arity=a, func=f)
    return make_rule


# F0 .. F7: rule factories for arities 0 through 7, e.g. F2(fn) == F(2, fn).
F0, F1, F2, F3, F4, F5, F6, F7 = map(F_, range(8))
|
|
|
|
def apply_rules() -> dict[str, F]:
|
|
def __bool_not(x : Term) -> Term | None:
|
|
match x:
|
|
case Const("Bool.Truth"):
|
|
return Contradiction()
|
|
|
|
case Const("Bool.Contradiction"):
|
|
return Truth()
|
|
|
|
def __list_isempty(x : Term) -> Term | None:
|
|
match x:
|
|
case Const("List.Nil"):
|
|
return Truth()
|
|
|
|
case App(f=App(f=Const("List.Cons"))):
|
|
return Contradiction()
|
|
|
|
def __list_length(x : Term) -> Term | None:
|
|
match x:
|
|
case Const("List.Nil"):
|
|
return Zero()
|
|
|
|
case App(f=App(f=Const("List.Cons"), x=head), x=tail):
|
|
return Succ(A1("List.length", tail))
|
|
|
|
def __nat_add(x : Term, y : Term) -> Term | None:
|
|
match x:
|
|
case Const("Nat.Zero"):
|
|
return y
|
|
|
|
case App(f=Const("Nat.Succ"), x=x_):
|
|
return Succ(A2("Nat.add", x_, y))
|
|
|
|
def __nat_iszero(x : Term) -> Term | None:
|
|
match x:
|
|
case Const("Nat.Zero"):
|
|
return Truth()
|
|
|
|
case App(f=Const("Nat.Succ")):
|
|
return Contradiction()
|
|
|
|
return {
|
|
"Bool.not": F1(__bool_not),
|
|
"List.isEmpty": F1(__list_isempty),
|
|
"List.length": F1(__list_length),
|
|
"Nat.add" : F2(__nat_add),
|
|
"Nat.isZero" : F1(__nat_iszero),
|
|
}
|
|
|
|
def normalize(term : Term) -> Term:
|
|
match term:
|
|
case App():
|
|
f = normalize(term.f)
|
|
x = normalize(term.x)
|
|
|
|
cursor = term
|
|
items = [ term ]
|
|
|
|
while True:
|
|
match cursor.f:
|
|
case App():
|
|
cursor = cursor.f
|
|
items.append(cursor)
|
|
|
|
case Const():
|
|
name = cursor.f.name
|
|
rule = apply_rules().get(name, None)
|
|
|
|
if rule is None:
|
|
return App(f, x) # Unknown function
|
|
if len(items) != rule.arity:
|
|
# If len(items) is too small, we don't have
|
|
# enough inputs yet to evaluate the function
|
|
# If len(items) is too large, we already evaluated
|
|
# the function before and couldn't optimize.
|
|
return App(f, x)
|
|
|
|
out = rule.func(*reversed([i.x for i in items]))
|
|
|
|
return normalize(out) if out is not None else App(f, x)
|
|
|
|
case Const():
|
|
return term
|
|
|
|
case Term():
|
|
return term
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Proof checker
|
|
# -----------------------------------------------------------------------------
|
|
|
|
@dataclass(frozen=True)
class Proven:
    """Successful proof-check outcome."""


@dataclass(frozen=True)
class ProveFailure:
    """Failed proof-check outcome with an explanation."""

    s : str  # human-readable reason for the failure


# Outcome of a proof check: success or failure.
# NOTE(review): nothing in this module produces these yet. `check` returns a
# `Prop` (`Trivial` / `Contradict`) instead.
ProofCheck = Union["Proven", "ProveFailure"]
|
|
|
|
def check(goal: Prop, proof : Proof, ctx : Context) -> Prop:
    """Run ``proof`` against ``goal`` in context ``ctx`` and return its verdict.

    This delegates directly to ``proof.apply(goal, ctx)``.
    """
    verdict = proof.apply(goal, ctx)
    return verdict
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# -----------------------------------------------------------------------------
|
|
# -----------------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
    # Demo equations, each attempted with `Refl` in an empty context.
    demos = [
        # 2 + 3 == 5: both sides compute to the same numeral.
        Eq(
            lhs=A2("Nat.add", kernel_from_int(2), kernel_from_int(3)),
            rhs=kernel_from_int(5),
        ),
        # 2 + 3 == 7: false, so Refl reports where the sides differ.
        Eq(
            lhs=A2("Nat.add", kernel_from_int(2), kernel_from_int(3)),
            rhs=kernel_from_int(7),
        ),
        # 0 + x == x: Nat.add recurses on its first argument, so this reduces.
        Eq(lhs=A2("Nat.add", Zero(), Const("x")), rhs=Const("x")),
        # x + 0 == x: stuck, since `x` is neither Zero nor Succ.
        Eq(lhs=A2("Nat.add", Const("x"), Zero()), rhs=Const("x")),
    ]

    for goal in demos:
        print(check(goal, Refl(), Context(props=[])))
|