Sync to upstream/release/523 (#459)

Arseny Kapoulkine 2022-04-14 16:57:43 -07:00 committed by GitHub
parent d37d0c857b
commit 8e7845076b
76 changed files with 4575 additions and 639 deletions

View File

@ -14,12 +14,15 @@ using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
struct CloneState
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
int recursionCount = 0;
bool encounteredFreeType = false;
};
TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState);
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
} // namespace Luau
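
With the seen-type maps folded into CloneState, callers now thread a single state object through every clone call. A minimal sketch of the new call shape (destArena and someType are placeholders; the error handling is hypothetical):

CloneState cloneState;
TypeId copied = clone(someType, destArena, cloneState);
if (cloneState.encounteredFreeType)
    reportFreeTypeEscape(); // hypothetical handler: a free type escaping a public interface indicates a type checker bug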

View File

@ -287,12 +287,20 @@ struct TypesAreUnrelated
bool operator==(const TypesAreUnrelated& rhs) const;
};
struct NormalizationTooComplex
{
bool operator==(const NormalizationTooComplex&) const
{
return true;
}
};
using TypeErrorData =
Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, DuplicateTypeDefinition,
CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, IncorrectGenericParameterCount, SyntaxError,
CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed,
ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, DuplicateGenericParameter, CannotInferBinaryOperation,
MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated>;
MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated, NormalizationTooComplex>;
struct TypeError
{

View File

@ -70,6 +70,7 @@ struct SourceNode
std::vector<std::pair<ModuleName, Location>> requireLocations;
bool dirty = true;
bool dirtyAutocomplete = true;
double autocompleteLimitsMult = 1.0;
};
struct FrontendOptions

View File

@ -35,8 +35,12 @@ const LValue* baseof(const LValue& lvalue);
std::optional<LValue> tryGetLValue(const class AstExpr& expr);
// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys.
// TODO: remove with FFlagLuauTypecheckOptPass
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue);
// Utility function: breaks down an LValue to get at the Symbol
Symbol getBaseSymbol(const LValue& lvalue);
template<typename T>
const T* get(const LValue& lvalue)
{

View File

@ -113,7 +113,7 @@ struct Module
// This helps us to force TypeVar ownership into a DAG rather than a DCG.
// Returns true if there were any free types encountered in the public interface. This
// indicates a bug in the type checker that we want to surface.
bool clonePublicInterface();
bool clonePublicInterface(InternalErrorReporter& ice);
};
} // namespace Luau

View File

@ -0,0 +1,19 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Substitution.h"
#include "Luau/TypeVar.h"
#include "Luau/Module.h"
namespace Luau
{
struct InternalErrorReporter;
bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice);
} // namespace Luau
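
Both entry points route internal failures through an InternalErrorReporter, and normalize reports via the returned bool whether it completed. A rough usage sketch (the arena, reporter, and types are placeholders; mapping the false case to NormalizationTooComplex is inferred from the error type added in this commit):

TypeArena arena;
InternalErrorReporter ice;
auto [normTy, fullyNormalized] = normalize(someTy, arena, ice);
if (!fullyNormalized)
    reportError(location, NormalizationTooComplex{}); // hypothetical reporting site
bool related = isSubtype(a, b, ice);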

View File

@ -4,10 +4,21 @@
#include "Luau/Common.h"
#include <stdexcept>
#include <exception>
LUAU_FASTFLAG(LuauRecursionLimitException);
namespace Luau
{
struct RecursionLimitException : public std::exception
{
const char* what() const noexcept
{
return "Internal recursion counter limit exceeded";
}
};
struct RecursionCounter
{
RecursionCounter(int* count)
@ -28,11 +39,22 @@ private:
struct RecursionLimiter : RecursionCounter
{
RecursionLimiter(int* count, int limit)
// TODO: remove ctx after LuauRecursionLimitException is removed
RecursionLimiter(int* count, int limit, const char* ctx)
: RecursionCounter(count)
{
LUAU_ASSERT(ctx);
if (limit > 0 && *count > limit)
throw std::runtime_error("Internal recursion counter limit exceeded");
{
if (FFlag::LuauRecursionLimitException)
throw RecursionLimitException();
else
{
std::string m = "Internal recursion counter limit exceeded: ";
m += ctx;
throw std::runtime_error(m);
}
}
}
};
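
RecursionLimiter now takes a context string and, behind FFlag::LuauRecursionLimitException, throws the dedicated RecursionLimitException rather than std::runtime_error, so callers can catch the limit specifically. A sketch of guarding a recursive walk (walkType and the counter are illustrative):

int recursionCount = 0;

void walkType(TypeId ty)
{
    // Throws once the shared counter passes the limit; the counter is decremented on scope exit.
    RecursionLimiter limiter(&recursionCount, FInt::LuauTypeCloneRecursionLimit, "walking TypeId");
    // ... recurse into ty's children, each frame bumping the shared counter ...
}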

View File

@ -90,6 +90,7 @@ struct Tarjan
std::vector<int> lowlink;
int childCount = 0;
int childLimit = 0;
// This should never be null; ensure you initialize it before calling
// substitution methods.

View File

@ -28,6 +28,7 @@ struct ToStringOptions
bool functionTypeArguments = false; // If true, output function type argument names when they are available
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool indent = false;
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
std::optional<ToStringNameMap> nameMap;
@ -73,6 +74,8 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp
std::string dump(TypeId ty);
std::string dump(TypePackId ty);
std::string dump(const std::shared_ptr<Scope>& scope, const char* name);
std::string generateName(size_t n);
} // namespace Luau

View File

@ -7,7 +7,7 @@
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
LUAU_FASTFLAG(LuauShareTxnSeen);
LUAU_FASTFLAG(LuauTypecheckOptPass)
namespace Luau
{
@ -64,13 +64,17 @@ T* getMutable(PendingTypePack* pending)
struct TxnLog
{
TxnLog()
: ownedSeen()
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, ownedSeen()
, sharedSeen(&ownedSeen)
{
}
explicit TxnLog(TxnLog* parent)
: parent(parent)
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, parent(parent)
{
if (parent)
{
@ -83,14 +87,19 @@ struct TxnLog
}
explicit TxnLog(std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen)
: sharedSeen(sharedSeen)
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, sharedSeen(sharedSeen)
{
}
TxnLog(TxnLog* parent, std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen)
: parent(parent)
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, parent(parent)
, sharedSeen(sharedSeen)
{
LUAU_ASSERT(!FFlag::LuauTypecheckOptPass);
}
TxnLog(const TxnLog&) = delete;
@ -243,6 +252,12 @@ struct TxnLog
return Luau::getMutable<T>(ty);
}
template<typename T, typename TID>
const T* get(TID ty) const
{
return this->getMutable<T>(ty);
}
// Returns whether a given type or type pack is a given state, respecting the
// log's pending state.
//
@ -263,11 +278,8 @@ private:
// unique_ptr is used to give us stable pointers across insertions into the
// map. Otherwise, it would be really easy to accidentally invalidate the
// pointers returned from queue/pending.
//
// We can't use a DenseHashMap here because we need a non-const iterator
// over the map when we concatenate.
std::unordered_map<TypeId, std::unique_ptr<PendingType>, DenseHashPointer> typeVarChanges;
std::unordered_map<TypePackId, std::unique_ptr<PendingTypePack>, DenseHashPointer> typePackChanges;
DenseHashMap<TypeId, std::unique_ptr<PendingType>> typeVarChanges;
DenseHashMap<TypePackId, std::unique_ptr<PendingTypePack>> typePackChanges;
TxnLog* parent = nullptr;

View File

@ -76,19 +76,32 @@ struct Instantiation : Substitution
// A substitution which replaces free types by any
struct Anyification : Substitution
{
Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack)
Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack)
: Substitution(TxnLog::empty(), arena)
, iceHandler(iceHandler)
, anyType(anyType)
, anyTypePack(anyTypePack)
{
}
InternalErrorReporter* iceHandler;
TypeId anyType;
TypePackId anyTypePack;
bool normalizationTooComplex = false;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
bool ignoreChildren(TypeId ty) override
{
return ty->persistent;
}
bool ignoreChildren(TypePackId ty) override
{
return ty->persistent;
}
};
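
Anyification now carries an InternalErrorReporter and a normalizationTooComplex flag, and leaves the children of persistent types alone via ignoreChildren. A hedged sketch of how a caller might drive it (the arena, handler, and types are placeholders; substitute comes from the Substitution base class):

Anyification anyify{&arena, iceHandler, anyType, anyTypePack};
std::optional<TypeId> result = anyify.substitute(ty);
if (anyify.normalizationTooComplex)
    reportError(location, NormalizationTooComplex{}); // hypothetical reporting site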
// A substitution which replaces the type parameters of a type function by arguments
@ -139,6 +152,7 @@ struct TypeChecker
TypeChecker& operator=(const TypeChecker&) = delete;
ModulePtr check(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope = std::nullopt);
ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope = std::nullopt);
std::vector<std::pair<Location, ScopePtr>> getScopes() const;
@ -160,6 +174,7 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);
ExprResult<TypeId> checkExpr(
@ -172,6 +187,7 @@ struct TypeChecker
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr);
TypeId checkRelationalOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
@ -258,6 +274,8 @@ struct TypeChecker
ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location);
void unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location);
@ -395,6 +413,7 @@ private:
void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
bool isNonstrictMode() const;
bool useConstrainedIntersections() const;
public:
/** Extract the types in a type pack, given the assumption that the pack must have some exact length.
@ -421,7 +440,10 @@ public:
std::vector<RequireCycle> requireCycles;
// Type inference limits
std::optional<double> finishTime;
std::optional<int> instantiationChildLimit;
std::optional<int> unifierIterationLimit;
public:
const TypeId nilType;

View File

@ -40,6 +40,7 @@ struct TypePack
struct VariadicTypePack
{
TypeId ty;
bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail.
};
struct TypePackVar
@ -109,10 +110,10 @@ private:
};
TypePackIterator begin(TypePackId tp);
TypePackIterator begin(TypePackId tp, TxnLog* log);
TypePackIterator begin(TypePackId tp, const TxnLog* log);
TypePackIterator end(TypePackId tp);
using SeenSet = std::set<std::pair<void*, void*>>;
using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
@ -122,7 +123,7 @@ TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper);
size_t size(TypePackId tp, TxnLog* log = nullptr);
bool finite(TypePackId tp, TxnLog* log = nullptr);
size_t size(const TypePack& tp, TxnLog* log = nullptr);
std::optional<TypeId> first(TypePackId tp);
std::optional<TypeId> first(TypePackId tp, bool ignoreHiddenVariadics = true);
TypePackVar* asMutable(TypePackId tp);
TypePack* asMutable(const TypePack* tp);
@ -154,5 +155,12 @@ bool isEmpty(TypePackId tp);
/// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp);
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp, const TxnLog& log);
/// Returns true if the type pack arose from a function that is declared to be variadic.
/// Returns *false* for function argument packs that are inferred to be safe to oversaturate!
bool isVariadic(TypePackId tp);
bool isVariadic(TypePackId tp, const TxnLog& log);
} // namespace Luau
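
Hidden variadics let inferred argument packs remain safe to oversaturate without showing up in toString() or first(). A small construction sketch (arena and anyType are placeholders):

// A pack ending in a hidden variadic: toString() omits the tail and first() skips it by default.
TypePackId tail = arena.addTypePack(TypePackVar{VariadicTypePack{anyType, /*hidden*/ true}});
// isVariadic(tail) still distinguishes declared '...' parameters from inferred tails like this one.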

View File

@ -109,6 +109,23 @@ struct PrimitiveTypeVar
}
};
struct ConstrainedTypeVar
{
explicit ConstrainedTypeVar(TypeLevel level)
: level(level)
{
}
explicit ConstrainedTypeVar(TypeLevel level, const std::vector<TypeId>& parts)
: parts(parts)
, level(level)
{
}
std::vector<TypeId> parts;
TypeLevel level;
};
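
ConstrainedTypeVar accumulates the lower bounds of a type during inference; normalization later rewrites it into a BoundTypeVar or UnionTypeVar (see Normalize.cpp below). A hedged construction sketch (arena and the part types are placeholders):

TypeId ct = arena.addType(ConstrainedTypeVar{TypeLevel{}, {numberType, stringType}});
// After normalization this becomes either a bound to a single surviving part or a union of the parts.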
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BooleanSingleton
@ -248,6 +265,7 @@ struct FunctionTypeVar
MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr.
bool hasSelf;
Tags tags;
bool hasNoGenerics = false;
};
enum class TableState
@ -418,8 +436,8 @@ struct LazyTypeVar
using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar,
AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, ConstrainedTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
struct TypeVar final
{
@ -436,6 +454,7 @@ struct TypeVar final
TypeVar(const TypeVariant& ty, bool persistent)
: ty(ty)
, persistent(persistent)
, normal(persistent) // We assume that all persistent types are irreducible.
{
}
@ -446,6 +465,10 @@ struct TypeVar final
// Persistent TypeVars do not get cloned.
bool persistent = false;
// Normalization sets this for types that are fully normalized.
// This implies that they are transitively immutable.
bool normal = false;
std::optional<std::string> documentationSymbol;
// Pointer to the type arena that allocated this type.
@ -458,7 +481,7 @@ struct TypeVar final
TypeVar& operator=(TypeVariant&& rhs);
};
using SeenSet = std::set<std::pair<void*, void*>>;
using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs);
// Follow BoundTypeVars until we get to something real
@ -545,6 +568,8 @@ void persist(TypePackId tp);
const TypeLevel* getLevel(TypeId ty);
TypeLevel* getMutableLevel(TypeId ty);
std::optional<TypeLevel> getLevel(TypePackId tp);
const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name);
bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent);

View File

@ -56,6 +56,14 @@ struct TypeLevel
}
};
inline TypeLevel max(const TypeLevel& a, const TypeLevel& b)
{
if (a.subsumes(b))
return b;
else
return a;
}
inline TypeLevel min(const TypeLevel& a, const TypeLevel& b)
{
if (a.subsumes(b))
@ -64,7 +72,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b)
return b;
}
namespace Unifiable
} // namespace Luau
namespace Luau::Unifiable
{
using Name = std::string;
@ -125,7 +135,6 @@ private:
};
template<typename Id, typename... Value>
using Variant = Variant<Free, Bound<Id>, Generic, Error, Value...>;
using Variant = Luau::Variant<Free, Bound<Id>, Generic, Error, Value...>;
} // namespace Unifiable
} // namespace Luau
} // namespace Luau::Unifiable

View File

@ -49,14 +49,14 @@ struct Unifier
ErrorVec errors;
Location location;
Variance variance = Covariant;
bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once.
CountMismatch::Context ctx = CountMismatch::Arg;
UnifierSharedState& sharedState;
Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState,
TxnLog* parentLog = nullptr);
Unifier(TypeArena* types, Mode mode, std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr);
Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr);
Unifier(TypeArena* types, Mode mode, std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen, const Location& location, Variance variance,
UnifierSharedState& sharedState, TxnLog* parentLog = nullptr);
// Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId subTy, TypeId superTy);
@ -106,7 +106,12 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy);
void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy);
public:
void unifyLowerBound(TypePackId subTy, TypePackId superTy);
// Report an "infinite type error" if the type "needle" already occurs within "haystack"
void occursCheck(TypeId needle, TypeId haystack);
void occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
@ -115,12 +120,7 @@ public:
Unifier makeChildUnifier();
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
void reportError(TypeError error)
{
errors.push_back(error);
}
void reportError(TypeError err);
private:
bool isNonstrictMode() const;
@ -135,4 +135,6 @@ private:
std::optional<int> firstPackErrorPos;
};
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, TypePackId tp);
} // namespace Luau

View File

@ -28,7 +28,9 @@ struct TypeIdPairHash
struct UnifierCounters
{
int recursionCount = 0;
int recursionLimit = 0;
int iterationCount = 0;
int iterationLimit = 0;
};
struct UnifierSharedState

View File

@ -82,6 +82,15 @@ void visit(TypeId ty, F& f, Set& seen)
else if (auto etv = get<ErrorTypeVar>(ty))
apply(ty, *etv, seen, f);
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
if (apply(ty, *ctv, seen, f))
{
for (TypeId part : ctv->parts)
visit(part, f, seen);
}
}
else if (auto ptv = get<PrimitiveTypeVar>(ty))
apply(ty, *ptv, seen, f);

View File

@ -151,8 +151,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp
auto idxExpr = nodes.back()->as<AstExprIndexName>();
bool hasImplicitSelf = idxExpr && idxExpr->op == ':';
auto args = Luau::flatten(func->argTypes);
bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value();
auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes);
if (argVariadicPack.has_value() && isVariadic(*argVariadicPack))
return ParenthesesRecommendation::CursorInside;
bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1);
return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside;
}

View File

@ -6,7 +6,10 @@
#include "Luau/TypePack.h"
#include "Luau/Unifiable.h"
LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAG(LuauTypecheckOptPass)
namespace Luau
{
@ -23,11 +26,11 @@ struct TypePackCloner;
struct TypeCloner
{
TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState)
: dest(dest)
, typeId(typeId)
, seenTypes(seenTypes)
, seenTypePacks(seenTypePacks)
, seenTypes(cloneState.seenTypes)
, seenTypePacks(cloneState.seenTypePacks)
, cloneState(cloneState)
{
}
@ -46,6 +49,7 @@ struct TypeCloner
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const ConstrainedTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
@ -65,11 +69,11 @@ struct TypePackCloner
SeenTypePacks& seenTypePacks;
CloneState& cloneState;
TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState)
: dest(dest)
, typePackId(typePackId)
, seenTypes(seenTypes)
, seenTypePacks(seenTypePacks)
, seenTypes(cloneState.seenTypes)
, seenTypePacks(cloneState.seenTypePacks)
, cloneState(cloneState)
{
}
@ -103,13 +107,15 @@ struct TypePackCloner
// We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer.
void operator()(const Unifiable::Bound<TypePackId>& t)
{
TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState);
TypePackId cloned = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const VariadicTypePack& t)
{
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}});
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}});
seenTypePacks[typePackId] = cloned;
}
@ -121,10 +127,10 @@ struct TypePackCloner
seenTypePacks[typePackId] = cloned;
for (TypeId ty : t.head)
destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
destTp->head.push_back(clone(ty, dest, cloneState));
if (t.tail)
destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState);
destTp->tail = clone(*t.tail, dest, cloneState);
}
};
@ -150,7 +156,9 @@ void TypeCloner::operator()(const Unifiable::Generic& t)
void TypeCloner::operator()(const Unifiable::Bound<TypeId>& t)
{
TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState);
TypeId boundTo = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
boundTo = dest.addType(BoundTypeVar{boundTo});
seenTypes[typeId] = boundTo;
}
@ -164,6 +172,23 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t)
defaultClone(t);
}
void TypeCloner::operator()(const ConstrainedTypeVar& t)
{
cloneState.encounteredFreeType = true;
TypeId res = dest.addType(ConstrainedTypeVar{t.level});
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(res);
LUAU_ASSERT(ctv);
seenTypes[typeId] = res;
std::vector<TypeId> parts;
for (TypeId part : t.parts)
parts.push_back(clone(part, dest, cloneState));
ctv->parts = std::move(parts);
}
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
@ -178,23 +203,26 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
seenTypes[typeId] = result;
for (TypeId generic : t.generics)
ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState));
ftv->generics.push_back(clone(generic, dest, cloneState));
for (TypePackId genericPack : t.genericPacks)
ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState));
ftv->genericPacks.push_back(clone(genericPack, dest, cloneState));
ftv->tags = t.tags;
ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState);
ftv->argTypes = clone(t.argTypes, dest, cloneState);
ftv->argNames = t.argNames;
ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState);
ftv->retType = clone(t.retType, dest, cloneState);
if (FFlag::LuauTypecheckOptPass)
ftv->hasNoGenerics = t.hasNoGenerics;
}
void TypeCloner::operator()(const TableTypeVar& t)
{
// If table is now bound to another one, we ignore the content of the original
if (t.boundTo)
if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState);
TypeId boundTo = clone(*t.boundTo, dest, cloneState);
seenTypes[typeId] = boundTo;
return;
}
@ -209,18 +237,20 @@ void TypeCloner::operator()(const TableTypeVar& t)
ttv->level = TypeLevel{0, 0};
if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, cloneState);
for (const auto& [name, prop] : t.props)
ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags};
ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState),
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)};
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)};
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState);
arg = clone(arg, dest, cloneState);
for (TypePackId& arg : ttv->instantiatedTypePackParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState);
arg = clone(arg, dest, cloneState);
if (ttv->state == TableState::Free)
{
@ -240,8 +270,8 @@ void TypeCloner::operator()(const MetatableTypeVar& t)
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result);
seenTypes[typeId] = result;
mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState);
mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState);
mtv->table = clone(t.table, dest, cloneState);
mtv->metatable = clone(t.metatable, dest, cloneState);
}
void TypeCloner::operator()(const ClassTypeVar& t)
@ -252,13 +282,13 @@ void TypeCloner::operator()(const ClassTypeVar& t)
seenTypes[typeId] = result;
for (const auto& [name, prop] : t.props)
ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags};
ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.parent)
ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState);
ctv->parent = clone(*t.parent, dest, cloneState);
if (t.metatable)
ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState);
ctv->metatable = clone(*t.metatable, dest, cloneState);
}
void TypeCloner::operator()(const AnyTypeVar& t)
@ -272,7 +302,7 @@ void TypeCloner::operator()(const UnionTypeVar& t)
options.reserve(t.options.size());
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
options.push_back(clone(ty, dest, cloneState));
TypeId result = dest.addType(UnionTypeVar{std::move(options)});
seenTypes[typeId] = result;
@ -287,7 +317,7 @@ void TypeCloner::operator()(const IntersectionTypeVar& t)
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.parts)
option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
option->parts.push_back(clone(ty, dest, cloneState));
}
void TypeCloner::operator()(const LazyTypeVar& t)
@ -297,36 +327,36 @@ void TypeCloner::operator()(const LazyTypeVar& t)
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
{
if (tp->persistent)
return tp;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypePackId");
TypePackId& res = seenTypePacks[tp];
TypePackId& res = cloneState.seenTypePacks[tp];
if (res == nullptr)
{
TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState};
TypePackCloner cloner{dest, tp, cloneState};
Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into.
}
return res;
}
TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
{
if (typeId->persistent)
return typeId;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypeId");
TypeId& res = seenTypes[typeId];
TypeId& res = cloneState.seenTypes[typeId];
if (res == nullptr)
{
TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState};
TypeCloner cloner{dest, typeId, cloneState};
Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into.
// Persistent types are not being cloned and we get the original type back which might be read-only
@ -337,33 +367,33 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks
return res;
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
TypeFun result;
for (auto param : typeFun.typeParams)
{
TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState);
TypeId ty = clone(param.ty, dest, cloneState);
std::optional<TypeId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState);
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typeParams.push_back({ty, defaultValue});
}
for (auto param : typeFun.typePackParams)
{
TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState);
TypePackId tp = clone(param.tp, dest, cloneState);
std::optional<TypePackId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState);
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typePackParams.push_back({tp, defaultValue});
}
result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState);
result.type = clone(typeFun.type, dest, cloneState);
return result;
}

View File

@ -8,7 +8,6 @@
#include <stdexcept>
LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false);
LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false);
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
@ -252,14 +251,7 @@ struct ErrorConverter
std::string operator()(const Luau::SyntaxError& e) const
{
if (FFlag::BetterDiagnosticCodesInStudio)
{
return e.message;
}
else
{
return "Syntax error: " + e.message;
}
return e.message;
}
std::string operator()(const Luau::CodeTooComplex&) const
@ -451,6 +443,11 @@ struct ErrorConverter
{
return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated";
}
std::string operator()(const NormalizationTooComplex&) const
{
return "Code is too complex to typecheck! Consider simplifying the code around this area";
}
};
struct InvalidNameChecker
@ -716,14 +713,14 @@ bool containsParseErrorName(const TypeError& error)
}
template<typename T>
void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState)
void copyError(T& e, TypeArena& destArena, CloneState cloneState)
{
auto clone = [&](auto&& ty) {
return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState);
return ::Luau::clone(ty, destArena, cloneState);
};
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, seenTypes, seenTypePacks, cloneState);
copyError(e, destArena, cloneState);
};
if constexpr (false)
@ -844,18 +841,19 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks&
e.left = clone(e.left);
e.right = clone(e.right);
}
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
{
}
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
void copyErrors(ErrorVec& errors, TypeArena& destArena)
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, seenTypes, seenTypePacks, cloneState);
copyError(e, destArena, cloneState);
};
LUAU_ASSERT(!destArena.typeVars.isFrozen());

View File

@ -11,16 +11,18 @@
#include "Luau/TimeTrace.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
#include "Luau/Common.h"
#include <algorithm>
#include <chrono>
#include <stdexcept>
LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauCyclicModuleTypeSurface)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false)
LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0)
namespace Luau
@ -97,13 +99,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
if (checkedModule->errors.size() > 0)
return LoadDefinitionFileResult{false, parseResult, checkedModule};
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState);
TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
@ -113,7 +113,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState);
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
@ -440,13 +440,42 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
else
typeCheckerForAutocomplete.finishTime = std::nullopt;
if (FFlag::LuauAutocompleteDynamicLimits)
{
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (FInt::LuauTarjanChildLimit > 0)
typeCheckerForAutocomplete.instantiationChildLimit =
std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;
if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckerForAutocomplete.unifierIterationLimit =
std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt;
}
ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict);
moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete;
double duration = getTimestamp() - timestamp;
if (moduleForAutocomplete->timeout)
{
checkResult.timeoutHits.push_back(moduleName);
stats.timeCheck += getTimestamp() - timestamp;
if (FFlag::LuauAutocompleteDynamicLimits)
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
}
else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0)
{
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
}
stats.timeCheck += duration;
stats.filesStrict += 1;
sourceNode.dirtyAutocomplete = false;
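For intuition on the multiplier: each autocomplete timeout halves autocompleteLimitsMult, and each comfortably fast check (under half the time budget) doubles it back toward 1.0, so the derived limits drift toward a value under which the file can actually finish. A worked sketch of the arithmetic in the code above:

double mult = 1.0;
mult /= 2.0;                                               // first timeout -> 0.5
mult /= 2.0;                                               // second timeout -> 0.25
int childLimit = std::max(1, int(FInt::LuauTarjanChildLimit * mult));
mult = std::min(mult * 2.0, 1.0);                          // a fast check recovers to 0.5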

View File

@ -184,6 +184,8 @@ static void errorToString(std::ostream& stream, const T& err)
}
else if constexpr (std::is_same_v<T, TypesAreUnrelated>)
stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }";
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
stream << "NormalizationTooComplex { }";
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}

View File

@ -5,6 +5,8 @@
#include <vector>
LUAU_FASTFLAG(LuauTypecheckOptPass)
namespace Luau
{
@ -79,6 +81,8 @@ std::optional<LValue> tryGetLValue(const AstExpr& node)
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue)
{
LUAU_ASSERT(!FFlag::LuauTypecheckOptPass);
const LValue* current = &lvalue;
std::vector<std::string> keys;
while (auto field = get<Field>(*current))
@ -92,6 +96,19 @@ std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue)
return {*symbol, std::vector<std::string>(keys.rbegin(), keys.rend())};
}
Symbol getBaseSymbol(const LValue& lvalue)
{
LUAU_ASSERT(FFlag::LuauTypecheckOptPass);
const LValue* current = &lvalue;
while (auto field = get<Field>(*current))
current = baseof(*current);
const Symbol* symbol = get<Symbol>(*current);
LUAU_ASSERT(symbol);
return *symbol;
}
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
{
for (const auto& [k, a] : r)

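When refinement code only needs the root Symbol of an LValue, getBaseSymbol avoids building and reversing the whole key vector that getFullName returns. A hedged sketch (expr is assumed to come from the AST being refined):

if (std::optional<LValue> lvalue = tryGetLValue(*expr))
{
    Symbol base = FFlag::LuauTypecheckOptPass ? getBaseSymbol(*lvalue) : getFullName(*lvalue).first;
    // look up 'base' in the current scope's bindings, ignoring the field path
}
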
View File

@ -14,7 +14,6 @@
LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false)
LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false)
namespace Luau
{
@ -1140,25 +1139,8 @@ private:
Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata.
Kind_Vector, // 'vector' but only used when type is used
Kind_Userdata, // custom userdata type
// TODO: remove these with LuauLintNoRobloxBits
Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc.
Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc.
};
bool containsPropName(TypeId ty, const std::string& propName)
{
LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits);
if (auto ctv = get<ClassTypeVar>(ty))
return lookupClassProp(ctv, propName) != nullptr;
if (auto ttv = get<TableTypeVar>(ty))
return ttv->props.find(propName) != ttv->props.end();
return false;
}
TypeKind getTypeKind(const std::string& name)
{
if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" ||
@ -1168,23 +1150,10 @@ private:
if (name == "vector")
return Kind_Vector;
if (FFlag::LuauLintNoRobloxBits)
{
if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name))
return Kind_Userdata;
if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name))
return Kind_Userdata;
return Kind_Unknown;
}
else
{
if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name))
// Kind_Userdata is probably not 100% precise but is close enough
return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata;
else if (std::optional<TypeFun> maybeTy = context->scope->lookupImportedType("Enum", name))
return Kind_Enum;
return Kind_Unknown;
}
return Kind_Unknown;
}
void validateType(AstExprConstantString* expr, std::initializer_list<TypeKind> expected, const char* expectedString)
@ -1202,67 +1171,11 @@ private:
{
if (kind == ek)
return;
// as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type
if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem"))
return;
}
emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString);
}
bool acceptsClassName(AstName method)
{
LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits);
return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" ||
method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA");
}
bool visit(AstExprCall* node) override
{
// TODO: Simply remove the override
if (FFlag::LuauLintNoRobloxBits)
return true;
if (AstExprIndexName* index = node->func->as<AstExprIndexName>())
{
AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as<AstExprConstantString>() : NULL;
if (arg0)
{
if (node->self && index->index == "IsA" && node->args.size == 1)
{
validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type");
}
else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1)
{
AstExprGlobal* g = index->expr->as<AstExprGlobal>();
if (g && (g->name == "game" || g->name == "Game"))
{
validateType(arg0, {Kind_Class}, "class type");
}
}
else if (node->self && acceptsClassName(index->index) && node->args.size == 1)
{
validateType(arg0, {Kind_Class}, "class type");
}
else if (!node->self && index->index == "new" && node->args.size <= 2)
{
AstExprGlobal* g = index->expr->as<AstExprGlobal>();
if (g && g->name == "Instance")
{
validateType(arg0, {Kind_Class}, "class type");
}
}
}
}
return true;
}
bool visit(AstExprBinary* node) override
{
if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq)

View File

@ -1,8 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h"
#include "Luau/Common.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/Normalize.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
@ -14,6 +15,7 @@
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false)
LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false)
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
namespace Luau
{
@ -143,32 +145,51 @@ Module::~Module()
unfreeze(internalTypes);
}
bool Module::clonePublicInterface()
bool Module::clonePublicInterface(InternalErrorReporter& ice)
{
LUAU_ASSERT(interfaceTypes.typeVars.empty());
LUAU_ASSERT(interfaceTypes.typePacks.empty());
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
ScopePtr moduleScope = getModuleScope();
moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState);
moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, cloneState);
if (moduleScope->varargPack)
moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState);
moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation)
{
normalize(moduleScope->returnType, interfaceTypes, ice);
if (moduleScope->varargPack)
normalize(*moduleScope->varargPack, interfaceTypes, ice);
}
for (auto& [name, tf] : moduleScope->exportedTypeBindings)
tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState);
{
tf = clone(tf, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation)
normalize(tf.type, interfaceTypes, ice);
}
for (TypeId ty : moduleScope->returnType)
{
if (get<GenericTypeVar>(follow(ty)))
*asMutable(ty) = AnyTypeVar{};
{
auto t = asMutable(ty);
t->ty = AnyTypeVar{};
t->normal = true;
}
}
if (FFlag::LuauCloneDeclaredGlobals)
{
for (auto& [name, ty] : declaredGlobals)
ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState);
{
ty = clone(ty, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation)
normalize(ty, interfaceTypes, ice);
}
}
freeze(internalTypes);

Analysis/src/Normalize.cpp (new file, 814 lines)
View File

@ -0,0 +1,814 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Normalize.h"
#include <algorithm>
#include "Luau/Clone.h"
#include "Luau/DenseHash.h"
#include "Luau/Substitution.h"
#include "Luau/Unifier.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false)
// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineIntersectionFix, false);
namespace Luau
{
namespace
{
struct Replacer : Substitution
{
TypeId sourceType;
TypeId replacedType;
DenseHashMap<TypeId, TypeId> replacedTypes{nullptr};
DenseHashMap<TypePackId, TypePackId> replacedPacks{nullptr};
Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType)
: Substitution(TxnLog::empty(), arena)
, sourceType(sourceType)
, replacedType(replacedType)
{
}
bool isDirty(TypeId ty) override
{
if (!sourceType)
return false;
auto vecHasSourceType = [sourceType = sourceType](const auto& vec) {
return end(vec) != std::find(begin(vec), end(vec), sourceType);
};
// Walk every kind of TypeVar and find pointers to sourceType
if (auto t = get<FreeTypeVar>(ty))
return false;
else if (auto t = get<GenericTypeVar>(ty))
return false;
else if (auto t = get<ErrorTypeVar>(ty))
return false;
else if (auto t = get<PrimitiveTypeVar>(ty))
return false;
else if (auto t = get<ConstrainedTypeVar>(ty))
return vecHasSourceType(t->parts);
else if (auto t = get<SingletonTypeVar>(ty))
return false;
else if (auto t = get<FunctionTypeVar>(ty))
{
if (vecHasSourceType(t->generics))
return true;
return false;
}
else if (auto t = get<TableTypeVar>(ty))
{
if (t->boundTo)
return *t->boundTo == sourceType;
for (const auto& [_name, prop] : t->props)
{
if (prop.type == sourceType)
return true;
}
if (auto indexer = t->indexer)
{
if (indexer->indexType == sourceType || indexer->indexResultType == sourceType)
return true;
}
if (vecHasSourceType(t->instantiatedTypeParams))
return true;
return false;
}
else if (auto t = get<MetatableTypeVar>(ty))
return t->table == sourceType || t->metatable == sourceType;
else if (auto t = get<ClassTypeVar>(ty))
return false;
else if (auto t = get<AnyTypeVar>(ty))
return false;
else if (auto t = get<UnionTypeVar>(ty))
return vecHasSourceType(t->options);
else if (auto t = get<IntersectionTypeVar>(ty))
return vecHasSourceType(t->parts);
else if (auto t = get<LazyTypeVar>(ty))
return false;
LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type");
LUAU_UNREACHABLE();
}
bool isDirty(TypePackId tp) override
{
if (auto it = replacedPacks.find(tp))
return false;
if (auto pack = get<TypePack>(tp))
{
for (TypeId ty : pack->head)
{
if (ty == sourceType)
return true;
}
return false;
}
else if (auto vtp = get<VariadicTypePack>(tp))
return vtp->ty == sourceType;
else
return false;
}
TypeId clean(TypeId ty) override
{
LUAU_ASSERT(sourceType && replacedType);
// Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType
// Before returning, memoize the result for later use.
// Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This
// function returns the identity for things like primitives.
TypeId res = clone(ty);
if (auto t = get<FreeTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<GenericTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<ErrorTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<PrimitiveTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<ConstrainedTypeVar>(res))
{
for (TypeId& part : t->parts)
{
if (part == sourceType)
part = replacedType;
}
}
else if (auto t = get<SingletonTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<FunctionTypeVar>(res))
{
// The constituent typepacks are cleaned separately. We just need to walk the generics array.
for (TypeId& g : t->generics)
{
if (g == sourceType)
g = replacedType;
}
}
else if (auto t = getMutable<TableTypeVar>(res))
{
for (auto& [_key, prop] : t->props)
{
if (prop.type == sourceType)
prop.type = replacedType;
}
}
else if (auto t = getMutable<MetatableTypeVar>(res))
{
if (t->table == sourceType)
t->table = replacedType;
if (t->metatable == sourceType)
t->metatable = replacedType;
}
else if (auto t = get<ClassTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<AnyTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<UnionTypeVar>(res))
{
for (TypeId& option : t->options)
{
if (option == sourceType)
option = replacedType;
}
}
else if (auto t = getMutable<IntersectionTypeVar>(res))
{
for (TypeId& part : t->parts)
{
if (part == sourceType)
part = replacedType;
}
}
else if (auto t = get<LazyTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else
LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type");
replacedTypes[ty] = res;
return res;
}
TypePackId clean(TypePackId tp) override
{
TypePackId res = clone(tp);
if (auto pack = getMutable<TypePack>(res))
{
for (TypeId& type : pack->head)
{
if (type == sourceType)
type = replacedType;
}
}
else if (auto vtp = getMutable<VariadicTypePack>(res))
{
if (vtp->ty == sourceType)
vtp->ty = replacedType;
}
replacedPacks[tp] = res;
return res;
}
TypeId smartClone(TypeId t)
{
std::optional<TypeId> res = replace(t);
LUAU_ASSERT(res.has_value()); // TODO think about this
if (*res == t)
return clone(t);
return *res;
}
};
} // anonymous namespace
bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState};
u.anyIsTop = true;
u.tryUnify(subTy, superTy);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
template<typename T>
static bool areNormal_(const T& t, const DenseHashSet<void*>& seen, InternalErrorReporter& ice)
{
int count = 0;
auto isNormal = [&](TypeId ty) {
++count;
if (count >= FInt::LuauNormalizeIterationLimit)
ice.ice("Luau::areNormal hit iteration limit");
return ty->normal || seen.find(asMutable(ty));
};
return std::all_of(begin(t), end(t), isNormal);
}
static bool areNormal(const std::vector<TypeId>& types, const DenseHashSet<void*>& seen, InternalErrorReporter& ice)
{
return areNormal_(types, seen, ice);
}
static bool areNormal(TypePackId tp, const DenseHashSet<void*>& seen, InternalErrorReporter& ice)
{
tp = follow(tp);
if (get<FreeTypePack>(tp))
return false;
auto [head, tail] = flatten(tp);
if (!areNormal_(head, seen, ice))
return false;
if (!tail)
return true;
if (auto vtp = get<VariadicTypePack>(*tail))
return vtp->ty->normal || seen.find(asMutable(vtp->ty));
return true;
}
#define CHECK_ITERATION_LIMIT(...) \
do \
{ \
if (iterationLimit > FInt::LuauNormalizeIterationLimit) \
{ \
limitExceeded = true; \
return __VA_ARGS__; \
} \
++iterationLimit; \
} while (false)
struct Normalize
{
TypeArena& arena;
InternalErrorReporter& ice;
// Debug data. Types being normalized are invalidated but trying to see what's going on is painful.
// To actually see the original type, read it by using the pointer of the type being normalized.
// e.g. in lldb, `e dump(originalTys[ty])`.
SeenTypes originalTys;
SeenTypePacks originalTps;
int iterationLimit = 0;
bool limitExceeded = false;
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
template<typename TID>
void cycle(TID)
{
}
bool operator()(TypeId ty, const FreeTypeVar&)
{
LUAU_ASSERT(!ty->normal);
return false;
}
bool operator()(TypeId ty, const BoundTypeVar& btv)
{
// It should never be the case that this TypeVar is normal, but is bound to a non-normal type.
LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal);
asMutable(ty)->normal = btv.boundTo->normal;
return !ty->normal;
}
bool operator()(TypeId ty, const PrimitiveTypeVar&)
{
LUAU_ASSERT(ty->normal);
return false;
}
bool operator()(TypeId ty, const GenericTypeVar&)
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool operator()(TypeId ty, const ErrorTypeVar&)
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, DenseHashSet<void*>& seen)
{
CHECK_ITERATION_LIMIT(false);
ConstrainedTypeVar* ctv = const_cast<ConstrainedTypeVar*>(&ctvRef);
std::vector<TypeId> parts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId part : parts)
visit_detail::visit(part, *this, seen);
std::vector<TypeId> newParts = normalizeUnion(parts);
const bool normal = areNormal(newParts, seen, ice);
if (newParts.size() == 1)
*asMutable(ty) = BoundTypeVar{newParts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(newParts)};
asMutable(ty)->normal = normal;
return false;
}
bool operator()(TypeId ty, const FunctionTypeVar& ftv) = delete;
bool operator()(TypeId ty, const FunctionTypeVar& ftv, DenseHashSet<void*>& seen)
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
visit_detail::visit(ftv.argTypes, *this, seen);
visit_detail::visit(ftv.retType, *this, seen);
asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice);
return false;
}
bool operator()(TypeId ty, const TableTypeVar& ttv, DenseHashSet<void*>& seen)
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
bool normal = true;
auto checkNormal = [&](TypeId t) {
// if t is on the stack, it is possible that this type is normal.
// If t is not normal and it is not on the stack, this type is definitely not normal.
if (!t->normal && !seen.find(asMutable(t)))
normal = false;
};
if (ttv.boundTo)
{
visit_detail::visit(*ttv.boundTo, *this, seen);
asMutable(ty)->normal = (*ttv.boundTo)->normal;
return false;
}
for (const auto& [_name, prop] : ttv.props)
{
visit_detail::visit(prop.type, *this, seen);
checkNormal(prop.type);
}
if (ttv.indexer)
{
visit_detail::visit(ttv.indexer->indexType, *this, seen);
checkNormal(ttv.indexer->indexType);
visit_detail::visit(ttv.indexer->indexResultType, *this, seen);
checkNormal(ttv.indexer->indexResultType);
}
asMutable(ty)->normal = normal;
return false;
}
bool operator()(TypeId ty, const MetatableTypeVar& mtv, DenseHashSet<void*>& seen)
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
visit_detail::visit(mtv.table, *this, seen);
visit_detail::visit(mtv.metatable, *this, seen);
asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal;
return false;
}
bool operator()(TypeId ty, const ClassTypeVar& ctv)
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool operator()(TypeId ty, const AnyTypeVar&)
{
LUAU_ASSERT(ty->normal);
return false;
}
bool operator()(TypeId ty, const UnionTypeVar& utvRef, DenseHashSet<void*>& seen)
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
UnionTypeVar* utv = &const_cast<UnionTypeVar&>(utvRef);
std::vector<TypeId> options = std::move(utv->options);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId option : options)
visit_detail::visit(option, *this, seen);
std::vector<TypeId> newOptions = normalizeUnion(options);
const bool normal = areNormal(newOptions, seen, ice);
LUAU_ASSERT(!newOptions.empty());
if (newOptions.size() == 1)
*asMutable(ty) = BoundTypeVar{newOptions[0]};
else
utv->options = std::move(newOptions);
asMutable(ty)->normal = normal;
return false;
}
bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, DenseHashSet<void*>& seen)
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
IntersectionTypeVar* itv = &const_cast<IntersectionTypeVar&>(itvRef);
std::vector<TypeId> oldParts = std::move(itv->parts);
for (TypeId part : oldParts)
visit_detail::visit(part, *this, seen);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, itv, part);
}
}
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
itv->parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
itv->parts.push_back(newTable);
}
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
return false;
}
bool operator()(TypeId ty, const LazyTypeVar&)
{
return false;
}
std::vector<TypeId> normalizeUnion(const std::vector<TypeId>& options)
{
if (options.size() == 1)
return options;
std::vector<TypeId> result;
for (TypeId part : options)
combineIntoUnion(result, part);
return result;
}
void combineIntoUnion(std::vector<TypeId>& result, TypeId ty)
{
ty = follow(ty);
if (auto utv = get<UnionTypeVar>(ty))
{
for (TypeId t : utv)
combineIntoUnion(result, t);
return;
}
for (TypeId& part : result)
{
if (isSubtype(ty, part, ice))
return; // no need to do anything
else if (isSubtype(part, ty, ice))
{
part = ty; // replace the less general type by the more general one
return;
}
}
result.push_back(ty);
}
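// A hedged walk-through of the subtype-based dedup above, assuming the usual
// builtin subtyping (for instance, the string singleton "foo" <: string):
//
//   combineIntoUnion(result, number)   -- result: {number}
//   combineIntoUnion(result, "foo")    -- result: {number, "foo"}
//   combineIntoUnion(result, string)   -- "foo" <: string, so that slot is widened: {number, string}
//   combineIntoUnion(result, number)   -- number is already subsumed, nothing to add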
/**
* @param replacer knows how to clone a type such that any recursive references point at the new containing type.
* @param result is an intersection that is safe for us to mutate in-place.
*/
void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty)
{
// Note: this check guards against running out of stack space
// so if you increase the size of a stack frame, you'll need to decrease the limit.
CHECK_ITERATION_LIMIT();
ty = follow(ty);
if (auto itv = get<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
combineIntoIntersection(replacer, result, part);
return;
}
// Maintain the invariant that, if any table is part of this intersection, the last part of the result is always a table
if (get<TableTypeVar>(ty))
{
if (result->parts.empty())
result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}}));
TypeId theTable = result->parts.back();
if (!get<TableTypeVar>(FFlag::LuauNormalizeCombineIntersectionFix ? follow(theTable) : theTable))
{
result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}}));
theTable = result->parts.back();
}
TypeId newTable = replacer.smartClone(theTable);
result->parts.back() = newTable;
combineIntoTable(replacer, getMutable<TableTypeVar>(newTable), ty);
}
else if (auto ftv = get<FunctionTypeVar>(ty))
{
bool merged = false;
for (TypeId& part : result->parts)
{
if (isSubtype(part, ty, ice))
{
merged = true;
break; // no need to do anything
}
else if (isSubtype(ty, part, ice))
{
merged = true;
part = ty; // ty <: part, so for an intersection the more specific type subsumes the more general one; keep ty
break;
}
}
if (!merged)
result->parts.push_back(ty);
}
else
result->parts.push_back(ty);
}
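// An illustrative (non-exhaustive) sketch of the rules above: starting from an empty
// result, combining {x: number}, then {y: string}, then (number) -> number yields
// parts { {x: number, y: string}, (number) -> number } -- the two tables are merged
// into the single trailing table part, while the function is appended because it is
// neither a subtype nor a supertype of any existing part. Combining a function that
// is a supertype of an existing function part is dropped; one that is a subtype
// replaces the existing part.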
TableState combineTableStates(TableState lhs, TableState rhs)
{
if (lhs == rhs)
return lhs;
if (lhs == TableState::Free || rhs == TableState::Free)
return TableState::Free;
if (lhs == TableState::Unsealed || rhs == TableState::Unsealed)
return TableState::Unsealed;
return lhs;
}
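// Illustrative outcomes, derived directly from the rules above:
//   Free     + Sealed -> Free
//   Unsealed + Sealed -> Unsealed
//   Sealed   + Sealed -> Sealed
//   Generic  + Sealed -> Generic   (neither special case applies, so lhs wins)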
/**
* @param replacer gives us a way to clone a type such that recursive references are rewritten to the new
* "containing" type.
* @param table always points into a table that is safe for us to mutate.
*/
void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty)
{
// Note: this check guards against running out of stack space
// so if you increase the size of a stack frame, you'll need to decrease the limit.
CHECK_ITERATION_LIMIT();
LUAU_ASSERT(table);
ty = follow(ty);
TableTypeVar* tyTable = getMutable<TableTypeVar>(ty);
LUAU_ASSERT(tyTable);
for (const auto& [propName, prop] : tyTable->props)
{
if (auto it = table->props.find(propName); it != table->props.end())
{
/**
* If we are going to recursively merge intersections of tables, we need to ensure that we never mutate
* a table that comes from somewhere else in the type graph.
*
* smartClone() does some nice things for us: It will perform a clone that is as shallow as possible
* while still rewriting any cyclic references back to the new 'root' table.
*
* replacer also keeps a mapping of types that have previously been copied, so we have the added
* advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is
* safe for us to mutate in-place.
*/
TypeId clone = replacer.smartClone(it->second.type);
it->second.type = combine(replacer, clone, prop.type);
}
else
table->props.insert({propName, prop});
}
table->state = combineTableStates(table->state, tyTable->state);
table->level = max(table->level, tyTable->level);
}
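// A hedged example of the merge above: combining {x: number, y: string} into a table
// that already has x: "hello" inserts y: string untouched and rewrites the shared
// property to combine(smartClone("hello"), number), i.e. the intersection "hello" & number,
// since neither property type is itself a table or an intersection.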
/**
* @param a is always cloned by the caller. It is safe to mutate in-place.
* @param b will never be mutated.
*/
TypeId combine(Replacer& replacer, TypeId a, TypeId b)
{
if (FFlag::LuauNormalizeCombineTableFix && a == b)
return a;
if (!get<IntersectionTypeVar>(a) && !get<TableTypeVar>(a))
{
if (!FFlag::LuauNormalizeCombineTableFix && a == b)
return a;
else
return arena.addType(IntersectionTypeVar{{a, b}});
}
if (auto itv = getMutable<IntersectionTypeVar>(a))
{
combineIntoIntersection(replacer, itv, b);
return a;
}
else if (auto ttv = getMutable<TableTypeVar>(a))
{
if (FFlag::LuauNormalizeCombineTableFix && !get<TableTypeVar>(follow(b)))
return arena.addType(IntersectionTypeVar{{a, b}});
combineIntoTable(replacer, ttv, b);
return a;
}
LUAU_ASSERT(!"Impossible");
LUAU_UNREACHABLE();
}
};
#undef CHECK_ITERATION_LIMIT
/**
* @returns A pair of TypeId and a success indicator. (true indicates that the normalization completed successfully; false means an internal limit was hit and the result may be only partially normalized)
*/
std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice)
{
CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(ty, arena, state);
Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)};
DenseHashSet<void*> seen{nullptr};
visitTypeVarOnce(ty, n, seen);
return {ty, !n.limitExceeded};
}
// TODO: Think about using a temporary arena and cloning types out of it so that we
// reclaim memory used by wantonly allocated intermediate types here.
// The main wrinkle here is that we don't want clone() to copy a type if the source and dest
// arena are the same.
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice)
{
return normalize(ty, module->internalTypes, ice);
}
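// A minimal usage sketch, mirroring how the type checker consumes this API (names such
// as `location` and `iceHandler` are assumed to be in scope):
//
//   auto [t, ok] = normalize(ty, currentModule, *iceHandler);
//   if (!ok)
//       reportError(location, NormalizationTooComplex{});
//   ty = t;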
/**
* @returns A pair of TypePackId and a success indicator. (true indicates that the normalization completed successfully; false means an internal limit was hit and the result may be only partially normalized)
*/
std::pair<TypePackId, bool> normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice)
{
CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(tp, arena, state);
Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)};
DenseHashSet<void*> seen{nullptr};
visitTypeVarOnce(tp, n, seen);
return {tp, !n.limitExceeded};
}
std::pair<TypePackId, bool> normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice)
{
return normalize(tp, module->internalTypes, ice);
}
} // namespace Luau

View File

@ -4,6 +4,8 @@
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAG(LuauTypecheckOptPass)
namespace Luau
{
@ -12,6 +14,8 @@ struct Quantifier
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool seenGenericType = false;
bool seenMutableType = false;
Quantifier(TypeLevel level)
: level(level)
@ -23,6 +27,9 @@ struct Quantifier
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
if (FFlag::LuauTypecheckOptPass)
seenMutableType = true;
if (!level.subsumes(ftv.level))
return false;
@ -44,17 +51,40 @@ struct Quantifier
return true;
}
bool operator()(TypeId ty, const ConstrainedTypeVar&)
{
return true;
}
bool operator()(TypeId ty, const TableTypeVar&)
{
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (FFlag::LuauTypecheckOptPass)
{
if (ttv.state == TableState::Generic)
seenGenericType = true;
if (ttv.state == TableState::Free)
seenMutableType = true;
}
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
if (!level.subsumes(ttv.level))
{
if (FFlag::LuauTypecheckOptPass && ttv.state == TableState::Unsealed)
seenMutableType = true;
return false;
}
if (ttv.state == TableState::Free)
{
ttv.state = TableState::Generic;
if (FFlag::LuauTypecheckOptPass)
seenGenericType = true;
}
else if (ttv.state == TableState::Unsealed)
ttv.state = TableState::Sealed;
@ -65,6 +95,9 @@ struct Quantifier
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
if (FFlag::LuauTypecheckOptPass)
seenMutableType = true;
if (!level.subsumes(ftp.level))
return false;
@ -84,6 +117,9 @@ void quantify(TypeId ty, TypeLevel level)
LUAU_ASSERT(ftv);
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
if (FFlag::LuauTypecheckOptPass && ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
}
} // namespace Luau

View File

@ -7,24 +7,36 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000)
LUAU_FASTFLAG(LuauTypecheckOptPass)
LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false)
namespace Luau
{
void Tarjan::visitChildren(TypeId ty, int index)
{
ty = log->follow(ty);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(ty == log->follow(ty));
else
ty = log->follow(ty);
if (ignoreChildren(ty))
return;
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
if (FFlag::LuauTypecheckOptPass)
{
if (auto pty = log->pending(ty))
ty = &pty->pending;
}
if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get<FunctionTypeVar>(ty) : log->getMutable<FunctionTypeVar>(ty))
{
visitChild(ftv->argTypes);
visitChild(ftv->retType);
}
else if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get<TableTypeVar>(ty) : log->getMutable<TableTypeVar>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
for (const auto& [name, prop] : ttv->props)
@ -41,38 +53,52 @@ void Tarjan::visitChildren(TypeId ty, int index)
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp);
}
else if (const MetatableTypeVar* mtv = log->getMutable<MetatableTypeVar>(ty))
else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get<MetatableTypeVar>(ty) : log->getMutable<MetatableTypeVar>(ty))
{
visitChild(mtv->table);
visitChild(mtv->metatable);
}
else if (const UnionTypeVar* utv = log->getMutable<UnionTypeVar>(ty))
else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get<UnionTypeVar>(ty) : log->getMutable<UnionTypeVar>(ty))
{
for (TypeId opt : utv->options)
visitChild(opt);
}
else if (const IntersectionTypeVar* itv = log->getMutable<IntersectionTypeVar>(ty))
else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get<IntersectionTypeVar>(ty) : log->getMutable<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
visitChild(part);
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
for (TypeId part : ctv->parts)
visitChild(part);
}
}
void Tarjan::visitChildren(TypePackId tp, int index)
{
tp = log->follow(tp);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(tp == log->follow(tp));
else
tp = log->follow(tp);
if (ignoreChildren(tp))
return;
if (const TypePack* tpp = log->getMutable<TypePack>(tp))
if (FFlag::LuauTypecheckOptPass)
{
if (auto ptp = log->pending(tp))
tp = &ptp->pending;
}
if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get<TypePack>(tp) : log->getMutable<TypePack>(tp))
{
for (TypeId tv : tpp->head)
visitChild(tv);
if (tpp->tail)
visitChild(*tpp->tail);
}
else if (const VariadicTypePack* vtp = log->getMutable<VariadicTypePack>(tp))
else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get<VariadicTypePack>(tp) : log->getMutable<VariadicTypePack>(tp))
{
visitChild(vtp->ty);
}
@ -80,7 +106,10 @@ void Tarjan::visitChildren(TypePackId tp, int index)
std::pair<int, bool> Tarjan::indexify(TypeId ty)
{
ty = log->follow(ty);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(ty == log->follow(ty));
else
ty = log->follow(ty);
bool fresh = !typeToIndex.contains(ty);
int& index = typeToIndex[ty];
@ -98,7 +127,10 @@ std::pair<int, bool> Tarjan::indexify(TypeId ty)
std::pair<int, bool> Tarjan::indexify(TypePackId tp)
{
tp = log->follow(tp);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(tp == log->follow(tp));
else
tp = log->follow(tp);
bool fresh = !packToIndex.contains(tp);
int& index = packToIndex[tp];
@ -141,7 +173,7 @@ TarjanResult Tarjan::loop()
if (currEdge == -1)
{
++childCount;
if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount)
if (childLimit > 0 && childLimit < childCount)
return TarjanResult::TooManyChildren;
stack.push_back(index);
@ -229,6 +261,9 @@ TarjanResult Tarjan::loop()
TarjanResult Tarjan::visitRoot(TypeId ty)
{
childCount = 0;
if (childLimit == 0)
childLimit = FInt::LuauTarjanChildLimit;
ty = log->follow(ty);
auto [index, fresh] = indexify(ty);
@ -239,6 +274,9 @@ TarjanResult Tarjan::visitRoot(TypeId ty)
TarjanResult Tarjan::visitRoot(TypePackId tp)
{
childCount = 0;
if (childLimit == 0)
childLimit = FInt::LuauTarjanChildLimit;
tp = log->follow(tp);
auto [index, fresh] = indexify(tp);
@ -347,7 +385,13 @@ TypeId Substitution::clone(TypeId ty)
TypeId result = ty;
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
if (FFlag::LuauTypecheckOptPass)
{
if (auto pty = log->pending(ty))
ty = &pty->pending;
}
if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get<FunctionTypeVar>(ty) : log->getMutable<FunctionTypeVar>(ty))
{
FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf};
clone.generics = ftv->generics;
@ -357,7 +401,7 @@ TypeId Substitution::clone(TypeId ty)
clone.argNames = ftv->argNames;
result = addType(std::move(clone));
}
else if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get<TableTypeVar>(ty) : log->getMutable<TableTypeVar>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state};
@ -370,24 +414,29 @@ TypeId Substitution::clone(TypeId ty)
clone.tags = ttv->tags;
result = addType(std::move(clone));
}
else if (const MetatableTypeVar* mtv = log->getMutable<MetatableTypeVar>(ty))
else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get<MetatableTypeVar>(ty) : log->getMutable<MetatableTypeVar>(ty))
{
MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable};
clone.syntheticName = mtv->syntheticName;
result = addType(std::move(clone));
}
else if (const UnionTypeVar* utv = log->getMutable<UnionTypeVar>(ty))
else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get<UnionTypeVar>(ty) : log->getMutable<UnionTypeVar>(ty))
{
UnionTypeVar clone;
clone.options = utv->options;
result = addType(std::move(clone));
}
else if (const IntersectionTypeVar* itv = log->getMutable<IntersectionTypeVar>(ty))
else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get<IntersectionTypeVar>(ty) : log->getMutable<IntersectionTypeVar>(ty))
{
IntersectionTypeVar clone;
clone.parts = itv->parts;
result = addType(std::move(clone));
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
ConstrainedTypeVar clone{ctv->level, ctv->parts};
result = addType(std::move(clone));
}
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
@ -396,14 +445,21 @@ TypeId Substitution::clone(TypeId ty)
TypePackId Substitution::clone(TypePackId tp)
{
tp = log->follow(tp);
if (const TypePack* tpp = log->getMutable<TypePack>(tp))
if (FFlag::LuauTypecheckOptPass)
{
if (auto ptp = log->pending(tp))
tp = &ptp->pending;
}
if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get<TypePack>(tp) : log->getMutable<TypePack>(tp))
{
TypePack clone;
clone.head = tpp->head;
clone.tail = tpp->tail;
return addTypePack(std::move(clone));
}
else if (const VariadicTypePack* vtp = log->getMutable<VariadicTypePack>(tp))
else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get<VariadicTypePack>(tp) : log->getMutable<VariadicTypePack>(tp))
{
VariadicTypePack clone;
clone.ty = vtp->ty;
@ -415,25 +471,34 @@ TypePackId Substitution::clone(TypePackId tp)
void Substitution::foundDirty(TypeId ty)
{
ty = log->follow(ty);
if (isDirty(ty))
newTypes[ty] = clean(ty);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(ty == log->follow(ty));
else
newTypes[ty] = clone(ty);
ty = log->follow(ty);
if (isDirty(ty))
newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(ty)) : clean(ty);
else
newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(ty)) : clone(ty);
}
void Substitution::foundDirty(TypePackId tp)
{
tp = log->follow(tp);
if (isDirty(tp))
newPacks[tp] = clean(tp);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(tp == log->follow(tp));
else
newPacks[tp] = clone(tp);
tp = log->follow(tp);
if (isDirty(tp))
newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(tp)) : clean(tp);
else
newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(tp)) : clone(tp);
}
TypeId Substitution::replace(TypeId ty)
{
ty = log->follow(ty);
if (TypeId* prevTy = newTypes.find(ty))
return *prevTy;
else
@ -443,6 +508,7 @@ TypeId Substitution::replace(TypeId ty)
TypePackId Substitution::replace(TypePackId tp)
{
tp = log->follow(tp);
if (TypePackId* prevTp = newPacks.find(tp))
return *prevTp;
else
@ -451,7 +517,13 @@ TypePackId Substitution::replace(TypePackId tp)
void Substitution::replaceChildren(TypeId ty)
{
ty = log->follow(ty);
if (BoundTypeVar* btv = log->getMutable<BoundTypeVar>(ty); FFlag::LuauLowerBoundsCalculation && btv)
btv->boundTo = replace(btv->boundTo);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(ty == log->follow(ty));
else
ty = log->follow(ty);
if (ignoreChildren(ty))
return;
@ -493,11 +565,19 @@ void Substitution::replaceChildren(TypeId ty)
for (TypeId& part : itv->parts)
part = replace(part);
}
else if (ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty))
{
for (TypeId& part : ctv->parts)
part = replace(part);
}
}
void Substitution::replaceChildren(TypePackId tp)
{
tp = log->follow(tp);
if (FFlag::LuauTypecheckOptPass)
LUAU_ASSERT(tp == log->follow(tp));
else
tp = log->follow(tp);
if (ignoreChildren(tp))
return;

View File

@ -237,6 +237,15 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty);
finishNode();
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
formatAppend(result, "ConstrainedTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
for (TypeId part : ctv->parts)
visitChild(part, index);
}
else if (get<ErrorTypeVar>(ty))
{
formatAppend(result, "ErrorTypeVar %d", index);
@ -258,6 +267,28 @@ void StateDot::visitChildren(TypeId ty, int index)
if (ctv->metatable)
visitChild(*ctv->metatable, index, "[metatable]");
}
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(ty))
{
std::string res;
if (const StringSingleton* ss = get<StringSingleton>(stv))
{
// Don't put in quotes anywhere. If it's outside of the call to escape,
// then it's invalid syntax. If it's inside, then escaping is super noisy.
res = "string: " + escape(ss->value);
}
else if (const BooleanSingleton* bs = get<BooleanSingleton>(stv))
{
res = "boolean: ";
res += bs->value ? "true" : "false";
}
else
LUAU_ASSERT(!"unknown singleton type");
formatAppend(result, "SingletonTypeVar %s", res.c_str());
finishNodeLabel(ty);
finishNode();
}
else
{
LUAU_ASSERT(!"unknown type kind");

View File

@ -10,6 +10,8 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
/*
* Prefix generic typenames with gen-
* Additionally, free types will be prefixed with free- and suffixed with their level, e.g. free-a-4
@ -33,8 +35,8 @@ struct FindCyclicTypes
bool exhaustive = false;
std::unordered_set<TypeId> visited;
std::unordered_set<TypePackId> visitedPacks;
std::unordered_set<TypeId> cycles;
std::unordered_set<TypePackId> cycleTPs;
std::set<TypeId> cycles;
std::set<TypePackId> cycleTPs;
void cycle(TypeId ty)
{
@ -86,7 +88,7 @@ struct FindCyclicTypes
};
template<typename TID>
void findCyclicTypes(std::unordered_set<TypeId>& cycles, std::unordered_set<TypePackId>& cycleTPs, TID ty, bool exhaustive)
void findCyclicTypes(std::set<TypeId>& cycles, std::set<TypePackId>& cycleTPs, TID ty, bool exhaustive)
{
FindCyclicTypes fct;
fct.exhaustive = exhaustive;
@ -124,6 +126,7 @@ struct StringifierState
std::unordered_map<TypePackId, std::string> cycleTpNames;
std::unordered_set<void*> seen;
std::unordered_set<std::string> usedNames;
size_t indentation = 0;
bool exhaustive;
@ -216,6 +219,34 @@ struct StringifierState
result.name += s;
}
void indent()
{
indentation += 4;
}
void dedent()
{
indentation -= 4;
}
void newline()
{
if (!opts.useLineBreaks)
return emit(" ");
emit("\n");
emitIndentation();
}
private:
void emitIndentation()
{
if (!opts.indent)
return;
emit(std::string(indentation, ' '));
}
};
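// A hedged sketch of how the helpers above are used by the stringifiers below:
// emit(","); newline(); produces ", " when opts.useLineBreaks is false, and a real
// line break followed by the current indentation (4 spaces per indent() call, emitted
// only when opts.indent is set) when it is true.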
struct TypeVarStringifier
@ -321,7 +352,7 @@ struct TypeVarStringifier
stringify(btv.boundTo);
}
void operator()(TypeId ty, const Unifiable::Generic& gtv)
void operator()(TypeId ty, const GenericTypeVar& gtv)
{
if (gtv.explicitName)
{
@ -332,6 +363,26 @@ struct TypeVarStringifier
state.emit(state.getName(ty));
}
void operator()(TypeId, const ConstrainedTypeVar& ctv)
{
state.result.invalid = true;
state.emit("[[");
bool first = true;
for (TypeId ty : ctv.parts)
{
if (first)
first = false;
else
state.emit("|");
stringify(ty);
}
state.emit("]]");
}
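// For example (hedged), a constrained type whose parts are number and string would be
// rendered as [[number|string]]; the result is marked invalid since, presumably, this
// notation is debug output rather than parseable Luau syntax.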
void operator()(TypeId, const PrimitiveTypeVar& ptv)
{
switch (ptv.type)
@ -415,10 +466,25 @@ struct TypeVarStringifier
state.emit(") -> ");
bool plural = true;
if (auto retPack = get<TypePack>(follow(ftv.retType)))
if (FFlag::LuauLowerBoundsCalculation)
{
if (retPack->head.size() == 1 && !retPack->tail)
plural = false;
auto retBegin = begin(ftv.retType);
auto retEnd = end(ftv.retType);
if (retBegin != retEnd)
{
++retBegin;
if (retBegin == retEnd && !retBegin.tail())
plural = false;
}
}
else
{
if (auto retPack = get<TypePack>(follow(ftv.retType)))
{
if (retPack->head.size() == 1 && !retPack->tail)
plural = false;
}
}
if (plural)
@ -511,6 +577,7 @@ struct TypeVarStringifier
}
state.emit(openbrace);
state.indent();
bool comma = false;
if (ttv.indexer)
@ -527,7 +594,10 @@ struct TypeVarStringifier
for (const auto& [name, prop] : ttv.props)
{
if (comma)
state.emit(state.opts.useLineBreaks ? ",\n" : ", ");
{
state.emit(",");
state.newline();
}
size_t length = state.result.name.length() - oldLength;
@ -553,6 +623,7 @@ struct TypeVarStringifier
++index;
}
state.dedent();
state.emit(closedbrace);
state.unsee(&ttv);
@ -563,7 +634,8 @@ struct TypeVarStringifier
state.result.invalid = true;
state.emit("{ @metatable ");
stringify(mtv.metatable);
state.emit(state.opts.useLineBreaks ? ",\n" : ", ");
state.emit(",");
state.newline();
stringify(mtv.table);
state.emit(" }");
}
@ -784,13 +856,16 @@ struct TypePackStringifier
if (tp.tail && !isEmpty(*tp.tail))
{
const auto& tail = *tp.tail;
if (first)
first = false;
else
state.emit(", ");
TypePackId tail = follow(*tp.tail);
if (auto vtp = get<VariadicTypePack>(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden))
{
if (first)
first = false;
else
state.emit(", ");
stringify(tail);
stringify(tail);
}
}
state.unsee(&tp);
@ -805,6 +880,8 @@ struct TypePackStringifier
void operator()(TypePackId, const VariadicTypePack& pack)
{
state.emit("...");
if (FFlag::DebugLuauVerboseTypeNames && pack.hidden)
state.emit("<hidden>");
stringify(pack.ty);
}
@ -858,15 +935,12 @@ void TypeVarStringifier::stringify(TypePackId tpid, const std::vector<std::optio
tps.stringify(tpid);
}
static void assignCycleNames(const std::unordered_set<TypeId>& cycles, const std::unordered_set<TypePackId>& cycleTPs,
static void assignCycleNames(const std::set<TypeId>& cycles, const std::set<TypePackId>& cycleTPs,
std::unordered_map<TypeId, std::string>& cycleNames, std::unordered_map<TypePackId, std::string>& cycleTpNames, bool exhaustive)
{
int nextIndex = 1;
std::vector<TypeId> sortedCycles{cycles.begin(), cycles.end()};
std::sort(sortedCycles.begin(), sortedCycles.end(), std::less<TypeId>{});
for (TypeId cycleTy : sortedCycles)
for (TypeId cycleTy : cycles)
{
std::string name;
@ -888,10 +962,7 @@ static void assignCycleNames(const std::unordered_set<TypeId>& cycles, const std
cycleNames[cycleTy] = std::move(name);
}
std::vector<TypePackId> sortedCycleTps{cycleTPs.begin(), cycleTPs.end()};
std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less<TypePackId>());
for (TypePackId tp : sortedCycleTps)
for (TypePackId tp : cycleTPs)
{
std::string name = "tp" + std::to_string(nextIndex);
++nextIndex;
@ -913,8 +984,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
StringifierState state{opts, result, opts.nameMap};
std::unordered_set<TypeId> cycles;
std::unordered_set<TypePackId> cycleTPs;
std::set<TypeId> cycles;
std::set<TypePackId> cycleTPs;
findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive);
@ -1016,8 +1087,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts)
ToStringResult result;
StringifierState state{opts, result, opts.nameMap};
std::unordered_set<TypeId> cycles;
std::unordered_set<TypePackId> cycleTPs;
std::set<TypeId> cycles;
std::set<TypePackId> cycleTPs;
findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive);
@ -1058,7 +1129,7 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts)
state.emit(name);
state.emit(" = ");
Luau::visit(
[&tvs, cycleTy = cycleTy](auto&& t) {
[&tvs, cycleTy = cycleTy](auto t) {
return tvs(cycleTy, t);
},
cycleTy->ty);
@ -1163,14 +1234,18 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp
if (argPackIter.tail())
{
if (!first)
state.emit(", ");
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()); !vtp || !vtp->hidden)
{
if (!first)
state.emit(", ");
state.emit("...: ");
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()))
tvs.stringify(vtp->ty);
else
tvs.stringify(*argPackIter.tail());
state.emit("...: ");
if (vtp)
tvs.stringify(vtp->ty);
else
tvs.stringify(*argPackIter.tail());
}
}
state.emit("): ");
@ -1210,6 +1285,24 @@ std::string dump(TypePackId ty)
return s;
}
std::string dump(const ScopePtr& scope, const char* name)
{
auto binding = scope->linearSearchForBinding(name);
if (!binding)
{
printf("No binding %s\n", name);
return {};
}
TypeId ty = binding->typeId;
ToStringOptions opts;
opts.exhaustive = true;
opts.functionTypeArguments = true;
std::string s = toString(ty, opts);
printf("%s\n", s.c_str());
return s;
}
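// A hedged usage sketch of this debugging helper: calling
//
//   dump(scope, "x");
//
// searches the scope chain for the binding `x` and prints its exhaustive,
// argument-annotated type, or "No binding x" if the name is not bound.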
std::string generateName(size_t i)
{
std::string n;

View File

@ -215,6 +215,7 @@ struct ArcCollector : public AstVisitor
}
}
// Adds a dependency from the current node to the named node.
void add(const Identifier& name)
{
Node** it = map.find(name);

View File

@ -8,6 +8,7 @@
#include <stdexcept>
LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false)
LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false)
namespace Luau
{
@ -161,18 +162,37 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs)
bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const
{
const std::pair<TypeOrPackId, TypeOrPackId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair))
if (FFlag::LuauJustOneCallFrameForHaveSeen && !FFlag::LuauTypecheckOptPass)
{
return true;
}
// This function will technically work if `this` is nullptr, but this
// indicates a bug, so we explicitly assert.
LUAU_ASSERT(static_cast<const void*>(this) != nullptr);
if (parent)
const std::pair<TypeOrPackId, TypeOrPackId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
for (const TxnLog* current = this; current; current = current->parent)
{
if (current->sharedSeen->end() != std::find(current->sharedSeen->begin(), current->sharedSeen->end(), sortedPair))
return true;
}
return false;
}
else
{
return parent->haveSeen(lhs, rhs);
}
const std::pair<TypeOrPackId, TypeOrPackId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair))
{
return true;
}
return false;
if (!FFlag::LuauTypecheckOptPass && parent)
{
return parent->haveSeen(lhs, rhs);
}
return false;
}
}
void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs)
@ -222,8 +242,8 @@ PendingType* TxnLog::pending(TypeId ty) const
for (const TxnLog* current = this; current; current = current->parent)
{
if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end())
return it->second.get();
if (auto it = current->typeVarChanges.find(ty))
return it->get();
}
return nullptr;
@ -237,8 +257,8 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const
for (const TxnLog* current = this; current; current = current->parent)
{
if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end())
return it->second.get();
if (auto it = current->typePackChanges.find(tp))
return it->get();
}
return nullptr;

View File

@ -94,6 +94,16 @@ public:
}
}
AstType* operator()(const ConstrainedTypeVar& ctv)
{
AstArray<AstType*> types;
types.size = ctv.parts.size();
types.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * ctv.parts.size()));
for (size_t i = 0; i < ctv.parts.size(); ++i)
types.data[i] = Luau::visit(*this, ctv.parts[i]->ty);
return allocator->alloc<AstTypeIntersection>(Location(), types);
}
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
@ -364,6 +374,9 @@ public:
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
if (vtp.hidden)
return nullptr;
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
}

View File

@ -3,12 +3,15 @@
#include "Luau/Common.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Normalize.h"
#include "Luau/Parser.h"
#include "Luau/Quantify.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/Substitution.h"
#include "Luau/TopoSortStatements.h"
#include "Luau/TypePack.h"
#include "Luau/ToString.h"
#include "Luau/TypeUtils.h"
#include "Luau/ToString.h"
#include "Luau/TypeVar.h"
@ -19,14 +22,17 @@
LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false)
LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500)
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000)
LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAG(LuauSeparateTypechecks)
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTFLAG(LuauAutocompleteSingletonTypes)
LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false)
LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false)
LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false)
LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false)
@ -39,6 +45,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false)
LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false)
LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false)
LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false)
LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false)
LUAU_FASTFLAG(LuauTypeMismatchModuleName)
LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false)
@ -53,6 +60,8 @@ LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false)
LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional)
LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false)
LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false)
LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false)
LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false);
namespace Luau
{
@ -140,6 +149,34 @@ bool hasBreak(AstStat* node)
}
}
static bool hasReturn(const AstStat* node)
{
struct Searcher : AstVisitor
{
bool result = false;
bool visit(AstStat*) override
{
return !result; // if we've already found a return statement, don't bother to traverse inward anymore
}
bool visit(AstStatReturn*) override
{
result = true;
return false;
}
bool visit(AstExprFunction*) override
{
return false; // We don't care if the function uses a lambda that itself returns
}
};
Searcher searcher;
const_cast<AstStat*>(node)->visit(&searcher);
return searcher.result;
}
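// Illustrative behaviour (Luau sources shown in comments only):
//   function f() return 1 end                       -- hasReturn(f's body) == true
//   function g() local h = function() return 1 end end
//                                                   -- hasReturn(g's body) == false:
//                                                   -- the nested lambda is not traversed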
// returns the last statement before the block exits, or nullptr if the block never exits
const AstStat* getFallthrough(const AstStat* node)
{
@ -253,6 +290,26 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan
}
ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope)
{
if (FFlag::LuauRecursionLimitException)
{
try
{
return checkWithoutRecursionCheck(module, mode, environmentScope);
}
catch (const RecursionLimitException&)
{
reportErrorCodeTooComplex(module.root->location);
return std::move(currentModule);
}
}
else
{
return checkWithoutRecursionCheck(module, mode, environmentScope);
}
}
ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope)
{
LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
@ -268,6 +325,12 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona
iceHandler->moduleName = module.name;
if (FFlag::LuauAutocompleteDynamicLimits)
{
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit;
}
ScopePtr parentScope = environmentScope.value_or(globalScope);
ScopePtr moduleScope = std::make_shared<Scope>(parentScope);
@ -312,7 +375,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona
prepareErrorsForDisplay(currentModule->errors);
bool encounteredFreeType = currentModule->clonePublicInterface();
bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler);
if (encounteredFreeType)
{
reportError(TypeError{module.root->location,
@ -415,7 +478,26 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
reportErrorCodeTooComplex(block.location);
return;
}
if (FFlag::LuauRecursionLimitException)
{
try
{
checkBlockWithoutRecursionCheck(scope, block);
}
catch (const RecursionLimitException&)
{
reportErrorCodeTooComplex(block.location);
return;
}
}
else
{
checkBlockWithoutRecursionCheck(scope, block);
}
}
void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block)
{
int subLevel = 0;
std::vector<AstStat*> sorted(block.body.data, block.body.data + block.body.size);
@ -435,6 +517,16 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
std::unordered_map<AstStat*, std::pair<TypeId, ScopePtr>> functionDecls;
auto isLocalLambda = [](AstStat* stat) -> AstStatLocal* {
AstStatLocal* local = stat->as<AstStatLocal>();
if (FFlag::LuauLowerBoundsCalculation && local && local->vars.size == 1 && local->values.size == 1 &&
local->values.data[0]->is<AstExprFunction>())
return local;
else
return nullptr;
};
auto checkBody = [&](AstStat* stat) {
if (auto fun = stat->as<AstStatFunction>())
{
@ -482,7 +574,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
// function f<a>(x:a):a local x: number = g(37) return x end
// function g(x:number):number return f(x) end
// ```
if (containsFunctionCallOrReturn(**protoIter))
if (containsFunctionCallOrReturn(**protoIter) || (FFlag::LuauLowerBoundsCalculation && isLocalLambda(*protoIter)))
{
while (checkIter != protoIter)
{
@ -513,7 +605,8 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
functionDecls[*protoIter] = pair;
++subLevel;
TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level);
TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level));
unify(funTy, leftType, fun->location);
}
else if (auto fun = (*protoIter)->as<AstStatLocalFunction>())
@ -658,6 +751,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement)
checkExpr(repScope, *statement.condition);
}
void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location)
{
Unifier state = mkUnifier(location);
state.unifyLowerBound(subTy, superTy);
state.log.commit();
reportErrors(state.errors);
}
void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_)
{
std::vector<std::optional<TypeId>> expectedTypes;
@ -682,6 +785,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_)
TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type;
if (useConstrainedIntersections())
{
unifyLowerBound(retPack, scope->returnType, return_.location);
return;
}
// HACK: Nonstrict mode gets a bit too smart and strict for us when we
// start typechecking everything across module boundaries.
if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType))
@ -1209,9 +1318,11 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco
else if (tableSelf->state == TableState::Sealed)
reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}});
const bool tableIsExtendable = tableSelf && tableSelf->state != TableState::Sealed;
ty = follow(ty);
if (tableSelf && tableSelf->state != TableState::Sealed)
if (tableIsExtendable)
tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation};
const FunctionTypeVar* funTy = get<FunctionTypeVar>(ty);
@ -1224,7 +1335,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco
checkFunctionBody(funScope, ty, *function.func);
if (tableSelf && tableSelf->state != TableState::Sealed)
if (tableIsExtendable)
tableSelf->props[indexName->index.value] = {
follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation};
}
@ -1372,7 +1483,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
for (auto param : binding->typePackParams)
clone.instantiatedTypePackParams.push_back(param.tp);
bool isNormal = ty->normal;
ty = addType(std::move(clone));
if (FFlag::LuauLowerBoundsCalculation)
asMutable(ty)->normal = isNormal;
}
}
else
@ -1400,6 +1515,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
if (FFlag::LuauTwoPassAliasDefinitionFix && ok)
bindingType = ty;
if (FFlag::LuauLowerBoundsCalculation)
{
auto [t, ok] = normalize(bindingType, currentModule, *iceHandler);
bindingType = t;
if (!ok)
reportError(typealias.location, NormalizationTooComplex{});
}
}
}
@ -1673,10 +1796,11 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa
{
return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)};
}
else if (get<Unifiable::Free>(retPack))
else if (const FreeTypePack* ftp = get<Unifiable::Free>(retPack))
{
TypeId head = freshType(scope);
TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}});
TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level;
TypeId head = freshType(level);
TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}});
unify(pack, retPack, expr.location);
return {head, std::move(result.predicates)};
}
@ -1793,7 +1917,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
for (TypeId t : utv)
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeForType unions");
// Not needed when we normalize types.
if (get<AnyTypeVar>(follow(t)))
@ -1817,12 +1941,25 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
return std::nullopt;
}
std::vector<TypeId> result = reduceUnion(goodOptions);
if (FFlag::LuauLowerBoundsCalculation)
{
auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule,
*iceHandler); // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away.
if (result.size() == 1)
return result[0];
if (!ok)
reportError(location, NormalizationTooComplex{});
return addType(UnionTypeVar{std::move(result)});
return t;
}
else
{
std::vector<TypeId> result = reduceUnion(goodOptions);
if (result.size() == 1)
return result[0];
return addType(UnionTypeVar{std::move(result)});
}
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{
@ -1830,7 +1967,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
for (TypeId t : itv->parts)
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeFromType intersections");
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
parts.push_back(*ty);
@ -1982,7 +2119,6 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location)
{
if (!std::any_of(begin(utv), end(utv), isNil))
return ty;
}
if (std::optional<TypeId> strippedUnion = tryStripUnionFromNil(ty))
@ -2124,7 +2260,26 @@ TypeId TypeChecker::checkExprTable(
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType)
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
if (FFlag::LuauTableUseCounterInstead)
{
RecursionCounter _rc(&checkRecursionCount);
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
{
reportErrorCodeTooComplex(expr.location);
return {errorRecoveryType(scope)};
}
return checkExpr_(scope, expr, expectedType);
}
else
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "checkExpr for tables");
return checkExpr_(scope, expr, expectedType);
}
}
ExprResult<TypeId> TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType)
{
std::vector<std::pair<TypeId, TypeId>> fieldTypes(expr.items.size);
const TableTypeVar* expectedTable = nullptr;
@ -3176,6 +3331,10 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
funScope->varargPack = anyTypePack;
}
}
else if (FFlag::LuauLowerBoundsCalculation && !isNonstrictMode())
{
funScope->varargPack = addTypePack(TypePackVar{VariadicTypePack{anyType, /*hidden*/ true}});
}
std::vector<TypeId> argTypes;
@ -3311,9 +3470,24 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE
{
check(scope, *function.body);
// We explicitly don't follow here to check if we have a 'true' free type instead of bound one
if (get_if<FreeTypePack>(&funTy->retType->ty))
*asMutable(funTy->retType) = TypePack{{}, std::nullopt};
if (useConstrainedIntersections())
{
TypePackId retPack = follow(funTy->retType);
// It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type
// if it is expected to conform to some other interface. (e.g. the function may be a lambda passed as a callback)
if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get<FreeTypePack>(retPack))
{
auto level = getLevel(retPack);
if (level && scope->level.subsumes(*level))
*asMutable(retPack) = TypePack{{}, std::nullopt};
}
}
else
{
// We explicitly don't follow here to check if we have a 'true' free type instead of bound one
if (get_if<FreeTypePack>(&funTy->retType->ty))
*asMutable(funTy->retType) = TypePack{{}, std::nullopt};
}
bool reachesImplicitReturn = getFallthrough(function.body) != nullptr;
@ -3418,6 +3592,19 @@ void TypeChecker::checkArgumentList(
size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack);
auto reportCountMismatchError = [&state, &argLocations, minParams, paramPack, argPack]() {
// For this case, we want the error span to cover every errant extra parameter
Location location = state.location;
if (!argLocations.empty())
location = {state.location.begin, argLocations.back().end};
size_t mp = minParams;
if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes)
mp = getMinParameterCount(&state.log, paramPack);
state.reportError(TypeError{location, CountMismatch{mp, std::distance(begin(argPack), end(argPack))}});
};
while (true)
{
state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location;
@ -3472,6 +3659,8 @@ void TypeChecker::checkArgumentList(
}
else if (auto vtp = state.log.getMutable<VariadicTypePack>(tail))
{
// Function is variadic and requires that all subsequent parameters
// be compatible with a type.
while (paramIter != endIter)
{
state.tryUnify(vtp->ty, *paramIter);
@ -3506,14 +3695,22 @@ void TypeChecker::checkArgumentList(
else if (state.log.getMutable<ErrorTypeVar>(t))
{
} // ok
else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.getMutable<AnyTypeVar>(t))
else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.get<AnyTypeVar>(t))
{
} // ok
else
{
if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes)
minParams = getMinParameterCount(&state.log, paramPack);
bool isVariadic = FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic && !finite(paramPack, &state.log);
bool isVariadic = false;
if (FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic)
{
std::optional<TypePackId> tail = flatten(paramPack, state.log).second;
if (tail)
isVariadic = Luau::isVariadic(*tail);
}
state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}});
return;
}
@ -3532,14 +3729,7 @@ void TypeChecker::checkArgumentList(
unify(errorRecoveryType(scope), *argIter, state.location);
++argIter;
}
// For this case, we want the error span to cover every errant extra parameter
Location location = state.location;
if (!argLocations.empty())
location = {state.location.begin, argLocations.back().end};
if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes)
minParams = getMinParameterCount(&state.log, paramPack);
state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
reportCountMismatchError();
return;
}
TypePackId tail = state.log.follow(*paramIter.tail());
@ -3551,6 +3741,21 @@ void TypeChecker::checkArgumentList(
}
else if (auto vtp = state.log.getMutable<VariadicTypePack>(tail))
{
if (FFlag::LuauLowerBoundsCalculation && vtp->hidden)
{
// We know that this function can technically be oversaturated, but we have its definition and we
// know that the extra arguments are useless.
TypeId e = errorRecoveryType(scope);
while (argIter != endIter)
{
unify(e, *argIter, state.location);
++argIter;
}
reportCountMismatchError();
return;
}
// Function is variadic and requires that all subsequent parameters
// be compatible with a type.
size_t argIndex = paramIndex;
@ -3595,14 +3800,7 @@ void TypeChecker::checkArgumentList(
}
else if (state.log.getMutable<GenericTypePack>(tail))
{
// For this case, we want the error span to cover every errant extra parameter
Location location = state.location;
if (!argLocations.empty())
location = {state.location.begin, argLocations.back().end};
// TODO: Better error message?
if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes)
minParams = getMinParameterCount(&state.log, paramPack);
state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
reportCountMismatchError();
return;
}
}
@ -3661,7 +3859,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
actualFunctionType = follow(actualFunctionType);
TypePackId retPack;
if (!FFlag::LuauWidenIfSupertypeIsFree2)
if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2)
{
retPack = freshTypePack(scope->level);
}
@ -3809,21 +4007,49 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
return {{errorRecoveryTypePack(scope)}};
}
if (get<FreeTypeVar>(fn))
if (auto ftv = get<FreeTypeVar>(fn))
{
// fn is one of the overloads of actualFunctionType, which
// has been instantiated, so is a monotype. We can therefore
// unify it with a monomorphic function.
TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack));
if (FFlag::LuauWidenIfSupertypeIsFree2)
if (useConstrainedIntersections())
{
UnifierOptions options;
options.isFunctionCall = true;
unify(r, fn, expr.location, options);
// This ternary is phrased deliberately. We need ties between sibling scopes to bias toward ftv->level.
const TypeLevel level = scope->level.subsumes(ftv->level) ? scope->level : ftv->level;
std::vector<TypeId> adjustedArgTypes;
auto it = begin(argPack);
auto endIt = end(argPack);
Widen widen{&currentModule->internalTypes};
for (; it != endIt; ++it)
{
TypeId t = *it;
TypeId widened = widen.substitute(t).value_or(t); // Surely widening is infallible
adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widened}}));
}
TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()});
TxnLog log;
promoteTypeLevels(log, &currentModule->internalTypes, level, retPack);
log.commit();
*asMutable(fn) = FunctionTypeVar{level, adjustedArgPack, retPack};
return {{retPack}};
}
else
unify(fn, r, expr.location);
return {{retPack}};
{
TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack));
if (FFlag::LuauWidenIfSupertypeIsFree2)
{
UnifierOptions options;
options.isFunctionCall = true;
unify(r, fn, expr.location, options);
}
else
unify(fn, r, expr.location);
return {{retPack}};
}
}
std::vector<Location> metaArgLocations;
@ -4363,10 +4589,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s
bool Instantiation::isDirty(TypeId ty)
{
if (log->getMutable<FunctionTypeVar>(ty))
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
{
if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics)
return false;
return true;
}
else
{
return false;
}
}
bool Instantiation::isDirty(TypePackId tp)
@ -4414,14 +4647,21 @@ TypePackId Instantiation::clean(TypePackId tp)
bool ReplaceGenerics::ignoreChildren(TypeId ty)
{
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
{
if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics)
return true;
// We aren't recursing in the case of a generic function which
// binds the same generics. This can happen if, for example, there are recursive types.
// If T = <a>(a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'.
// It's OK to use vector equality here, since we always generate fresh generics
// whenever we quantify, so the vectors overlap if and only if they are equal.
return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks);
}
else
{
return false;
}
}
bool ReplaceGenerics::isDirty(TypeId ty)
@ -4464,16 +4704,24 @@ TypePackId ReplaceGenerics::clean(TypePackId tp)
bool Anyification::isDirty(TypeId ty)
{
if (ty->persistent)
return false;
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed));
else if (log->getMutable<FreeTypeVar>(ty))
return true;
else if (get<ConstrainedTypeVar>(ty))
return true;
else
return false;
}
bool Anyification::isDirty(TypePackId tp)
{
if (tp->persistent)
return false;
if (log->getMutable<FreeTypePack>(tp))
return true;
else
@ -4494,7 +4742,16 @@ TypeId Anyification::clean(TypeId ty)
clone.syntheticName = ttv->syntheticName;
clone.tags = ttv->tags;
}
return addType(std::move(clone));
TypeId res = addType(std::move(clone));
asMutable(res)->normal = ty->normal;
return res;
}
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
auto [t, ok] = normalize(ty, *arena, *iceHandler);
if (!ok)
normalizationTooComplex = true;
return t;
}
else
return anyType;
@ -4511,16 +4768,34 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
ty = follow(ty);
const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty);
if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty())
return ty;
if (ftv && ftv->generics.empty() && ftv->genericPacks.empty())
Luau::quantify(ty, scope->level);
if (FFlag::LuauLowerBoundsCalculation && ftv)
{
auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler);
if (!ok)
reportError(location, NormalizationTooComplex{});
return t;
}
Luau::quantify(ty, scope->level);
return ty;
}
TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log)
{
if (FFlag::LuauTypecheckOptPass)
{
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(ty));
if (ftv && ftv->hasNoGenerics)
return ty;
}
Instantiation instantiation{log, &currentModule->internalTypes, scope->level};
if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit)
instantiation.childLimit = *instantiationChildLimit;
std::optional<TypeId> instantiated = instantiation.substitute(ty);
if (instantiated.has_value())
return *instantiated;
@ -4533,8 +4808,18 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat
TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location)
{
Anyification anyification{&currentModule->internalTypes, anyType, anyTypePack};
if (FFlag::LuauLowerBoundsCalculation)
{
auto [t, ok] = normalize(ty, currentModule, *iceHandler);
if (!ok)
reportError(location, NormalizationTooComplex{});
ty = t;
}
Anyification anyification{&currentModule->internalTypes, iceHandler, anyType, anyTypePack};
std::optional<TypeId> any = anyification.substitute(ty);
if (anyification.normalizationTooComplex)
reportError(location, NormalizationTooComplex{});
if (any.has_value())
return *any;
else
@ -4546,7 +4831,15 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location)
TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location)
{
Anyification anyification{&currentModule->internalTypes, anyType, anyTypePack};
if (FFlag::LuauLowerBoundsCalculation)
{
auto [t, ok] = normalize(ty, currentModule, *iceHandler);
if (!ok)
reportError(location, NormalizationTooComplex{});
ty = t;
}
Anyification anyification{&currentModule->internalTypes, iceHandler, anyType, anyTypePack};
std::optional<TypePackId> any = anyification.substitute(ty);
if (any.has_value())
return *any;
@ -4830,6 +5123,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
ToStringOptions opts;
opts.exhaustive = true;
opts.maxTableLength = 0;
opts.useLineBreaks = true;
TypeId param = resolveType(scope, *lit->parameters.data[0].type);
luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str()));
@ -5487,25 +5781,82 @@ std::optional<TypeId> TypeChecker::resolveLValue(const ScopePtr& scope, const LV
// We need to search in the provided Scope. Find t.x.y first.
// We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x.
// If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate.
const auto& [symbol, keys] = getFullName(lvalue);
if (!FFlag::LuauTypecheckOptPass)
{
const auto& [symbol, keys] = getFullName(lvalue);
ScopePtr currentScope = scope;
while (currentScope)
{
std::optional<TypeId> found;
std::vector<LValue> childKeys;
const LValue* currentLValue = &lvalue;
while (currentLValue)
{
if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end())
{
found = it->second;
break;
}
childKeys.push_back(*currentLValue);
currentLValue = baseof(*currentLValue);
}
if (!found)
{
// Should not be using scope->lookup. This is already recursive.
if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end())
found = it->second.typeId;
else
{
// Nothing exists in this Scope. Just skip and try the parent one.
currentScope = currentScope->parent;
continue;
}
}
for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it)
{
const LValue& key = *it;
// Symbol can happen. Skip.
if (get<Symbol>(key))
continue;
else if (auto field = get<Field>(key))
{
found = getIndexTypeFromType(scope, *found, field->key, Location(), false);
if (!found)
return std::nullopt; // Turns out this type doesn't have the property at all. We're done.
}
else
LUAU_ASSERT(!"New LValue alternative not handled here.");
}
return found;
}
// No entry for it at all. Can happen when LValue root is a global.
return std::nullopt;
}
const Symbol symbol = getBaseSymbol(lvalue);
ScopePtr currentScope = scope;
while (currentScope)
{
std::optional<TypeId> found;
std::vector<LValue> childKeys;
const LValue* currentLValue = &lvalue;
while (currentLValue)
const LValue* topLValue = nullptr;
for (topLValue = &lvalue; topLValue; topLValue = baseof(*topLValue))
{
if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end())
if (auto it = currentScope->refinements.find(*topLValue); it != currentScope->refinements.end())
{
found = it->second;
break;
}
childKeys.push_back(*currentLValue);
currentLValue = baseof(*currentLValue);
}
if (!found)
@ -5521,9 +5872,15 @@ std::optional<TypeId> TypeChecker::resolveLValue(const ScopePtr& scope, const LV
}
}
// We need to walk the l-value path in reverse, so we collect components into a vector
std::vector<const LValue*> childKeys;
for (const LValue* curr = &lvalue; curr != topLValue; curr = baseof(*curr))
childKeys.push_back(curr);
for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it)
{
const LValue& key = *it;
const LValue& key = **it;
// Symbol can happen. Skip.
if (get<Symbol>(key))
@ -5938,6 +6295,11 @@ bool TypeChecker::isNonstrictMode() const
return (currentModule->mode == Mode::Nonstrict) || (currentModule->mode == Mode::NoCheck);
}
bool TypeChecker::useConstrainedIntersections() const
{
return FFlag::LuauLowerBoundsCalculation && !isNonstrictMode();
}
std::vector<TypeId> TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp, size_t expectedLength, const Location& location)
{
TypePackId expectedTypePack = addTypePack({});

View File

@ -104,7 +104,7 @@ TypePackIterator begin(TypePackId tp)
return TypePackIterator{tp};
}
TypePackIterator begin(TypePackId tp, TxnLog* log)
TypePackIterator begin(TypePackId tp, const TxnLog* log)
{
return TypePackIterator{tp, log};
}
@ -256,7 +256,7 @@ size_t size(const TypePack& tp, TxnLog* log)
return result;
}
std::optional<TypeId> first(TypePackId tp)
std::optional<TypeId> first(TypePackId tp, bool ignoreHiddenVariadics)
{
auto it = begin(tp);
auto endIter = end(tp);
@ -266,7 +266,7 @@ std::optional<TypeId> first(TypePackId tp)
if (auto tail = it.tail())
{
if (auto vtp = get<VariadicTypePack>(*tail))
if (auto vtp = get<VariadicTypePack>(*tail); vtp && (!vtp->hidden || !ignoreHiddenVariadics))
return vtp->ty;
}
@ -299,6 +299,46 @@ std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp)
return {res, iter.tail()};
}
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp, const TxnLog& log)
{
tp = log.follow(tp);
std::vector<TypeId> flattened;
std::optional<TypePackId> tail = std::nullopt;
TypePackIterator it(tp, &log);
for (; it != end(tp); ++it)
{
flattened.push_back(*it);
}
tail = it.tail();
return {flattened, tail};
}
bool isVariadic(TypePackId tp)
{
return isVariadic(tp, *TxnLog::empty());
}
bool isVariadic(TypePackId tp, const TxnLog& log)
{
std::optional<TypePackId> tail = flatten(tp, log).second;
if (!tail)
return false;
if (log.get<GenericTypePack>(*tail))
return true;
if (auto vtp = log.get<VariadicTypePack>(*tail); vtp && !vtp->hidden)
return true;
return false;
}
TypePackVar* asMutable(TypePackId tp)
{
return const_cast<TypePackVar*>(tp);

View File

@ -366,7 +366,7 @@ bool maybeSingleton(TypeId ty)
bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount)
{
RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit, "hasLength");
ty = follow(ty);
@ -654,9 +654,9 @@ static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persi
static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true};
static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true};
static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true};
static TypeVar anyType_{AnyTypeVar{}};
static TypeVar errorType_{ErrorTypeVar{}};
static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}};
static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true};
static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true};
static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}, /*persistent*/ true};
static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true};
static TypePackVar errorTypePack_{Unifiable::Error{}};
@ -698,7 +698,7 @@ TypeId SingletonTypes::makeStringMetatable()
{
const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}});
const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}});
const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}});
const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}});
const TypePackId oneStringPack = arena->addTypePack({stringType});
const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true});
@ -802,6 +802,7 @@ void persist(TypeId ty)
continue;
asMutable(t)->persistent = true;
asMutable(t)->normal = true; // all persistent types are assumed to be normal
if (auto btv = get<BoundTypeVar>(t))
queue.push_back(btv->boundTo);
@ -838,6 +839,11 @@ void persist(TypeId ty)
for (TypeId opt : itv->parts)
queue.push_back(opt);
}
else if (auto ctv = get<ConstrainedTypeVar>(t))
{
for (TypeId opt : ctv->parts)
queue.push_back(opt);
}
else if (auto mtv = get<MetatableTypeVar>(t))
{
queue.push_back(mtv->table);
@ -899,6 +905,16 @@ TypeLevel* getMutableLevel(TypeId ty)
return const_cast<TypeLevel*>(getLevel(ty));
}
std::optional<TypeLevel> getLevel(TypePackId tp)
{
tp = follow(tp);
if (auto ftv = get<Unifiable::Free>(tp))
return ftv->level;
else
return std::nullopt;
}
const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name)
{
while (cls)

View File

@ -14,9 +14,12 @@
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferTypePackLoopLimit);
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000);
LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false);
LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false);
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false)
LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false)
@ -27,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false)
LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false)
LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false)
LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional)
LUAU_FASTFLAG(LuauTypecheckOptPass)
namespace Luau
{
@ -126,7 +130,6 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel
visitTypeVarOnce(ty, ptl, seen);
}
// TODO: use this and make it static.
void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp)
{
// Type levels of types from other modules are already global, so we don't need to promote anything inside
@ -305,8 +308,7 @@ static std::optional<std::pair<Luau::Name, const SingletonTypeVar*>> getTableMat
return std::nullopt;
}
Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState,
TxnLog* parentLog)
Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog)
: types(types)
, mode(mode)
, log(parentLog)
@ -326,6 +328,7 @@ Unifier::Unifier(TypeArena* types, Mode mode, std::vector<std::pair<TypeOrPackId
, variance(variance)
, sharedState(sharedState)
{
LUAU_ASSERT(!FFlag::LuauTypecheckOptPass);
LUAU_ASSERT(sharedState.iceHandler);
}
@ -338,14 +341,26 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i
void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypeId tryUnify_");
++sharedState.counters.iterationCount;
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
if (FFlag::LuauAutocompleteDynamicLimits)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
superTy = log.follow(superTy);
@ -354,6 +369,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superTy == subTy)
return;
if (log.get<ConstrainedTypeVar>(superTy))
return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy);
auto superFree = log.getMutable<FreeTypeVar>(superTy);
auto subFree = log.getMutable<FreeTypeVar>(subTy);
@ -442,7 +460,18 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (get<ErrorTypeVar>(superTy) || get<AnyTypeVar>(superTy))
return tryUnifyWithAny(subTy, superTy);
if (get<ErrorTypeVar>(subTy) || get<AnyTypeVar>(subTy))
if (get<AnyTypeVar>(subTy))
{
if (anyIsTop)
{
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
return;
}
else
return tryUnifyWithAny(superTy, subTy);
}
if (get<ErrorTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy);
bool cacheEnabled;
@ -484,7 +513,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
size_t errorCount = errors.size();
if (const UnionTypeVar* uv = log.getMutable<UnionTypeVar>(subTy))
if (log.get<ConstrainedTypeVar>(subTy))
tryUnifyWithConstrainedSubTypeVar(subTy, superTy);
else if (const UnionTypeVar* uv = log.getMutable<UnionTypeVar>(subTy))
{
tryUnifyUnionWithType(subTy, uv, superTy);
}
@ -946,7 +977,7 @@ struct WeirdIter
LUAU_ASSERT(log.getMutable<TypePack>(newTail));
level = log.getMutable<Unifiable::Free>(packId)->level;
log.replace(packId, Unifiable::Bound<TypePackId>(newTail));
log.replace(packId, BoundTypePack(newTail));
packId = newTail;
pack = log.getMutable<TypePack>(newTail);
index = 0;
@ -994,39 +1025,32 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall
tryUnify_(subTp, superTp, isFunctionCall);
}
static std::pair<std::vector<TypeId>, std::optional<TypePackId>> logAwareFlatten(TypePackId tp, const TxnLog& log)
{
tp = log.follow(tp);
std::vector<TypeId> flattened;
std::optional<TypePackId> tail = std::nullopt;
TypePackIterator it(tp, &log);
for (; it != end(tp); ++it)
{
flattened.push_back(*it);
}
tail = it.tail();
return {flattened, tail};
}
/*
* This is quite tricky: we are walking two rope-like structures and unifying corresponding elements.
* If one is longer than the other, but the short end is free, we grow it to the required length.
*/
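(For intuition, a sketch of the growth case the comment above describes: unifying a shorter pack such as (number) whose tail is a free type pack a... against (number, string, boolean) grows the free tail so both packs reach the same length before the element-wise unification proceeds; when neither short end is free, the length difference is typically surfaced as a CountMismatch error instead.)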
void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypePackId tryUnify_");
++sharedState.counters.iterationCount;
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
if (FFlag::LuauAutocompleteDynamicLimits)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
else
{
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
}
superTp = log.follow(superTp);
@ -1087,8 +1111,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
// If the size of two heads does not match, but both packs have free tail
// We set the sentinel variable to say so to avoid growing it forever.
auto [superTypes, superTail] = logAwareFlatten(superTp, log);
auto [subTypes, subTail] = logAwareFlatten(subTp, log);
auto [superTypes, superTail] = flatten(superTp, log);
auto [subTypes, subTail] = flatten(subTp, log);
bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable<FreeTypePack>(*superTail)) &&
(subTail && log.getMutable<FreeTypePack>(*subTail));
@ -1165,19 +1189,20 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
else
{
// A union type including nil marks an optional argument
if (superIter.good() && isOptional(*superIter))
if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && superIter.good() && isOptional(*superIter))
{
superIter.advance();
continue;
}
else if (subIter.good() && isOptional(*subIter))
else if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && subIter.good() && isOptional(*subIter))
{
subIter.advance();
continue;
}
// In nonstrict mode, any also marks an optional argument.
else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && log.getMutable<AnyTypeVar>(log.follow(*superIter)))
else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() &&
log.getMutable<AnyTypeVar>(log.follow(*superIter)))
{
superIter.advance();
continue;
@ -1195,7 +1220,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
return;
}
if (!isFunctionCall && subIter.good())
if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && !isFunctionCall && subIter.good())
{
// Sometimes it is ok to pass too many arguments
return;
@ -1418,14 +1443,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (FFlag::LuauAnyInIsOptionalIsOptional)
{
if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type))
if (subIter == subTable->props.end() &&
(!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type))
missingProperties.push_back(propName);
}
else
{
bool isAny = log.getMutable<AnyTypeVar>(log.follow(superProp.type));
if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny)
if (subIter == subTable->props.end() &&
(!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) &&
!isAny)
missingProperties.push_back(propName);
}
}
@ -1438,8 +1466,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
}
// And vice versa if we're invariant
if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed &&
superTable->state != TableState::Free)
if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free)
{
for (const auto& [propName, subProp] : subTable->props)
{
@ -1453,7 +1480,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
else
{
bool isAny = log.is<AnyTypeVar>(log.follow(subProp.type));
if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny)))
if (superIter == superTable->props.end() &&
(FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny)))
extraProperties.push_back(propName);
}
}
@ -1499,13 +1527,15 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (innerState.errors.empty())
log.concat(std::move(innerState.log));
}
else if (FFlag::LuauAnyInIsOptionalIsOptional && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type))
else if (FFlag::LuauAnyInIsOptionalIsOptional &&
(!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type))
// This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }`
// since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`.
// TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?)
{
}
else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type))))
else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) &&
(isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type))))
// This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }`
// since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`.
// TODO: should isOptional(anyType) be true?
@ -1664,9 +1694,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (FFlag::LuauTxnLogDontRetryForIndexers)
{
// Changing the indexer can invalidate the table pointers.
superTable = log.getMutable<TableTypeVar>(superTy);
subTable = log.getMutable<TableTypeVar>(subTy);
// Changing the indexer can invalidate the table pointers.
superTable = log.getMutable<TableTypeVar>(superTy);
subTable = log.getMutable<TableTypeVar>(subTy);
}
else if (FFlag::LuauTxnLogCheckForInvalidation)
{
@ -1921,8 +1951,6 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
if (!superTable || !subTable)
ice("passed non-table types to unifySealedTables");
Unifier innerState = makeChildUnifier();
std::vector<std::string> missingPropertiesInSuper;
bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt;
bool errorReported = false;
@ -1944,6 +1972,8 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
}
}
Unifier innerState = makeChildUnifier();
// Tables must have exactly the same props and their types must all unify
for (const auto& it : superTable->props)
{
@ -2376,6 +2406,180 @@ std::optional<TypeId> Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N
return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location);
}
void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy)
{
const ConstrainedTypeVar* subConstrained = get<ConstrainedTypeVar>(subTy);
if (!subConstrained)
ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!");
const std::vector<TypeId>& subTyParts = subConstrained->parts;
// A | B <: T if A <: T and B <: T
bool failed = false;
std::optional<TypeError> unificationTooComplex;
const size_t count = subTyParts.size();
for (size_t i = 0; i < count; ++i)
{
TypeId type = subTyParts[i];
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(type, superTy);
if (i == count - 1)
log.concat(std::move(innerState.log));
++i;
if (auto e = hasUnificationTooComplex(innerState.errors))
unificationTooComplex = e;
if (!innerState.errors.empty())
{
failed = true;
break;
}
}
if (unificationTooComplex)
reportError(*unificationTooComplex);
else if (failed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
else
log.replace(subTy, BoundTypeVar{superTy});
}
void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy)
{
ConstrainedTypeVar* superC = log.getMutable<ConstrainedTypeVar>(superTy);
if (!superC)
ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!");
// subTy could be a
// table
// metatable
// class
// function
// primitive
// free
// generic
// intersection
// union
// Do we really just tack it on? I think we might!
// We can certainly do some deduplication.
// Is there any point to deducing Player|Instance when we could just reduce to Instance?
// Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type?
// Maybe we do a simplification step during quantification.
auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy);
if (it != superC->parts.end())
return;
superC->parts.push_back(subTy);
}
void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy)
{
// The duplication between this and regular typepack unification is tragic.
auto superIter = begin(superTy, &log);
auto superEndIter = end(superTy);
auto subIter = begin(subTy, &log);
auto subEndIter = end(subTy);
int count = FInt::LuauTypeInferLowerBoundsIterationLimit;
for (; subIter != subEndIter; ++subIter)
{
if (0 >= --count)
ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound");
if (superIter != superEndIter)
{
tryUnify_(*subIter, *superIter);
++superIter;
continue;
}
if (auto t = superIter.tail())
{
TypePackId tailPack = follow(*t);
if (log.get<FreeTypePack>(tailPack))
occursCheck(tailPack, subTy);
FreeTypePack* freeTailPack = log.getMutable<FreeTypePack>(tailPack);
if (!freeTailPack)
return;
TypeLevel level = freeTailPack->level;
TypePack* tp = getMutable<TypePack>(log.replace(tailPack, TypePack{}));
for (; subIter != subEndIter; ++subIter)
{
tp->head.push_back(types->addType(ConstrainedTypeVar{level, {follow(*subIter)}}));
}
tp->tail = subIter.tail();
}
return;
}
if (superIter != superEndIter)
{
if (auto subTail = subIter.tail())
{
TypePackId subTailPack = follow(*subTail);
if (get<FreeTypePack>(subTailPack))
{
TypePack* tp = getMutable<TypePack>(log.replace(subTailPack, TypePack{}));
for (; superIter != superEndIter; ++superIter)
tp->head.push_back(*superIter);
}
}
else
{
while (superIter != superEndIter)
{
if (!isOptional(*superIter))
{
errors.push_back(TypeError{location, CountMismatch{size(superTy), size(subTy), CountMismatch::Return}});
return;
}
++superIter;
}
}
return;
}
// Both iters are at their respective tails
auto subTail = subIter.tail();
auto superTail = superIter.tail();
if (subTail && superTail)
tryUnify(*subTail, *superTail);
else if (subTail)
{
const FreeTypePack* freeSubTail = log.getMutable<FreeTypePack>(*subTail);
if (freeSubTail)
{
log.replace(*subTail, TypePack{});
}
}
else if (superTail)
{
const FreeTypePack* freeSuperTail = log.getMutable<FreeTypePack>(*superTail);
if (freeSuperTail)
{
log.replace(*superTail, TypePack{});
}
}
}
void Unifier::occursCheck(TypeId needle, TypeId haystack)
{
sharedState.tempSeenTy.clear();
@ -2385,7 +2589,8 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack)
void Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
{
RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypeId");
auto check = [&](TypeId tv) {
occursCheck(seen, needle, tv);
@ -2425,6 +2630,11 @@ void Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
for (TypeId ty : a->parts)
check(ty);
}
else if (auto a = log.getMutable<ConstrainedTypeVar>(haystack))
{
for (TypeId ty : a->parts)
check(ty);
}
}
void Unifier::occursCheck(TypePackId needle, TypePackId haystack)
@ -2450,7 +2660,8 @@ void Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
if (!log.getMutable<Unifiable::Free>(needle))
ice("Expected needle pack to be free");
RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit);
RecursionLimiter _ra(&sharedState.counters.recursionCount,
FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypePackId");
while (!log.getMutable<ErrorTypeVar>(haystack))
{
@ -2474,7 +2685,23 @@ void Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
Unifier Unifier::makeChildUnifier()
{
return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log};
if (FFlag::LuauTypecheckOptPass)
{
Unifier u = Unifier{types, mode, location, variance, sharedState, &log};
u.anyIsTop = anyIsTop;
return u;
}
Unifier u = Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log};
u.anyIsTop = anyIsTop;
return u;
}
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
void Unifier::reportError(TypeError err)
{
errors.push_back(std::move(err));
}
bool Unifier::isNonstrictMode() const

View File

@ -32,6 +32,7 @@ class DenseHashTable
{
public:
class const_iterator;
class iterator;
DenseHashTable(const Key& empty_key, size_t buckets = 0)
: count(0)
@ -43,7 +44,7 @@ public:
// don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs:
// https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547
if (buckets)
data.resize(buckets, ItemInterface::create(empty_key));
resize_data<Item>(buckets);
}
void clear()
@ -125,7 +126,7 @@ public:
if (data.empty() && data.capacity() >= newsize)
{
LUAU_ASSERT(count == 0);
data.resize(newsize, ItemInterface::create(empty_key));
resize_data<Item>(newsize);
return;
}
@ -169,6 +170,21 @@ public:
return const_iterator(this, data.size());
}
iterator begin()
{
size_t start = 0;
while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key))
start++;
return iterator(this, start);
}
iterator end()
{
return iterator(this, data.size());
}
size_t size() const
{
return count;
@ -233,7 +249,82 @@ public:
size_t index;
};
class iterator
{
public:
iterator()
: set(0)
, index(0)
{
}
iterator(DenseHashTable<Key, Item, MutableItem, ItemInterface, Hash, Eq>* set, size_t index)
: set(set)
, index(index)
{
}
MutableItem& operator*() const
{
return *reinterpret_cast<MutableItem*>(&set->data[index]);
}
MutableItem* operator->() const
{
return reinterpret_cast<MutableItem*>(&set->data[index]);
}
bool operator==(const iterator& other) const
{
return set == other.set && index == other.index;
}
bool operator!=(const iterator& other) const
{
return set != other.set || index != other.index;
}
iterator& operator++()
{
size_t size = set->data.size();
do
{
index++;
} while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key));
return *this;
}
iterator operator++(int)
{
iterator res = *this;
++*this;
return res;
}
private:
DenseHashTable<Key, Item, MutableItem, ItemInterface, Hash, Eq>* set;
size_t index;
};
private:
template<typename T>
void resize_data(size_t count, typename std::enable_if_t<std::is_copy_assignable_v<T>>* dummy = nullptr)
{
data.resize(count, ItemInterface::create(empty_key));
}
template<typename T>
void resize_data(size_t count, typename std::enable_if_t<!std::is_copy_assignable_v<T>>* dummy = nullptr)
{
size_t size = data.size();
data.resize(count);
for (size_t i = size; i < count; i++)
data[i].first = empty_key;
}
std::vector<Item> data;
size_t count;
Key empty_key;
@ -290,6 +381,7 @@ class DenseHashSet
public:
typedef typename Impl::const_iterator const_iterator;
typedef typename Impl::iterator iterator;
DenseHashSet(const Key& empty_key, size_t buckets = 0)
: impl(empty_key, buckets)
@ -336,6 +428,16 @@ public:
{
return impl.end();
}
iterator begin()
{
return impl.begin();
}
iterator end()
{
return impl.end();
}
};
// This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has
@ -348,6 +450,7 @@ class DenseHashMap
public:
typedef typename Impl::const_iterator const_iterator;
typedef typename Impl::iterator iterator;
DenseHashMap(const Key& empty_key, size_t buckets = 0)
: impl(empty_key, buckets)
@ -401,10 +504,21 @@ public:
{
return impl.begin();
}
const_iterator end() const
{
return impl.end();
}
iterator begin()
{
return impl.begin();
}
iterator end()
{
return impl.end();
}
};
} // namespace Luau

View File

@ -173,7 +173,7 @@ public:
}
const Lexeme& next();
const Lexeme& next(bool skipComments);
const Lexeme& next(bool skipComments, bool updatePrevLocation);
void nextline();
Lexeme lookahead();

View File

@ -349,13 +349,11 @@ void Lexer::setReadNames(bool read)
const Lexeme& Lexer::next()
{
return next(this->skipComments);
return next(this->skipComments, true);
}
const Lexeme& Lexer::next(bool skipComments)
const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation)
{
bool first = true;
// in skipComments mode we reject valid comments
do
{
@ -363,11 +361,11 @@ const Lexeme& Lexer::next(bool skipComments)
while (isSpace(peekch()))
consume();
if (!FFlag::LuauParseLocationIgnoreCommentSkip || first)
if (!FFlag::LuauParseLocationIgnoreCommentSkip || updatePrevLocation)
prevLocation = lexeme.location;
lexeme = readNext();
first = false;
updatePrevLocation = false;
} while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment));
return lexeme;

View File

@ -11,6 +11,7 @@
LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false)
LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false)
namespace Luau
{
@ -2789,7 +2790,7 @@ void Parser::nextLexeme()
{
if (options.captureComments)
{
Lexeme::Type type = lexer.next(/* skipComments= */ false).type;
Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type;
while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment)
{
@ -2813,7 +2814,7 @@ void Parser::nextLexeme()
hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)});
}
type = lexer.next(/* skipComments= */ false).type;
type = lexer.next(/* skipComments= */ false, !FFlag::LuauParseLocationIgnoreCommentSkipInCapture).type;
}
}
else

View File

@ -1386,8 +1386,8 @@ struct Compiler
const Constant* cv = constants.find(expr->index);
if (cv && cv->type == Constant::Type_Number && double(int(cv->valueNumber)) == cv->valueNumber && cv->valueNumber >= 1 &&
cv->valueNumber <= 256)
if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 &&
double(int(cv->valueNumber)) == cv->valueNumber)
{
uint8_t rt = compileExprAuto(expr->expr, rs);
uint8_t i = uint8_t(int(cv->valueNumber) - 1);

Compiler/src/CostModel.cpp Normal file (258 lines added)
View File

@ -0,0 +1,258 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "CostModel.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
namespace Luau
{
namespace Compile
{
inline uint64_t parallelAddSat(uint64_t x, uint64_t y)
{
uint64_t s = x + y;
uint64_t m = s & 0x8080808080808080ull; // saturation mask
return (s ^ m) | (m - (m >> 7));
}
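For clarity, parallelAddSat treats the 64-bit model as eight independent byte-wide counters and clamps each at 0x7f, so one expensive subexpression cannot overflow into a neighbouring slot. A minimal standalone sketch of that behaviour (assuming, as Cost guarantees by clamping on construction, that every input byte is at most 0x7f and byte sums therefore never carry across lanes):

    #include <cassert>
    #include <cstdint>

    // Same bit trick as in the diff above: add 8 packed bytes at once, saturating each at 0x7f.
    static uint64_t parallelAddSatDemo(uint64_t x, uint64_t y)
    {
        uint64_t s = x + y;                     // per-byte sums (no cross-byte carry for inputs <= 0x7f)
        uint64_t m = s & 0x8080808080808080ull; // marks bytes whose sum reached 0x80 or more
        return (s ^ m) | (m - (m >> 7));        // clear bit 7, then OR in 0x7f for the saturated bytes
    }

    int main()
    {
        assert(parallelAddSatDemo(0x0503ull, 0x0102ull) == 0x0605ull); // per-byte: 0x05+0x01, 0x03+0x02
        assert(parallelAddSatDemo(0x7full, 0x05ull) == 0x7full);       // low byte saturates at 0x7f
        assert(parallelAddSatDemo(0x0100ull, 0x7f00ull) == 0x7f00ull); // second byte saturates independently
        return 0;
    }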
struct Cost
{
static const uint64_t kLiteral = ~0ull;
// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant
uint64_t model;
// constant mask: 8-byte 0xff mask; equal to all ff's for literals, for variables only byte #i (1+) is set to align with model
uint64_t constant;
Cost(int cost = 0, uint64_t constant = 0)
: model(cost < 0x7f ? cost : 0x7f)
, constant(constant)
{
}
Cost operator+(const Cost& other) const
{
Cost result;
result.model = parallelAddSat(model, other.model);
return result;
}
Cost& operator+=(const Cost& other)
{
model = parallelAddSat(model, other.model);
constant = 0;
return *this;
}
static Cost fold(const Cost& x, const Cost& y)
{
uint64_t newmodel = parallelAddSat(x.model, y.model);
uint64_t newconstant = x.constant & y.constant;
// the extra cost for folding is 1; the discount is 1 for the variable that is shared by x&y (or whichever one is used in x/y if the other is
// literal)
uint64_t extra = (newconstant == kLiteral) ? 0 : (1 | (0x0101010101010101ull & newconstant));
Cost result;
result.model = parallelAddSat(newmodel, extra);
result.constant = newconstant;
return result;
}
};
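Concretely, the packed layout works like this: byte 0 of model is the baseline cost and byte i+1 is the discount earned when the i-th tracked local or argument is constant, while constant keeps 0xff in byte i+1 for as long as an expression depends only on that variable (and is all ones for literals). Folding a use of the first tracked variable b with the literal 1, as in b + 1, gives newconstant = 0xff00 & kLiteral = 0xff00 and extra = 1 | (0x0101010101010101 & 0xff00) = 0x0101, so the resulting model charges a baseline of 1 and records a discount of 1 in b's byte; in other words, the operation becomes free once b is known to be constant.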
struct CostVisitor : AstVisitor
{
DenseHashMap<AstLocal*, uint64_t> vars;
Cost result;
CostVisitor()
: vars(nullptr)
{
}
Cost model(AstExpr* node)
{
if (AstExprGroup* expr = node->as<AstExprGroup>())
{
return model(expr->expr);
}
else if (node->is<AstExprConstantNil>() || node->is<AstExprConstantBool>() || node->is<AstExprConstantNumber>() ||
node->is<AstExprConstantString>())
{
return Cost(0, Cost::kLiteral);
}
else if (AstExprLocal* expr = node->as<AstExprLocal>())
{
const uint64_t* i = vars.find(expr->local);
return Cost(0, i ? *i : 0); // locals typically don't require extra instructions to compute
}
else if (node->is<AstExprGlobal>())
{
return 1;
}
else if (node->is<AstExprVarargs>())
{
return 3;
}
else if (AstExprCall* expr = node->as<AstExprCall>())
{
Cost cost = 3;
cost += model(expr->func);
for (size_t i = 0; i < expr->args.size; ++i)
{
Cost ac = model(expr->args.data[i]);
// for constants/locals we still need to copy them to the argument list
cost += ac.model == 0 ? Cost(1) : ac;
}
return cost;
}
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
{
return model(expr->expr) + 1;
}
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
{
return model(expr->expr) + model(expr->index) + 1;
}
else if (AstExprFunction* expr = node->as<AstExprFunction>())
{
return 10; // high baseline cost due to allocation
}
else if (AstExprTable* expr = node->as<AstExprTable>())
{
Cost cost = 10; // high baseline cost due to allocation
for (size_t i = 0; i < expr->items.size; ++i)
{
const AstExprTable::Item& item = expr->items.data[i];
if (item.key)
cost += model(item.key);
cost += model(item.value);
cost += 1;
}
return cost;
}
else if (AstExprUnary* expr = node->as<AstExprUnary>())
{
return Cost::fold(model(expr->expr), Cost(0, Cost::kLiteral));
}
else if (AstExprBinary* expr = node->as<AstExprBinary>())
{
return Cost::fold(model(expr->left), model(expr->right));
}
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
{
return model(expr->expr);
}
else if (AstExprIfElse* expr = node->as<AstExprIfElse>())
{
return model(expr->condition) + model(expr->trueExpr) + model(expr->falseExpr) + 2;
}
else
{
LUAU_ASSERT(!"Unknown expression type");
return {};
}
}
void assign(AstExpr* expr)
{
// variable assignments reset variable mask, so that further uses of this variable aren't discounted
// this doesn't work perfectly with backwards control flow like loops, but is good enough for a single pass
if (AstExprLocal* lv = expr->as<AstExprLocal>())
if (uint64_t* i = vars.find(lv->local))
*i = 0;
}
bool visit(AstExpr* node) override
{
// note: we short-circuit the visitor traversal through any expression trees by returning false
// recursive traversal is happening inside model() which makes it easier to get the resulting value of the subexpression
result += model(node);
return false;
}
bool visit(AstStat* node) override
{
if (node->is<AstStatIf>())
result += 2;
else if (node->is<AstStatWhile>() || node->is<AstStatRepeat>() || node->is<AstStatFor>() || node->is<AstStatForIn>())
result += 2;
else if (node->is<AstStatBreak>() || node->is<AstStatContinue>())
result += 1;
return true;
}
bool visit(AstStatLocal* node) override
{
for (size_t i = 0; i < node->values.size; ++i)
{
Cost arg = model(node->values.data[i]);
// propagate constant mask from expression through variables
if (arg.constant && i < node->vars.size)
vars[node->vars.data[i]] = arg.constant;
result += arg;
}
return false;
}
bool visit(AstStatAssign* node) override
{
for (size_t i = 0; i < node->vars.size; ++i)
assign(node->vars.data[i]);
return true;
}
bool visit(AstStatCompoundAssign* node) override
{
assign(node->var);
// if lhs is not a local, setting it requires an extra table operation
result += node->var->is<AstExprLocal>() ? 1 : 2;
return true;
}
};
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount)
{
CostVisitor visitor;
for (size_t i = 0; i < varCount && i < 7; ++i)
visitor.vars[vars[i]] = 0xffull << (i * 8 + 8);
root->visit(&visitor);
return visitor.result.model;
}
int computeCost(uint64_t model, const bool* varsConst, size_t varCount)
{
int cost = int(model & 0x7f);
// don't apply discounts to what is likely a saturated sum
if (cost == 0x7f)
return cost;
for (size_t i = 0; i < varCount && i < 7; ++i)
cost -= int((model >> (8 * i + 8)) & 0x7f) * varsConst[i];
return cost;
}
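Reading the two halves together: modelCost packs B and the per-variable discounts Di into one 64-bit word, and computeCost evaluates the B - sum(Di * Ci) formula from the header comment. A small self-contained sketch with a hand-packed, hypothetical model word, mirroring the decode loop above:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Decode a packed cost model the same way computeCost does: the low byte is the
    // baseline, bytes 1..7 are per-variable discounts applied when that variable is constant.
    static int computeCostDemo(uint64_t model, const bool* varsConst, size_t varCount)
    {
        int cost = int(model & 0x7f);
        if (cost == 0x7f)
            return cost; // likely a saturated sum; skip the discounts
        for (size_t i = 0; i < varCount && i < 7; ++i)
            cost -= int((model >> (8 * i + 8)) & 0x7f) * varsConst[i];
        return cost;
    }

    int main()
    {
        // Hypothetical model: baseline 5, discount 3 if variable #0 is constant, 1 if variable #1 is constant.
        uint64_t model = 0x05ull | (0x03ull << 8) | (0x01ull << 16);

        const bool noneConst[] = {false, false};
        const bool firstConst[] = {true, false};
        assert(computeCostDemo(model, noneConst, 2) == 5);  // B = 5
        assert(computeCostDemo(model, firstConst, 2) == 2); // B - D0 = 5 - 3
        return 0;
    }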
} // namespace Compile
} // namespace Luau

Compiler/src/CostModel.h Normal file (18 lines added)
View File

@ -0,0 +1,18 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
namespace Luau
{
namespace Compile
{
// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount);
// cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant
int computeCost(uint64_t model, const bool* varsConst, size_t varCount);
} // namespace Compile
} // namespace Luau

View File

@ -32,11 +32,13 @@ target_sources(Luau.Compiler PRIVATE
Compiler/src/Compiler.cpp
Compiler/src/Builtins.cpp
Compiler/src/ConstantFolding.cpp
Compiler/src/CostModel.cpp
Compiler/src/TableShape.cpp
Compiler/src/ValueTracking.cpp
Compiler/src/lcode.cpp
Compiler/src/Builtins.h
Compiler/src/ConstantFolding.h
Compiler/src/CostModel.h
Compiler/src/TableShape.h
Compiler/src/ValueTracking.h
)
@ -58,6 +60,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/LValue.h
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Normalize.h
Analysis/include/Luau/Predicate.h
Analysis/include/Luau/Quantify.h
Analysis/include/Luau/RecursionCounter.h
@ -94,6 +97,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Linter.cpp
Analysis/src/LValue.cpp
Analysis/src/Module.cpp
Analysis/src/Normalize.cpp
Analysis/src/Quantify.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
@ -216,6 +220,7 @@ if(TARGET Luau.UnitTest)
tests/Autocomplete.test.cpp
tests/BuiltinDefinitions.test.cpp
tests/Compiler.test.cpp
tests/CostModel.test.cpp
tests/Config.test.cpp
tests/Error.test.cpp
tests/Frontend.test.cpp
@ -224,6 +229,7 @@ if(TARGET Luau.UnitTest)
tests/LValue.test.cpp
tests/Module.test.cpp
tests/NonstrictMode.test.cpp
tests/Normalize.test.cpp
tests/Parser.test.cpp
tests/RequireTracer.test.cpp
tests/StringUtils.test.cpp

View File

@ -34,7 +34,7 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false)
LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary, false)
LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false)
// max size of both array and hash part is 2^MAXBITS
#define MAXBITS 26
@ -390,6 +390,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize)
setarrayvector(L, t, nasize);
/* create new hash part with appropriate size */
setnodevector(L, t, nhsize);
/* used for the migration check at the end */
LuaNode* nnew = t->node;
if (nasize < oldasize)
{ /* array part must shrink? */
t->sizearray = nasize;
@ -413,6 +415,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize)
/* shrink array */
luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat);
}
/* used for the migration check at the end */
TValue* anew = t->array;
/* re-insert elements from hash part */
if (FFlag::LuauTableRehashRework)
{
@ -441,14 +445,30 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize)
}
}
/* make sure we haven't recursively rehashed during element migration */
LUAU_ASSERT(nnew == t->node);
LUAU_ASSERT(anew == t->array);
if (nold != dummynode)
luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */
}
static int adjustasize(Table* t, int size, const TValue* ek)
{
LUAU_ASSERT(FFlag::LuauTableNewBoundary2);
bool tbound = t->node != dummynode || size < t->sizearray;
int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1;
/* move the array size up until the boundary is guaranteed to be inside the array part */
while (size + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, size + 1))))
size++;
return size;
}
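For example, if a table's array part currently holds indices 1..3 and the key being inserted (ek) is 4, adjustasize bumps the requested array size to 4 so the new element lands in the array part rather than the hash part; this is what maintains the boundary invariant asserted in luaH_getn later in this file (the sequence border always lies inside the array part, so length queries never have to search the hash part).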
void luaH_resizearray(lua_State* L, Table* t, int nasize)
{
int nsize = (t->node == dummynode) ? 0 : sizenode(t);
resize(L, t, nasize, nsize);
int asize = FFlag::LuauTableNewBoundary2 ? adjustasize(t, nasize, NULL) : nasize;
resize(L, t, asize, nsize);
}
void luaH_resizehash(lua_State* L, Table* t, int nhsize)
@ -470,21 +490,12 @@ static void rehash(lua_State* L, Table* t, const TValue* ek)
totaluse++;
/* compute new size for array part */
int na = computesizes(nums, &nasize);
int nh = totaluse - na;
/* enforce the boundary invariant; for performance, only do hash lookups if we must */
if (FFlag::LuauTableNewBoundary)
{
bool tbound = t->node != dummynode || nasize < t->sizearray;
int ekindex = ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1;
/* move the array size up until the boundary is guaranteed to be inside the array part */
while (nasize + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, nasize + 1))))
{
nasize++;
na++;
}
}
if (FFlag::LuauTableNewBoundary2)
nasize = adjustasize(t, nasize, ek);
/* resize the table to new computed sizes */
LUAU_ASSERT(na <= totaluse);
resize(L, t, nasize, totaluse - na);
resize(L, t, nasize, nh);
}
/*
@ -544,7 +555,7 @@ static LuaNode* getfreepos(Table* t)
static TValue* newkey(lua_State* L, Table* t, const TValue* key)
{
/* enforce boundary invariant */
if (FFlag::LuauTableNewBoundary && ttisnumber(key) && nvalue(key) == t->sizearray + 1)
if (FFlag::LuauTableNewBoundary2 && ttisnumber(key) && nvalue(key) == t->sizearray + 1)
{
rehash(L, t, key); /* grow table */
@ -735,7 +746,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key)
static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j)
{
LUAU_ASSERT(!FFlag::LuauTableNewBoundary);
LUAU_ASSERT(!FFlag::LuauTableNewBoundary2);
unsigned int i = j; /* i is zero or a present index */
j++;
/* find `i' and `j' such that i is present and j is not */
@ -820,7 +831,7 @@ int luaH_getn(Table* t)
maybesetaboundary(t, boundary);
return boundary;
}
else if (FFlag::LuauTableNewBoundary)
else if (FFlag::LuauTableNewBoundary2)
{
/* validate boundary invariant */
LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1)));

View File

@ -199,7 +199,7 @@ static int tmove(lua_State* L)
int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */
luaL_checktype(L, tt, LUA_TTABLE);
void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry;
void (*telemetrycb)(lua_State * L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry;
if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f)
{

View File

@ -16,7 +16,7 @@
#include <string.h>
LUAU_FASTFLAG(LuauTableNewBoundary)
LUAU_FASTFLAG(LuauTableNewBoundary2)
// Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__
@ -2268,7 +2268,7 @@ static void luau_execute(lua_State* L)
VM_NEXT();
}
}
else if (FFlag::LuauTableNewBoundary || (h->lsizenode == 0 && ttisnil(gval(h->node))))
else if (FFlag::LuauTableNewBoundary2 || (h->lsizenode == 0 && ttisnil(gval(h->node))))
{
// fallthrough to exit
VM_NEXT();

tests/CostModel.test.cpp Normal file (101 lines added)
View File

@ -0,0 +1,101 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "doctest.h"
using namespace Luau;
namespace Luau
{
namespace Compile
{
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount);
int computeCost(uint64_t model, const bool* varsConst, size_t varCount);
} // namespace Compile
} // namespace Luau
TEST_SUITE_BEGIN("CostModel");
static uint64_t modelFunction(const char* source)
{
Allocator allocator;
AstNameTable names(allocator);
ParseResult result = Parser::parse(source, strlen(source), names, allocator);
REQUIRE(result.root != nullptr);
AstStatFunction* func = result.root->body.data[0]->as<AstStatFunction>();
REQUIRE(func);
return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size);
}
TEST_CASE("Expression")
{
uint64_t model = modelFunction(R"(
function test(a, b, c)
return a + (b + 1) * (b + 1) - c
end
)");
const bool args1[] = {false, false, false};
const bool args2[] = {false, true, false};
CHECK_EQ(5, Luau::Compile::computeCost(model, args1, 3));
CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 3));
}
TEST_CASE("PropagateVariable")
{
uint64_t model = modelFunction(R"(
function test(a)
local b = a * a * a
return b * b
end
)");
const bool args1[] = {false};
const bool args2[] = {true};
CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1));
CHECK_EQ(0, Luau::Compile::computeCost(model, args2, 1));
}
TEST_CASE("LoopAssign")
{
uint64_t model = modelFunction(R"(
function test(a)
for i=1,3 do
a[i] = i
end
end
)");
const bool args1[] = {false};
const bool args2[] = {true};
// loop baseline cost is 2
CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1));
CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1));
}
TEST_CASE("MutableVariable")
{
uint64_t model = modelFunction(R"(
function test(a, b)
local x = a * a
x += b
return x * x
end
)");
const bool args1[] = {false};
const bool args2[] = {true};
CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1));
CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1));
}
TEST_SUITE_END();

View File

@ -231,7 +231,7 @@ ModulePtr Fixture::getMainModule()
SourceModule* Fixture::getMainSourceModule()
{
return frontend.getSourceModule(fromString("MainModule"));
return frontend.getSourceModule(fromString(mainModuleName));
}
std::optional<PrimitiveTypeVar::Type> Fixture::getPrimitiveType(TypeId ty)
@ -259,7 +259,7 @@ std::optional<TypeId> Fixture::getType(const std::string& name)
TypeId Fixture::requireType(const std::string& name)
{
std::optional<TypeId> ty = getType(name);
REQUIRE(bool(ty));
REQUIRE_MESSAGE(bool(ty), "Unable to requireType \"" << name << "\"");
return follow(*ty);
}

View File

@ -68,7 +68,9 @@ TEST_CASE("encode_tables")
REQUIRE(parseResult.errors.size() == 0);
std::string json = toJson(parseResult.root);
CHECK(json == R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})");
CHECK(
json ==
R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})");
}
TEST_SUITE_END();

View File

@ -597,8 +597,6 @@ return foo1
TEST_CASE_FIXTURE(Fixture, "UnknownType")
{
ScopedFastFlag sff("LuauLintNoRobloxBits", true);
unfreeze(typeChecker.globalTypes);
TableTypeVar::Props instanceProps{
{"ClassName", {typeChecker.anyType}},
@ -1439,6 +1437,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi")
{
unfreeze(typeChecker.globalTypes);
TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}});
persist(instanceType);
typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType};
getMutable<ClassTypeVar>(instanceType)->props = {

View File

@ -2,6 +2,7 @@
#include "Luau/Clone.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/RecursionCounter.h"
#include "Fixture.h"
@ -9,6 +10,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
TEST_SUITE_BEGIN("ModuleTests");
TEST_CASE_FIXTURE(Fixture, "is_within_comment")
@ -42,29 +45,23 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment")
TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive")
{
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
// numberType is persistent. We leave it as-is.
TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState);
TypeId newNumber = clone(typeChecker.numberType, dest, cloneState);
CHECK_EQ(newNumber, typeChecker.numberType);
}
TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive")
{
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
// Create a new number type that isn't persistent
unfreeze(typeChecker.globalTypes);
TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number});
freeze(typeChecker.globalTypes);
TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState);
TypeId newNumber = clone(oldNumber, dest, cloneState);
CHECK_NE(newNumber, oldNumber);
CHECK_EQ(*oldNumber, *newNumber);
@ -90,12 +87,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
TypeId counterType = requireType("Cyclic");
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
TypeArena dest;
TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState);
CloneState cloneState;
TypeId counterCopy = clone(counterType, dest, cloneState);
TableTypeVar* ttv = getMutable<TableTypeVar>(counterCopy);
REQUIRE(ttv != nullptr);
@ -112,8 +106,11 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
REQUIRE(methodReturnType);
CHECK_EQ(methodReturnType, counterCopy);
CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type
CHECK_EQ(2, dest.typeVars.size()); // One table and one function
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ(3, dest.typePacks.size()); // function args, its return type, and the hidden any... pack
else
CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type
CHECK_EQ(2, dest.typeVars.size()); // One table and one function
}
TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena")
@ -143,15 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena")
TEST_CASE_FIXTURE(Fixture, "deepClone_union")
{
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
unfreeze(typeChecker.globalTypes);
TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}});
freeze(typeChecker.globalTypes);
TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState);
TypeId newUnion = clone(oldUnion, dest, cloneState);
CHECK_NE(newUnion, oldUnion);
CHECK_EQ("number | string", toString(newUnion));
@ -161,15 +155,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union")
TEST_CASE_FIXTURE(Fixture, "deepClone_intersection")
{
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
unfreeze(typeChecker.globalTypes);
TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}});
freeze(typeChecker.globalTypes);
TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState);
TypeId newIntersection = clone(oldIntersection, dest, cloneState);
CHECK_NE(newIntersection, oldIntersection);
CHECK_EQ("number & string", toString(newIntersection));
@ -191,12 +182,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_class")
std::nullopt, &exampleMetaClass, {}, {}}};
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState);
TypeId cloned = clone(&exampleClass, dest, cloneState);
const ClassTypeVar* ctv = get<ClassTypeVar>(cloned);
REQUIRE(ctv != nullptr);
@ -216,16 +204,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types")
TypePackVar freeTp(FreeTypePack{TypeLevel{}});
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState);
TypeId clonedTy = clone(&freeTy, dest, cloneState);
CHECK_EQ("any", toString(clonedTy));
CHECK(cloneState.encounteredFreeType);
cloneState = {};
TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState);
TypePackId clonedTp = clone(&freeTp, dest, cloneState);
CHECK_EQ("...any", toString(clonedTp));
CHECK(cloneState.encounteredFreeType);
}
@ -237,16 +223,32 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables")
ttv->state = TableState::Free;
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState);
TypeId cloned = clone(&tableTy, dest, cloneState);
const TableTypeVar* clonedTtv = get<TableTypeVar>(cloned);
CHECK_EQ(clonedTtv->state, TableState::Sealed);
CHECK(cloneState.encounteredFreeType);
}
TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection")
{
TypeArena src;
TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {getSingletonTypes().numberType, getSingletonTypes().stringType}});
TypeArena dest;
CloneState cloneState;
TypeId cloned = clone(constrained, dest, cloneState);
CHECK_NE(constrained, cloned);
const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(cloned);
REQUIRE_EQ(2, ctv->parts.size());
CHECK_EQ(getSingletonTypes().numberType, ctv->parts[0]);
CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]);
}
TEST_CASE_FIXTURE(Fixture, "clone_self_property")
{
ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true};
@ -284,6 +286,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit")
int limit = 400;
#endif
ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit};
ScopedFastFlag sff{"LuauRecursionLimitException", true};
TypeArena src;
@ -299,11 +302,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit")
}
TypeArena dest;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error);
CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException);
}
TEST_SUITE_END();

View File

@ -275,4 +275,38 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok")
REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType));
}
TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values")
{
CheckResult result = check(R"(
--!nonstrict
function foo(): (boolean, string?)
if true then
return true, "hello"
else
return false
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "returning_too_many_values")
{
CheckResult result = check(R"(
--!nonstrict
function foo(): boolean
if true then
return true, "hello"
else
return false
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

tests/Normalize.test.cpp Normal file (967 lines added)
View File

@ -0,0 +1,967 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "doctest.h"
#include "Luau/Normalize.h"
#include "Luau/BuiltinDefinitions.h"
using namespace Luau;
struct NormalizeFixture : Fixture
{
ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true};
ScopedFastFlag sff2{"LuauTableSubtypingVariance2", true};
};
void createSomeClasses(TypeChecker& typeChecker)
{
auto& arena = typeChecker.globalTypes;
unfreeze(arena);
TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr});
ClassTypeVar* parentClass = getMutable<ClassTypeVar>(parentType);
parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})};
parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})};
addGlobalBinding(typeChecker, "Parent", {parentType});
typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType};
TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr});
ClassTypeVar* childClass = getMutable<ClassTypeVar>(childType);
childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})};
addGlobalBinding(typeChecker, "Child", {childType});
typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType};
TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr});
addGlobalBinding(typeChecker, "Unrelated", {unrelatedType});
typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType};
freeze(arena);
}
static bool isSubtype(TypeId a, TypeId b)
{
InternalErrorReporter ice;
return isSubtype(a, b, ice);
}
TEST_SUITE_BEGIN("isSubtype");
TEST_CASE_FIXTURE(NormalizeFixture, "primitives")
{
check(R"(
local a = 41
local b = 32
local c = "hello"
local d = "world"
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
TypeId d = requireType("d");
CHECK(isSubtype(b, a));
CHECK(isSubtype(d, c));
CHECK(!isSubtype(d, a));
}
TEST_CASE_FIXTURE(NormalizeFixture, "functions")
{
check(R"(
function a(x: number): number return x end
function b(x: number): number return x end
function c(x: number?): number return x end
function d(x: number): number? return x end
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
TypeId d = requireType("d");
CHECK(isSubtype(b, a));
CHECK(isSubtype(c, a));
CHECK(!isSubtype(d, a));
CHECK(isSubtype(a, d));
}
TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any")
{
check(R"(
function a(n: number) return "string" end
function b(q: any) return 5 :: any end
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
// Intuition:
// We cannot use b where a is required because we cannot rely on b to return a string.
// We cannot use a where b is required because we cannot rely on a to accept non-number arguments.
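// A tiny illustration of both directions (hypothetical callers, not checked by this test):
//   local s: string = b(1)   -- b's return type is any, which need not be a string
//   local n = a("oops")      -- a only accepts a number argument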
CHECK(!isSubtype(b, a));
CHECK(!isSubtype(a, b));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities")
{
check(R"(
type A = (any) -> ()
type B = (any, any) -> ()
type T = A & B
local a: A
local b: B
local t: T
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
CHECK(!isSubtype(a, b)); // !!
CHECK(!isSubtype(b, a));
CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity")
{
check(R"(
local a: (number) -> ()
local b: () -> ()
local c: () -> number
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
CHECK(!isSubtype(b, a));
CHECK(!isSubtype(c, a));
CHECK(!isSubtype(a, b));
CHECK(!isSubtype(c, b));
CHECK(!isSubtype(a, c));
CHECK(!isSubtype(b, c));
}
TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters")
{
/*
* (T0..TN) <: (T0..TN, A?)
* (T0..TN) <: (T0..TN, any)
* (T0..TN, A?) </: (T0..TN). We don't technically need to spell this out, but it's quite important.
* T <: T
* if A <: B and B <: C then A <: C
* T -> R <: U -> S if U <: T and R <: S
* A | B <: T if A <: T and B <: T
* T <: A | B if T <: A or T <: B
*/
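/*
 * A worked instance of the rules above, matching the checks below: for
 * a : (number?) -> () and b : (number) -> (), we have (number) <: (number?)
 * since number <: number | nil, so by the function rule (arguments are
 * contravariant) a <: b, while b </: a because nil </: number.
 */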
check(R"(
local a: (number?) -> ()
local b: (number) -> ()
local c: (number, number?) -> ()
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
/*
* (number) -> () </: (number?) -> ()
* because number? </: number (because number <: number, but nil </: number)
*/
CHECK(!isSubtype(b, a));
/*
* (number, number?) -> () </: (number?) -> ()
* because number? </: number (as above)
*/
CHECK(!isSubtype(c, a));
/*
* (number?) -> () <: (number) -> ()
* because number <: number? (because number <: number)
*/
CHECK(isSubtype(a, b));
/*
* (number, number?) -> () <: (number) -> (number)
* The packs have unequal lengths, but (number) <: (number, number?)
* and number <: number
*/
CHECK(!isSubtype(c, b));
/*
* (number?) -> () </: (number, number?) -> ()
* because (number, number?) </: (number)
*/
CHECK(!isSubtype(a, c));
/*
* (number) -> () </: (number, number?) -> ()
* because (number, number?) </: (number)
*/
CHECK(!isSubtype(b, c));
}
TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param")
{
check(R"(
local a: (number?) -> ()
local b: (number) -> ()
local c: (number, any) -> ()
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
/*
* (number) -> () </: (number?) -> ()
* because number? </: number (because number <: number, but nil </: number)
*/
CHECK(!isSubtype(b, a));
/*
* (number, any) -> () </: (number?) -> ()
* because number? </: number (as above)
*/
CHECK(!isSubtype(c, a));
/*
* (number?) -> () <: (number) -> ()
* because number <: number? (because number <: number)
*/
CHECK(isSubtype(a, b));
/*
* (number, any) -> () </: (number) -> (number)
* The packs have unequal lengths
*/
CHECK(!isSubtype(c, b));
/*
* (number?) -> () </: (number, any) -> ()
* The packs have unequal lengths
*/
CHECK(!isSubtype(a, c));
/*
* (number) -> () </: (number, any) -> ()
* The packs have unequal lengths
*/
CHECK(!isSubtype(b, c));
}
TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head")
{
check(R"(
local a: (...number) -> ()
local b: (...number?) -> ()
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
CHECK(isSubtype(b, a));
CHECK(!isSubtype(a, b));
}
#if 0
TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head")
{
check(R"(
local a: (...number) -> ()
local b: (number, number) -> ()
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
CHECK(!isSubtype(b, a));
CHECK(isSubtype(a, b));
}
#endif
TEST_CASE_FIXTURE(NormalizeFixture, "union")
{
check(R"(
local a: number | string
local b: number
local c: string
local d: number?
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
TypeId d = requireType("d");
CHECK(isSubtype(b, a));
CHECK(!isSubtype(a, b));
CHECK(isSubtype(c, a));
CHECK(!isSubtype(a, c));
CHECK(!isSubtype(d, a));
CHECK(!isSubtype(a, d));
CHECK(isSubtype(b, d));
CHECK(!isSubtype(d, b));
}
TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop")
{
check(R"(
local a: {x: number}
local b: {x: number?}
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
CHECK(isSubtype(a, b));
CHECK(!isSubtype(b, a));
}
TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop")
{
check(R"(
local a: {x: number}
local b: {x: any}
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
CHECK(isSubtype(a, b));
CHECK(!isSubtype(b, a));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersection")
{
check(R"(
local a: number & string
local b: number
local c: string
local d: number & nil
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
TypeId d = requireType("d");
CHECK(!isSubtype(b, a));
CHECK(isSubtype(a, b));
CHECK(!isSubtype(c, a));
CHECK(isSubtype(a, c));
CHECK(!isSubtype(d, a));
CHECK(!isSubtype(a, d));
}
TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection")
{
check(R"(
local a: number & string
local b: number | nil
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
CHECK(!isSubtype(b, a));
CHECK(isSubtype(a, b));
}
TEST_CASE_FIXTURE(NormalizeFixture, "table_with_table_prop")
{
check(R"(
type T = {x: {y: number}} & {x: {y: string}}
local a: T
)");
CHECK_EQ("{| x: {| y: number & string |} |}", toString(requireType("a")));
}
#if 0
TEST_CASE_FIXTURE(NormalizeFixture, "tables")
{
check(R"(
local a: {x: number}
local b: {x: any}
local c: {y: number}
local d: {x: number, y: number}
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
TypeId d = requireType("d");
CHECK(isSubtype(a, b));
CHECK(!isSubtype(b, a));
CHECK(!isSubtype(c, a));
CHECK(!isSubtype(a, c));
CHECK(isSubtype(d, a));
CHECK(!isSubtype(a, d));
CHECK(isSubtype(d, b));
CHECK(!isSubtype(b, d));
}
TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant")
{
check(R"(
local a: {[string]: number}
local b: {[string]: any}
local c: {[string]: number}
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
CHECK(!isSubtype(b, a));
CHECK(!isSubtype(a, b));
CHECK(isSubtype(c, a));
CHECK(isSubtype(a, c));
}
TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers")
{
check(R"(
local a: {x: number}
local b: {[string]: number}
local c: {}
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
CHECK(isSubtype(b, a));
CHECK(!isSubtype(a, b));
CHECK(!isSubtype(c, b));
CHECK(isSubtype(b, c));
}
TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table")
{
check(R"(
type A = {method: (A) -> ()}
local a: A
type B = {method: (any) -> ()}
local b: B
type C = {method: (C) -> ()}
local c: C
type D = {method: (D) -> (), another: (D) -> ()}
local d: D
type E = {method: (A) -> (), another: (E) -> ()}
local e: E
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
TypeId c = requireType("c");
TypeId d = requireType("d");
TypeId e = requireType("e");
CHECK(isSubtype(b, a));
CHECK(!isSubtype(a, b));
CHECK(isSubtype(c, a));
CHECK(isSubtype(a, c));
CHECK(!isSubtype(d, a));
CHECK(!isSubtype(a, d));
CHECK(isSubtype(e, a));
CHECK(!isSubtype(a, e));
}
#endif
TEST_CASE_FIXTURE(NormalizeFixture, "classes")
{
createSomeClasses(typeChecker);
TypeId p = typeChecker.globalScope->lookupType("Parent")->type;
TypeId c = typeChecker.globalScope->lookupType("Child")->type;
TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type;
CHECK(isSubtype(c, p));
CHECK(!isSubtype(p, c));
CHECK(!isSubtype(u, p));
CHECK(!isSubtype(p, u));
}
#if 0
TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1})
{
check(R"(
local T = {}
T.__index = T
function T.new()
return setmetatable({}, T)
end
function T:method() end
local a: typeof(T.new)
local b: {method: (any) -> ()}
)");
TypeId a = requireType("a");
TypeId b = requireType("b");
CHECK(isSubtype(a, b));
}
#endif
TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_tables")
{
check(R"(
type T = {x: number} & ({x: number} & {y: string?})
local t: T
)");
CHECK("{| x: number, y: string? |}" == toString(requireType("t")));
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("Normalize");
TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_disjoint_tables")
{
check(R"(
type T = {a: number} & {b: number}
local t: T
)");
CHECK_EQ("{| a: number, b: number |}", toString(requireType("t")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_overlapping_tables")
{
check(R"(
type T = {a: number, b: string} & {b: number, c: string}
local t: T
)");
CHECK_EQ("{| a: number, b: number & string, c: string |}", toString(requireType("t")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_confluent_overlapping_tables")
{
check(R"(
type T = {a: number, b: string} & {b: string, c: string}
local t: T
)");
CHECK_EQ("{| a: number, b: string, c: string |}", toString(requireType("t")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship")
{
check(R"(
local t: {x: number} | {x: number?}
)");
ModulePtr tempModule{new Module};
// HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze
// the arena that the type lives in.
ModulePtr mainModule = getMainModule();
unfreeze(mainModule->internalTypes);
TypeId tType = requireType("t");
normalize(tType, tempModule, *typeChecker.iceHandler);
CHECK_EQ("{| x: number? |}", toString(tType, {true}));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions")
{
check(R"(
type T = ((any) -> string) & ((number) -> string)
local t: T
)");
CHECK_EQ("(any) -> string", toString(requireType("t")));
}
TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
};
check(R"(
--!nonstrict
if Math.random() then
return function(initialState, handlers)
return function(state, action)
return state
end
end
else
return function(initialState, handlers)
return function(state, action)
return state
end
end
end
)");
CHECK_EQ("(any, any) -> (...any)", toString(getMainModule()->getModuleScope()->returnType));
}
TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection")
{
check(R"(
function foo(x:number, y:number)
return x + y
end
)");
CHECK_EQ("(number, number) -> number", toString(requireType("foo")));
}
TEST_CASE_FIXTURE(Fixture, "higher_order_function")
{
check(R"(
function apply(f, x)
return f(x)
end
local a = apply(function(x: number) return x + x end, 5)
)");
TypeId aType = requireType("a");
CHECK_MESSAGE(isNumber(follow(aType)), "Expected a number but got ", toString(aType));
}
TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation")
{
check(R"(
function apply<a, b>(f: (a) -> b, x)
return f(x)
end
)");
CHECK_EQ("<a, b>((a) -> b, a) -> b", toString(requireType("apply")));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
check(R"(
type Fiber = {
return_: Fiber?
}
local f: Fiber
)");
TypeId t = requireType("f");
CHECK(t->normal);
}
TEST_CASE_FIXTURE(Fixture, "variadic_tail_is_marked_normal")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
type Weirdo = (...{x: number}) -> ()
local w: Weirdo
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId t = requireType("w");
auto ftv = get<FunctionTypeVar>(t);
REQUIRE(ftv);
auto [argHead, argTail] = flatten(ftv->argTypes);
CHECK(argHead.empty());
REQUIRE(argTail.has_value());
auto vtp = get<VariadicTypePack>(*argTail);
REQUIRE(vtp);
CHECK(vtp->ty->normal);
}
TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly")
{
CheckResult result = check(R"(
local Cyclic = {}
function Cyclic.get()
return Cyclic
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId ty = requireType("Cyclic");
CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true}));
}
TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
function fussy(a, b)
if math.random() > 0.5 then
return a
else
return b
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK("<a, b>(a, b) -> a | b" == toString(requireType("fussy")));
}
TEST_CASE_FIXTURE(Fixture, "constrained_intersection_of_intersections")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
local f : (() -> number) | ((number) -> number)
local g : (() -> number) | ((string) -> number)
function h()
if math.random() then
return f
else
return g
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId h = requireType("h");
CHECK("() -> (() -> number) | ((number) -> number) | ((string) -> number)" == toString(h));
}
TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
type X = {}
type Y = {y: number}
type Z = {z: string}
type W = {w: boolean}
type T = {x: Y & X} & {x:Z & W}
local x: X
local y: Y
local z: Z
local w: W
local t: T
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK("{| |}" == toString(requireType("x"), {true}));
CHECK("{| y: number |}" == toString(requireType("y"), {true}));
CHECK("{| z: string |}" == toString(requireType("z"), {true}));
CHECK("{| w: boolean |}" == toString(requireType("w"), {true}));
CHECK("{| x: {| w: boolean, y: number, z: string |} |}" == toString(requireType("t"), {true}));
}
TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_2")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
// We use a function and inferred parameter types to prevent intermediate normalizations from being performed.
// This exposes a bug where the type of y is mutated.
CheckResult result = check(R"(
function strange(w, x, y, z)
y.y = 5
z.z = "five"
w.w = true
type Z = {x: typeof(x) & typeof(y)} & {x: typeof(w) & typeof(z)}
return ((nil :: any) :: Z)
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId t = requireType("strange");
auto ftv = get<FunctionTypeVar>(t);
REQUIRE(ftv != nullptr);
std::vector<TypeId> args = flatten(ftv->argTypes).first;
REQUIRE(4 == args.size());
CHECK("{+ w: boolean +}" == toString(args[0]));
CHECK("a" == toString(args[1]));
CHECK("{+ y: number +}" == toString(args[2]));
CHECK("{+ z: string +}" == toString(args[3]));
std::vector<TypeId> ret = flatten(ftv->retType).first;
REQUIRE(1 == ret.size());
CHECK("{| x: a & {- w: boolean, y: number, z: string -} |}" == toString(ret[0]));
}
TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
// We use a function and inferred parameter types to prevent intermediate normalizations from being performed.
// This exposes a bug where the type of y is mutated.
CheckResult result = check(R"(
function strange(x, y, z)
x.x = true
y.y = y
z.z = "five"
type Z = {x: typeof(y)} & {x: typeof(x) & typeof(z)}
return ((nil :: any) :: Z)
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId t = requireType("strange");
auto ftv = get<FunctionTypeVar>(t);
REQUIRE(ftv != nullptr);
std::vector<TypeId> args = flatten(ftv->argTypes).first;
REQUIRE(3 == args.size());
CHECK("{+ x: boolean +}" == toString(args[0]));
CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1]));
CHECK("{+ z: string +}" == toString(args[2]));
std::vector<TypeId> ret = flatten(ftv->retType).first;
REQUIRE(1 == ret.size());
CHECK("{| x: {- x: boolean, y: t1, z: string -} |} where t1 = {+ y: t1 +}" == toString(ret[0]));
}
TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
// We use a function and inferred parameter types to prevent intermediate normalizations from being performed.
// This exposes a bug where the type of y is mutated.
CheckResult result = check(R"(
function strange(x, y, z)
x.x = true
z.z = "five"
type R = {x: typeof(y)} & {x: typeof(x) & typeof(z)}
local r: R
y.y = r
return r
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId t = requireType("strange");
auto ftv = get<FunctionTypeVar>(t);
REQUIRE(ftv != nullptr);
std::vector<TypeId> args = flatten(ftv->argTypes).first;
REQUIRE(3 == args.size());
CHECK("{+ x: boolean +}" == toString(args[0]));
CHECK("{+ y: t1 +} where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(args[1]));
CHECK("{+ z: string +}" == toString(args[2]));
std::vector<TypeId> ret = flatten(ftv->retType).first;
REQUIRE(1 == ret.size());
CHECK("t1 where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(ret[0]));
}
TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
{"LuauNormalizeCombineTableFix", true},
};
// CLI-52787
// ends up combining {_:any} with any, recursively
// which used to ICE because this combines a table with a non-table.
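// Concretely (a sketch of the failing step): intersecting the `_` properties of
// { _: {_:any} } and { _: any } asks normalization to combine the table {_: any}
// with the non-table any.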
CheckResult result = check(R"(
export type t0 = any & { _: {_:any} } & { _:any }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow")
{
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
{"LuauNormalizeCombineIntersectionFix", true},
};
CheckResult result = check(R"(
export type t0 = {_:{_:any} & {_:any|string}} & {_:{_:{}}}
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View File

@ -1618,6 +1618,26 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments")
CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end);
}
TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture")
{
ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true};
ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true};
// Same should hold when comments are captured
ParseOptions opts;
opts.captureComments = true;
AstStatBlock* block = parse(R"(
type F = number
--comment
print('hello')
)",
opts);
REQUIRE_EQ(2, block->body.size);
CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end);
}
TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control")
{
matchParseError("break", "break statement must be inside a loop");

View File

@ -7,6 +7,8 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
using namespace Luau;
struct ToDotClassFixture : Fixture
@ -101,9 +103,34 @@ local function f(a, ...: string) return a end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("<a>(a, ...string) -> a", toString(requireType("f")));
ToDotOptions opts;
opts.showPointers = false;
CHECK_EQ(R"(digraph graphname {
if (FFlag::LuauLowerBoundsCalculation)
{
CHECK_EQ(R"(digraph graphname {
n1 [label="FunctionTypeVar 1"];
n1 -> n2 [label="arg"];
n2 [label="TypePack 2"];
n2 -> n3;
n3 [label="GenericTypeVar 3"];
n2 -> n4 [label="tail"];
n4 [label="VariadicTypePack 4"];
n4 -> n5;
n5 [label="string"];
n1 -> n6 [label="ret"];
n6 [label="TypePack 6"];
n6 -> n7;
n7 [label="BoundTypeVar 7"];
n7 -> n3;
})",
toDot(requireType("f"), opts));
}
else
{
CHECK_EQ(R"(digraph graphname {
n1 [label="FunctionTypeVar 1"];
n1 -> n2 [label="arg"];
n2 [label="TypePack 2"];
@ -119,7 +146,8 @@ n6 -> n7;
n7 [label="TypePack 7"];
n7 -> n3;
})",
toDot(requireType("f"), opts));
toDot(requireType("f"), opts));
}
}
TEST_CASE_FIXTURE(Fixture, "union")
@ -361,4 +389,49 @@ n3 [label="number"];
toDot(*ty, opts));
}
TEST_CASE_FIXTURE(Fixture, "constrained")
{
// ConstrainedTypeVars never appear in the final type graph, so we have to create one directly
// to dotify it.
TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}};
ToDotOptions opts;
opts.showPointers = false;
CHECK_EQ(R"(digraph graphname {
n1 [label="ConstrainedTypeVar 1"];
n1 -> n2;
n2 [label="number"];
n1 -> n3;
n3 [label="string"];
n1 -> n4;
n4 [label="nil"];
})",
toDot(&t, opts));
}
TEST_CASE_FIXTURE(Fixture, "singletontypes")
{
CheckResult result = check(R"(
local x: "hi" | "\"hello\"" | true | false
)");
ToDotOptions opts;
opts.showPointers = false;
CHECK_EQ(R"(digraph graphname {
n1 [label="UnionTypeVar 1"];
n1 -> n2;
n2 [label="SingletonTypeVar string: hi"];
n1 -> n3;
)"
"n3 [label=\"SingletonTypeVar string: \\\"hello\\\"\"];"
R"(
n1 -> n4;
n4 [label="SingletonTypeVar boolean: true"];
n1 -> n5;
n5 [label="SingletonTypeVar boolean: false"];
})", toDot(requireType("x"), opts));
}
TEST_SUITE_END();

View File

@ -9,6 +9,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction);
TEST_SUITE_BEGIN("ToString");
TEST_CASE_FIXTURE(Fixture, "primitive")

View File

@ -340,26 +340,28 @@ TEST_CASE_FIXTURE(Fixture, "nested_type_annotations_depends_on_later_typealiases
TEST_CASE_FIXTURE(Fixture, "return_comes_last")
{
CheckResult result = check(R"(
export type Module = { bar: (number) -> boolean, foo: () -> string }
AstStatBlock* program = parse(R"(
local module = {}
return function() : Module
local module = {}
local function confuseCompiler() return module.foo() end
local function confuseCompiler() return module.foo() end
module.foo = function() return "" end
module.foo = function() return "" end
function module.bar(x:number)
confuseCompiler()
return true
end
function module.bar(x:number)
confuseCompiler()
return true
end
return module
end
return module
)");
LUAU_REQUIRE_NO_ERRORS(result);
auto sorted = toposort(*program);
CHECK_EQ(sorted[0], program->body.data[0]);
CHECK_EQ(sorted[2], program->body.data[1]);
CHECK_EQ(sorted[1], program->body.data[2]);
CHECK_EQ(sorted[3], program->body.data[3]);
CHECK_EQ(sorted[4], program->body.data[4]);
}
TEST_CASE_FIXTURE(Fixture, "break_comes_last")

View File

@ -388,7 +388,7 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly")
std::string actual = decorateWithTypes(code);
CHECK_EQ(expected, decorateWithTypes(code));
CHECK_EQ(expected, actual);
}
TEST_CASE_FIXTURE(Fixture, "function_type_location")

View File

@ -753,4 +753,14 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar")
REQUIRE(ocf);
}
TEST_CASE_FIXTURE(Fixture, "instantiation_clone_has_to_follow")
{
CheckResult result = check(R"(
export type t8<t8> = (t0)&(<t0...>((true)|(any))->"")
export type t0<t0> = ({})&({_:{[any]:number},})
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_SUITE_END();

View File

@ -8,6 +8,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
TEST_SUITE_BEGIN("BuiltinTests");
TEST_CASE_FIXTURE(Fixture, "math_things_are_defined")
@ -557,9 +559,9 @@ TEST_CASE_FIXTURE(Fixture, "xpcall")
)");
LUAU_REQUIRE_NO_ERRORS(result);
REQUIRE_EQ("boolean", toString(requireType("a")));
REQUIRE_EQ("number", toString(requireType("b")));
REQUIRE_EQ("boolean", toString(requireType("c")));
CHECK_EQ("boolean", toString(requireType("a")));
CHECK_EQ("number", toString(requireType("b")));
CHECK_EQ("boolean", toString(requireType("c")));
}
TEST_CASE_FIXTURE(Fixture, "see_thru_select")
@ -881,7 +883,10 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types")
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f")));
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f")));
else
CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f")));
}
TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2")

View File

@ -91,6 +91,9 @@ struct ClassFixture : Fixture
typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType};
addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test");
for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings)
persist(tf.type);
freeze(arena);
}
};

View File

@ -13,6 +13,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
TEST_SUITE_BEGIN("TypeInferFunctions");
TEST_CASE_FIXTURE(Fixture, "tc_function")
@ -98,7 +100,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified")
end
return result
end
end
return T
)");
@ -274,6 +276,10 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets")
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
function f(g)
return f(f)
@ -281,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args")
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f")));
CHECK_EQ("t1 where t1 = <a...>(t1) -> (a...)", toString(requireType("f")));
}
TEST_CASE_FIXTURE(Fixture, "another_higher_order_function")
@ -481,10 +487,10 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
std::vector<TypeId> fArgs = flatten(fType->argTypes).first;
TypeId xType = argVec[1];
TypeId xType = follow(argVec[1]);
CHECK_EQ(1, fArgs.size());
CHECK_EQ(xType, fArgs[0]);
CHECK_EQ(xType, follow(fArgs[0]));
}
TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
@ -1043,13 +1049,16 @@ f(function(x) return x * 2 end)
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0]));
// The return type is not inferred to be 'nil'
result = check(R"(
function f(a: (number) -> nil) return a(4) end
f(function(x) print(x) end)
)");
if (!FFlag::LuauLowerBoundsCalculation)
{
// The return type is not inferred to be 'nil'
result = check(R"(
function f(a: (number) -> nil) return a(4) end
f(function(x) print(x) end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
LUAU_REQUIRE_NO_ERRORS(result);
}
}
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
@ -1142,13 +1151,16 @@ f(function(x) return x * 2 end)
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0]));
// The return type is not inferred to be 'nil'
result = check(R"(
function f(a: (number) -> nil) return a(4) end
f(function(x) print(x) end)
)");
if (!FFlag::LuauLowerBoundsCalculation)
{
// The return type is not inferred to be 'nil'
result = check(R"(
function f(a: (number) -> nil) return a(4) end
f(function(x) print(x) end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
LUAU_REQUIRE_NO_ERRORS(result);
}
}
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call")
@ -1338,6 +1350,126 @@ end
CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')");
}
TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types")
{
const ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
function foo(a: boolean, b: number)
if a then
return nil
else
return b
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("(boolean, number) -> number?", toString(requireType("foo")));
// TODO: Test multiple returns
// Think of various cases where type packs need to grow. Maybe consult other tests.
// Basic normalization of ConstrainedTypeVars during quantification
}
TEST_CASE_FIXTURE(Fixture, "inconsistent_higher_order_function")
{
const ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
function foo(f)
f(5)
f("six")
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("<a...>((number | string) -> (a...)) -> ()", toString(requireType("foo")));
}
/* The bug here is that we are using the same level 2.0 for both the body of resolveDispatcher and the
* lambda useCallback.
*
* I think what we want to do is, at each scope level, never reuse the same sublevel.
*
* We also adjust checkBlock to consider the syntax `local x = function() ... end` to be sortable
* in the same way as `local function x() ... end`. This causes the function `resolveDispatcher` to be
* checked before the lambda.
*/
TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
--!strict
local function resolveDispatcher()
return (nil :: any) :: {useCallback: (any) -> any}
end
local useCallback = function(deps: any)
return resolveDispatcher().useCallback(deps)
end
)");
// LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken.
// You get a TypeMismatch error where both types stringify the same.
CHECK(result.errors.empty());
if (!result.errors.empty())
{
for (const auto& e : result.errors)
printf("%s: %s\n", toString(e.location).c_str(), toString(e).c_str());
}
}
TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2")
{
CheckResult result = check(R"(
--!strict
local function resolveDispatcher()
return (nil :: any) :: {useContext: (number?) -> any}
end
local useContext
useContext = function(unstable_observedBits: number?)
resolveDispatcher().useContext(unstable_observedBits)
end
)");
// LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken.
// You get a TypeMismatch error where both types stringify the same.
CHECK(result.errors.empty());
if (!result.errors.empty())
{
for (const auto& e : result.errors)
printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str());
}
}
TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time3")
{
CheckResult result = check(R"(
local foo
foo():bar(function()
return foo()
end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite")
{
ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true};
@ -1471,4 +1603,17 @@ pcall(wrapper, test)
CHECK(acm->isVariadic);
}
TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type")
{
CheckResult result = check(R"(
function f()
return 5, f()
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(nullptr != get<OccursCheckFailed>(result.errors[0]));
}
TEST_SUITE_END();

View File

@ -230,8 +230,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function")
CHECK_EQ(idFun->generics.size(), 1);
CHECK_EQ(idFun->genericPacks.size(), 0);
CHECK_EQ(args[0], idFun->generics[0]);
CHECK_EQ(rets[0], idFun->generics[0]);
CHECK_EQ(follow(args[0]), follow(idFun->generics[0]));
CHECK_EQ(follow(rets[0]), follow(idFun->generics[0]));
}
TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function")
@ -253,8 +253,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function")
CHECK_EQ(idFun->generics.size(), 1);
CHECK_EQ(idFun->genericPacks.size(), 0);
CHECK_EQ(args[0], idFun->generics[0]);
CHECK_EQ(rets[0], idFun->generics[0]);
CHECK_EQ(follow(args[0]), follow(idFun->generics[0]));
CHECK_EQ(follow(rets[0]), follow(idFun->generics[0]));
}
TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function")
@ -705,10 +705,10 @@ end
TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe")
{
ScopedFastFlag sffs[] = {
{ "LuauTableSubtypingVariance2", true },
{ "LuauUnsealedTableLiteral", true },
{ "LuauPropertiesGetExpectedType", true },
{ "LuauRecursiveTypeParameterRestriction", true },
{"LuauTableSubtypingVariance2", true},
{"LuauUnsealedTableLiteral", true},
{"LuauPropertiesGetExpectedType", true},
{"LuauRecursiveTypeParameterRestriction", true},
};
CheckResult result = check(R"(
@ -843,6 +843,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_function")
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("<a>(a) -> a", toString(requireType("id")));
CHECK_EQ(*typeChecker.numberType, *requireType("a"));
CHECK_EQ(*typeChecker.nilType, *requireType("b"));
}
@ -1037,25 +1038,39 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument")
ScopedFastFlag sff{"LuauUnsealedTableLiteral", true};
CheckResult result = check(R"(
local function sum<a>(x: a, y: a, f: (a, a) -> a) return f(x, y) end
return sum(2, 3, function(a, b) return a + b end)
local function sum<a>(x: a, y: a, f: (a, a) -> a)
return f(x, y)
end
return sum(2, 3, function(a, b) return a + b end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
result = check(R"(
local function map<a, b>(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end
local a = {1, 2, 3}
local r = map(a, function(a) return a + a > 100 end)
local function map<a, b>(arr: {a}, f: (a) -> b)
local r = {}
for i,v in ipairs(arr) do
table.insert(r, f(v))
end
return r
end
local a = {1, 2, 3}
local r = map(a, function(a) return a + a > 100 end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
REQUIRE_EQ("{boolean}", toString(requireType("r")));
check(R"(
local function foldl<a, b>(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end
local a = {1, 2, 3}
local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end)
local function foldl<a, b>(arr: {a}, init: b, f: (b, a) -> b)
local r = init
for i,v in ipairs(arr) do
r = f(r, v)
end
return r
end
local a = {1, 2, 3}
local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
@ -1065,25 +1080,19 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e
TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded")
{
CheckResult result = check(R"(
local function g1<T>(a: T, f: (T) -> T) return f(a) end
local function g2<T>(a: T, b: T, f: (T, T) -> T) return f(a, b) end
local g12: (<T>(T, (T) -> T) -> T) & (<T>(T, T, (T, T) -> T) -> T)
local g12: typeof(g1) & typeof(g2)
g12(1, function(x) return x + x end)
g12(1, 2, function(x, y) return x + y end)
g12(1, function(x) return x + x end)
g12(1, 2, function(x, y) return x + y end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
result = check(R"(
local function g1<T>(a: T, f: (T) -> T) return f(a) end
local function g2<T>(a: T, b: T, f: (T, T) -> T) return f(a, b) end
local g12: (<T>(T, (T) -> T) -> T) & (<T>(T, T, (T, T) -> T) -> T)
local g12: typeof(g1) & typeof(g2)
g12({x=1}, function(x) return {x=-x.x} end)
g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end)
g12({x=1}, function(x) return {x=-x.x} end)
g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end)
)");
LUAU_REQUIRE_NO_ERRORS(result);
@ -1121,12 +1130,12 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i
TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table")
{
CheckResult result = check(R"(
type A = { x: number }
local a: A = { x = 1 }
local b = a
type B = typeof(b)
type X<T> = T
local c: X<B>
type A = { x: number }
local a: A = { x = 1 }
local b = a
type B = typeof(b)
type X<T> = T
local c: X<B>
)");
LUAU_REQUIRE_NO_ERRORS(result);

View File

@ -8,6 +8,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
TEST_SUITE_BEGIN("IntersectionTypes");
TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn")
@ -306,7 +308,10 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed")
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'");
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table '{| x: number, y: number |}'");
else
CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'");
}
TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect")
@ -314,27 +319,34 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect")
ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true};
CheckResult result = check(R"(
type X = { x: (number) -> number }
type Y = { y: (string) -> string }
type X = { x: (number) -> number }
type Y = { y: (string) -> string }
type XY = X & Y
type XY = X & Y
local xy : XY = {
x = function(a: number) return -a end,
y = function(a: string) return a .. "b" end
}
function xy.z(a:number) return a * 10 end
function xy:y(a:number) return a * 10 end
function xy:w(a:number) return a * 10 end
local xy : XY = {
x = function(a: number) return -a end,
y = function(a: string) return a .. "b" end
}
function xy.z(a:number) return a * 10 end
function xy:y(a:number) return a * 10 end
function xy:w(a:number) return a * 10 end
)");
LUAU_REQUIRE_ERROR_COUNT(4, result);
CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string'
caused by:
Argument count mismatch. Function expects 2 arguments, but only 1 is specified)");
CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'");
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table '{| x: (number) -> number, y: (string) -> string |}'");
else
CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'");
CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'");
CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'");
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table '{| x: (number) -> number, y: (string) -> string |}'");
else
CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'");
}
TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect")
@ -375,6 +387,8 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable")
TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part")
{
ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}};
CheckResult result = check(R"(
type X = { x: number }
type Y = { y: number }
@ -393,6 +407,8 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all")
{
ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}};
CheckResult result = check(R"(
type X = { x: number }
type Y = { y: number }
@ -427,8 +443,8 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection")
repeat
type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0)
function _(l0):(t0)&(t0)
while nil do
end
while nil do
end
end
until _(_)(_)._
)");

View File

@ -199,16 +199,16 @@ end
TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail")
{
CheckResult result = check(R"(
--!nonstrict
local f = {}
function f:foo(a: number, b: number) end
--!nonstrict
local f = {}
function f:foo(a: number, b: number) end
function bar(...)
f.foo(f, 1, ...)
end
function bar(...)
f.foo(f, 1, ...)
end
bar(2)
)");
bar(2)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}

View File

@ -91,7 +91,8 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable")
const FunctionTypeVar* functionType = get<FunctionTypeVar>(requireType("add"));
std::optional<TypeId> retType = first(functionType->retType);
CHECK_EQ(std::optional<TypeId>(typeChecker.numberType), retType);
REQUIRE(retType.has_value());
CHECK_EQ(typeChecker.numberType, follow(*retType));
CHECK_EQ(requireType("n"), typeChecker.numberType);
CHECK_EQ(requireType("s"), typeChecker.stringType);
}

View File

@ -8,6 +8,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauEqConstraint)
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
using namespace Luau;
@ -527,6 +528,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
LUAU_REQUIRE_NO_ERRORS(result);
}
// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation
TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack")
{
ScopedFastFlag sff[]{
@ -556,10 +558,19 @@ TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type
LUAU_REQUIRE_NO_ERRORS(result);
// f and g should have the type () -> ()
CHECK_EQ("() -> (a...)", toString(requireType("f")));
CHECK_EQ("<a...>() -> (a...)", toString(requireType("g")));
CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now
if (FFlag::LuauLowerBoundsCalculation)
{
CHECK_EQ("() -> ()", toString(requireType("f")));
CHECK_EQ("() -> ()", toString(requireType("g")));
CHECK_EQ("nil", toString(requireType("x")));
}
else
{
// f and g should have the type () -> ()
CHECK_EQ("() -> (a...)", toString(requireType("f")));
CHECK_EQ("<a...>() -> (a...)", toString(requireType("g")));
CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now
}
}
TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early")
@ -575,6 +586,10 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early")
TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", false},
};
CheckResult result = check(R"(
local function f() return end
local g = function() return f() end
@ -585,6 +600,10 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack")
TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", false},
};
CheckResult result = check(R"(
--!strict
local function f(...) return ... end
@ -594,4 +613,112 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack")
LUAU_REQUIRE_ERRORS(result); // Ideally this should not have any errors, but it currently does.
}
TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
};
CheckResult result = check(R"(
function foo(f)
f(5, 'a')
f('b', 6)
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
// We incorrectly infer that the argument to foo could be called with (number, number) or (string, string)
// even though that is strictly more permissive than the calls in the source text warrant.
CHECK("<a...>((number | string, number | string) -> (a...)) -> ()" == toString(requireType("foo")));
}
// Once fixed, move this to Normalize.test.cpp
TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables")
{
#if defined(_DEBUG) || defined(_NOOPT)
ScopedFastInt sfi("LuauNormalizeIterationLimit", 500);
#endif
ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true},
};
// We use a function and inferred parameter types to prevent intermediate normalizations from being performed.
// This exposes a bug where the type of y is mutated.
CheckResult result = check(R"(
function strange(x, y)
x.x = y
y.x = x
type R = {x: typeof(x)} & {x: typeof(y)}
local r: R
return r
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(nullptr != get<NormalizationTooComplex>(result.errors[0]));
}
// Belongs in TypeInfer.builtins.test.cpp.
TEST_CASE_FIXTURE(Fixture, "pcall_returns_at_least_two_value_but_function_returns_nothing")
{
CheckResult result = check(R"(
local function f(): () end
local ok, res = pcall(f)
)");
LUAU_REQUIRE_ERRORS(result);
// LUAU_REQUIRE_NO_ERRORS(result);
// CHECK_EQ("boolean", toString(requireType("ok")));
// CHECK_EQ("any", toString(requireType("res")));
}
// Belongs in TypeInfer.builtins.test.cpp.
TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall")
{
CheckResult result = check(R"(
local function f(): number
if math.random() > 0.5 then
return 5
else
error("something")
end
end
local ok, res = pcall(f)
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("boolean", toString(requireType("ok")));
CHECK_EQ("number", toString(requireType("res")));
// CHECK_EQ("any", toString(requireType("res")));
}
// Belongs in TypeInfer.builtins.test.cpp.
TEST_CASE_FIXTURE(Fixture, "function_returns_many_things_but_first_of_it_is_forgotten")
{
CheckResult result = check(R"(
local function f(): (number, string, boolean)
if math.random() > 0.5 then
return 5, "hello", true
else
error("something")
end
end
local ok, res, s, b = pcall(f)
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("boolean", toString(requireType("ok")));
CHECK_EQ("number", toString(requireType("res")));
// CHECK_EQ("any", toString(requireType("res")));
CHECK_EQ("string", toString(requireType("s")));
CHECK_EQ("boolean", toString(requireType("b")));
}
TEST_SUITE_END();

View File

@ -1,4 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Normalize.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
@ -8,6 +9,7 @@
LUAU_FASTFLAG(LuauDiscriminableUnions2)
LUAU_FASTFLAG(LuauWeakEqConstraint)
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
using namespace Luau;
@ -48,6 +50,7 @@ struct RefinementClassFixture : Fixture
{"Y", Property{typeChecker.numberType}},
{"Z", Property{typeChecker.numberType}},
};
normalize(vec3, arena, *typeChecker.iceHandler);
TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr});
@ -55,17 +58,21 @@ struct RefinementClassFixture : Fixture
TypePackId isARets = arena.addTypePack({typeChecker.booleanType});
TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets});
getMutable<FunctionTypeVar>(isA)->magicFunction = magicFunctionInstanceIsA;
normalize(isA, arena, *typeChecker.iceHandler);
getMutable<ClassTypeVar>(inst)->props = {
{"Name", Property{typeChecker.stringType}},
{"IsA", Property{isA}},
};
normalize(inst, arena, *typeChecker.iceHandler);
TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr});
normalize(folder, arena, *typeChecker.iceHandler);
TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr});
getMutable<ClassTypeVar>(part)->props = {
{"Position", Property{vec3}},
};
normalize(part, arena, *typeChecker.iceHandler);
typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3};
typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst};
@ -697,7 +704,10 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables")
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28})));
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ("{| x: number, y: number |}", toString(requireTypeAtPosition({4, 28})));
else
CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28})));
CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28})));
}

View File

@ -5,8 +5,6 @@
#include "doctest.h"
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(BetterDiagnosticCodesInStudio)
using namespace Luau;
TEST_SUITE_BEGIN("TypeSingletons");
@ -261,14 +259,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer")
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::BetterDiagnosticCodesInStudio)
{
CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0]));
}
else
{
CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0]));
}
CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes")

View File

@ -11,6 +11,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
TEST_SUITE_BEGIN("TableTests");
TEST_CASE_FIXTURE(Fixture, "basic")
@ -1211,7 +1213,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_c
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(get<TypeMismatch>(result.errors[0]));
if (FFlag::LuauLowerBoundsCalculation)
CHECK(get<MissingProperties>(result.errors[0]));
else
CHECK(get<TypeMismatch>(result.errors[0]));
}
// This unit test could be flaky if the fix has regressed.
@ -2922,6 +2927,60 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the
LUAU_REQUIRE_NO_ERRORS(result);
}
// The real bug here was that we weren't always unconditionally typechecking a trailing return statement last.
TEST_CASE_FIXTURE(Fixture, "dont_leak_free_table_props")
{
CheckResult result = check(R"(
local function a(state)
print(state.blah)
end
local function b(state) -- The bug was that we inferred state: {blah: any, gwar: any}
print(state.gwar)
end
return function()
return function(state)
a(state)
b(state)
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("<a>({+ blah: a +}) -> ()", toString(requireType("a")));
CHECK_EQ("<a>({+ gwar: a +}) -> ()", toString(requireType("b")));
CHECK_EQ("() -> <a, b>({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->getModuleScope()->returnType));
}
TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
};
check(R"(
function Base64FileReader(data)
local reader = {}
local index: number
function reader:PeekByte()
return data:byte(index)
end
function reader:Byte()
return data:byte(index - 1)
end
return reader
end
)");
CHECK_EQ("<a...>(t1) -> {| Byte: <b>(b) -> (a...), PeekByte: <c>(c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}",
toString(requireType("Base64FileReader")));
}
TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys")
{
ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true};

View File

@ -13,6 +13,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr)
LUAU_FASTFLAG(LuauEqConstraint)
@ -177,7 +178,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_case")
)");
LUAU_REQUIRE_NO_ERRORS(result);
dumpErrors(result);
}
TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check")
@ -293,7 +293,7 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types")
// In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in the
// type checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead.
TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit")
TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count")
{
#if defined(LUAU_ENABLE_ASAN)
int limit = 250;
@ -302,12 +302,14 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit")
#else
int limit = 600;
#endif
ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100};
ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100};
ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0};
CHECK_NOTHROW(check("print('Hello!')"));
CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error);
ScopedFastFlag sff{"LuauTableUseCounterInstead", true};
ScopedFastInt sfi{"LuauCheckRecursionLimit", limit};
CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(nullptr != get<CodeTooComplex>(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit")
@ -721,9 +723,9 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error")
local l0
do end
while _ do
function _:_()
_ += _(_._(_:n0(xpcall,_)))
end
function _:_()
_ += _(_._(_:n0(xpcall,_)))
end
end
)");
@ -978,4 +980,48 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice")
{
ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2);
ScopedFastFlag sff{"LuauRecursionLimitException", true};
CheckResult result = check(R"(
function complex()
function _(l0:t0): (any, ()->())
return 0,_
end
type t0 = t0 | {}
_(nil)
end
)");
LUAU_REQUIRE_ERRORS(result);
CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution")
{
ScopedFastFlag substituteFollowNewTypes{"LuauSubstituteFollowNewTypes", true};
CheckResult result = check(R"(
local obj = {}
function obj:Method()
self.fieldA = function(object)
if object.a then
self.arr[object] = true
elseif object.b then
self.fieldB[object] = object:Connect(function(arg)
self.arr[arg] = nil
end)
end
end
end
return obj
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View File

@ -9,6 +9,8 @@
using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
TEST_SUITE_BEGIN("TypePackTests");
TEST_CASE_FIXTURE(Fixture, "infer_multi_return")
@ -27,8 +29,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return")
const auto& [returns, tail] = flatten(takeTwoType->retType);
CHECK_EQ(2, returns.size());
CHECK_EQ(typeChecker.numberType, returns[0]);
CHECK_EQ(typeChecker.numberType, returns[1]);
CHECK_EQ(typeChecker.numberType, follow(returns[0]));
CHECK_EQ(typeChecker.numberType, follow(returns[1]));
CHECK(!tail);
}
@ -74,9 +76,9 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac
const auto& [rets, tail] = flatten(takeOneMoreType->retType);
REQUIRE_EQ(3, rets.size());
CHECK_EQ(typeChecker.numberType, rets[0]);
CHECK_EQ(typeChecker.numberType, rets[1]);
CHECK_EQ(typeChecker.numberType, rets[2]);
CHECK_EQ(typeChecker.numberType, follow(rets[0]));
CHECK_EQ(typeChecker.numberType, follow(rets[1]));
CHECK_EQ(typeChecker.numberType, follow(rets[2]));
CHECK(!tail);
}
@ -91,26 +93,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function")
LUAU_REQUIRE_NO_ERRORS(result);
const FunctionTypeVar* applyType = get<FunctionTypeVar>(requireType("apply"));
REQUIRE(applyType != nullptr);
std::vector<TypeId> applyArgs = flatten(applyType->argTypes).first;
REQUIRE_EQ(3, applyArgs.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(applyArgs[0]));
REQUIRE(fType != nullptr);
const FunctionTypeVar* gType = get<FunctionTypeVar>(follow(applyArgs[1]));
REQUIRE(gType != nullptr);
std::vector<TypeId> gArgs = flatten(gType->argTypes).first;
REQUIRE_EQ(1, gArgs.size());
// function(function(t1, T2...): (t3, T4...), function(t5): (t1, T2...), t5): (t3, T4...)
REQUIRE_EQ(*gArgs[0], *applyArgs[2]);
REQUIRE_EQ(toString(fType->argTypes), toString(gType->retType));
REQUIRE_EQ(toString(fType->retType), toString(applyType->retType));
CHECK_EQ("<a, b..., c...>((b...) -> (c...), (a) -> (b...), a) -> (c...)", toString(requireType("apply")));
}
TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned")
@ -328,7 +311,10 @@ local c: Packed<string, number, boolean>
auto ttvA = get<TableTypeVar>(requireType("a"));
REQUIRE(ttvA);
CHECK_EQ(toString(requireType("a")), "Packed<number>");
CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}");
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}");
else
CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}");
REQUIRE(ttvA->instantiatedTypeParams.size() == 1);
REQUIRE(ttvA->instantiatedTypePackParams.size() == 1);
CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number");

View File

@ -6,6 +6,7 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
LUAU_FASTFLAG(LuauEqConstraint)
using namespace Luau;
@ -254,11 +255,11 @@ local c = bf.a.y
TEST_CASE_FIXTURE(Fixture, "optional_union_functions")
{
CheckResult result = check(R"(
local a = {}
function a.foo(x:number, y:number) return x + y end
type A = typeof(a)
local b: A? = a
local c = b.foo(1, 2)
local a = {}
function a.foo(x:number, y:number) return x + y end
type A = typeof(a)
local b: A? = a
local c = b.foo(1, 2)
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
@ -356,7 +357,10 @@ a.x = 2
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0]));
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", toString(result.errors[0]));
else
CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "optional_length_error")
@ -533,8 +537,13 @@ TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect")
LUAU_REQUIRE_ERROR_COUNT(1, result);
// NOTE: union normalization will improve this message
CHECK_EQ(toString(result.errors[0]),
R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)");
if (FFlag::LuauLowerBoundsCalculation)
CHECK_EQ(toString(result.errors[0]), "Type '(string) -> number' could not be converted into '(number) -> string'\n"
"caused by:\n"
" Argument #1 type is not compatible. Type 'number' could not be converted into 'string'");
else
CHECK_EQ(toString(result.errors[0]),
R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)");
}

View File

@ -581,4 +581,19 @@ do
assert(#arr == 5)
end
-- test boundary invariant maintenance when table is filled using SETLIST opcode
do
local arr = {[2]=2,1}
assert(#arr == 2)
end
-- test boundary invariant maintenance when table is filled using table.move
do
local t1 = {1, 2, 3, 4, 5}
local t2 = {[6] = 6}
table.move(t1, 1, 5, 1, t2)
assert(#t2 == 6)
end
return"OK"