Simplify abstract type structuring

main
Bram van den Heuvel 2026-06-29 15:02:21 +02:00
parent ffc79f4372
commit 90240cf067
1 changed files with 25 additions and 7 deletions

View File

@ -15,6 +15,16 @@ class App(Term):
f : Term f : Term
x : Term x : Term
A0 = lambda fn : Const(fn)
A1 = lambda fn, a : App(f=A0(fn), x=a)
A2 = lambda fn, a, b : App(f=A1(fn, a), x=b)
A3 = lambda fn, a, b, c : App(f=A2(fn, a, b), x=c)
A4 = lambda fn, a, b, c, d : App(f=A3(fn, a, b, c), x=d)
A5 = lambda fn, a, b, c, d, e : App(f=A4(fn, a, b, c, d), x=e)
A6 = lambda fn, a, b, c, d, e, f : App(f=A5(fn, a, b, c, d, e), x=f)
A7 = lambda fn, a, b, c, d, e, f, g : App(f=A6(fn, a, b, c, d, e, f), x=g)
A8 = lambda fn, a, b, c, d, e, f, g, h : App(f=A7(fn, a, b, c, d, e, f, g), x=h)
@dataclass(frozen=True) @dataclass(frozen=True)
class Const(Term): class Const(Term):
name : str name : str
@ -22,18 +32,26 @@ class Const(Term):
# --- Bool constructors --- # --- Bool constructors ---
def Truth() -> Term: def Truth() -> Term:
return Const("Bool.Truth") return A0("Bool.Truth")
def Contradiction() -> Term: def Contradiction() -> Term:
return Const("Bool.Contradiction") return A0("Bool.Contradiction")
# --- List constructors ---
def Nil() -> Term:
return A0("List.Nil")
def Cons(head : Term, tail : Term) -> Term:
return A2("List.Cons", head, tail)
# --- Nat constructors --- # --- Nat constructors ---
def Zero() -> Term: def Zero() -> Term:
return Const("Nat.Zero") return A0("Nat.Zero")
def Succ(n : Term) -> Term: def Succ(n : Term) -> Term:
return App(Const("Nat.Succ"), n) return A1("Nat.Succ", n)
def kernel_from_int(n : int) -> Term: def kernel_from_int(n : int) -> Term:
if n == 0: if n == 0:
@ -174,7 +192,7 @@ if __name__ == "__main__":
# 2 + 3 == 5 # 2 + 3 == 5
print(check( print(check(
Eq( Eq(
lhs=App(f=App(f=Const("Nat.add"), x=kernel_from_int(2)), x=kernel_from_int(3)), lhs=A2("Nat.add", kernel_from_int(2), kernel_from_int(3)),
rhs=kernel_from_int(5), rhs=kernel_from_int(5),
), ),
Refl(), Refl(),
@ -183,7 +201,7 @@ if __name__ == "__main__":
# 0 + x == x # 0 + x == x
print(check( print(check(
Eq( Eq(
lhs=App(f=App(f=Const("Nat.add"), x=Zero()), x=Const("x")), lhs=A2("Nat.add", Zero(), Const("x")),
rhs=Const("x") rhs=Const("x")
), ),
Refl(), Refl(),
@ -192,7 +210,7 @@ if __name__ == "__main__":
# x + 0 == x # x + 0 == x
print(check( print(check(
Eq( Eq(
lhs=App(f=App(f=Const("Nat.add"), x=Const("x")), x=Zero()), lhs=A2("Nat.add", Const("x"), Zero()),
rhs=Const("x") rhs=Const("x")
), ),
Refl(), Refl(),