mirror of
https://github.com/luau-lang/luau.git
synced 2024-11-15 06:15:44 +08:00
Equality graphs (#1285)
Working towards a full e-graph implementation as described by the [egg paper](https://arxiv.org/pdf/2004.03082). The type system has a couple of places where e-graphs would've been useful and solved some classes of problems trivially. For example: 1. Normalization and simplification cannot handle cyclic types due to the nature of their implementation. 2. Normalization can't tell when two tables or functions are equivalent, but simplification theoretically can albeit not implemented. 3. Normalization requires deep normalization for inhabitance check, whereas simplification would've returned the `never` type itself indicating uninhabited. 4. Simplification requires constraint ordering to have perfect timing to simplify. 5. Adding a rewrite rule requires implementing it twice, once in simplification and once again in normalization with completely different code design making it hard to verify that their behavior is materially equivalent. 6. In cases where we must cache for performance, two different types that are isomorphic have different cache entries resulting in cache misses. 7. Type family reduction can handle cyclic type families, but only if the cycle is not obscured by a different type family instance. (`t1 where t1 = union<number, add<t1, number>>` is irreducible) I think we're getting the point! --- Currently the implementation is missing a few features that makes e-graphs actually useful. Those will be coming in a future PR. 1. Pattern matching, 6. Applying rewrites, 7. Rewrite until saturation, and 8. Extracting the best e-node according to some cost function.
This commit is contained in:
parent
b6b74b4425
commit
e1bf6289c7
2
.gitignore
vendored
2
.gitignore
vendored
@ -14,3 +14,5 @@
|
||||
/luau-analyze
|
||||
/luau-compile
|
||||
__pycache__
|
||||
.cache
|
||||
.clangd
|
||||
|
@ -28,6 +28,7 @@ add_library(Luau.Ast STATIC)
|
||||
add_library(Luau.Compiler STATIC)
|
||||
add_library(Luau.Config STATIC)
|
||||
add_library(Luau.Analysis STATIC)
|
||||
add_library(Luau.EqSat STATIC)
|
||||
add_library(Luau.CodeGen STATIC)
|
||||
add_library(Luau.VM STATIC)
|
||||
add_library(isocline STATIC)
|
||||
@ -83,7 +84,11 @@ target_link_libraries(Luau.Config PUBLIC Luau.Ast)
|
||||
|
||||
target_compile_features(Luau.Analysis PUBLIC cxx_std_17)
|
||||
target_include_directories(Luau.Analysis PUBLIC Analysis/include)
|
||||
target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.Config)
|
||||
target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config)
|
||||
|
||||
target_compile_features(Luau.EqSat PUBLIC cxx_std_17)
|
||||
target_include_directories(Luau.EqSat PUBLIC EqSat/include)
|
||||
target_link_libraries(Luau.EqSat PUBLIC Luau.Common)
|
||||
|
||||
target_compile_features(Luau.CodeGen PRIVATE cxx_std_17)
|
||||
target_include_directories(Luau.CodeGen PUBLIC CodeGen/include)
|
||||
@ -141,6 +146,7 @@ endif()
|
||||
|
||||
target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS})
|
||||
target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS})
|
||||
target_compile_options(Luau.EqSat PRIVATE ${LUAU_OPTIONS})
|
||||
target_compile_options(Luau.CLI.lib PRIVATE ${LUAU_OPTIONS})
|
||||
target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS})
|
||||
target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS})
|
||||
@ -263,13 +269,13 @@ endif()
|
||||
add_subdirectory(fuzz)
|
||||
|
||||
# validate dependencies for internal libraries
|
||||
foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.CodeGen Luau.VM)
|
||||
foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.EqSat Luau.CodeGen Luau.VM)
|
||||
if(TARGET ${LIB})
|
||||
get_target_property(DEPENDS ${LIB} LINK_LIBRARIES)
|
||||
if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler")
|
||||
message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components")
|
||||
endif()
|
||||
if(LIB MATCHES "Ast|Analysis|Compiler" AND DEPENDS MATCHES "CodeGen|VM")
|
||||
if(LIB MATCHES "Ast|Analysis|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM")
|
||||
message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components")
|
||||
endif()
|
||||
if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config")
|
||||
|
228
EqSat/include/Luau/EGraph.h
Normal file
228
EqSat/include/Luau/EGraph.h
Normal file
@ -0,0 +1,228 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/Common.h"
|
||||
#include "Luau/Id.h"
|
||||
#include "Luau/Language.h"
|
||||
#include "Luau/UnionFind.h"
|
||||
#include "Luau/VecDeque.h"
|
||||
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
template<typename L, typename N>
|
||||
struct EGraph;
|
||||
|
||||
template<typename L, typename N>
|
||||
struct Analysis final
|
||||
{
|
||||
N analysis;
|
||||
|
||||
using D = typename N::Data;
|
||||
|
||||
template<typename T>
|
||||
static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode)
|
||||
{
|
||||
return analysis.make(egraph, *enode.template get<T>());
|
||||
}
|
||||
|
||||
template<typename... Ts>
|
||||
D make(const EGraph<L, N>& egraph, const Language<Ts...>& enode) const
|
||||
{
|
||||
using FnMake = D (*)(const N&, const EGraph<L, N>&, const L&);
|
||||
static constexpr FnMake tableMake[sizeof...(Ts)] = {&fnMake<Ts>...};
|
||||
|
||||
return tableMake[enode.index()](analysis, egraph, enode);
|
||||
}
|
||||
|
||||
void join(D& a, const D& b) const
|
||||
{
|
||||
return analysis.join(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
/// Each e-class is a set of e-nodes representing equivalent terms from a given language,
|
||||
/// and an e-node is a function symbol paired with a list of children e-classes.
|
||||
template<typename L, typename D>
|
||||
struct EClass final
|
||||
{
|
||||
Id id;
|
||||
std::vector<L> nodes;
|
||||
D data;
|
||||
std::vector<std::pair<L, Id>> parents;
|
||||
};
|
||||
|
||||
/// See <https://arxiv.org/pdf/2004.03082>.
|
||||
template<typename L, typename N>
|
||||
struct EGraph final
|
||||
{
|
||||
Id find(Id id) const
|
||||
{
|
||||
return unionfind.find(id);
|
||||
}
|
||||
|
||||
std::optional<Id> lookup(const L& enode) const
|
||||
{
|
||||
LUAU_ASSERT(isCanonical(enode));
|
||||
|
||||
if (auto it = hashcons.find(enode); it != hashcons.end())
|
||||
return it->second;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Id add(L enode)
|
||||
{
|
||||
canonicalize(enode);
|
||||
|
||||
if (auto id = lookup(enode))
|
||||
return *id;
|
||||
|
||||
Id id = makeEClass(enode);
|
||||
return id;
|
||||
}
|
||||
|
||||
void merge(Id id1, Id id2)
|
||||
{
|
||||
id1 = find(id1);
|
||||
id2 = find(id2);
|
||||
if (id1 == id2)
|
||||
return;
|
||||
|
||||
unionfind.merge(id1, id2);
|
||||
|
||||
EClass<L, typename N::Data>& eclass1 = get(id1);
|
||||
EClass<L, typename N::Data> eclass2 = std::move(get(id2));
|
||||
classes.erase(id2);
|
||||
|
||||
worklist.reserve(worklist.size() + eclass2.parents.size());
|
||||
for (auto [enode, id] : eclass2.parents)
|
||||
worklist.push_back({std::move(enode), id});
|
||||
|
||||
analysis.join(eclass1.data, eclass2.data);
|
||||
}
|
||||
|
||||
void rebuild()
|
||||
{
|
||||
while (!worklist.empty())
|
||||
{
|
||||
auto [enode, id] = worklist.back();
|
||||
worklist.pop_back();
|
||||
repair(get(find(id)));
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return classes.size();
|
||||
}
|
||||
|
||||
EClass<L, typename N::Data>& operator[](Id id)
|
||||
{
|
||||
return get(find(id));
|
||||
}
|
||||
|
||||
const EClass<L, typename N::Data>& operator[](Id id) const
|
||||
{
|
||||
return const_cast<EGraph*>(this)->get(find(id));
|
||||
}
|
||||
|
||||
private:
|
||||
Analysis<L, N> analysis;
|
||||
|
||||
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
|
||||
UnionFind unionfind;
|
||||
|
||||
/// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same
|
||||
/// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the
|
||||
/// e-class 𝑀[find(𝑎)].
|
||||
std::unordered_map<Id, EClass<L, typename N::Data>> classes;
|
||||
|
||||
/// The hashcons 𝐻 is a map from e-nodes to e-class ids.
|
||||
std::unordered_map<L, Id, typename L::Hash> hashcons;
|
||||
|
||||
VecDeque<std::pair<L, Id>> worklist;
|
||||
|
||||
private:
|
||||
void canonicalize(L& enode)
|
||||
{
|
||||
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
||||
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
||||
for (Id& id : enode.operands())
|
||||
id = find(id);
|
||||
}
|
||||
|
||||
bool isCanonical(const L& enode) const
|
||||
{
|
||||
bool canonical = true;
|
||||
for (Id id : enode.operands())
|
||||
canonical &= (id == find(id));
|
||||
return canonical;
|
||||
}
|
||||
|
||||
Id makeEClass(const L& enode)
|
||||
{
|
||||
LUAU_ASSERT(isCanonical(enode));
|
||||
|
||||
Id id = unionfind.makeSet();
|
||||
|
||||
classes.insert_or_assign(id, EClass<L, typename N::Data>{
|
||||
id,
|
||||
{enode},
|
||||
analysis.make(*this, enode),
|
||||
{},
|
||||
});
|
||||
|
||||
for (Id operand : enode.operands())
|
||||
get(operand).parents.push_back({enode, id});
|
||||
|
||||
worklist.push_back({enode, id});
|
||||
hashcons.insert_or_assign(enode, id);
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
// Looks up for an eclass from a given non-canonicalized `id`.
|
||||
// For a canonicalized eclass, use `get(find(id))` or `egraph[id]`.
|
||||
EClass<L, typename N::Data>& get(Id id)
|
||||
{
|
||||
return classes.at(id);
|
||||
}
|
||||
|
||||
void repair(EClass<L, typename N::Data>& eclass)
|
||||
{
|
||||
// In the egg paper, the `repair` function makes use of two loops over the `eclass.parents`
|
||||
// by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id.
|
||||
// And then in another loop that follows, deduplicate it.
|
||||
//
|
||||
// Here, we unify the two loops. I think it's equivalent?
|
||||
|
||||
// After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent.
|
||||
std::unordered_map<L, Id, typename L::Hash> map;
|
||||
for (auto& [enode, id] : eclass.parents)
|
||||
{
|
||||
// By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
|
||||
hashcons.erase(enode);
|
||||
canonicalize(enode);
|
||||
hashcons.insert_or_assign(enode, find(id));
|
||||
|
||||
if (auto it = map.find(enode); it != map.end())
|
||||
merge(id, it->second);
|
||||
|
||||
map.insert_or_assign(enode, find(id));
|
||||
}
|
||||
|
||||
eclass.parents.clear();
|
||||
for (auto it = map.begin(); it != map.end();)
|
||||
{
|
||||
auto node = map.extract(it++);
|
||||
eclass.parents.emplace_back(std::move(node.key()), node.mapped());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace Luau::EqSat
|
29
EqSat/include/Luau/Id.h
Normal file
29
EqSat/include/Luau/Id.h
Normal file
@ -0,0 +1,29 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
struct Id final
|
||||
{
|
||||
explicit Id(size_t id);
|
||||
|
||||
explicit operator size_t() const;
|
||||
|
||||
bool operator==(Id rhs) const;
|
||||
bool operator!=(Id rhs) const;
|
||||
|
||||
private:
|
||||
size_t id;
|
||||
};
|
||||
|
||||
} // namespace Luau::EqSat
|
||||
|
||||
template<>
|
||||
struct std::hash<Luau::EqSat::Id>
|
||||
{
|
||||
size_t operator()(Luau::EqSat::Id id) const;
|
||||
};
|
304
EqSat/include/Luau/Language.h
Normal file
304
EqSat/include/Luau/Language.h
Normal file
@ -0,0 +1,304 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/Id.h"
|
||||
#include "Luau/LanguageHash.h"
|
||||
#include "Luau/Slice.h"
|
||||
#include "Luau/Variant.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#define LUAU_EQSAT_ATOM(name, t) \
|
||||
struct name : public ::Luau::EqSat::Atom<name, t> \
|
||||
{ \
|
||||
static constexpr const char* tag = #name; \
|
||||
using Atom::Atom; \
|
||||
}
|
||||
|
||||
#define LUAU_EQSAT_NODE_ARRAY(name, ops) \
|
||||
struct name : public ::Luau::EqSat::NodeVector<name, std::array<::Luau::EqSat::Id, ops>> \
|
||||
{ \
|
||||
static constexpr const char* tag = #name; \
|
||||
using NodeVector::NodeVector; \
|
||||
}
|
||||
|
||||
#define LUAU_EQSAT_NODE_VECTOR(name) \
|
||||
struct name : public ::Luau::EqSat::NodeVector<name, std::vector<::Luau::EqSat::Id>> \
|
||||
{ \
|
||||
static constexpr const char* tag = #name; \
|
||||
using NodeVector::NodeVector; \
|
||||
}
|
||||
|
||||
#define LUAU_EQSAT_FIELD(name) \
|
||||
struct name : public ::Luau::EqSat::Field<name> \
|
||||
{ \
|
||||
}
|
||||
|
||||
#define LUAU_EQSAT_NODE_FIELDS(name, ...) \
|
||||
struct name : public ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \
|
||||
{ \
|
||||
static constexpr const char* tag = #name; \
|
||||
using NodeFields::NodeFields; \
|
||||
}
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
template<typename Phantom, typename T>
|
||||
struct Atom
|
||||
{
|
||||
Atom(const T& value)
|
||||
: _value(value)
|
||||
{
|
||||
}
|
||||
|
||||
const T& value() const
|
||||
{
|
||||
return _value;
|
||||
}
|
||||
|
||||
public:
|
||||
Slice<Id> operands()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
Slice<const Id> operands() const
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
bool operator==(const Atom& rhs) const
|
||||
{
|
||||
return _value == rhs._value;
|
||||
}
|
||||
|
||||
bool operator!=(const Atom& rhs) const
|
||||
{
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
struct Hash
|
||||
{
|
||||
size_t operator()(const Atom& value) const
|
||||
{
|
||||
return languageHash(value._value);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
T _value;
|
||||
};
|
||||
|
||||
template<typename Phantom, typename T>
|
||||
struct NodeVector
|
||||
{
|
||||
template<typename... Args>
|
||||
NodeVector(Args&&... args)
|
||||
: vector{std::forward<Args>(args)...}
|
||||
{
|
||||
}
|
||||
|
||||
Id operator[](size_t i) const
|
||||
{
|
||||
return vector[i];
|
||||
}
|
||||
|
||||
public:
|
||||
Slice<Id> operands()
|
||||
{
|
||||
return Slice{vector.data(), vector.size()};
|
||||
}
|
||||
|
||||
Slice<const Id> operands() const
|
||||
{
|
||||
return Slice{vector.data(), vector.size()};
|
||||
}
|
||||
|
||||
bool operator==(const NodeVector& rhs) const
|
||||
{
|
||||
return vector == rhs.vector;
|
||||
}
|
||||
|
||||
bool operator!=(const NodeVector& rhs) const
|
||||
{
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
struct Hash
|
||||
{
|
||||
size_t operator()(const NodeVector& value) const
|
||||
{
|
||||
return languageHash(value.vector);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
T vector;
|
||||
};
|
||||
|
||||
/// Empty base class just for static_asserts.
|
||||
struct FieldBase
|
||||
{
|
||||
FieldBase() = delete;
|
||||
|
||||
FieldBase(FieldBase&&) = delete;
|
||||
FieldBase& operator=(FieldBase&&) = delete;
|
||||
|
||||
FieldBase(const FieldBase&) = delete;
|
||||
FieldBase& operator=(const FieldBase&) = delete;
|
||||
};
|
||||
|
||||
template<typename Phantom>
|
||||
struct Field : FieldBase
|
||||
{
|
||||
};
|
||||
|
||||
template<typename Phantom, typename... Fields>
|
||||
struct NodeFields
|
||||
{
|
||||
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
|
||||
|
||||
template<typename T>
|
||||
static constexpr int getIndex()
|
||||
{
|
||||
constexpr int N = sizeof...(Fields);
|
||||
constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...};
|
||||
|
||||
for (int i = 0; i < N; ++i)
|
||||
if (is[i])
|
||||
return i;
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
public:
|
||||
template<typename... Args>
|
||||
NodeFields(Args&&... args)
|
||||
: array{std::forward<Args>(args)...}
|
||||
{
|
||||
}
|
||||
|
||||
Slice<Id> operands()
|
||||
{
|
||||
return Slice{array};
|
||||
}
|
||||
|
||||
Slice<const Id> operands() const
|
||||
{
|
||||
return Slice{array.data(), array.size()};
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Id field() const
|
||||
{
|
||||
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>);
|
||||
return array[getIndex<T>()];
|
||||
}
|
||||
|
||||
bool operator==(const NodeFields& rhs) const
|
||||
{
|
||||
return array == rhs.array;
|
||||
}
|
||||
|
||||
bool operator!=(const NodeFields& rhs) const
|
||||
{
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
struct Hash
|
||||
{
|
||||
size_t operator()(const NodeFields& value) const
|
||||
{
|
||||
return languageHash(value.array);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
std::array<Id, sizeof...(Fields)> array;
|
||||
};
|
||||
|
||||
template<typename... Ts>
|
||||
struct Language final
|
||||
{
|
||||
template<typename T>
|
||||
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
|
||||
|
||||
template<typename T>
|
||||
Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept
|
||||
: v(std::forward<T>(t))
|
||||
{
|
||||
}
|
||||
|
||||
Language(const Language&) noexcept = default;
|
||||
Language& operator=(const Language&) noexcept = default;
|
||||
|
||||
Language(Language&&) noexcept = default;
|
||||
Language& operator=(Language&&) noexcept = default;
|
||||
|
||||
int index() const noexcept
|
||||
{
|
||||
return v.index();
|
||||
}
|
||||
|
||||
/// You should never call this function with the intention of mutating the `Id`.
|
||||
/// Reading is ok, but you should also never assume that these `Id`s are stable.
|
||||
Slice<Id> operands() noexcept
|
||||
{
|
||||
return visit([](auto&& v) -> Slice<Id> {
|
||||
return v.operands();
|
||||
}, v);
|
||||
}
|
||||
|
||||
Slice<const Id> operands() const noexcept
|
||||
{
|
||||
return visit([](auto&& v) -> Slice<const Id> {
|
||||
return v.operands();
|
||||
}, v);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* get() noexcept
|
||||
{
|
||||
static_assert(WithinDomain<T>::value);
|
||||
return v.template get_if<T>();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T* get() const noexcept
|
||||
{
|
||||
static_assert(WithinDomain<T>::value);
|
||||
return v.template get_if<T>();
|
||||
}
|
||||
|
||||
bool operator==(const Language& rhs) const noexcept
|
||||
{
|
||||
return v == rhs.v;
|
||||
}
|
||||
|
||||
bool operator!=(const Language& rhs) const noexcept
|
||||
{
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
public:
|
||||
struct Hash
|
||||
{
|
||||
size_t operator()(const Language& language) const
|
||||
{
|
||||
size_t seed = std::hash<int>{}(language.index());
|
||||
hashCombine(seed, visit([](auto&& v) {
|
||||
return typename std::decay_t<decltype(v)>::Hash{}(v);
|
||||
}, language.v));
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
Variant<Ts...> v;
|
||||
};
|
||||
|
||||
} // namespace Luau::EqSat
|
56
EqSat/include/Luau/LanguageHash.h
Normal file
56
EqSat/include/Luau/LanguageHash.h
Normal file
@ -0,0 +1,56 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
template<typename T>
|
||||
struct LanguageHash
|
||||
{
|
||||
size_t operator()(const T& t, decltype(std::hash<T>{}(std::declval<T>()))* = 0) const
|
||||
{
|
||||
return std::hash<T>{}(t);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
size_t languageHash(const T& lang)
|
||||
{
|
||||
return LanguageHash<T>{}(lang);
|
||||
}
|
||||
|
||||
inline void hashCombine(size_t& seed, size_t hash)
|
||||
{
|
||||
// Golden Ratio constant used for better hash scattering
|
||||
// See https://softwareengineering.stackexchange.com/a/402543
|
||||
seed ^= hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
|
||||
template<typename T, size_t I>
|
||||
struct LanguageHash<std::array<T, I>>
|
||||
{
|
||||
size_t operator()(const std::array<T, I>& array) const
|
||||
{
|
||||
size_t seed = 0;
|
||||
for (const T& t : array)
|
||||
hashCombine(seed, languageHash(t));
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct LanguageHash<std::vector<T>>
|
||||
{
|
||||
size_t operator()(const std::vector<T>& vector) const
|
||||
{
|
||||
size_t seed = 0;
|
||||
for (const T& t : vector)
|
||||
hashCombine(seed, languageHash(t));
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace Luau::EqSat
|
78
EqSat/include/Luau/Slice.h
Normal file
78
EqSat/include/Luau/Slice.h
Normal file
@ -0,0 +1,78 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/Common.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
template<typename T>
|
||||
struct Slice final
|
||||
{
|
||||
Slice()
|
||||
: _data(nullptr)
|
||||
, _size(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Use this constructor if you have a dynamically sized vector.
|
||||
/// The slice is valid for as long as the backing vector has not moved
|
||||
/// elsewhere in memory.
|
||||
///
|
||||
/// In general, a slice should never be used from vectors except for
|
||||
/// any vectors whose size are statically unknown, but remains fixed
|
||||
/// upon the construction of such a slice over a vector.
|
||||
Slice(T* first, size_t last)
|
||||
: _data(first)
|
||||
, _size(last)
|
||||
{
|
||||
}
|
||||
|
||||
template<size_t I>
|
||||
explicit Slice(std::array<T, I>& array)
|
||||
: _data(array.data())
|
||||
, _size(array.size())
|
||||
{
|
||||
}
|
||||
|
||||
T* data() const
|
||||
{
|
||||
return _data;
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return _size;
|
||||
}
|
||||
|
||||
bool empty() const
|
||||
{
|
||||
return _size == 0;
|
||||
}
|
||||
|
||||
T& operator[](size_t i) const
|
||||
{
|
||||
LUAU_ASSERT(i < _size);
|
||||
return _data[i];
|
||||
}
|
||||
|
||||
public:
|
||||
T* _data;
|
||||
size_t _size;
|
||||
|
||||
public:
|
||||
T* begin() const
|
||||
{
|
||||
return _data;
|
||||
}
|
||||
|
||||
T* end() const
|
||||
{
|
||||
return _data + _size;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace Luau::EqSat
|
22
EqSat/include/Luau/UnionFind.h
Normal file
22
EqSat/include/Luau/UnionFind.h
Normal file
@ -0,0 +1,22 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/Id.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
/// See <https://dl.acm.org/doi/pdf/10.1145/321879.321884>.
|
||||
struct UnionFind final
|
||||
{
|
||||
Id makeSet();
|
||||
Id find(Id id) const;
|
||||
void merge(Id a, Id b);
|
||||
|
||||
private:
|
||||
std::vector<Id> parents;
|
||||
};
|
||||
|
||||
} // namespace Luau::EqSat
|
32
EqSat/src/Id.cpp
Normal file
32
EqSat/src/Id.cpp
Normal file
@ -0,0 +1,32 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#include "Luau/Id.h"
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
Id::Id(size_t id)
|
||||
: id(id)
|
||||
{
|
||||
}
|
||||
|
||||
Id::operator size_t() const
|
||||
{
|
||||
return id;
|
||||
}
|
||||
|
||||
bool Id::operator==(Id rhs) const
|
||||
{
|
||||
return id == rhs.id;
|
||||
}
|
||||
|
||||
bool Id::operator!=(Id rhs) const
|
||||
{
|
||||
return id != rhs.id;
|
||||
}
|
||||
|
||||
} // namespace Luau::EqSat
|
||||
|
||||
size_t std::hash<Luau::EqSat::Id>::operator()(Luau::EqSat::Id id) const
|
||||
{
|
||||
return std::hash<size_t>()(size_t(id));
|
||||
}
|
35
EqSat/src/UnionFind.cpp
Normal file
35
EqSat/src/UnionFind.cpp
Normal file
@ -0,0 +1,35 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#include "Luau/UnionFind.h"
|
||||
|
||||
#include "Luau/Common.h"
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
Id UnionFind::makeSet()
|
||||
{
|
||||
Id id{parents.size()};
|
||||
parents.push_back(id);
|
||||
return id;
|
||||
}
|
||||
|
||||
Id UnionFind::find(Id id) const
|
||||
{
|
||||
LUAU_ASSERT(size_t(id) < parents.size());
|
||||
|
||||
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
|
||||
while (id != parents[size_t(id)])
|
||||
id = parents[size_t(id)];
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
void UnionFind::merge(Id a, Id b)
|
||||
{
|
||||
LUAU_ASSERT(size_t(a) < parents.size());
|
||||
LUAU_ASSERT(size_t(b) < parents.size());
|
||||
|
||||
parents[size_t(b)] = a;
|
||||
}
|
||||
|
||||
} // namespace Luau::EqSat
|
28
Makefile
28
Makefile
@ -26,6 +26,10 @@ ANALYSIS_SOURCES=$(wildcard Analysis/src/*.cpp)
|
||||
ANALYSIS_OBJECTS=$(ANALYSIS_SOURCES:%=$(BUILD)/%.o)
|
||||
ANALYSIS_TARGET=$(BUILD)/libluauanalysis.a
|
||||
|
||||
EQSAT_SOURCES=$(wildcard EqSat/src/*.cpp)
|
||||
EQSAT_OBJECTS=$(EQSAT_SOURCES:%=$(BUILD)/%.o)
|
||||
EQSAT_TARGET=$(BUILD)/libluaueqsat.a
|
||||
|
||||
CODEGEN_SOURCES=$(wildcard CodeGen/src/*.cpp)
|
||||
CODEGEN_OBJECTS=$(CODEGEN_SOURCES:%=$(BUILD)/%.o)
|
||||
CODEGEN_TARGET=$(BUILD)/libluaucodegen.a
|
||||
@ -69,7 +73,7 @@ ifneq ($(opt),)
|
||||
TESTS_ARGS+=-O$(opt)
|
||||
endif
|
||||
|
||||
OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(CONFIG_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(BYTECODE_CLI_OBJECTS) $(FUZZ_OBJECTS)
|
||||
OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(CONFIG_OBJECTS) $(ANALYSIS_OBJECTS) $(EQSAT_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(BYTECODE_CLI_OBJECTS) $(FUZZ_OBJECTS)
|
||||
EXECUTABLE_ALIASES = luau luau-analyze luau-compile luau-bytecode luau-tests
|
||||
|
||||
# common flags
|
||||
@ -138,16 +142,17 @@ endif
|
||||
$(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include
|
||||
$(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include
|
||||
$(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include
|
||||
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IConfig/include
|
||||
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include
|
||||
$(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include
|
||||
$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals
|
||||
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include
|
||||
$(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include
|
||||
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY
|
||||
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY
|
||||
$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include
|
||||
$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IConfig/include -Iextern
|
||||
$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -Iextern
|
||||
$(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include
|
||||
$(BYTECODE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include
|
||||
$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICodeGen/include -IConfig/include
|
||||
$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IEqSat/include -IVM/include -ICodeGen/include -IConfig/include
|
||||
|
||||
$(TESTS_TARGET): LDFLAGS+=-lpthread
|
||||
$(REPL_CLI_TARGET): LDFLAGS+=-lpthread
|
||||
@ -218,9 +223,9 @@ luau-tests: $(TESTS_TARGET)
|
||||
ln -fs $^ $@
|
||||
|
||||
# executable targets
|
||||
$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
|
||||
$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
|
||||
$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
|
||||
$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) $(CONFIG_TARGET)
|
||||
$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET)
|
||||
$(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)
|
||||
$(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)
|
||||
|
||||
@ -228,22 +233,23 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(COMPILE_CLI_TARGET) $
|
||||
$(CXX) $^ $(LDFLAGS) -o $@
|
||||
|
||||
# executable targets for fuzzing
|
||||
fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)
|
||||
fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET)
|
||||
$(CXX) $^ $(LDFLAGS) -o $@
|
||||
|
||||
fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator
|
||||
fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator
|
||||
fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator
|
||||
fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator
|
||||
|
||||
# static library targets
|
||||
$(AST_TARGET): $(AST_OBJECTS)
|
||||
$(COMPILER_TARGET): $(COMPILER_OBJECTS)
|
||||
$(CONFIG_TARGET): $(CONFIG_OBJECTS)
|
||||
$(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS)
|
||||
$(EQSAT_TARGET): $(EQSAT_OBJECTS)
|
||||
$(CODEGEN_TARGET): $(CODEGEN_OBJECTS)
|
||||
$(VM_TARGET): $(VM_OBJECTS)
|
||||
$(ISOCLINE_TARGET): $(ISOCLINE_OBJECTS)
|
||||
|
||||
$(AST_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(ANALYSIS_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET):
|
||||
$(AST_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET):
|
||||
ar rcs $@ $^
|
||||
|
||||
# object file targets
|
||||
|
@ -7,6 +7,7 @@ if(NOT ${CMAKE_VERSION} VERSION_LESS "3.19")
|
||||
Common/include/Luau/BytecodeUtils.h
|
||||
Common/include/Luau/DenseHash.h
|
||||
Common/include/Luau/ExperimentalFlags.h
|
||||
Common/include/Luau/Variant.h
|
||||
Common/include/Luau/VecDeque.h
|
||||
)
|
||||
endif()
|
||||
@ -232,7 +233,6 @@ target_sources(Luau.Analysis PRIVATE
|
||||
Analysis/include/Luau/Unifier.h
|
||||
Analysis/include/Luau/Unifier2.h
|
||||
Analysis/include/Luau/UnifierSharedState.h
|
||||
Analysis/include/Luau/Variant.h
|
||||
Analysis/include/Luau/VisitType.h
|
||||
|
||||
Analysis/src/Anyification.cpp
|
||||
@ -295,6 +295,19 @@ target_sources(Luau.Analysis PRIVATE
|
||||
Analysis/src/Unifier2.cpp
|
||||
)
|
||||
|
||||
# Luau.Analysis Sources
|
||||
target_sources(Luau.EqSat PRIVATE
|
||||
EqSat/include/Luau/EGraph.h
|
||||
EqSat/include/Luau/Id.h
|
||||
EqSat/include/Luau/Language.h
|
||||
EqSat/include/Luau/LanguageHash.h
|
||||
EqSat/include/Luau/Slice.h
|
||||
EqSat/include/Luau/UnionFind.h
|
||||
|
||||
EqSat/src/Id.cpp
|
||||
EqSat/src/UnionFind.cpp
|
||||
)
|
||||
|
||||
# Luau.VM Sources
|
||||
target_sources(Luau.VM PRIVATE
|
||||
VM/include/lua.h
|
||||
@ -418,6 +431,9 @@ if(TARGET Luau.UnitTest)
|
||||
tests/DiffAsserts.cpp
|
||||
tests/DiffAsserts.h
|
||||
tests/Differ.test.cpp
|
||||
tests/EqSat.language.test.cpp
|
||||
tests/EqSat.propositional.test.cpp
|
||||
tests/EqSat.slice.test.cpp
|
||||
tests/Error.test.cpp
|
||||
tests/Fixture.cpp
|
||||
tests/Fixture.h
|
||||
|
144
tests/EqSat.language.test.cpp
Normal file
144
tests/EqSat.language.test.cpp
Normal file
@ -0,0 +1,144 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#include <doctest.h>
|
||||
|
||||
#include "Luau/Id.h"
|
||||
#include "Luau/Language.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
LUAU_EQSAT_ATOM(I32, int);
|
||||
LUAU_EQSAT_ATOM(Bool, bool);
|
||||
LUAU_EQSAT_ATOM(Str, std::string);
|
||||
|
||||
LUAU_EQSAT_FIELD(Left);
|
||||
LUAU_EQSAT_FIELD(Right);
|
||||
LUAU_EQSAT_NODE_FIELDS(Add, Left, Right);
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
using Value = EqSat::Language<I32, Bool, Str, Add>;
|
||||
|
||||
TEST_SUITE_BEGIN("EqSatLanguage");
|
||||
|
||||
TEST_CASE("atom_equality")
|
||||
{
|
||||
CHECK(I32{0} == I32{0});
|
||||
CHECK(I32{0} != I32{1});
|
||||
}
|
||||
|
||||
TEST_CASE("node_equality")
|
||||
{
|
||||
CHECK(Add{EqSat::Id{0}, EqSat::Id{0}} == Add{EqSat::Id{0}, EqSat::Id{0}});
|
||||
CHECK(Add{EqSat::Id{1}, EqSat::Id{0}} != Add{EqSat::Id{0}, EqSat::Id{0}});
|
||||
}
|
||||
|
||||
TEST_CASE("language_get")
|
||||
{
|
||||
Value v{I32{5}};
|
||||
|
||||
auto i = v.get<I32>();
|
||||
REQUIRE(i);
|
||||
CHECK(i->value());
|
||||
|
||||
CHECK(!v.get<Bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("language_copy_ctor")
|
||||
{
|
||||
Value v1{I32{5}};
|
||||
Value v2 = v1;
|
||||
|
||||
auto i1 = v1.get<I32>();
|
||||
auto i2 = v2.get<I32>();
|
||||
REQUIRE(i1);
|
||||
REQUIRE(i2);
|
||||
CHECK(i1->value() == i2->value());
|
||||
}
|
||||
|
||||
TEST_CASE("language_move_ctor")
|
||||
{
|
||||
Value v1{Str{"hello"}};
|
||||
{
|
||||
auto s1 = v1.get<Str>();
|
||||
REQUIRE(s1);
|
||||
CHECK(s1->value() == "hello");
|
||||
}
|
||||
|
||||
Value v2 = std::move(v1);
|
||||
|
||||
auto s1 = v1.get<Str>();
|
||||
REQUIRE(s1);
|
||||
CHECK(s1->value() == ""); // this also tests the dtor.
|
||||
|
||||
auto s2 = v2.get<Str>();
|
||||
REQUIRE(s2);
|
||||
CHECK(s2->value() == "hello");
|
||||
}
|
||||
|
||||
TEST_CASE("language_equality")
|
||||
{
|
||||
Value v1{I32{0}};
|
||||
Value v2{I32{0}};
|
||||
Value v3{I32{1}};
|
||||
Value v4{Bool{true}};
|
||||
Value v5{Add{EqSat::Id{0}, EqSat::Id{1}}};
|
||||
|
||||
CHECK(v1 == v2);
|
||||
CHECK(v2 != v3);
|
||||
CHECK(v3 != v4);
|
||||
CHECK(v4 != v5);
|
||||
}
|
||||
|
||||
TEST_CASE("language_is_mappable")
|
||||
{
|
||||
std::unordered_map<Value, int, Value::Hash> map;
|
||||
|
||||
Value v1{I32{5}};
|
||||
Value v2{I32{5}};
|
||||
Value v3{Bool{true}};
|
||||
Value v4{Add{EqSat::Id{0}, EqSat::Id{1}}};
|
||||
|
||||
map[v1] = 1;
|
||||
map[v2] = 2;
|
||||
map[v3] = 42;
|
||||
map[v4] = 37;
|
||||
|
||||
CHECK(map[v1] == 2);
|
||||
CHECK(map[v2] == 2);
|
||||
CHECK(map[v3] == 42);
|
||||
CHECK(map[v4] == 37);
|
||||
}
|
||||
|
||||
TEST_CASE("node_field")
|
||||
{
|
||||
EqSat::Id left{0};
|
||||
EqSat::Id right{1};
|
||||
|
||||
Add add{left, right};
|
||||
|
||||
EqSat::Id left2 = add.field<Left>();
|
||||
EqSat::Id right2 = add.field<Right>();
|
||||
|
||||
CHECK(left == left2);
|
||||
CHECK(left != right2);
|
||||
CHECK(right == right2);
|
||||
CHECK(right != left2);
|
||||
}
|
||||
|
||||
TEST_CASE("language_operands")
|
||||
{
|
||||
Value v1{I32{0}};
|
||||
CHECK(v1.operands().empty());
|
||||
|
||||
Value v2{Add{EqSat::Id{0}, EqSat::Id{1}}};
|
||||
const Add* add = v2.get<Add>();
|
||||
REQUIRE(add);
|
||||
|
||||
EqSat::Slice<EqSat::Id> actual = v2.operands();
|
||||
CHECK(actual.size() == 2);
|
||||
CHECK(actual[0] == add->field<Left>());
|
||||
CHECK(actual[1] == add->field<Right>());
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
197
tests/EqSat.propositional.test.cpp
Normal file
197
tests/EqSat.propositional.test.cpp
Normal file
@ -0,0 +1,197 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#include <doctest.h>
|
||||
|
||||
#include "Luau/EGraph.h"
|
||||
#include "Luau/Id.h"
|
||||
#include "Luau/Language.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
LUAU_EQSAT_ATOM(Var, std::string);
|
||||
LUAU_EQSAT_ATOM(Bool, bool);
|
||||
LUAU_EQSAT_NODE_ARRAY(Not, 1);
|
||||
LUAU_EQSAT_NODE_ARRAY(And, 2);
|
||||
LUAU_EQSAT_NODE_ARRAY(Or, 2);
|
||||
LUAU_EQSAT_NODE_ARRAY(Implies, 2);
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
using PropositionalLogic = EqSat::Language<Var, Bool, Not, And, Or, Implies>;
|
||||
|
||||
using EGraph = EqSat::EGraph<PropositionalLogic, struct ConstantFold>;
|
||||
|
||||
struct ConstantFold
|
||||
{
|
||||
using Data = std::optional<bool>;
|
||||
|
||||
Data make(const EGraph& egraph, const Var& var) const
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Data make(const EGraph& egraph, const Bool& b) const
|
||||
{
|
||||
return b.value();
|
||||
}
|
||||
|
||||
Data make(const EGraph& egraph, const Not& n) const
|
||||
{
|
||||
Data data = egraph[n[0]].data;
|
||||
if (data)
|
||||
return !*data;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Data make(const EGraph& egraph, const And& a) const
|
||||
{
|
||||
Data l = egraph[a[0]].data;
|
||||
Data r = egraph[a[1]].data;
|
||||
if (l && r)
|
||||
return *l && *r;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Data make(const EGraph& egraph, const Or& o) const
|
||||
{
|
||||
Data l = egraph[o[0]].data;
|
||||
Data r = egraph[o[1]].data;
|
||||
if (l && r)
|
||||
return *l || *r;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Data make(const EGraph& egraph, const Implies& i) const
|
||||
{
|
||||
Data antecedent = egraph[i[0]].data;
|
||||
Data consequent = egraph[i[1]].data;
|
||||
if (antecedent && consequent)
|
||||
return !*antecedent || *consequent;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void join(Data& a, const Data& b) const
|
||||
{
|
||||
if (!a && b)
|
||||
a = b;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_SUITE_BEGIN("EqSatPropositionalLogic");
|
||||
|
||||
TEST_CASE("egraph_hashconsing")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id id1 = egraph.add(Bool{true});
|
||||
EqSat::Id id2 = egraph.add(Bool{true});
|
||||
EqSat::Id id3 = egraph.add(Bool{false});
|
||||
|
||||
CHECK(id1 == id2);
|
||||
CHECK(id2 != id3);
|
||||
}
|
||||
|
||||
TEST_CASE("egraph_data")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id id1 = egraph.add(Bool{true});
|
||||
EqSat::Id id2 = egraph.add(Bool{false});
|
||||
|
||||
CHECK(egraph[id1].data == true);
|
||||
CHECK(egraph[id2].data == false);
|
||||
}
|
||||
|
||||
TEST_CASE("egraph_merge")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id id1 = egraph.add(Var{"a"});
|
||||
EqSat::Id id2 = egraph.add(Bool{true});
|
||||
egraph.merge(id1, id2);
|
||||
|
||||
CHECK(egraph[id1].data == true);
|
||||
CHECK(egraph[id2].data == true);
|
||||
}
|
||||
|
||||
TEST_CASE("const_fold_true_and_true")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id id1 = egraph.add(Bool{true});
|
||||
EqSat::Id id2 = egraph.add(Bool{true});
|
||||
EqSat::Id id3 = egraph.add(And{id1, id2});
|
||||
|
||||
CHECK(egraph[id3].data == true);
|
||||
}
|
||||
|
||||
TEST_CASE("const_fold_true_and_false")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id id1 = egraph.add(Bool{true});
|
||||
EqSat::Id id2 = egraph.add(Bool{false});
|
||||
EqSat::Id id3 = egraph.add(And{id1, id2});
|
||||
|
||||
CHECK(egraph[id3].data == false);
|
||||
}
|
||||
|
||||
TEST_CASE("const_fold_false_and_false")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id id1 = egraph.add(Bool{false});
|
||||
EqSat::Id id2 = egraph.add(Bool{false});
|
||||
EqSat::Id id3 = egraph.add(And{id1, id2});
|
||||
|
||||
CHECK(egraph[id3].data == false);
|
||||
}
|
||||
|
||||
TEST_CASE("implications")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id t = egraph.add(Bool{true});
|
||||
EqSat::Id f = egraph.add(Bool{false});
|
||||
|
||||
EqSat::Id a = egraph.add(Implies{t, t}); // true
|
||||
EqSat::Id b = egraph.add(Implies{t, f}); // false
|
||||
EqSat::Id c = egraph.add(Implies{f, t}); // true
|
||||
EqSat::Id d = egraph.add(Implies{f, f}); // true
|
||||
|
||||
CHECK(egraph[a].data == true);
|
||||
CHECK(egraph[b].data == false);
|
||||
CHECK(egraph[c].data == true);
|
||||
CHECK(egraph[d].data == true);
|
||||
}
|
||||
|
||||
TEST_CASE("merge_x_and_y")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
EqSat::Id x = egraph.add(Var{"x"});
|
||||
EqSat::Id y = egraph.add(Var{"y"});
|
||||
|
||||
EqSat::Id a = egraph.add(Var{"a"});
|
||||
EqSat::Id ax = egraph.add(And{a, x});
|
||||
EqSat::Id ay = egraph.add(And{a, y});
|
||||
|
||||
egraph.merge(x, y); // [x y] [ax] [ay] [a]
|
||||
CHECK_EQ(egraph.size(), 4);
|
||||
CHECK_EQ(egraph.find(x), egraph.find(y));
|
||||
CHECK_NE(egraph.find(ax), egraph.find(ay));
|
||||
CHECK_NE(egraph.find(a), egraph.find(x));
|
||||
CHECK_NE(egraph.find(a), egraph.find(y));
|
||||
|
||||
egraph.rebuild(); // [x y] [ax ay] [a]
|
||||
CHECK_EQ(egraph.size(), 3);
|
||||
CHECK_EQ(egraph.find(x), egraph.find(y));
|
||||
CHECK_EQ(egraph.find(ax), egraph.find(ay));
|
||||
CHECK_NE(egraph.find(a), egraph.find(x));
|
||||
CHECK_NE(egraph.find(a), egraph.find(y));
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
58
tests/EqSat.slice.test.cpp
Normal file
58
tests/EqSat.slice.test.cpp
Normal file
@ -0,0 +1,58 @@
|
||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#include <doctest.h>
|
||||
|
||||
#include "Luau/Slice.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
TEST_SUITE_BEGIN("EqSatSlice");
|
||||
|
||||
TEST_CASE("slice_is_a_view_over_array")
|
||||
{
|
||||
std::array<int, 8> a{1, 2, 3, 4, 5, 6, 7, 8};
|
||||
|
||||
EqSat::Slice<int> slice{a};
|
||||
|
||||
CHECK(slice.data() == a.data());
|
||||
CHECK(slice.size() == a.size());
|
||||
|
||||
for (size_t i = 0; i < a.size(); ++i)
|
||||
{
|
||||
CHECK(slice[i] == a[i]);
|
||||
CHECK(&slice[i] == &a[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("slice_is_a_view_over_vector")
|
||||
{
|
||||
std::vector<int> vector{1, 2, 3, 4, 5, 6, 7, 8};
|
||||
|
||||
EqSat::Slice<int> slice{vector.data(), vector.size()};
|
||||
|
||||
CHECK(slice.data() == vector.data());
|
||||
CHECK(slice.size() == vector.size());
|
||||
|
||||
for (size_t i = 0; i < vector.size(); ++i)
|
||||
{
|
||||
CHECK(slice[i] == vector[i]);
|
||||
CHECK(&slice[i] == &vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("mutate_via_slice")
|
||||
{
|
||||
std::array<int, 2> a{1, 2};
|
||||
CHECK(a[0] == 1);
|
||||
CHECK(a[1] == 2);
|
||||
|
||||
EqSat::Slice<int> slice{a};
|
||||
slice[0] = 42;
|
||||
slice[1] = 37;
|
||||
|
||||
CHECK(a[0] == 42);
|
||||
CHECK(a[1] == 37);
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
Loading…
Reference in New Issue
Block a user