Merge branch 'upstream' into merge

This commit is contained in:
Vighnesh 2024-05-10 09:21:07 -07:00
commit f172471b87
45 changed files with 1536 additions and 467 deletions

View File

@ -284,11 +284,13 @@ struct Constraint
std::vector<NotNull<Constraint>> dependencies;
DenseHashSet<TypeId> getFreeTypes() const;
DenseHashSet<TypeId> getMaybeMutatedFreeTypes() const;
};
using ConstraintPtr = std::unique_ptr<Constraint>;
bool isReferenceCountedType(const TypeId typ);
inline Constraint& asMutable(const Constraint& c)
{
return const_cast<Constraint&>(c);

View File

@ -242,6 +242,24 @@ struct ConstraintSolver
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);
/**
* Shifts the count of references from `source` to `target`. This should be paired
* with any instance of binding a free type in order to maintain accurate refcounts.
* If `target` is not a free type, this is a noop.
* @param source the free type which is being bound
* @param target the type which the free type is being bound to
*/
void shiftReferences(TypeId source, TypeId target);
/**
* Generalizes the given free type if the reference counting allows it.
* @param the scope to generalize in
* @param type the free type we want to generalize
* @returns a non-free type that generalizes the argument, or `std::nullopt` if one
* does not exist
*/
std::optional<TypeId> generalizeFreeType(NotNull<Scope> scope, TypeId type);
/**
* Checks the existing set of constraints to see if there exist any that contain
* the provided free type, indicating that it is not yet ready to be replaced by

View File

@ -307,6 +307,9 @@ struct NormalizedType
/// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString()
bool isSubtypeOfString() const;
/// Returns true if the type is a subtype of boolean(it could be a singleton). Behaves like Type::isBoolean()
bool isSubtypeOfBooleans() const;
/// Returns true if this type should result in error suppressing behavior.
bool shouldSuppressErrors() const;
@ -360,7 +363,6 @@ public:
Normalizer& operator=(Normalizer&) = delete;
// If this returns null, the typechecker should emit a "too complex" error
const NormalizedType* DEPRECATED_normalize(TypeId ty);
std::shared_ptr<const NormalizedType> normalize(TypeId ty);
void clearNormal(NormalizedType& norm);
@ -395,7 +397,7 @@ public:
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);
void subtractSingleton(NormalizedType& here, TypeId ty);
NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated = false);
NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect);
// ------- Normalizing intersections
TypeId intersectionOfTops(TypeId here, TypeId there);
@ -404,8 +406,8 @@ public:
void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there);
void intersectTablesWithTable(TypeIds& heres, TypeId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet);
void intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes);
void intersectTables(TypeIds& heres, const TypeIds& theres);
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
@ -413,7 +415,7 @@ public:
NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes);
NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes);
NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType);
NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet);
// Check for inhabitance
NormalizationResult isInhabited(TypeId ty);
@ -423,6 +425,7 @@ public:
// Check for intersections being inhabited
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right);
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet);
// -------- Convert back from a normalized type to a type
TypeId typeFromNormal(const NormalizedType& norm);

View File

@ -4,7 +4,6 @@
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
LUAU_FASTFLAG(LuauFixSetIter)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
@ -142,13 +141,10 @@ public:
const_iterator(typename Impl::const_iterator impl_, typename Impl::const_iterator end_)
: impl(impl_)
, end(end_)
{
if (FFlag::LuauFixSetIter || FFlag::DebugLuauDeferredConstraintResolution)
{
while (impl != end && impl->second == false)
++impl;
}
}
const T& operator*() const
{

View File

@ -78,6 +78,11 @@ struct Unifier2
bool unify(TableType* subTable, const TableType* superTable);
bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable);
bool unify(const AnyType* subAny, const FunctionType* superFn);
bool unify(const FunctionType* subFn, const AnyType* superAny);
bool unify(const AnyType* subAny, const TableType* superTable);
bool unify(const TableType* subTable, const AnyType* superAny);
// TODO think about this one carefully. We don't do unions or intersections of type packs
bool unify(TypePackId subTp, TypePackId superTp);

View File

@ -13,12 +13,12 @@ Constraint::Constraint(NotNull<Scope> scope, const Location& location, Constrain
{
}
struct FreeTypeCollector : TypeOnceVisitor
struct ReferenceCountInitializer : TypeOnceVisitor
{
DenseHashSet<TypeId>* result;
FreeTypeCollector(DenseHashSet<TypeId>* result)
ReferenceCountInitializer(DenseHashSet<TypeId>* result)
: result(result)
{
}
@ -29,6 +29,18 @@ struct FreeTypeCollector : TypeOnceVisitor
return false;
}
bool visit(TypeId ty, const BlockedType&) override
{
result->insert(ty);
return false;
}
bool visit(TypeId ty, const PendingExpansionType&) override
{
result->insert(ty);
return false;
}
bool visit(TypeId ty, const ClassType&) override
{
// ClassTypes never contain free types.
@ -36,26 +48,92 @@ struct FreeTypeCollector : TypeOnceVisitor
}
};
DenseHashSet<TypeId> Constraint::getFreeTypes() const
bool isReferenceCountedType(const TypeId typ)
{
// n.b. this should match whatever `ReferenceCountInitializer` includes.
return get<FreeType>(typ) || get<BlockedType>(typ) || get<PendingExpansionType>(typ);
}
DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{
DenseHashSet<TypeId> types{{}};
FreeTypeCollector ftc{&types};
ReferenceCountInitializer rci{&types};
if (auto sc = get<SubtypeConstraint>(*this))
if (auto ec = get<EqualityConstraint>(*this))
{
ftc.traverse(sc->subType);
ftc.traverse(sc->superType);
rci.traverse(ec->resultType);
// `EqualityConstraints` should not mutate `assignmentType`.
}
else if (auto sc = get<SubtypeConstraint>(*this))
{
rci.traverse(sc->subType);
rci.traverse(sc->superType);
}
else if (auto psc = get<PackSubtypeConstraint>(*this))
{
ftc.traverse(psc->subPack);
ftc.traverse(psc->superPack);
rci.traverse(psc->subPack);
rci.traverse(psc->superPack);
}
else if (auto gc = get<GeneralizationConstraint>(*this))
{
rci.traverse(gc->generalizedType);
// `GeneralizationConstraints` should not mutate `sourceType` or `interiorTypes`.
}
else if (auto itc = get<IterableConstraint>(*this))
{
rci.traverse(itc->variables);
// `IterableConstraints` should not mutate `iterator`.
}
else if (auto nc = get<NameConstraint>(*this))
{
rci.traverse(nc->namedType);
}
else if (auto taec = get<TypeAliasExpansionConstraint>(*this))
{
rci.traverse(taec->target);
}
else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
{
// we need to take into account primitive type constraints to prevent type families from reducing on
// primitive whose types we have not yet selected to be singleton or not.
ftc.traverse(ptc->freeType);
rci.traverse(ptc->freeType);
}
else if (auto hpc = get<HasPropConstraint>(*this))
{
rci.traverse(hpc->resultType);
// `HasPropConstraints` should not mutate `subjectType`.
}
else if (auto spc = get<SetPropConstraint>(*this))
{
rci.traverse(spc->resultType);
// `SetPropConstraints` should not mutate `subjectType` or `propType`.
// TODO: is this true? it "unifies" with `propType`, so maybe mutates that one too?
}
else if (auto hic = get<HasIndexerConstraint>(*this))
{
rci.traverse(hic->resultType);
// `HasIndexerConstraint` should not mutate `subjectType` or `indexType`.
}
else if (auto sic = get<SetIndexerConstraint>(*this))
{
rci.traverse(sic->propType);
// `SetIndexerConstraints` should not mutate `subjectType` or `indexType`.
}
else if (auto uc = get<UnpackConstraint>(*this))
{
rci.traverse(uc->resultPack);
// `UnpackConstraint` should not mutate `sourcePack`.
}
else if (auto u1c = get<Unpack1Constraint>(*this))
{
rci.traverse(u1c->resultType);
// `Unpack1Constraint` should not mutate `sourceType`.
}
else if (auto rc = get<ReduceConstraint>(*this))
{
rci.traverse(rc->ty);
}
else if (auto rpc = get<ReducePackConstraint>(*this))
{
rci.traverse(rpc->tp);
}
return types;

View File

@ -27,6 +27,7 @@
#include <utility>
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false);
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false);
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500);
@ -251,6 +252,15 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts)
auto it = cs->blockedConstraints.find(c);
int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second);
printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str());
if (FFlag::DebugLuauLogSolverIncludeDependencies)
{
for (NotNull<Constraint> dep : c->dependencies)
{
if (std::find(cs->unsolvedConstraints.begin(), cs->unsolvedConstraints.end(), dep) != cs->unsolvedConstraints.end())
printf("\t\t|\t%s\n", toString(*dep, opts).c_str());
}
}
}
}
@ -305,7 +315,7 @@ ConstraintSolver::ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope
unsolvedConstraints.push_back(c);
// initialize the reference counts for the free types in this constraint.
for (auto ty : c->getFreeTypes())
for (auto ty : c->getMaybeMutatedFreeTypes())
{
// increment the reference count for `ty`
auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0);
@ -394,7 +404,7 @@ void ConstraintSolver::run()
unsolvedConstraints.erase(unsolvedConstraints.begin() + i);
// decrement the referenced free types for this constraint if we dispatched successfully!
for (auto ty : c->getFreeTypes())
for (auto ty : c->getMaybeMutatedFreeTypes())
{
// this is a little weird, but because we're only counting free types in subtyping constraints,
// some constraints (like unpack) might actually produce _more_ references to a free type.
@ -720,8 +730,6 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull<const Constr
{
// nothing (yet)
}
else
return block(c.namedType, constraint);
return true;
}
@ -771,6 +779,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
auto bindResult = [this, &c, constraint](TypeId result) {
LUAU_ASSERT(get<PendingExpansionType>(c.target));
shiftReferences(c.target, result);
emplaceType<BoundType>(asMutable(c.target), result);
unblock(c.target, constraint->location);
};
@ -1190,6 +1199,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
{
if (!lambdaExpr->args.data[j]->annotation && get<FreeType>(follow(lambdaArgTys[j])))
{
shiftReferences(lambdaArgTys[j], expectedLambdaArgTys[j]);
emplaceType<BoundType>(asMutable(lambdaArgTys[j]), expectedLambdaArgTys[j]);
}
}
@ -1242,6 +1252,7 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull<con
else if (expectedType && maybeSingleton(*expectedType))
bindTo = freeType->lowerBound;
shiftReferences(c.freeType, bindTo);
emplaceType<BoundType>(asMutable(c.freeType), bindTo);
return true;
@ -1551,7 +1562,11 @@ bool ConstraintSolver::tryDispatchHasIndexer(
if (0 == results.size())
emplaceType<BoundType>(asMutable(resultType), builtinTypes->errorType);
else if (1 == results.size())
emplaceType<BoundType>(asMutable(resultType), *results.begin());
{
TypeId firstResult = *results.begin();
shiftReferences(resultType, firstResult);
emplaceType<BoundType>(asMutable(resultType), firstResult);
}
else
emplaceType<UnionType>(asMutable(resultType), std::vector(results.begin(), results.end()));
@ -1716,7 +1731,10 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull<const Constraint> constraint,
--lt->blockCount;
if (0 == lt->blockCount)
{
shiftReferences(ty, lt->domain);
emplaceType<BoundType>(asMutable(ty), lt->domain);
}
};
if (auto ut = get<UnionType>(resultTy))
@ -1732,6 +1750,7 @@ bool ConstraintSolver::tryDispatchUnpack1(NotNull<const Constraint> constraint,
// constitute any meaningful constraint, so we replace it
// with a free type.
TypeId f = freshType(arena, builtinTypes, constraint->scope);
shiftReferences(resultTy, f);
emplaceType<BoundType>(asMutable(resultTy), f);
}
else
@ -1798,8 +1817,11 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
--lt->blockCount;
if (0 == lt->blockCount)
{
shiftReferences(resultTy, lt->domain);
emplaceType<BoundType>(asMutable(resultTy), lt->domain);
}
}
else if (get<BlockedType>(resultTy) || get<PendingExpansionType>(resultTy))
{
emplaceType<BoundType>(asMutable(resultTy), builtinTypes->nilType);
@ -1977,10 +1999,13 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
LUAU_ASSERT(0 <= lt->blockCount);
if (0 == lt->blockCount)
{
shiftReferences(ty, lt->domain);
emplaceType<BoundType>(asMutable(ty), lt->domain);
}
}
}
}
else
unpack(builtinTypes->errorType);
}
@ -2395,11 +2420,16 @@ void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId
LUAU_ASSERT(freeScope);
emplaceType<BoundType>(asMutable(blockedTy), arena->freshType(freeScope));
TypeId freeType = arena->freshType(freeScope);
shiftReferences(blockedTy, freeType);
emplaceType<BoundType>(asMutable(blockedTy), freeType);
}
else
{
shiftReferences(blockedTy, resultTy);
emplaceType<BoundType>(asMutable(blockedTy), resultTy);
}
}
bool ConstraintSolver::block_(BlockedConstraintId target, NotNull<const Constraint> constraint)
{
@ -2700,10 +2730,43 @@ void ConstraintSolver::reportError(TypeError e)
errors.back().moduleName = currentModuleName;
}
void ConstraintSolver::shiftReferences(TypeId source, TypeId target)
{
target = follow(target);
// if the target isn't a reference counted type, there's nothing to do.
// this stops us from keeping unnecessary counts for e.g. primitive types.
if (!isReferenceCountedType(target))
return;
auto sourceRefs = unresolvedConstraints.find(source);
if (!sourceRefs)
return;
// we read out the count before proceeding to avoid hash invalidation issues.
size_t count = *sourceRefs;
auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0);
targetRefs += count;
}
std::optional<TypeId> ConstraintSolver::generalizeFreeType(NotNull<Scope> scope, TypeId type)
{
if (get<FreeType>(type))
{
auto refCount = unresolvedConstraints.find(type);
if (!refCount || *refCount > 1)
return {};
}
Unifier2 u2{NotNull{arena}, builtinTypes, scope, NotNull{&iceReporter}};
return u2.generalize(type);
}
bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty)
{
if (auto refCount = unresolvedConstraints.find(ty))
return *refCount > 0;
return *refCount > 1;
return false;
}

View File

@ -1297,6 +1297,30 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
result->type = sourceModule.type;
result->upperBoundContributors = std::move(cs.upperBoundContributors);
if (result->timeout || result->cancelled)
{
// If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending
// types
ScopePtr moduleScope = result->getModuleScope();
moduleScope->returnType = builtinTypes->errorRecoveryTypePack();
for (auto& [name, ty] : result->declaredGlobals)
ty = builtinTypes->errorRecoveryType();
for (auto& [name, tf] : result->exportedTypeBindings)
tf.type = builtinTypes->errorRecoveryType();
}
else
{
if (mode == Mode::Nonstrict)
Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get());
else
Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get());
}
unfreeze(result->interfaceTypes);
result->clonePublicInterface(builtinTypes, *iceHandler);
if (FFlag::DebugLuauForbidInternalTypes)
{
InternalTypeFinder finder;
@ -1325,30 +1349,6 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
finder.traverse(tp);
}
if (result->timeout || result->cancelled)
{
// If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending
// types
ScopePtr moduleScope = result->getModuleScope();
moduleScope->returnType = builtinTypes->errorRecoveryTypePack();
for (auto& [name, ty] : result->declaredGlobals)
ty = builtinTypes->errorRecoveryType();
for (auto& [name, tf] : result->exportedTypeBindings)
tf.type = builtinTypes->errorRecoveryType();
}
else
{
if (mode == Mode::Nonstrict)
Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get());
else
Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get());
}
unfreeze(result->interfaceTypes);
result->clonePublicInterface(builtinTypes, *iceHandler);
// It would be nice if we could freeze the arenas before doing type
// checking, but we'll have to do some work to get there.
//

View File

@ -17,21 +17,16 @@
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false)
LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false);
LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false);
LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false);
LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false);
// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
static bool fixNormalizeCaching()
{
return FFlag::LuauFixNormalizeCaching || FFlag::DebugLuauDeferredConstraintResolution;
}
static bool fixCyclicUnionsOfIntersections()
{
return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution;
@ -42,6 +37,11 @@ static bool fixReduceStackPressure()
return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution;
}
static bool fixCyclicTablesBlowingStack()
{
return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::DebugLuauDeferredConstraintResolution;
}
namespace Luau
{
@ -353,6 +353,12 @@ bool NormalizedType::isSubtypeOfString() const
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::isSubtypeOfBooleans() const
{
return hasBooleans() && !hasTops() && !hasClasses() && !hasErrors() && !hasNils() && !hasNumbers() && !hasStrings() && !hasThreads() &&
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars();
}
bool NormalizedType::shouldSuppressErrors() const
{
return hasErrors() || get<AnyType>(tops);
@ -561,22 +567,21 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set<TypeId>& seen)
return isInhabited(mtv->metatable, seen);
}
if (fixNormalizeCaching())
{
std::shared_ptr<const NormalizedType> norm = normalize(ty);
return isInhabited(norm.get(), seen);
}
else
{
const NormalizedType* norm = DEPRECATED_normalize(ty);
return isInhabited(norm, seen);
}
}
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
{
Set<TypeId> seen{nullptr};
return isIntersectionInhabited(left, right, seen);
}
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet)
{
left = follow(left);
right = follow(right);
// We're asking if intersection is inahbited between left and right but we've already seen them ....
if (cacheInhabitance)
{
@ -584,12 +589,8 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
return *result ? NormalizationResult::True : NormalizationResult::False;
}
Set<TypeId> seen{nullptr};
seen.insert(left);
seen.insert(right);
NormalizedType norm{builtinTypes};
NormalizationResult res = normalizeIntersections({left, right}, norm);
NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet);
if (res != NormalizationResult::True)
{
if (cacheInhabitance && res == NormalizationResult::False)
@ -598,7 +599,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
return res;
}
NormalizationResult result = isInhabited(&norm, seen);
NormalizationResult result = isInhabited(&norm, seenSet);
if (cacheInhabitance && result == NormalizationResult::True)
cachedIsInhabitedIntersection[{left, right}] = true;
@ -870,31 +871,6 @@ Normalizer::Normalizer(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, Not
{
}
const NormalizedType* Normalizer::DEPRECATED_normalize(TypeId ty)
{
if (!arena)
sharedState->iceHandler->ice("Normalizing types outside a module");
auto found = cachedNormals.find(ty);
if (found != cachedNormals.end())
return found->second.get();
NormalizedType norm{builtinTypes};
Set<TypeId> seenSetTypes{nullptr};
NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes);
if (res != NormalizationResult::True)
return nullptr;
if (norm.isUnknown())
{
clearNormal(norm);
norm.tops = builtinTypes->unknownType;
}
std::shared_ptr<NormalizedType> shared = std::make_shared<NormalizedType>(std::move(norm));
const NormalizedType* result = shared.get();
cachedNormals[ty] = std::move(shared);
return result;
}
static bool isCacheable(TypeId ty, Set<TypeId>& seen);
static bool isCacheable(TypePackId tp, Set<TypeId>& seen)
@ -949,9 +925,6 @@ static bool isCacheable(TypeId ty, Set<TypeId>& seen)
static bool isCacheable(TypeId ty)
{
if (!fixNormalizeCaching())
return true;
Set<TypeId> seen{nullptr};
return isCacheable(ty, seen);
}
@ -985,7 +958,7 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
return shared;
}
NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType)
NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet)
{
if (!arena)
sharedState->iceHandler->ice("Normalizing types outside a module");
@ -995,7 +968,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
Set<TypeId> seenSetTypes{nullptr};
for (auto ty : intersections)
{
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSetTypes);
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet);
if (res != NormalizationResult::True)
return res;
}
@ -1743,20 +1716,13 @@ bool Normalizer::withinResourceLimits()
return true;
}
NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated)
NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect)
{
std::optional<NormalizedType> negated;
if (useDeprecated)
{
const NormalizedType* normal = DEPRECATED_normalize(toNegate);
negated = negateNormal(*normal);
}
else
{
std::shared_ptr<const NormalizedType> normal = normalize(toNegate);
negated = negateNormal(*normal);
}
if (!negated)
return NormalizationResult::False;
@ -1911,16 +1877,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
{
std::optional<NormalizedType> tn;
if (fixNormalizeCaching())
{
std::shared_ptr<const NormalizedType> thereNormal = normalize(ntv->ty);
tn = negateNormal(*thereNormal);
}
else
{
const NormalizedType* thereNormal = DEPRECATED_normalize(ntv->ty);
tn = negateNormal(*thereNormal);
}
if (!tn)
return NormalizationResult::False;
@ -2519,7 +2477,7 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
return arena->addTypePack({});
}
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there)
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet)
{
if (here == there)
return here;
@ -2600,8 +2558,33 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
if (fixReduceStackPressure())
{
if (normalizeAwayUninhabitableTables() &&
NormalizationResult::True != isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
if (fixCyclicTablesBlowingStack())
{
if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy))
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
return {builtinTypes->neverType};
}
else
{
seenSet.insert(*hprop.readTy);
seenSet.insert(*tprop.readTy);
}
}
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenSet);
// Cleanup
if (fixCyclicTablesBlowingStack())
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
}
if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res)
return {builtinTypes->neverType};
}
else
@ -2720,7 +2703,7 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
if (tmtable && hmtable)
{
// NOTE: this assumes metatables are ivariant
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable))
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenSet))
{
if (table == htable && *mtable == hmtable)
return here;
@ -2750,12 +2733,12 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
return table;
}
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there)
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes)
{
TypeIds tmp;
for (TypeId here : heres)
{
if (std::optional<TypeId> inter = intersectionOfTables(here, there))
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
tmp.insert(*inter);
}
heres.retain(tmp);
@ -2769,7 +2752,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres)
{
for (TypeId there : theres)
{
if (std::optional<TypeId> inter = intersectionOfTables(here, there))
Set<TypeId> seenSetTypes{nullptr};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes))
tmp.insert(*inter);
}
}
@ -3137,7 +3121,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
TypeIds tables = std::move(here.tables);
clearNormal(here);
intersectTablesWithTable(tables, there);
intersectTablesWithTable(tables, there, seenSetTypes);
here.tables = std::move(tables);
}
else if (get<ClassType>(there))
@ -3210,23 +3194,12 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
else if (const SingletonType* stv = get<SingletonType>(t))
subtractSingleton(here, follow(ntv->ty));
else if (get<ClassType>(t))
{
if (fixNormalizeCaching())
{
NormalizationResult res = intersectNormalWithNegationTy(t, here);
if (shouldEarlyExit(res))
return res;
}
else
{
NormalizationResult res = intersectNormalWithNegationTy(t, here, /* useDeprecated */ true);
if (shouldEarlyExit(res))
return res;
}
}
else if (const UnionType* itv = get<UnionType>(t))
{
if (fixNormalizeCaching())
{
for (TypeId part : itv->options)
{
@ -3235,28 +3208,6 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
return res;
}
}
else
{
if (fixNormalizeCaching())
{
for (TypeId part : itv->options)
{
NormalizationResult res = intersectNormalWithNegationTy(part, here);
if (shouldEarlyExit(res))
return res;
}
}
else
{
for (TypeId part : itv->options)
{
NormalizationResult res = intersectNormalWithNegationTy(part, here, /* useDeprecated */ true);
if (shouldEarlyExit(res))
return res;
}
}
}
}
else if (get<AnyType>(t))
{
// HACK: Refinements sometimes intersect with ~any under the

View File

@ -1,5 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
LUAU_FASTFLAGVARIABLE(LuauFixSetIter, false)

View File

@ -20,7 +20,6 @@
#include <string>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauToStringiteTypesSingleLine, false)
/*
* Enables increasing levels of verbosity for Luau type names when stringifying.

View File

@ -500,6 +500,15 @@ TypeFamilyReductionResult<TypeId> lenFamilyFn(TypeId instance, NotNull<TypeFamil
if (isPending(operandTy, ctx->solver) || get<LocalType>(operandTy))
return {std::nullopt, false, {operandTy}, {}};
// if the type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy);
if (!maybeGeneralized)
return {std::nullopt, false, {operandTy}, {}};
operandTy = *maybeGeneralized;
}
std::shared_ptr<const NormalizedType> normTy = ctx->normalizer->normalize(operandTy);
NormalizationResult inhabited = ctx->normalizer->isInhabited(normTy.get());
@ -576,6 +585,15 @@ TypeFamilyReductionResult<TypeId> unmFamilyFn(TypeId instance, NotNull<TypeFamil
if (isPending(operandTy, ctx->solver))
return {std::nullopt, false, {operandTy}, {}};
// if the type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy);
if (!maybeGeneralized)
return {std::nullopt, false, {operandTy}, {}};
operandTy = *maybeGeneralized;
}
std::shared_ptr<const NormalizedType> normTy = ctx->normalizer->normalize(operandTy);
// if the operand failed to normalize, we can't reduce, but know nothing about inhabitance.
@ -674,6 +692,21 @@ TypeFamilyReductionResult<TypeId> numericBinopFamilyFn(TypeId instance, NotNull<
else if (isPending(rhsTy, ctx->solver))
return {std::nullopt, false, {rhsTy}, {}};
// if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
if (!lhsMaybeGeneralized)
return {std::nullopt, false, {lhsTy}, {}};
else if (!rhsMaybeGeneralized)
return {std::nullopt, false, {rhsTy}, {}};
lhsTy = *lhsMaybeGeneralized;
rhsTy = *rhsMaybeGeneralized;
}
// TODO: Normalization needs to remove cyclic type families from a `NormalizedType`.
std::shared_ptr<const NormalizedType> normLhsTy = ctx->normalizer->normalize(lhsTy);
std::shared_ptr<const NormalizedType> normRhsTy = ctx->normalizer->normalize(rhsTy);
@ -895,6 +928,21 @@ TypeFamilyReductionResult<TypeId> concatFamilyFn(TypeId instance, NotNull<TypeFa
else if (isPending(rhsTy, ctx->solver))
return {std::nullopt, false, {rhsTy}, {}};
// if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
if (!lhsMaybeGeneralized)
return {std::nullopt, false, {lhsTy}, {}};
else if (!rhsMaybeGeneralized)
return {std::nullopt, false, {rhsTy}, {}};
lhsTy = *lhsMaybeGeneralized;
rhsTy = *rhsMaybeGeneralized;
}
std::shared_ptr<const NormalizedType> normLhsTy = ctx->normalizer->normalize(lhsTy);
std::shared_ptr<const NormalizedType> normRhsTy = ctx->normalizer->normalize(rhsTy);
@ -982,13 +1030,27 @@ TypeFamilyReductionResult<TypeId> andFamilyFn(TypeId instance, NotNull<TypeFamil
if (follow(lhsTy) == instance && lhsTy != rhsTy)
return {rhsTy, false, {}, {}};
// check to see if both operand types are resolved enough, and wait to reduce if not
if (isPending(lhsTy, ctx->solver))
return {std::nullopt, false, {lhsTy}, {}};
else if (isPending(rhsTy, ctx->solver))
return {std::nullopt, false, {rhsTy}, {}};
// if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
if (!lhsMaybeGeneralized)
return {std::nullopt, false, {lhsTy}, {}};
else if (!rhsMaybeGeneralized)
return {std::nullopt, false, {rhsTy}, {}};
lhsTy = *lhsMaybeGeneralized;
rhsTy = *rhsMaybeGeneralized;
}
// And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is truthy.
SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->falsyType);
SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result);
@ -1025,6 +1087,21 @@ TypeFamilyReductionResult<TypeId> orFamilyFn(TypeId instance, NotNull<TypeFamily
else if (isPending(rhsTy, ctx->solver))
return {std::nullopt, false, {rhsTy}, {}};
// if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
if (!lhsMaybeGeneralized)
return {std::nullopt, false, {lhsTy}, {}};
else if (!rhsMaybeGeneralized)
return {std::nullopt, false, {rhsTy}, {}};
lhsTy = *lhsMaybeGeneralized;
rhsTy = *rhsMaybeGeneralized;
}
// Or evalutes to the LHS type if the LHS is truthy, and the RHS type if LHS is falsy.
SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->truthyType);
SimplifyResult overallResult = simplifyUnion(ctx->builtins, ctx->arena, rhsTy, filteredLhs.result);
@ -1088,6 +1165,21 @@ static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, Not
lhsTy = follow(lhsTy);
rhsTy = follow(rhsTy);
// if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
if (!lhsMaybeGeneralized)
return {std::nullopt, false, {lhsTy}, {}};
else if (!rhsMaybeGeneralized)
return {std::nullopt, false, {rhsTy}, {}};
lhsTy = *lhsMaybeGeneralized;
rhsTy = *rhsMaybeGeneralized;
}
// check to see if both operand types are resolved enough, and wait to reduce if not
std::shared_ptr<const NormalizedType> normLhsTy = ctx->normalizer->normalize(lhsTy);
@ -1196,6 +1288,21 @@ TypeFamilyReductionResult<TypeId> eqFamilyFn(TypeId instance, NotNull<TypeFamily
else if (isPending(rhsTy, ctx->solver))
return {std::nullopt, false, {rhsTy}, {}};
// if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy);
std::optional<TypeId> rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy);
if (!lhsMaybeGeneralized)
return {std::nullopt, false, {lhsTy}, {}};
else if (!rhsMaybeGeneralized)
return {std::nullopt, false, {rhsTy}, {}};
lhsTy = *lhsMaybeGeneralized;
rhsTy = *rhsMaybeGeneralized;
}
std::shared_ptr<const NormalizedType> normLhsTy = ctx->normalizer->normalize(lhsTy);
std::shared_ptr<const NormalizedType> normRhsTy = ctx->normalizer->normalize(rhsTy);
NormalizationResult lhsInhabited = ctx->normalizer->isInhabited(normLhsTy.get());
@ -1223,10 +1330,25 @@ TypeFamilyReductionResult<TypeId> eqFamilyFn(TypeId instance, NotNull<TypeFamily
// if neither type has a metatable entry for `__eq`, then we'll check for inhabitance of the intersection!
NormalizationResult intersectInhabited = ctx->normalizer->isIntersectionInhabited(lhsTy, rhsTy);
if (!mmType && intersectInhabited == NormalizationResult::True)
if (!mmType)
{
if (intersectInhabited == NormalizationResult::True)
return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay!
else if (!mmType)
// we might be in a case where we still want to accept the comparison...
if (intersectInhabited == NormalizationResult::False)
{
// if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`.
if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString())
return {ctx->builtins->falseType, false, {}, {}};
// if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`.
if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans())
return {ctx->builtins->falseType, false, {}, {}};
}
return {std::nullopt, true, {}, {}}; // if it's not, then this family is irreducible!
}
mmType = follow(*mmType);
if (isPending(*mmType, ctx->solver))
@ -1303,6 +1425,21 @@ TypeFamilyReductionResult<TypeId> refineFamilyFn(TypeId instance, NotNull<TypeFa
else if (isPending(discriminantTy, ctx->solver))
return {std::nullopt, false, {discriminantTy}, {}};
// if either type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, targetTy);
std::optional<TypeId> discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminantTy);
if (!targetMaybeGeneralized)
return {std::nullopt, false, {targetTy}, {}};
else if (!discriminantMaybeGeneralized)
return {std::nullopt, false, {discriminantTy}, {}};
targetTy = *targetMaybeGeneralized;
discriminantTy = *discriminantMaybeGeneralized;
}
// we need a more complex check for blocking on the discriminant in particular
FindRefinementBlockers frb;
frb.traverse(discriminantTy);
@ -1358,6 +1495,15 @@ TypeFamilyReductionResult<TypeId> singletonFamilyFn(TypeId instance, NotNull<Typ
if (isPending(type, ctx->solver))
return {std::nullopt, false, {type}, {}};
// if the type is free but has only one remaining reference, we can generalize it to its upper bound here.
if (ctx->solver)
{
std::optional<TypeId> maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type);
if (!maybeGeneralized)
return {std::nullopt, false, {type}, {}};
type = *maybeGeneralized;
}
TypeId followed = type;
// we want to follow through a negation here as well.
if (auto negation = get<NegationType>(followed))

View File

@ -39,7 +39,6 @@ LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
LUAU_FASTFLAG(LuauFixNormalizeCaching)
namespace Luau
{
@ -2649,24 +2648,12 @@ static std::optional<bool> areEqComparable(NotNull<TypeArena> arena, NotNull<Nor
NormalizationResult nr;
if (FFlag::LuauFixNormalizeCaching)
{
TypeId c = arena->addType(IntersectionType{{a, b}});
std::shared_ptr<const NormalizedType> n = normalizer->normalize(c);
if (!n)
return std::nullopt;
nr = normalizer->isInhabited(n.get());
}
else
{
TypeId c = arena->addType(IntersectionType{{a, b}});
const NormalizedType* n = normalizer->DEPRECATED_normalize(c);
if (!n)
return std::nullopt;
nr = normalizer->isInhabited(n);
}
switch (nr)
{

View File

@ -23,7 +23,6 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false)
LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false)
LUAU_FASTFLAG(LuauFixNormalizeCaching)
namespace Luau
{
@ -579,8 +578,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (log.get<AnyType>(subTy))
{
if (normalize)
{
if (FFlag::LuauFixNormalizeCaching)
{
// TODO: there are probably cheaper ways to check if any <: T.
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
@ -591,18 +588,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (!log.get<AnyType>(superNorm->tops))
failure = true;
}
else
{
// TODO: there are probably cheaper ways to check if any <: T.
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!superNorm)
return reportError(location, NormalizationTooComplex{});
if (!log.get<AnyType>(superNorm->tops))
failure = true;
}
}
else
failure = true;
return tryUnifyWithAny(superTy, builtinTypes->anyType);
@ -962,8 +947,6 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
// We deal with this by type normalization.
Unifier innerState = makeChildUnifier();
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
@ -973,19 +956,6 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
if (!innerState.failure)
log.concat(std::move(innerState.log));
@ -999,8 +969,6 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
// It is possible that T <: A | B even though T </: A and T </:B
// for example boolean <: true | false.
// We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
@ -1011,19 +979,6 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
else
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
}
}
else if (!found)
{
if (errorsSuppressed)
@ -1125,24 +1080,12 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
// It is possible that A & B <: T even though A </: T and B </: T
// for example (string?) & ~nil <: string.
// We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
return;
}
@ -1192,8 +1135,6 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
// for example string? & number? <: nil.
// We deal with this by type normalization.
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
@ -1201,16 +1142,6 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
else
reportError(location, NormalizationTooComplex{});
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
}
}
else if (!found)
{
reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible", mismatchContext()});
@ -2712,8 +2643,6 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy)
if (!log.get<NegationType>(subTy) && !log.get<NegationType>(superTy))
ice("tryUnifyNegations superTy or subTy must be a negation type");
if (FFlag::LuauFixNormalizeCaching)
{
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
@ -2725,20 +2654,6 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy)
if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}
else
{
const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy);
const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy);
if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{});
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
}
}
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
{

View File

@ -204,25 +204,21 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
auto subAny = get<AnyType>(subTy);
auto superAny = get<AnyType>(superTy);
if (subAny && superAny)
return true;
else if (subAny && superFn)
{
// If `any` is the subtype, then we can propagate that inward.
bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack);
bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes);
return argResult && retResult;
}
else if (subFn && superAny)
{
// If `any` is the supertype, then we can propagate that inward.
bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes);
bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack);
return argResult && retResult;
}
auto subTable = getMutable<TableType>(subTy);
auto superTable = get<TableType>(superTy);
if (subAny && superAny)
return true;
else if (subAny && superFn)
return unify(subAny, superFn);
else if (subFn && superAny)
return unify(subFn, superAny);
else if (subAny && superTable)
return unify(subAny, superTable);
else if (subTable && superAny)
return unify(subTable, superAny);
if (subTable && superTable)
{
// `boundTo` works like a bound type, and therefore we'd replace it
@ -451,7 +447,16 @@ bool Unifier2::unify(TableType* subTable, const TableType* superTable)
* an indexer, we therefore conclude that the unsealed table has the
* same indexer.
*/
subTable->indexer = *superTable->indexer;
TypeId indexType = superTable->indexer->indexType;
if (TypeId* subst = genericSubstitutions.find(indexType))
indexType = *subst;
TypeId indexResultType = superTable->indexer->indexResultType;
if (TypeId* subst = genericSubstitutions.find(indexResultType))
indexResultType = *subst;
subTable->indexer = TableIndexer{indexType, indexResultType};
}
return result;
@ -462,6 +467,62 @@ bool Unifier2::unify(const MetatableType* subMetatable, const MetatableType* sup
return unify(subMetatable->metatable, superMetatable->metatable) && unify(subMetatable->table, superMetatable->table);
}
bool Unifier2::unify(const AnyType* subAny, const FunctionType* superFn)
{
// If `any` is the subtype, then we can propagate that inward.
bool argResult = unify(superFn->argTypes, builtinTypes->anyTypePack);
bool retResult = unify(builtinTypes->anyTypePack, superFn->retTypes);
return argResult && retResult;
}
bool Unifier2::unify(const FunctionType* subFn, const AnyType* superAny)
{
// If `any` is the supertype, then we can propagate that inward.
bool argResult = unify(builtinTypes->anyTypePack, subFn->argTypes);
bool retResult = unify(subFn->retTypes, builtinTypes->anyTypePack);
return argResult && retResult;
}
bool Unifier2::unify(const AnyType* subAny, const TableType* superTable)
{
for (const auto& [propName, prop]: superTable->props)
{
if (prop.readTy)
unify(builtinTypes->anyType, *prop.readTy);
if (prop.writeTy)
unify(*prop.writeTy, builtinTypes->anyType);
}
if (superTable->indexer)
{
unify(builtinTypes->anyType, superTable->indexer->indexType);
unify(builtinTypes->anyType, superTable->indexer->indexResultType);
}
return true;
}
bool Unifier2::unify(const TableType* subTable, const AnyType* superAny)
{
for (const auto& [propName, prop]: subTable->props)
{
if (prop.readTy)
unify(*prop.readTy, builtinTypes->anyType);
if (prop.writeTy)
unify(builtinTypes->anyType, *prop.writeTy);
}
if (subTable->indexer)
{
unify(subTable->indexer->indexType, builtinTypes->anyType);
unify(subTable->indexer->indexResultType, builtinTypes->anyType);
}
return true;
}
// FIXME? This should probably return an ErrorVec or an optional<TypeError>
// rather than a boolean to signal an occurs check failure.
bool Unifier2::unify(TypePackId subTp, TypePackId superTp)
@ -596,6 +657,43 @@ struct FreeTypeSearcher : TypeVisitor
}
}
DenseHashSet<const void*> seenPositive{nullptr};
DenseHashSet<const void*> seenNegative{nullptr};
bool seenWithPolarity(const void* ty)
{
switch (polarity)
{
case Positive:
{
if (seenPositive.contains(ty))
return true;
seenPositive.insert(ty);
return false;
}
case Negative:
{
if (seenNegative.contains(ty))
return true;
seenNegative.insert(ty);
return false;
}
case Both:
{
if (seenPositive.contains(ty) && seenNegative.contains(ty))
return true;
seenPositive.insert(ty);
seenNegative.insert(ty);
return false;
}
}
return false;
}
// The keys in these maps are either TypeIds or TypePackIds. It's safe to
// mix them because we only use these pointers as unique keys. We never
// indirect them.
@ -604,12 +702,18 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty) override
{
if (seenWithPolarity(ty))
return false;
LUAU_ASSERT(ty);
return true;
}
bool visit(TypeId ty, const FreeType& ft) override
{
if (seenWithPolarity(ty))
return false;
if (!subsumes(scope, ft.scope))
return true;
@ -632,6 +736,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const TableType& tt) override
{
if (seenWithPolarity(ty))
return false;
if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
{
switch (polarity)
@ -675,6 +782,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const FunctionType& ft) override
{
if (seenWithPolarity(ty))
return false;
flip();
traverse(ft.argTypes);
flip();
@ -691,6 +801,9 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (seenWithPolarity(tp))
return false;
if (!subsumes(scope, ftp.scope))
return true;

View File

@ -17,7 +17,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
// flag so that we don't break production games by reverting syntax changes.
// See docs/SyntaxChanges.md for an explanation.
LUAU_FASTFLAG(LuauCheckedFunctionSyntax)
LUAU_FASTFLAGVARIABLE(LuauReadWritePropertySyntax, false)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
namespace Luau
@ -1340,8 +1339,6 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
AstTableAccess access = AstTableAccess::ReadWrite;
std::optional<Location> accessLocation;
if (FFlag::LuauReadWritePropertySyntax || FFlag::DebugLuauDeferredConstraintResolution)
{
if (lexer.current().type == Lexeme::Name && lexer.lookahead().type != ':')
{
if (AstName(lexer.current().name) == "read")
@ -1357,7 +1354,6 @@ AstType* Parser::parseTableType(bool inDeclarationContext)
lexer.next();
}
}
}
if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString))
{

View File

@ -144,7 +144,10 @@ static int lua_require(lua_State* L)
if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
Luau::CodeGen::compile(ML, -1);
{
Luau::CodeGen::CompilationOptions nativeOptions;
Luau::CodeGen::compile(ML, -1, nativeOptions);
}
if (coverageActive())
coverageTrack(ML, -1);
@ -602,7 +605,10 @@ static bool runFile(const char* name, lua_State* GL, bool repl)
if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
Luau::CodeGen::compile(L, -1);
{
Luau::CodeGen::CompilationOptions nativeOptions;
Luau::CodeGen::compile(L, -1, nativeOptions);
}
if (coverageActive())
coverageTrack(L, -1);

View File

@ -13,10 +13,11 @@ namespace CodeGen
{
struct IrFunction;
struct HostIrHooks;
void loadBytecodeTypeInfo(IrFunction& function);
void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpTargets);
void analyzeBytecodeTypes(IrFunction& function);
void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks);
} // namespace CodeGen
} // namespace Luau

View File

@ -66,6 +66,39 @@ struct CompilationResult
}
};
struct IrBuilder;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler = bool (*)(
IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
struct HostIrHooks
{
// Suggest result type of a vector field access
HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr;
// Suggest result type of a vector function namecall
HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr;
// Handle vector value field access
// 'sourceReg' is guaranteed to be a vector
// Guards should take a VM exit to 'pcpos'
HostVectorAccessHandler vectorAccess = nullptr;
// Handle namecalled performed on a vector value
// 'sourceReg' (self argument) is guaranteed to be a vector
// All other arguments can be of any type
// Guards should take a VM exit to 'pcpos'
HostVectorNamecallHandler vectorNamecall = nullptr;
};
struct CompilationOptions
{
unsigned int flags = 0;
HostIrHooks hooks;
};
struct CompilationStats
{
size_t bytecodeSizeBytes = 0;
@ -118,8 +151,11 @@ void setNativeExecutionEnabled(lua_State* L, bool enabled);
using ModuleId = std::array<uint8_t, 16>;
// Builds target function and all inner functions
CompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr);
CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats = nullptr);
CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr);
using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
@ -164,7 +200,7 @@ struct AssemblyOptions
Target target = Host;
unsigned int flags = 0;
CompilationOptions compilationOptions;
bool outputBinary = false;

View File

@ -16,11 +16,11 @@ namespace Luau
namespace CodeGen
{
struct AssemblyOptions;
struct HostIrHooks;
struct IrBuilder
{
IrBuilder();
IrBuilder(const HostIrHooks& hostHooks);
void buildFunctionIr(Proto* proto);
@ -64,13 +64,17 @@ struct IrBuilder
IrOp vmExit(uint32_t pcpos);
const HostIrHooks& hostHooks;
bool inTerminatedBlock = false;
bool interruptRequested = false;
bool activeFastcallFallback = false;
IrOp fastcallFallbackReturn;
int fastcallSkipTarget = -1;
// Force builder to skip source commands
int cmdSkipTarget = -1;
IrFunction function;

View File

@ -2,6 +2,7 @@
#include "Luau/BytecodeAnalysis.h"
#include "Luau/BytecodeUtils.h"
#include "Luau/CodeGen.h"
#include "Luau/IrData.h"
#include "Luau/IrUtils.h"
@ -17,6 +18,8 @@ LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo loa
LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately
LUAU_FASTFLAG(LuauTypeInfoLookupImprovement)
LUAU_FASTFLAGVARIABLE(LuauCodegenVectorMispredictFix, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenAnalyzeHostVectorOps, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTypeUpvalCheck, false)
namespace Luau
{
@ -95,7 +98,10 @@ void loadBytecodeTypeInfo(IrFunction& function)
uint32_t upvalCount = readVarInt(data, offset);
uint32_t localCount = readVarInt(data, offset);
if (!FFlag::LuauCodegenLoadTypeUpvalCheck)
{
CODEGEN_ASSERT(upvalCount == unsigned(proto->nups));
}
if (typeSize != 0)
{
@ -114,6 +120,11 @@ void loadBytecodeTypeInfo(IrFunction& function)
if (upvalCount != 0)
{
if (FFlag::LuauCodegenLoadTypeUpvalCheck)
{
CODEGEN_ASSERT(upvalCount == unsigned(proto->nups));
}
typeInfo.upvalueTypes.resize(upvalCount);
uint8_t* types = (uint8_t*)data + offset;
@ -611,7 +622,7 @@ void buildBytecodeBlocks(IrFunction& function, const std::vector<uint8_t>& jumpT
}
}
void analyzeBytecodeTypes(IrFunction& function)
void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
{
Proto* proto = function.proto;
CODEGEN_ASSERT(proto);
@ -662,6 +673,8 @@ void analyzeBytecodeTypes(IrFunction& function)
for (int i = proto->numparams; i < proto->maxstacksize; ++i)
regTags[i] = LBC_TYPE_ANY;
LuauBytecodeType knownNextCallResult = LBC_TYPE_ANY;
for (int i = block.startpc; i <= block.finishpc;)
{
const Instruction* pc = &proto->code[i];
@ -790,6 +803,9 @@ void analyzeBytecodeTypes(IrFunction& function)
if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
if (FFlag::LuauCodegenAnalyzeHostVectorOps && regTags[ra] == LBC_TYPE_ANY && hostHooks.vectorAccessBytecodeType)
regTags[ra] = hostHooks.vectorAccessBytecodeType(field, str->len);
}
}
else
@ -1161,6 +1177,34 @@ void analyzeBytecodeTypes(IrFunction& function)
regTags[ra + 1] = bcType.a;
bcType.result = LBC_TYPE_FUNCTION;
if (FFlag::LuauCodegenAnalyzeHostVectorOps && bcType.a == LBC_TYPE_VECTOR && hostHooks.vectorNamecallBytecodeType)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);
knownNextCallResult = LuauBytecodeType(hostHooks.vectorNamecallBytecodeType(field, str->len));
}
}
break;
}
case LOP_CALL:
{
if (FFlag::LuauCodegenAnalyzeHostVectorOps)
{
int ra = LUAU_INSN_A(*pc);
if (knownNextCallResult != LBC_TYPE_ANY)
{
bcType.result = knownNextCallResult;
knownNextCallResult = LBC_TYPE_ANY;
regTags[ra] = bcType.result;
}
if (FFlag::LuauCodegenTypeInfo)
refineRegType(bcTypeInfo, ra, i, bcType.result);
}
break;
}
@ -1199,7 +1243,6 @@ void analyzeBytecodeTypes(IrFunction& function)
}
case LOP_GETGLOBAL:
case LOP_SETGLOBAL:
case LOP_CALL:
case LOP_RETURN:
case LOP_JUMP:
case LOP_JUMPBACK:

View File

@ -201,12 +201,12 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size)
}
template<typename AssemblyBuilder>
static std::optional<OldNativeProto> createNativeFunction(
AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result)
static std::optional<OldNativeProto> createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount,
const HostIrHooks& hooks, CodeGenCompilationResult& result)
{
CODEGEN_ASSERT(!FFlag::LuauCodegenContext);
IrBuilder ir;
IrBuilder ir(hooks);
ir.buildFunctionIr(proto);
unsigned instCount = unsigned(ir.function.instructions.size());
@ -476,7 +476,7 @@ void setNativeExecutionEnabled(lua_State* L, bool enabled)
}
}
static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
static CompilationResult compile_OLD(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CompilationResult compilationResult;
@ -485,7 +485,7 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags,
Proto* root = clvalue(func)->l.p;
if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
{
compilationResult.result = CodeGenCompilationResult::NotNativeModule;
return compilationResult;
@ -500,7 +500,7 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags,
}
std::vector<Proto*> protos;
gatherFunctions(protos, root, flags);
gatherFunctions(protos, root, options.flags);
// Skip protos that have been compiled during previous invocations of CodeGen::compile
protos.erase(std::remove_if(protos.begin(), protos.end(),
@ -541,7 +541,7 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags,
{
CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success;
if (std::optional<OldNativeProto> np = createNativeFunction(build, helpers, p, totalIrInstCount, protoResult))
if (std::optional<OldNativeProto> np = createNativeFunction(build, helpers, p, totalIrInstCount, options.hooks, protoResult))
results.push_back(*np);
else
compilationResult.protoFailures.push_back({protoResult, p->debugname ? getstr(p->debugname) : "", p->linedefined});
@ -618,13 +618,15 @@ static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags,
CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
{
Luau::CodeGen::CompilationOptions options{flags};
if (FFlag::LuauCodegenContext)
{
return compile_NEW(L, idx, flags, stats);
return compile_NEW(L, idx, options, stats);
}
else
{
return compile_OLD(L, idx, flags, stats);
return compile_OLD(L, idx, options, stats);
}
}
@ -632,7 +634,27 @@ CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsig
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compile_NEW(moduleId, L, idx, flags, stats);
Luau::CodeGen::CompilationOptions options{flags};
return compile_NEW(moduleId, L, idx, options, stats);
}
CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
if (FFlag::LuauCodegenContext)
{
return compile_NEW(L, idx, options, stats);
}
else
{
return compile_OLD(L, idx, options, stats);
}
}
CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compile_NEW(moduleId, L, idx, options, stats);
}
void setPerfLog(void* context, PerfLogFn logFn)

View File

@ -183,11 +183,11 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
{
Proto* root = clvalue(func)->l.p;
if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
if ((options.compilationOptions.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
return std::string();
std::vector<Proto*> protos;
gatherFunctions(protos, root, options.flags);
gatherFunctions(protos, root, options.compilationOptions.flags);
protos.erase(std::remove_if(protos.begin(), protos.end(),
[](Proto* p) {
@ -215,7 +215,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
for (Proto* p : protos)
{
IrBuilder ir;
IrBuilder ir(options.compilationOptions.hooks);
ir.buildFunctionIr(p);
unsigned asmSize = build.getCodeSize();
unsigned asmCount = build.getInstructionCount();

View File

@ -478,12 +478,12 @@ void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext)
}
template<typename AssemblyBuilder>
[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(
AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result)
[[nodiscard]] static NativeProtoExecDataPtr createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto,
uint32_t& totalIrInstCount, const HostIrHooks& hooks, CodeGenCompilationResult& result)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
IrBuilder ir;
IrBuilder ir(hooks);
ir.buildFunctionIr(proto);
unsigned instCount = unsigned(ir.function.instructions.size());
@ -505,7 +505,7 @@ template<typename AssemblyBuilder>
}
[[nodiscard]] static CompilationResult compileInternal(
const std::optional<ModuleId>& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
const std::optional<ModuleId>& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
CODEGEN_ASSERT(lua_isLfunction(L, idx));
@ -513,7 +513,7 @@ template<typename AssemblyBuilder>
Proto* root = clvalue(func)->l.p;
if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
return CompilationResult{CodeGenCompilationResult::NotNativeModule};
BaseCodeGenContext* codeGenContext = getCodeGenContext(L);
@ -521,7 +521,7 @@ template<typename AssemblyBuilder>
return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized};
std::vector<Proto*> protos;
gatherFunctions(protos, root, flags);
gatherFunctions(protos, root, options.flags);
// Skip protos that have been compiled during previous invocations of CodeGen::compile
protos.erase(std::remove_if(protos.begin(), protos.end(),
@ -572,7 +572,7 @@ template<typename AssemblyBuilder>
{
CodeGenCompilationResult protoResult = CodeGenCompilationResult::Success;
NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, protoResult);
NativeProtoExecDataPtr nativeExecData = createNativeFunction(build, helpers, protos[i], totalIrInstCount, options.hooks, protoResult);
if (nativeExecData != nullptr)
{
nativeProtos.push_back(std::move(nativeExecData));
@ -639,18 +639,18 @@ template<typename AssemblyBuilder>
return compilationResult;
}
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compileInternal(moduleId, L, idx, flags, stats);
return compileInternal(moduleId, L, idx, options, stats);
}
CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
CompilationResult compile_NEW(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats)
{
CODEGEN_ASSERT(FFlag::LuauCodegenContext);
return compileInternal({}, L, idx, flags, stats);
return compileInternal({}, L, idx, options, stats);
}
[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L)

View File

@ -107,8 +107,8 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC
// destroyed via lua_close.
void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext);
CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats);
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats);
CompilationResult compile_NEW(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats);
CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats);
// Returns true if native execution is currently enabled for this VM
[[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L);

View File

@ -15,6 +15,7 @@
LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used
LUAU_FASTFLAG(LuauTypeInfoLookupImprovement)
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps)
namespace Luau
{
@ -23,8 +24,9 @@ namespace CodeGen
constexpr unsigned kNoAssociatedBlockIndex = ~0u;
IrBuilder::IrBuilder()
: constantMap({IrConstKind::Tag, ~0ull})
IrBuilder::IrBuilder(const HostIrHooks& hostHooks)
: hostHooks(hostHooks)
, constantMap({IrConstKind::Tag, ~0ull})
{
}
static bool hasTypedParameters_DEPRECATED(Proto* proto)
@ -230,7 +232,7 @@ void IrBuilder::buildFunctionIr(Proto* proto)
rebuildBytecodeBasicBlocks(proto);
// Infer register tags in bytecode
analyzeBytecodeTypes(function);
analyzeBytecodeTypes(function, hostHooks);
function.bcMapping.resize(proto->sizecode, {~0u, ~0u});
@ -283,10 +285,10 @@ void IrBuilder::buildFunctionIr(Proto* proto)
translateInst(op, pc, i);
if (fastcallSkipTarget != -1)
if (cmdSkipTarget != -1)
{
nexti = fastcallSkipTarget;
fastcallSkipTarget = -1;
nexti = cmdSkipTarget;
cmdSkipTarget = -1;
}
}
@ -613,7 +615,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
translateInstCapture(*this, pc, i);
break;
case LOP_NAMECALL:
if (FFlag::LuauCodegenAnalyzeHostVectorOps)
{
if (translateInstNamecall(*this, pc, i))
cmdSkipTarget = i + 3;
}
else
{
translateInstNamecall(*this, pc, i);
}
break;
case LOP_PREPVARARGS:
inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc)));
@ -654,7 +664,7 @@ void IrBuilder::handleFastcallFallback(IrOp fallbackOrUndef, const Instruction*
}
else
{
fastcallSkipTarget = i + skip + 2;
cmdSkipTarget = i + skip + 2;
}
}

View File

@ -3,6 +3,7 @@
#include "Luau/Bytecode.h"
#include "Luau/BytecodeUtils.h"
#include "Luau/CodeGen.h"
#include "Luau/IrBuilder.h"
#include "Luau/IrUtils.h"
@ -14,6 +15,7 @@
LUAU_FASTFLAGVARIABLE(LuauCodegenDirectUserdataFlow, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenFixVectorFields, false)
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps)
namespace Luau
{
@ -1218,6 +1220,10 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
}
else
{
if (FFlag::LuauCodegenAnalyzeHostVectorOps && build.hostHooks.vectorAccess &&
build.hostHooks.vectorAccess(build, field, str->len, ra, rb, pcpos))
return;
build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux));
}
@ -1376,7 +1382,7 @@ void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos)
}
}
void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos)
bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
@ -1388,8 +1394,24 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos)
{
build.loadAndCheckTag(build.vmReg(rb), LUA_TVECTOR, build.vmExit(pcpos));
if (FFlag::LuauCodegenAnalyzeHostVectorOps && build.hostHooks.vectorNamecall)
{
Instruction call = pc[2];
CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int callra = LUAU_INSN_A(call);
int nparams = LUAU_INSN_B(call) - 1;
int nresults = LUAU_INSN_C(call) - 1;
TString* str = gco2ts(build.function.proto->k[aux].value.gc);
const char* field = getstr(str);
if (build.hostHooks.vectorNamecall(build, field, str->len, callra, rb, nparams, nresults, pcpos))
return true;
}
build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux));
return;
return false;
}
if (FFlag::LuauCodegenDirectUserdataFlow && bcTypes.a == LBC_TYPE_USERDATA)
@ -1397,7 +1419,7 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos)
build.loadAndCheckTag(build.vmReg(rb), LUA_TUSERDATA, build.vmExit(pcpos));
build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux));
return;
return false;
}
IrOp next = build.blockAtInst(pcpos + getOpLength(LOP_NAMECALL));
@ -1451,6 +1473,8 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::JUMP, next);
build.beginBlock(next);
return false;
}
void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c)

View File

@ -61,7 +61,7 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos);
bool translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c);
void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c);
void translateInstNewClosure(IrBuilder& build, const Instruction* pc, int pcpos);

View File

@ -17,5 +17,6 @@ void luau_codegen_create(lua_State* L)
void luau_codegen_compile(lua_State* L, int idx)
{
Luau::CodeGen::compile(L, idx);
Luau::CodeGen::CompilationOptions options;
Luau::CodeGen::compile(L, idx, options);
}

View File

@ -266,7 +266,6 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Refinement.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
Analysis/src/Set.cpp
Analysis/src/Simplify.cpp
Analysis/src/Substitution.cpp
Analysis/src/Subtyping.cpp
@ -494,6 +493,7 @@ if(TARGET Luau.Conformance)
target_sources(Luau.Conformance PRIVATE
tests/RegisterCallbacks.h
tests/RegisterCallbacks.cpp
tests/ConformanceIrHooks.h
tests/Conformance.test.cpp
tests/IrLowering.test.cpp
tests/SharedCodeAllocator.test.cpp

View File

@ -3272,9 +3272,9 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement")
// https://github.com/Roblox/luau/issues/858
TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement2")
{
ScopedFastFlag sff[]{
{FFlag::DebugLuauDeferredConstraintResolution, true},
};
// don't run this when the DCR flag isn't set
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
check(R"(
--!strict

View File

@ -16,6 +16,7 @@
#include "doctest.h"
#include "ScopedFlags.h"
#include "ConformanceIrHooks.h"
#include <fstream>
#include <string>
@ -48,6 +49,13 @@ static lua_CompileOptions defaultOptions()
return copts;
}
static Luau::CodeGen::CompilationOptions defaultCodegenOptions()
{
Luau::CodeGen::CompilationOptions opts = {};
opts.flags = Luau::CodeGen::CodeGen_ColdFunctions;
return opts;
}
static int lua_collectgarbage(lua_State* L)
{
static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr};
@ -118,6 +126,15 @@ static int lua_vector_dot(lua_State* L)
return 1;
}
static int lua_vector_cross(lua_State* L)
{
const float* a = luaL_checkvector(L, 1);
const float* b = luaL_checkvector(L, 2);
lua_pushvector(L, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]);
return 1;
}
static int lua_vector_index(lua_State* L)
{
const float* v = luaL_checkvector(L, 1);
@ -129,6 +146,14 @@ static int lua_vector_index(lua_State* L)
return 1;
}
if (strcmp(name, "Unit") == 0)
{
float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]);
lua_pushvector(L, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt);
return 1;
}
if (strcmp(name, "Dot") == 0)
{
lua_pushcfunction(L, lua_vector_dot, "Dot");
@ -144,6 +169,9 @@ static int lua_vector_namecall(lua_State* L)
{
if (strcmp(str, "Dot") == 0)
return lua_vector_dot(L);
if (strcmp(str, "Cross") == 0)
return lua_vector_cross(L);
}
luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1));
@ -157,7 +185,8 @@ int lua_silence(lua_State* L)
using StateRef = std::unique_ptr<lua_State, void (*)(lua_State*)>;
static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr,
lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false)
lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false,
Luau::CodeGen::CompilationOptions* codegenOptions = nullptr)
{
#ifdef LUAU_CONFORMANCE_SOURCE_DIR
std::string path = LUAU_CONFORMANCE_SOURCE_DIR;
@ -238,7 +267,11 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n
free(bytecode);
if (result == 0 && codegen && !skipCodegen && luau_codegen_supported())
Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions);
{
Luau::CodeGen::CompilationOptions nativeOpts = codegenOptions ? *codegenOptions : defaultCodegenOptions();
Luau::CodeGen::compile(L, -1, nativeOpts);
}
int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX;
@ -533,12 +566,51 @@ TEST_CASE("Pack")
TEST_CASE("Vector")
{
lua_CompileOptions copts = defaultOptions();
Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions();
SUBCASE("NoIrHooks")
{
SUBCASE("O0")
{
copts.optimizationLevel = 0;
}
SUBCASE("O1")
{
copts.optimizationLevel = 1;
}
SUBCASE("O2")
{
copts.optimizationLevel = 2;
}
}
SUBCASE("IrHooks")
{
nativeOpts.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType;
nativeOpts.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType;
nativeOpts.hooks.vectorAccess = vectorAccess;
nativeOpts.hooks.vectorNamecall = vectorNamecall;
SUBCASE("O0")
{
copts.optimizationLevel = 0;
}
SUBCASE("O1")
{
copts.optimizationLevel = 1;
}
SUBCASE("O2")
{
copts.optimizationLevel = 2;
}
}
runConformance(
"vector.lua",
[](lua_State* L) {
setupVectorHelpers(L);
},
nullptr, nullptr, nullptr);
nullptr, nullptr, &copts, false, &nativeOpts);
}
static void populateRTTI(lua_State* L, Luau::TypeId type)
@ -2141,7 +2213,10 @@ TEST_CASE("HugeFunction")
REQUIRE(result == 0);
if (codegen && luau_codegen_supported())
Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions);
{
Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions};
Luau::CodeGen::compile(L, -1, nativeOptions);
}
int status = lua_resume(L, nullptr, 0);
REQUIRE(status == 0);
@ -2263,8 +2338,9 @@ TEST_CASE("IrInstructionLimit")
REQUIRE(result == 0);
Luau::CodeGen::CompilationOptions nativeOptions{Luau::CodeGen::CodeGen_ColdFunctions};
Luau::CodeGen::CompilationStats nativeStats = {};
Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions, &nativeStats);
Luau::CodeGen::CompilationResult nativeResult = Luau::CodeGen::compile(L, -1, nativeOptions, &nativeStats);
// Limit is not hit immediately, so with some functions compiled it should be a success
CHECK(nativeResult.result == Luau::CodeGen::CodeGenCompilationResult::Success);

151
tests/ConformanceIrHooks.h Normal file
View File

@ -0,0 +1,151 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/IrBuilder.h"
inline uint8_t vectorAccessBytecodeType(const char* member, size_t memberLength)
{
using namespace Luau::CodeGen;
if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0)
return LBC_TYPE_NUMBER;
if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0)
return LBC_TYPE_VECTOR;
return LBC_TYPE_ANY;
}
inline bool vectorAccess(Luau::CodeGen::IrBuilder& build, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos)
{
using namespace Luau::CodeGen;
if (memberLength == strlen("Magnitude") && strcmp(member, "Magnitude") == 0)
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(resultReg), mag);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TNUMBER));
return true;
}
if (memberLength == strlen("Unit") && strcmp(member, "Unit") == 0)
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);
build.inst(IrCmd::STORE_VECTOR, build.vmReg(resultReg), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(resultReg), build.constTag(LUA_TVECTOR));
return true;
}
return false;
}
inline uint8_t vectorNamecallBytecodeType(const char* member, size_t memberLength)
{
if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0)
return LBC_TYPE_NUMBER;
if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0)
return LBC_TYPE_VECTOR;
return LBC_TYPE_ANY;
}
inline bool vectorNamecall(
Luau::CodeGen::IrBuilder& build, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos)
{
using namespace Luau::CodeGen;
if (memberLength == strlen("Dot") && strcmp(member, "Dot") == 0 && params == 2 && results <= 1)
{
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4));
IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8));
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(argResReg), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TNUMBER));
// If the function is called in multi-return context, stack has to be adjusted
if (results == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1));
return true;
}
if (memberLength == strlen("Cross") && strcmp(member, "Cross") == 0 && params == 2 && results <= 1)
{
build.loadAndCheckTag(build.vmReg(argResReg + 2), LUA_TVECTOR, build.vmExit(pcpos));
IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(0));
IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(4));
IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(4));
IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(sourceReg), build.constInt(8));
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, build.vmReg(argResReg + 2), build.constInt(8));
IrOp y1z2 = build.inst(IrCmd::MUL_NUM, y1, z2);
IrOp z1y2 = build.inst(IrCmd::MUL_NUM, z1, y2);
IrOp xr = build.inst(IrCmd::SUB_NUM, y1z2, z1y2);
IrOp z1x2 = build.inst(IrCmd::MUL_NUM, z1, x2);
IrOp x1z2 = build.inst(IrCmd::MUL_NUM, x1, z2);
IrOp yr = build.inst(IrCmd::SUB_NUM, z1x2, x1z2);
IrOp x1y2 = build.inst(IrCmd::MUL_NUM, x1, y2);
IrOp y1x2 = build.inst(IrCmd::MUL_NUM, y1, x2);
IrOp zr = build.inst(IrCmd::SUB_NUM, x1y2, y1x2);
build.inst(IrCmd::STORE_VECTOR, build.vmReg(argResReg), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(argResReg), build.constTag(LUA_TVECTOR));
// If the function is called in multi-return context, stack has to be adjusted
if (results == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(argResReg), build.constInt(1));
return true;
}
return false;
}

View File

@ -21,6 +21,11 @@ using namespace Luau::CodeGen;
class IrBuilderFixture
{
public:
IrBuilderFixture()
: build(hooks)
{
}
void constantFold()
{
for (IrBlock& block : build.function.blocks)
@ -109,6 +114,7 @@ public:
computeCfgDominanceTreeChildren(build.function);
}
HostIrHooks hooks;
IrBuilder build;
// Luau.VM headers are not accessible

View File

@ -6,9 +6,11 @@
#include "Luau/CodeGen.h"
#include "Luau/Compiler.h"
#include "Luau/Parser.h"
#include "Luau/IrBuilder.h"
#include "doctest.h"
#include "ScopedFlags.h"
#include "ConformanceIrHooks.h"
#include <memory>
@ -22,11 +24,17 @@ LUAU_FASTFLAG(LuauCodegenIrTypeNames)
LUAU_FASTFLAG(LuauCompileTempTypeInfo)
LUAU_FASTFLAG(LuauCodegenFixVectorFields)
LUAU_FASTFLAG(LuauCodegenVectorMispredictFix)
LUAU_FASTFLAG(LuauCodegenAnalyzeHostVectorOps)
static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1)
{
Luau::CodeGen::AssemblyOptions options;
options.compilationOptions.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType;
options.compilationOptions.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType;
options.compilationOptions.hooks.vectorAccess = vectorAccess;
options.compilationOptions.hooks.vectorNamecall = vectorNamecall;
// For IR, we don't care about assembly, but we want a stable target
options.target = Luau::CodeGen::AssemblyOptions::Target::X64_SystemV;
@ -513,6 +521,277 @@ bb_6:
)");
}
TEST_CASE("VectorCustomAccess")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true};
ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true};
ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function vec3magn(a: vector)
return a.Magnitude * 2
end
)"),
R"(
; function vec3magn($arg0) line 2
bb_0:
CHECK_TAG R0, tvector, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_FLOAT R0, 0i
%7 = LOAD_FLOAT R0, 4i
%8 = LOAD_FLOAT R0, 8i
%9 = MUL_NUM %6, %6
%10 = MUL_NUM %7, %7
%11 = MUL_NUM %8, %8
%12 = ADD_NUM %9, %10
%13 = ADD_NUM %12, %11
%14 = SQRT_NUM %13
%20 = MUL_NUM %14, 2
STORE_DOUBLE R1, %20
STORE_TAG R1, tnumber
INTERRUPT 3u
RETURN R1, 1i
)");
}
TEST_CASE("VectorCustomNamecall")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true};
ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true};
ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function vec3dot(a: vector, b: vector)
return (a:Dot(b))
end
)"),
R"(
; function vec3dot($arg0, $arg1) line 2
bb_0:
CHECK_TAG R0, tvector, exit(entry)
CHECK_TAG R1, tvector, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%6 = LOAD_TVALUE R1
STORE_TVALUE R4, %6
%12 = LOAD_FLOAT R0, 0i
%13 = LOAD_FLOAT R4, 0i
%14 = MUL_NUM %12, %13
%15 = LOAD_FLOAT R0, 4i
%16 = LOAD_FLOAT R4, 4i
%17 = MUL_NUM %15, %16
%18 = LOAD_FLOAT R0, 8i
%19 = LOAD_FLOAT R4, 8i
%20 = MUL_NUM %18, %19
%21 = ADD_NUM %14, %17
%22 = ADD_NUM %21, %20
STORE_DOUBLE R2, %22
STORE_TAG R2, tnumber
INTERRUPT 4u
RETURN R2, 1i
)");
}
TEST_CASE("VectorCustomAccessChain")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true};
ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true};
ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true};
ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: vector, b: vector)
return a.Unit * b.Magnitude
end
)"),
R"(
; function foo($arg0, $arg1) line 2
bb_0:
CHECK_TAG R0, tvector, exit(entry)
CHECK_TAG R1, tvector, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%8 = LOAD_FLOAT R0, 0i
%9 = LOAD_FLOAT R0, 4i
%10 = LOAD_FLOAT R0, 8i
%11 = MUL_NUM %8, %8
%12 = MUL_NUM %9, %9
%13 = MUL_NUM %10, %10
%14 = ADD_NUM %11, %12
%15 = ADD_NUM %14, %13
%16 = SQRT_NUM %15
%17 = DIV_NUM 1, %16
%18 = MUL_NUM %8, %17
%19 = MUL_NUM %9, %17
%20 = MUL_NUM %10, %17
STORE_VECTOR R3, %18, %19, %20
STORE_TAG R3, tvector
%25 = LOAD_FLOAT R1, 0i
%26 = LOAD_FLOAT R1, 4i
%27 = LOAD_FLOAT R1, 8i
%28 = MUL_NUM %25, %25
%29 = MUL_NUM %26, %26
%30 = MUL_NUM %27, %27
%31 = ADD_NUM %28, %29
%32 = ADD_NUM %31, %30
%33 = SQRT_NUM %32
%40 = LOAD_TVALUE R3
%42 = NUM_TO_VEC %33
%43 = MUL_VEC %40, %42
%44 = TAG_VECTOR %43
STORE_TVALUE R2, %44
INTERRUPT 5u
RETURN R2, 1i
)");
}
TEST_CASE("VectorCustomNamecallChain")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true};
ScopedFastFlag luauCodegenVectorMispredictFix{FFlag::LuauCodegenVectorMispredictFix, true};
ScopedFastFlag LuauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true};
ScopedFastFlag luauCodegenAnalyzeHostVectorOps{FFlag::LuauCodegenAnalyzeHostVectorOps, true};
CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(n: vector, b: vector, t: vector)
return n:Cross(t):Dot(b) + 1
end
)"),
R"(
; function foo($arg0, $arg1, $arg2) line 2
bb_0:
CHECK_TAG R0, tvector, exit(entry)
CHECK_TAG R1, tvector, exit(entry)
CHECK_TAG R2, tvector, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%8 = LOAD_TVALUE R2
STORE_TVALUE R6, %8
%14 = LOAD_FLOAT R0, 0i
%15 = LOAD_FLOAT R6, 0i
%16 = LOAD_FLOAT R0, 4i
%17 = LOAD_FLOAT R6, 4i
%18 = LOAD_FLOAT R0, 8i
%19 = LOAD_FLOAT R6, 8i
%20 = MUL_NUM %16, %19
%21 = MUL_NUM %18, %17
%22 = SUB_NUM %20, %21
%23 = MUL_NUM %18, %15
%24 = MUL_NUM %14, %19
%25 = SUB_NUM %23, %24
%26 = MUL_NUM %14, %17
%27 = MUL_NUM %16, %15
%28 = SUB_NUM %26, %27
STORE_VECTOR R4, %22, %25, %28
STORE_TAG R4, tvector
%31 = LOAD_TVALUE R1
STORE_TVALUE R6, %31
%37 = LOAD_FLOAT R4, 0i
%38 = LOAD_FLOAT R6, 0i
%39 = MUL_NUM %37, %38
%40 = LOAD_FLOAT R4, 4i
%41 = LOAD_FLOAT R6, 4i
%42 = MUL_NUM %40, %41
%43 = LOAD_FLOAT R4, 8i
%44 = LOAD_FLOAT R6, 8i
%45 = MUL_NUM %43, %44
%46 = ADD_NUM %39, %42
%47 = ADD_NUM %46, %45
%53 = ADD_NUM %47, 1
STORE_DOUBLE R3, %53
STORE_TAG R3, tnumber
INTERRUPT 9u
RETURN R3, 1i
)");
}
TEST_CASE("VectorCustomNamecallChain2")
{
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCompileTempTypeInfo, true},
{FFlag::LuauCodegenVectorMispredictFix, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
type Vertex = {n: vector, b: vector}
local function foo(v: Vertex, t: vector)
return v.n:Cross(t):Dot(v.b) + 1
end
)"),
R"(
; function foo($arg0, $arg1) line 4
bb_0:
CHECK_TAG R0, ttable, exit(entry)
CHECK_TAG R1, tvector, exit(entry)
JUMP bb_2
bb_2:
JUMP bb_bytecode_1
bb_bytecode_1:
%8 = LOAD_POINTER R0
%9 = GET_SLOT_NODE_ADDR %8, 0u, K1
CHECK_SLOT_MATCH %9, K1, bb_fallback_3
%11 = LOAD_TVALUE %9, 0i
STORE_TVALUE R3, %11
JUMP bb_4
bb_4:
%16 = LOAD_TVALUE R1
STORE_TVALUE R5, %16
CHECK_TAG R3, tvector, exit(3)
CHECK_TAG R5, tvector, exit(3)
%22 = LOAD_FLOAT R3, 0i
%23 = LOAD_FLOAT R5, 0i
%24 = LOAD_FLOAT R3, 4i
%25 = LOAD_FLOAT R5, 4i
%26 = LOAD_FLOAT R3, 8i
%27 = LOAD_FLOAT R5, 8i
%28 = MUL_NUM %24, %27
%29 = MUL_NUM %26, %25
%30 = SUB_NUM %28, %29
%31 = MUL_NUM %26, %23
%32 = MUL_NUM %22, %27
%33 = SUB_NUM %31, %32
%34 = MUL_NUM %22, %25
%35 = MUL_NUM %24, %23
%36 = SUB_NUM %34, %35
STORE_VECTOR R3, %30, %33, %36
CHECK_TAG R0, ttable, exit(6)
%41 = LOAD_POINTER R0
%42 = GET_SLOT_NODE_ADDR %41, 6u, K3
CHECK_SLOT_MATCH %42, K3, bb_fallback_5
%44 = LOAD_TVALUE %42, 0i
STORE_TVALUE R5, %44
JUMP bb_6
bb_6:
CHECK_TAG R3, tvector, exit(8)
CHECK_TAG R5, tvector, exit(8)
%53 = LOAD_FLOAT R3, 0i
%54 = LOAD_FLOAT R5, 0i
%55 = MUL_NUM %53, %54
%56 = LOAD_FLOAT R3, 4i
%57 = LOAD_FLOAT R5, 4i
%58 = MUL_NUM %56, %57
%59 = LOAD_FLOAT R3, 8i
%60 = LOAD_FLOAT R5, 8i
%61 = MUL_NUM %59, %60
%62 = ADD_NUM %55, %58
%63 = ADD_NUM %62, %61
%69 = ADD_NUM %63, 1
STORE_DOUBLE R2, %69
STORE_TAG R2, tnumber
INTERRUPT 12u
RETURN R2, 1i
)");
}
TEST_CASE("UserDataGetIndex")
{
ScopedFastFlag luauCodegenDirectUserdataFlow{FFlag::LuauCodegenDirectUserdataFlow, true};
@ -1040,7 +1319,7 @@ TEST_CASE("ResolveVectorNamecalls")
{
ScopedFastFlag sffs[]{{FFlag::LuauLoadTypeInfo, true}, {FFlag::LuauCompileTypeInfo, true}, {FFlag::LuauCodegenTypeInfo, true},
{FFlag::LuauCodegenRemoveDeadStores5, true}, {FFlag::LuauTypeInfoLookupImprovement, true}, {FFlag::LuauCodegenIrTypeNames, true},
{FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}};
{FFlag::LuauCompileTempTypeInfo, true}, {FFlag::LuauCodegenDirectUserdataFlow, true}, {FFlag::LuauCodegenAnalyzeHostVectorOps, true}};
CHECK_EQ("\n" + getCodegenAssembly(R"(
type Vertex = {pos: vector, normal: vector}
@ -1083,10 +1362,20 @@ bb_6:
%31 = LOAD_TVALUE K1, 0i, tvector
STORE_TVALUE R4, %31
CHECK_TAG R2, tvector, exit(4)
FALLBACK_NAMECALL 4u, R2, R2, K2
INTERRUPT 6u
SET_SAVEDPC 7u
CALL R2, 2i, -1i
%37 = LOAD_FLOAT R2, 0i
%38 = LOAD_FLOAT R4, 0i
%39 = MUL_NUM %37, %38
%40 = LOAD_FLOAT R2, 4i
%41 = LOAD_FLOAT R4, 4i
%42 = MUL_NUM %40, %41
%43 = LOAD_FLOAT R2, 8i
%44 = LOAD_FLOAT R4, 8i
%45 = MUL_NUM %43, %44
%46 = ADD_NUM %39, %42
%47 = ADD_NUM %46, %45
STORE_DOUBLE R2, %47
STORE_TAG R2, tnumber
ADJUST_STACK_TO_REG R2, 1i
INTERRUPT 7u
RETURN R2, -1i
)");

View File

@ -11,7 +11,6 @@
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauFixNormalizeCaching)
LUAU_FASTFLAG(LuauNormalizeNotUnknownIntersection)
LUAU_FASTFLAG(LuauFixCyclicUnionsOfIntersections);
LUAU_FASTINT(LuauTypeInferRecursionLimit)
@ -428,7 +427,6 @@ struct NormalizeFixture : Fixture
UnifierSharedState unifierState{&iceHandler};
Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}};
Scope globalScope{builtinTypes->anyTypePack};
ScopedFastFlag fixNormalizeCaching{FFlag::LuauFixNormalizeCaching, true};
NormalizeFixture()
{

View File

@ -17,7 +17,6 @@ LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeLengthLimit);
LUAU_FASTINT(LuauParseErrorLimit);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauReadWritePropertySyntax);
namespace
{
@ -3156,8 +3155,6 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_@_as_variable_name")
TEST_CASE_FIXTURE(Fixture, "read_write_table_properties")
{
ScopedFastFlag sff{FFlag::LuauReadWritePropertySyntax, true};
auto pr = tryParse(R"(
type A = {read x: number}
type B = {write x: number}

View File

@ -7,8 +7,6 @@
#include <string>
#include <vector>
LUAU_FASTFLAG(LuauFixSetIter);
TEST_SUITE_BEGIN("SetTests");
TEST_CASE("empty_set_size_0")
@ -107,8 +105,6 @@ TEST_CASE("iterate_over_set_skips_erased_elements")
TEST_CASE("iterate_over_set_skips_first_element_if_it_is_erased")
{
ScopedFastFlag sff{FFlag::LuauFixSetIter, true};
/*
* As of this writing, in the following set, the key "y" happens to occur
* before "x" in the underlying DenseHashSet. This is important because it

View File

@ -438,10 +438,13 @@ TEST_CASE("SharedAllocation")
const ModuleId moduleId = {0x01};
CompilationOptions options;
options.flags = CodeGen_ColdFunctions;
CompilationStats nativeStats1 = {};
CompilationStats nativeStats2 = {};
const CompilationResult codeGenResult1 = Luau::CodeGen::compile(moduleId, L1.get(), -1, CodeGen_ColdFunctions, &nativeStats1);
const CompilationResult codeGenResult2 = Luau::CodeGen::compile(moduleId, L2.get(), -1, CodeGen_ColdFunctions, &nativeStats2);
const CompilationResult codeGenResult1 = Luau::CodeGen::compile(moduleId, L1.get(), -1, options, &nativeStats1);
const CompilationResult codeGenResult2 = Luau::CodeGen::compile(moduleId, L2.get(), -1, options, &nativeStats2);
REQUIRE(codeGenResult1.result == CodeGenCompilationResult::Success);
REQUIRE(codeGenResult2.result == CodeGenCompilationResult::Success);

View File

@ -354,21 +354,41 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded")
function f2(f) return f or f1 end
function f3(f) return f or f2 end
)");
if (FFlag::DebugLuauDeferredConstraintResolution)
{
LUAU_REQUIRE_ERROR_COUNT(3, result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors[0]);
LUAU_ASSERT(err);
CHECK("(...any) -> ()" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("unknown" == toString(err->recommendedArgs[0].second));
err = get<ExplicitFunctionAnnotationRecommended>(result.errors[1]);
LUAU_ASSERT(err);
// FIXME: this recommendation could be better
CHECK("<a>(a) -> or<a, (...any) -> ()>" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("unknown" == toString(err->recommendedArgs[0].second));
err = get<ExplicitFunctionAnnotationRecommended>(result.errors[2]);
LUAU_ASSERT(err);
// FIXME: this recommendation could be better
CHECK("<a>(a) -> or<a, <b>(b) -> or<b, (...any) -> ()>>" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("unknown" == toString(err->recommendedArgs[0].second));
ToStringOptions o;
o.exhaustive = false;
o.maxTypeLength = 20;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "<a>(a) -> or<a, () -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f2"), o), "<b>(b) -> or<b, <a>(a... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "<c>(c) -> or<c, <b>(b... *TRUNCATED*");
}
else
{
LUAU_REQUIRE_NO_ERRORS(result);
ToStringOptions o;
o.exhaustive = false;
if (FFlag::DebugLuauDeferredConstraintResolution)
{
o.maxTypeLength = 30;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "<a>(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
CHECK_EQ(toString(requireType("f2"), o), "<b>(b) -> (<a>(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "<c>(c) -> (<b>(b) -> (<a>(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
}
else
{
o.maxTypeLength = 40;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()");
@ -385,20 +405,42 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive")
function f2(f) return f or f1 end
function f3(f) return f or f2 end
)");
if (FFlag::DebugLuauDeferredConstraintResolution)
{
LUAU_REQUIRE_ERROR_COUNT(3, result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors[0]);
LUAU_ASSERT(err);
CHECK("(...any) -> ()" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("unknown" == toString(err->recommendedArgs[0].second));
err = get<ExplicitFunctionAnnotationRecommended>(result.errors[1]);
LUAU_ASSERT(err);
// FIXME: this recommendation could be better
CHECK("<a>(a) -> or<a, (...any) -> ()>" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("unknown" == toString(err->recommendedArgs[0].second));
err = get<ExplicitFunctionAnnotationRecommended>(result.errors[2]);
LUAU_ASSERT(err);
// FIXME: this recommendation could be better
CHECK("<a>(a) -> or<a, <b>(b) -> or<b, (...any) -> ()>>" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("unknown" == toString(err->recommendedArgs[0].second));
ToStringOptions o;
o.exhaustive = true;
o.maxTypeLength = 20;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "<a>(a) -> or<a, () -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f2"), o), "<b>(b) -> or<b, <a>(a... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "<c>(c) -> or<c, <b>(b... *TRUNCATED*");
}
else
{
LUAU_REQUIRE_NO_ERRORS(result);
ToStringOptions o;
o.exhaustive = true;
if (FFlag::DebugLuauDeferredConstraintResolution)
{
o.maxTypeLength = 30;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "<a>(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
CHECK_EQ(toString(requireType("f2"), o), "<b>(b) -> (<a>(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "<c>(c) -> (<b>(b) -> (<a>(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*");
}
else
{
o.maxTypeLength = 40;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()");

View File

@ -2351,8 +2351,9 @@ end
LUAU_REQUIRE_ERRORS(result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back());
LUAU_ASSERT(err);
CHECK("false | number" == toString(err->recommendedReturn));
CHECK(err->recommendedArgs.size() == 0);
CHECK("number" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("number" == toString(err->recommendedArgs[0].second));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type")
@ -2673,4 +2674,17 @@ TEST_CASE_FIXTURE(Fixture, "captured_local_is_assigned_a_function")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "error_suppression_propagates_through_function_calls")
{
CheckResult result = check(R"(
function first(x: any)
return pairs(x)(x)
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK("(any) -> (any?, any)" == toString(requireType("first")));
}
TEST_SUITE_END();

View File

@ -1010,7 +1010,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_add_an_indexer")
TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_retroactively_add_an_indexer")
{
CheckResult result = check(R"(
--!strict
@ -1025,7 +1025,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_add_an_indexer")
)");
if (FFlag::DebugLuauDeferredConstraintResolution)
LUAU_REQUIRE_ERROR_COUNT(2, result);
{
// We regress a little here: The old solver would typecheck the first
// access to prices.wwwww on a table that had no indexer, and the second
// on a table that does.
LUAU_REQUIRE_ERROR_COUNT(0, result);
}
else
LUAU_REQUIRE_ERROR_COUNT(1, result);
}
@ -1114,4 +1119,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "forin_metatable_iter_mm")
CHECK_EQ("number", toString(requireTypeAtPosition({6, 21})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_preserves_error_suppression")
{
CheckResult result = check(R"(
function first(x: any)
for k, v in pairs(x) do
print(k, v)
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK("any" == toString(requireTypeAtPosition({3, 22})));
CHECK("any" == toString(requireTypeAtPosition({3, 25})));
}
TEST_SUITE_END();

View File

@ -21,7 +21,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping);
LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls);
LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering);
LUAU_FASTFLAG(DebugLuauSharedSelf);
LUAU_FASTFLAG(LuauReadWritePropertySyntax);
LUAU_FASTFLAG(LuauMetatableInstantiationCloneCheck);
LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError)
@ -2729,7 +2728,9 @@ TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals")
TEST_CASE_FIXTURE(Fixture, "should_not_unblock_table_type_twice")
{
ScopedFastFlag sff(FFlag::DebugLuauDeferredConstraintResolution, true);
// don't run this when the DCR flag isn't set
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
check(R"(
local timer = peek(timerQueue)
@ -4014,7 +4015,6 @@ TEST_CASE_FIXTURE(Fixture, "identify_all_problematic_table_fields")
TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported")
{
ScopedFastFlag sff[] = {
{FFlag::LuauReadWritePropertySyntax, true},
{FFlag::DebugLuauDeferredConstraintResolution, false},
};
@ -4040,8 +4040,6 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported
TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported")
{
ScopedFastFlag sff{FFlag::LuauReadWritePropertySyntax, true};
CheckResult result = check(R"(
type T = {read [string]: number}
type U = {write [string]: boolean}
@ -4155,7 +4153,9 @@ TEST_CASE_FIXTURE(Fixture, "write_annotations_are_unsupported_even_with_the_new_
TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported")
{
ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}};
ScopedFastFlag sff[] = {
{FFlag::DebugLuauDeferredConstraintResolution, false}
};
CheckResult result = check(R"(
type W = {read x: number}
@ -4179,7 +4179,9 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported
TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported")
{
ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, false}};
ScopedFastFlag sff[] = {
{FFlag::DebugLuauDeferredConstraintResolution, false}
};
CheckResult result = check(R"(
type T = {read [string]: number}
@ -4199,7 +4201,9 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties")
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
ScopedFastFlag sff[] = {{FFlag::LuauReadWritePropertySyntax, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}};
ScopedFastFlag sff[] = {
{FFlag::DebugLuauDeferredConstraintResolution, true}
};
CheckResult result = check(R"(
function oc(player, speaker)
@ -4439,4 +4443,21 @@ TEST_CASE_FIXTURE(Fixture, "insert_a_and_f_of_a_into_table_res_in_a_loop")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_adds_an_unbounded_indexer")
{
CheckResult result = check(R"(
--!strict
local a = {}
ipairs(a)
)");
// The old solver erroneously leaves a free type dangling here. The new
// solver does better.
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK("{unknown}" == toString(requireType("a"), {true}));
else
CHECK("{a}" == toString(requireType("a"), {true}));
}
TEST_SUITE_END();

View File

@ -4,6 +4,12 @@ print('testing vectors')
-- detect vector size
local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3
function ecall(fn, ...)
local ok, err = pcall(fn, ...)
assert(not ok)
return err:sub((err:find(": ") or -1) + 2, #err)
end
-- equality
assert(vector(1, 2, 3) == vector(1, 2, 3))
assert(vector(0, 1, 2) == vector(-0, 1, 2))
@ -92,9 +98,29 @@ assert(nanv ~= nanv);
-- __index
assert(vector(1, 2, 2).Magnitude == 3)
assert(vector(0, 0, 0)['Dot'](vector(1, 2, 4), vector(5, 6, 7)) == 45)
assert(vector(2, 0, 0).Unit == vector(1, 0, 0))
-- __namecall
assert(vector(1, 2, 4):Dot(vector(5, 6, 7)) == 45)
assert(ecall(function() vector(1, 2, 4):Dot() end) == "missing argument #2 (vector expected)")
assert(ecall(function() vector(1, 2, 4):Dot("a") end) == "invalid argument #2 (vector expected, got string)")
local function doDot1(a: vector, b)
return a:Dot(b)
end
local function doDot2(a: vector, b)
return (a:Dot(b))
end
local v124 = vector(1, 2, 4)
assert(doDot1(v124, vector(5, 6, 7)) == 45)
assert(doDot2(v124, vector(5, 6, 7)) == 45)
assert(ecall(function() doDot1(v124, "a") end) == "invalid argument #2 (vector expected, got string)")
assert(ecall(function() doDot2(v124, "a") end) == "invalid argument #2 (vector expected, got string)")
assert(select("#", doDot1(v124, vector(5, 6, 7))) == 1)
assert(select("#", doDot2(v124, vector(5, 6, 7))) == 1)
-- can't use vector with NaN components as table key
assert(pcall(function() local t = {} t[vector(0/0, 2, 3)] = 1 end) == false)
@ -102,6 +128,9 @@ assert(pcall(function() local t = {} t[vector(1, 0/0, 3)] = 1 end) == false)
assert(pcall(function() local t = {} t[vector(1, 2, 0/0)] = 1 end) == false)
assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == false)
assert(vector(1, 0, 0):Cross(vector(0, 1, 0)) == vector(0, 0, 1))
assert(vector(0, 1, 0):Cross(vector(1, 0, 0)) == vector(0, 0, -1))
-- make sure we cover both builtin and C impl
assert(vector(1, 2, 4) == vector("1", "2", "4"))

View File

@ -4,6 +4,7 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg
AutocompleteTest.autocomplete_string_singletons
AutocompleteTest.do_wrong_compatible_nonself_calls
AutocompleteTest.string_singleton_as_table_key
AutocompleteTest.string_singleton_in_if_statement2
AutocompleteTest.suggest_table_keys
AutocompleteTest.type_correct_suggestion_for_overloads
AutocompleteTest.type_correct_suggestion_in_table
@ -133,9 +134,11 @@ RefinementTest.call_an_incompatible_function_after_using_typeguard
RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never
RefinementTest.discriminate_from_isa_of_x
RefinementTest.discriminate_from_truthiness_of_x
RefinementTest.function_call_with_colon_after_refining_not_to_be_nil
RefinementTest.free_type_is_equal_to_an_lvalue
RefinementTest.globals_can_be_narrowed_too
RefinementTest.isa_type_refinement_must_be_known_ahead_of_time
RefinementTest.luau_polyfill_isindexkey_refine_conjunction
RefinementTest.luau_polyfill_isindexkey_refine_conjunction_variant
RefinementTest.not_t_or_some_prop_of_t
RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage
RefinementTest.refine_a_property_of_some_global
@ -154,6 +157,7 @@ TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible
TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible
TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode
TableTests.array_factory_function
TableTests.cannot_augment_sealed_table
TableTests.casting_tables_with_props_into_table_with_indexer2
TableTests.casting_tables_with_props_into_table_with_indexer3
TableTests.casting_unsealed_tables_with_props_into_table_with_indexer
@ -177,6 +181,7 @@ TableTests.generalize_table_argument
TableTests.generic_table_instantiation_potential_regression
TableTests.indexer_on_sealed_table_must_unify_with_free_table
TableTests.indexers_get_quantified_too
TableTests.inequality_operators_imply_exactly_matching_types
TableTests.infer_array
TableTests.infer_indexer_from_array_like_table
TableTests.infer_indexer_from_its_variable_type_and_unifiable
@ -206,6 +211,7 @@ TableTests.quantify_even_that_table_was_never_exported_at_all
TableTests.quantify_metatables_of_metatables_of_table
TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table
TableTests.recursive_metatable_type_call
TableTests.refined_thing_can_be_an_array
TableTests.right_table_missing_key2
TableTests.scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type
TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type
@ -214,6 +220,7 @@ TableTests.setmetatable_has_a_side_effect
TableTests.shared_selfs
TableTests.shared_selfs_from_free_param
TableTests.shared_selfs_through_metatables
TableTests.should_not_unblock_table_type_twice
TableTests.table_call_metamethod_basic
TableTests.table_call_metamethod_must_be_callable
TableTests.table_param_width_subtyping_2
@ -236,6 +243,8 @@ ToString.named_metatable_toStringNamedFunction
ToString.no_parentheses_around_cyclic_function_type_in_intersection
ToString.pick_distinct_names_for_mixed_explicit_and_implicit_generics
ToString.primitive
ToString.quit_stringifying_type_when_length_is_exceeded
ToString.stringifying_type_is_still_capped_when_exhaustive
ToString.toStringDetailed2
ToString.toStringErrorPack
TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType
@ -332,8 +341,10 @@ TypeInferFunctions.occurs_check_failure_in_function_return_type
TypeInferFunctions.other_things_are_not_related_to_function
TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible
TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2
TypeInferFunctions.regex_benchmark_string_format_minimization
TypeInferFunctions.report_exiting_without_return_nonstrict
TypeInferFunctions.return_type_by_overload
TypeInferFunctions.tf_suggest_return_type
TypeInferFunctions.too_few_arguments_variadic
TypeInferFunctions.too_few_arguments_variadic_generic
TypeInferFunctions.too_few_arguments_variadic_generic2
@ -406,6 +417,7 @@ TypeSingletons.error_detailed_tagged_union_mismatch_bool
TypeSingletons.error_detailed_tagged_union_mismatch_string
TypeSingletons.overloaded_function_call_with_singletons_mismatch
TypeSingletons.return_type_of_f_is_not_widened
TypeSingletons.singletons_stick_around_under_assignment
TypeSingletons.table_properties_type_error_escapes
TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton
TypeStatesTest.typestates_preserve_error_suppression_properties