diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index dd886a44..559ee119 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -13,10 +13,15 @@ struct UnionFind final { Id makeSet(); Id find(Id id) const; + Id find(Id id); void merge(Id a, Id b); private: std::vector parents; + std::vector ranks; + +private: + Id canonicalize(Id id) const; }; } // namespace Luau::EqSat diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 04d9ba74..5c01e968 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -10,10 +10,51 @@ Id UnionFind::makeSet() { Id id{parents.size()}; parents.push_back(id); + ranks.push_back(0); + return id; } Id UnionFind::find(Id id) const +{ + return canonicalize(id); +} + +Id UnionFind::find(Id id) +{ + Id set = canonicalize(id); + + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. + while (id != parents[size_t(id)]) + { + // Note: we don't update the ranks here since a rank + // represents the upper bound on the maximum depth of a tree + Id parent = parents[size_t(id)]; + parents[size_t(id)] = set; + id = parent; + } + + return set; +} + +void UnionFind::merge(Id a, Id b) +{ + Id aSet = find(a); + Id bSet = find(b); + if (aSet == bSet) + return; + + // Ensure that the rank of set A is greater than the rank of set B + if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) + std::swap(aSet, bSet); + + parents[size_t(bSet)] = aSet; + + if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) + ranks[size_t(aSet)]++; +} + +Id UnionFind::canonicalize(Id id) const { LUAU_ASSERT(size_t(id) < parents.size()); @@ -24,12 +65,4 @@ Id UnionFind::find(Id id) const return id; } -void UnionFind::merge(Id a, Id b) -{ - LUAU_ASSERT(size_t(a) < parents.size()); - LUAU_ASSERT(size_t(b) < parents.size()); - - parents[size_t(b)] = a; -} - } // namespace Luau::EqSat