Optimizations for UnionFind (#1334)

Implements ranks & path compression for union find.

---------

Co-authored-by: Alexander McCord <11488393+alexmccord@users.noreply.github.com>
This commit is contained in:
birds3345 2024-07-17 19:19:57 -04:00 committed by GitHub
parent 623e1e30db
commit 2874ca9e86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 8 deletions

View File

@ -13,10 +13,15 @@ struct UnionFind final
{ {
Id makeSet(); Id makeSet();
Id find(Id id) const; Id find(Id id) const;
Id find(Id id);
void merge(Id a, Id b); void merge(Id a, Id b);
private: private:
std::vector<Id> parents; std::vector<Id> parents;
std::vector<int> ranks;
private:
Id canonicalize(Id id) const;
}; };
} // namespace Luau::EqSat } // namespace Luau::EqSat

View File

@ -10,10 +10,51 @@ Id UnionFind::makeSet()
{ {
Id id{parents.size()}; Id id{parents.size()};
parents.push_back(id); parents.push_back(id);
ranks.push_back(0);
return id; return id;
} }
Id UnionFind::find(Id id) const Id UnionFind::find(Id id) const
{
return canonicalize(id);
}
Id UnionFind::find(Id id)
{
Id set = canonicalize(id);
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)])
{
// Note: we don't update the ranks here since a rank
// represents the upper bound on the maximum depth of a tree
Id parent = parents[size_t(id)];
parents[size_t(id)] = set;
id = parent;
}
return set;
}
void UnionFind::merge(Id a, Id b)
{
Id aSet = find(a);
Id bSet = find(b);
if (aSet == bSet)
return;
// Ensure that the rank of set A is greater than the rank of set B
if (ranks[size_t(aSet)] < ranks[size_t(bSet)])
std::swap(aSet, bSet);
parents[size_t(bSet)] = aSet;
if (ranks[size_t(aSet)] == ranks[size_t(bSet)])
ranks[size_t(aSet)]++;
}
Id UnionFind::canonicalize(Id id) const
{ {
LUAU_ASSERT(size_t(id) < parents.size()); LUAU_ASSERT(size_t(id) < parents.size());
@ -24,12 +65,4 @@ Id UnionFind::find(Id id) const
return id; return id;
} }
void UnionFind::merge(Id a, Id b)
{
LUAU_ASSERT(size_t(a) < parents.size());
LUAU_ASSERT(size_t(b) < parents.size());
parents[size_t(b)] = a;
}
} // namespace Luau::EqSat } // namespace Luau::EqSat