luau/tests/EqSat.propositional.test.cpp
Alexander McCord e1bf6289c7
Equality graphs (#1285)
Working towards a full e-graph implementation as described by the [egg
paper](https://arxiv.org/pdf/2004.03082).

The type system has a couple of places where e-graphs would've been
useful and solved some classes of problems trivially. For example:

1. Normalization and simplification cannot handle cyclic types due to
the nature of their implementation.
2. Normalization can't tell when two tables or functions are equivalent,
but simplification theoretically can, albeit that is not implemented.
3. Normalization requires deep normalization for inhabitance check,
whereas simplification would've returned the `never` type itself,
indicating that the type is uninhabited.
4. Simplification requires constraint ordering to have perfect timing to
simplify.
5. Adding a rewrite rule requires implementing it twice, once in
simplification and once again in normalization with completely different
code design making it hard to verify that their behavior is materially
equivalent.
6. In cases where we must cache for performance, two different types
that are isomorphic have different cache entries resulting in cache
misses.
7. Type family reduction can handle cyclic type families, but only if
the cycle is not obscured by a different type family instance. (`t1
where t1 = union<number, add<t1, number>>` is irreducible)

I think we're getting the point!

---

Currently the implementation is missing a few features that make
e-graphs actually useful. Those will be coming in a future PR.

1. Pattern matching,
2. Applying rewrites,
3. Rewriting until saturation, and
4. Extracting the best e-node according to some cost function.
2024-07-16 10:35:20 -07:00

198 lines
4.5 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include <doctest.h>
#include "Luau/EGraph.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include <optional>
// Leaf node types of the propositional language: a named variable (holding a
// std::string, e.g. Var{"x"}) and a boolean literal (holding a bool).
LUAU_EQSAT_ATOM(Var, std::string);
LUAU_EQSAT_ATOM(Bool, bool);
// Connective node types. The second macro argument appears to be the number of
// child e-class ids each node holds (Not is indexed only at [0]; the binary
// connectives at [0] and [1]).
LUAU_EQSAT_NODE_ARRAY(Not, 1);
LUAU_EQSAT_NODE_ARRAY(And, 2);
LUAU_EQSAT_NODE_ARRAY(Or, 2);
LUAU_EQSAT_NODE_ARRAY(Implies, 2);
using namespace Luau;
// The set of e-node types this e-graph can store.
using PropositionalLogic = EqSat::Language<Var, Bool, Not, And, Or, Implies>;
// `struct ConstantFold` here is an elaborated type specifier that forward-declares
// the analysis defined below; its Data is what `egraph[id].data` exposes.
using EGraph = EqSat::EGraph<PropositionalLogic, struct ConstantFold>;
struct ConstantFold
{
using Data = std::optional<bool>;
Data make(const EGraph& egraph, const Var& var) const
{
return std::nullopt;
}
Data make(const EGraph& egraph, const Bool& b) const
{
return b.value();
}
Data make(const EGraph& egraph, const Not& n) const
{
Data data = egraph[n[0]].data;
if (data)
return !*data;
return std::nullopt;
}
Data make(const EGraph& egraph, const And& a) const
{
Data l = egraph[a[0]].data;
Data r = egraph[a[1]].data;
if (l && r)
return *l && *r;
return std::nullopt;
}
Data make(const EGraph& egraph, const Or& o) const
{
Data l = egraph[o[0]].data;
Data r = egraph[o[1]].data;
if (l && r)
return *l || *r;
return std::nullopt;
}
Data make(const EGraph& egraph, const Implies& i) const
{
Data antecedent = egraph[i[0]].data;
Data consequent = egraph[i[1]].data;
if (antecedent && consequent)
return !*antecedent || *consequent;
return std::nullopt;
}
void join(Data& a, const Data& b) const
{
if (!a && b)
a = b;
}
};
TEST_SUITE_BEGIN("EqSatPropositionalLogic");

// Hash-consing: adding a structurally identical e-node a second time must give
// back the same e-class id, while a different literal gets a distinct id.
TEST_CASE("egraph_hashconsing")
{
    EGraph egraph;

    const EqSat::Id trueA = egraph.add(Bool{true});
    const EqSat::Id trueB = egraph.add(Bool{true});
    const EqSat::Id falseId = egraph.add(Bool{false});

    CHECK(trueA == trueB);
    CHECK(trueB != falseId);
}
// Adding a literal seeds its e-class's analysis data with the literal's value.
TEST_CASE("egraph_data")
{
    EGraph egraph;

    const EqSat::Id trueId = egraph.add(Bool{true});
    const EqSat::Id falseId = egraph.add(Bool{false});

    CHECK(egraph[trueId].data == true);
    CHECK(egraph[falseId].data == false);
}
// Merging a variable (no known value) with a literal must leave the merged
// class carrying the literal's value, whichever id it is looked up through.
TEST_CASE("egraph_merge")
{
    EGraph egraph;

    const EqSat::Id variable = egraph.add(Var{"a"});
    const EqSat::Id literal = egraph.add(Bool{true});

    egraph.merge(variable, literal);

    CHECK(egraph[variable].data == true);
    CHECK(egraph[literal].data == true);
}
// And(true, true) folds to true.
TEST_CASE("const_fold_true_and_true")
{
    EGraph egraph;

    const EqSat::Id lhs = egraph.add(Bool{true});
    const EqSat::Id rhs = egraph.add(Bool{true});
    const EqSat::Id conjunction = egraph.add(And{lhs, rhs});

    CHECK(egraph[conjunction].data == true);
}
// And(true, false) folds to false.
TEST_CASE("const_fold_true_and_false")
{
    EGraph egraph;

    const EqSat::Id lhs = egraph.add(Bool{true});
    const EqSat::Id rhs = egraph.add(Bool{false});
    const EqSat::Id conjunction = egraph.add(And{lhs, rhs});

    CHECK(egraph[conjunction].data == false);
}
// And(false, false) folds to false.
TEST_CASE("const_fold_false_and_false")
{
    EGraph egraph;
    EqSat::Id id1 = egraph.add(Bool{false});
    EqSat::Id id2 = egraph.add(Bool{false});
    EqSat::Id id3 = egraph.add(And{id1, id2});
    CHECK(egraph[id3].data == false);
}

// Not folds a known operand and stays unknown for an unknown operand.
// (ConstantFold's Not overload is not exercised by any other test.)
TEST_CASE("const_fold_not")
{
    EGraph egraph;
    EqSat::Id t = egraph.add(Bool{true});
    EqSat::Id f = egraph.add(Bool{false});
    EqSat::Id x = egraph.add(Var{"x"});
    EqSat::Id nt = egraph.add(Not{t}); // false
    EqSat::Id nf = egraph.add(Not{f}); // true
    EqSat::Id nx = egraph.add(Not{x}); // unknown
    CHECK(egraph[nt].data == false);
    CHECK(egraph[nf].data == true);
    CHECK(!egraph[nx].data);
}

// Full truth table for Or, plus a genuinely unknown case (Or(x, false) depends on x).
// (ConstantFold's Or overload is not exercised by any other test.)
TEST_CASE("const_fold_or")
{
    EGraph egraph;
    EqSat::Id t = egraph.add(Bool{true});
    EqSat::Id f = egraph.add(Bool{false});
    EqSat::Id x = egraph.add(Var{"x"});
    EqSat::Id tt = egraph.add(Or{t, t}); // true
    EqSat::Id tf = egraph.add(Or{t, f}); // true
    EqSat::Id ft = egraph.add(Or{f, t}); // true
    EqSat::Id ff = egraph.add(Or{f, f}); // false
    EqSat::Id xf = egraph.add(Or{x, f}); // unknown
    CHECK(egraph[tt].data == true);
    CHECK(egraph[tf].data == true);
    CHECK(egraph[ft].data == true);
    CHECK(egraph[ff].data == false);
    CHECK(!egraph[xf].data);
}
// Truth table for material implication: p -> q is false only for true -> false.
TEST_CASE("implications")
{
    EGraph egraph;

    const EqSat::Id yes = egraph.add(Bool{true});
    const EqSat::Id no = egraph.add(Bool{false});

    const EqSat::Id yesYes = egraph.add(Implies{yes, yes});
    const EqSat::Id yesNo = egraph.add(Implies{yes, no});
    const EqSat::Id noYes = egraph.add(Implies{no, yes});
    const EqSat::Id noNo = egraph.add(Implies{no, no});

    CHECK(egraph[yesYes].data == true);
    CHECK(egraph[yesNo].data == false);
    CHECK(egraph[noYes].data == true);
    CHECK(egraph[noNo].data == true);
}
// Congruence: after x and y are merged, And(a, x) and And(a, y) become
// equivalent, but that is only reflected in the e-graph once rebuild() runs.
TEST_CASE("merge_x_and_y")
{
    EGraph egraph;

    const EqSat::Id x = egraph.add(Var{"x"});
    const EqSat::Id y = egraph.add(Var{"y"});
    const EqSat::Id a = egraph.add(Var{"a"});
    const EqSat::Id ax = egraph.add(And{a, x});
    const EqSat::Id ay = egraph.add(And{a, y});

    // Before rebuild: only x and y share a class -> [x y] [ax] [ay] [a]
    egraph.merge(x, y);
    CHECK_EQ(egraph.size(), 4);
    CHECK_EQ(egraph.find(x), egraph.find(y));
    CHECK_NE(egraph.find(ax), egraph.find(ay));
    CHECK_NE(egraph.find(a), egraph.find(x));
    CHECK_NE(egraph.find(a), egraph.find(y));

    // After rebuild: the two conjunctions collapse too -> [x y] [ax ay] [a]
    egraph.rebuild();
    CHECK_EQ(egraph.size(), 3);
    CHECK_EQ(egraph.find(x), egraph.find(y));
    CHECK_EQ(egraph.find(ax), egraph.find(ay));
    CHECK_NE(egraph.find(a), egraph.find(x));
    CHECK_NE(egraph.find(a), egraph.find(y));
}

TEST_SUITE_END();