luau/EqSat/include/Luau/Language.h
2024-07-19 10:21:40 -07:00

298 lines
6.1 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Id.h"
#include "Luau/LanguageHash.h"
#include "Luau/Slice.h"
#include "Luau/Variant.h"
#include <array>
#include <type_traits>
#include <utility>
/// Declares an atom type `name` that stores one value of type `t` and has no operands.
/// The generated `tag` is the stringified type name.
#define LUAU_EQSAT_ATOM(name, t) \
    struct name : ::Luau::EqSat::Atom<name, t> \
    { \
        using Atom::Atom; \
        static constexpr const char* tag = #name; \
    }
/// Declares a node type `name` with exactly `ops` operand ids, held in a std::array.
/// The generated `tag` is the stringified type name.
#define LUAU_EQSAT_NODE_ARRAY(name, ops) \
    struct name : ::Luau::EqSat::NodeVector<name, std::array<::Luau::EqSat::Id, ops>> \
    { \
        using NodeVector::NodeVector; \
        static constexpr const char* tag = #name; \
    }
/// Declares a node type `name` with a variable number of operand ids, held in a std::vector.
/// The generated `tag` is the stringified type name.
#define LUAU_EQSAT_NODE_VECTOR(name) \
    struct name : ::Luau::EqSat::NodeVector<name, std::vector<::Luau::EqSat::Id>> \
    { \
        using NodeVector::NodeVector; \
        static constexpr const char* tag = #name; \
    }
/// Declares a field tag type `name` for use as a parameter of LUAU_EQSAT_NODE_FIELDS.
/// Field tags carry no data; FieldBase deletes all of their constructors.
#define LUAU_EQSAT_FIELD(name) struct name : ::Luau::EqSat::Field<name> {}
/// Declares a node type `name` whose operands are addressed by the field tag types passed
/// in the variadic arguments (each declared with LUAU_EQSAT_FIELD).
/// The generated `tag` is the stringified type name.
#define LUAU_EQSAT_NODE_FIELDS(name, ...) \
    struct name : ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \
    { \
        using NodeFields::NodeFields; \
        static constexpr const char* tag = #name; \
    }
namespace Luau::EqSat
{
/// A node that stores a single value of type `T` and has no operands.
///
/// `Phantom` is not used in the body. LUAU_EQSAT_ATOM passes the declared type itself,
/// which keeps atoms of different kinds distinct even when they share `T`.
template<typename Phantom, typename T>
struct Atom
{
    Atom(const T& value)
        : _value(value)
    {
    }

    /// Read-only access to the stored value.
    const T& value() const
    {
        return _value;
    }

    /// Atoms have no children, so both overloads yield an empty slice.
    Slice<Id> operands()
    {
        return Slice<Id>{};
    }

    Slice<const Id> operands() const
    {
        return Slice<const Id>{};
    }

    /// Two atoms are equal exactly when their stored values compare equal.
    bool operator==(const Atom& rhs) const
    {
        return value() == rhs.value();
    }

    bool operator!=(const Atom& rhs) const
    {
        return !operator==(rhs);
    }

    /// Hashes only the stored value, via languageHash.
    struct Hash
    {
        size_t operator()(const Atom& atom) const
        {
            return languageHash(atom.value());
        }
    };

private:
    T _value;
};
/// A node whose operand ids live in a contiguous container `T`.
///
/// LUAU_EQSAT_NODE_ARRAY instantiates this with std::array (fixed operand count), and
/// LUAU_EQSAT_NODE_VECTOR instantiates it with std::vector (variable operand count).
template<typename Phantom, typename T>
struct NodeVector
{
    /// Forwards every argument into brace-initialization of the underlying container.
    template<typename... Args>
    NodeVector(Args&&... args)
        : vector{std::forward<Args>(args)...}
    {
    }

    /// Returns operand `i`, using the container's unchecked operator[].
    Id operator[](size_t i) const
    {
        return vector[i];
    }

    /// View over all operands, in storage order.
    Slice<Id> operands()
    {
        Id* first = vector.data();
        return Slice<Id>{first, vector.size()};
    }

    Slice<const Id> operands() const
    {
        const Id* first = vector.data();
        return Slice<const Id>{first, vector.size()};
    }

    /// Element-wise comparison of the operand containers.
    bool operator==(const NodeVector& rhs) const
    {
        return vector == rhs.vector;
    }

    bool operator!=(const NodeVector& rhs) const
    {
        return !operator==(rhs);
    }

    /// Hashes the whole operand container, via languageHash.
    struct Hash
    {
        size_t operator()(const NodeVector& node) const
        {
            return languageHash(node.vector);
        }
    };

private:
    T vector;
};
/// Empty, non-constructible base class for field tag types.
///
/// NodeFields static_asserts that every one of its field parameters derives from this.
/// All constructors and assignments are deleted, so field tags can only be used as types,
/// never as values.
struct FieldBase
{
    FieldBase() = delete;

    FieldBase(const FieldBase&) = delete;
    FieldBase(FieldBase&&) = delete;

    FieldBase& operator=(const FieldBase&) = delete;
    FieldBase& operator=(FieldBase&&) = delete;
};
/// Base for a single field tag type.
///
/// `Phantom` is the deriving field type itself (LUAU_EQSAT_FIELD passes `name`). This gives
/// each field a distinct base while still deriving from FieldBase.
template<typename Phantom>
struct Field : public FieldBase
{
};
/// A node whose operands are addressed by named field tag types.
///
/// Every type in `Fields` must derive from FieldBase (see LUAU_EQSAT_FIELD). Operand ids are
/// stored positionally in a std::array, one slot per field, in the order `Fields` is listed.
template<typename Phantom, typename... Fields>
struct NodeFields
{
    static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);

    /// Position of field `T` (after decay) within `Fields`, or -1 if it is not listed.
    /// If the same field type is listed more than once, the first position is returned.
    template<typename T>
    static constexpr int getIndex()
    {
        constexpr int N = sizeof...(Fields);
        constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...};

        for (int i = 0; i < N; ++i)
            if (is[i])
                return i;

        return -1;
    }

public:
    /// Brace-initializes the operand array from `args`, one Id per field in `Fields` order.
    template<typename... Args>
    NodeFields(Args&&... args)
        : array{std::forward<Args>(args)...}
    {
    }

    /// View over all operands, in `Fields` order.
    Slice<Id> operands()
    {
        // Same (pointer, size) construction as the const overload and NodeVector.
        return Slice{array.data(), array.size()};
    }

    Slice<const Id> operands() const
    {
        return Slice{array.data(), array.size()};
    }

    /// Returns the operand stored for field `T`, which must be one of `Fields`.
    template<typename T>
    Id field() const
    {
        static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>);

        // Binding to a constexpr local guarantees the lookup is resolved at compile time.
        constexpr int index = getIndex<T>();
        return array[index];
    }

    /// Element-wise comparison of the operand arrays.
    bool operator==(const NodeFields& rhs) const
    {
        return array == rhs.array;
    }

    bool operator!=(const NodeFields& rhs) const
    {
        return !(*this == rhs);
    }

    /// Hashes the whole operand array, via languageHash.
    struct Hash
    {
        size_t operator()(const NodeFields& value) const
        {
            return languageHash(value.array);
        }
    };

private:
    std::array<Id, sizeof...(Fields)> array;
};
/// A tagged union over the node types `Ts...` that make up one language.
///
/// Equality and hashing take both the active alternative and its contents into account.
template<typename... Ts>
struct Language final
{
    /// True iff `T`, after decay, is exactly one of `Ts`.
    template<typename T>
    using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;

    /// Implicitly converts any node type from `Ts` into a Language.
    /// The enable_if parameter removes this overload for types outside the domain.
    /// NOTE(review): this constructor is noexcept. If copying or moving `T` throws (e.g. a
    /// std::vector allocation for nodes declared with LUAU_EQSAT_NODE_VECTOR),
    /// std::terminate is called.
    template<typename T>
    Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = nullptr) noexcept
        : v(std::forward<T>(t))
    {
    }

    /// Zero-based position of the active alternative within `Ts`.
    int index() const noexcept
    {
        return v.index();
    }

    /// You should never call this function with the intention of mutating the `Id`.
    /// Reading is ok, but you should also never assume that these `Id`s are stable.
    Slice<Id> operands() noexcept
    {
        return visit([](auto&& node) -> Slice<Id> {
            return node.operands();
        }, v);
    }

    Slice<const Id> operands() const noexcept
    {
        return visit([](auto&& node) -> Slice<const Id> {
            return node.operands();
        }, v);
    }

    /// Pointer to the active alternative if it is `T`, otherwise null.
    template<typename T>
    T* get() noexcept
    {
        static_assert(WithinDomain<T>::value);
        return v.template get_if<T>();
    }

    template<typename T>
    const T* get() const noexcept
    {
        static_assert(WithinDomain<T>::value);
        return v.template get_if<T>();
    }

    bool operator==(const Language& rhs) const noexcept
    {
        return v == rhs.v;
    }

    bool operator!=(const Language& rhs) const noexcept
    {
        return !(*this == rhs);
    }

public:
    /// Combines the active alternative's index with that alternative's own Hash.
    struct Hash
    {
        size_t operator()(const Language& language) const
        {
            size_t seed = std::hash<int>{}(language.index());
            hashCombine(seed, visit([](auto&& node) {
                return typename std::decay_t<decltype(node)>::Hash{}(node);
            }, language.v));
            return seed;
        }
    };

private:
    Variant<Ts...> v;
};
} // namespace Luau::EqSat