From 8e7845076b240cd10ea73680d48cad9525b31c20 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 14 Apr 2022 16:57:43 -0700 Subject: [PATCH] Sync to upstream/release/523 (#459) --- Analysis/include/Luau/Clone.h | 9 +- Analysis/include/Luau/Error.h | 10 +- Analysis/include/Luau/Frontend.h | 1 + Analysis/include/Luau/LValue.h | 4 + Analysis/include/Luau/Module.h | 2 +- Analysis/include/Luau/Normalize.h | 19 + Analysis/include/Luau/RecursionCounter.h | 26 +- Analysis/include/Luau/Substitution.h | 1 + Analysis/include/Luau/ToString.h | 3 + Analysis/include/Luau/TxnLog.h | 32 +- Analysis/include/Luau/TypeInfer.h | 24 +- Analysis/include/Luau/TypePack.h | 14 +- Analysis/include/Luau/TypeVar.h | 31 +- Analysis/include/Luau/Unifiable.h | 17 +- Analysis/include/Luau/Unifier.h | 22 +- Analysis/include/Luau/UnifierSharedState.h | 2 + Analysis/include/Luau/VisitTypeVar.h | 9 + Analysis/src/Autocomplete.cpp | 8 +- Analysis/src/Clone.cpp | 116 ++- Analysis/src/Error.cpp | 28 +- Analysis/src/Frontend.cpp | 41 +- Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/LValue.cpp | 17 + Analysis/src/Linter.cpp | 93 +- Analysis/src/Module.cpp | 39 +- Analysis/src/Normalize.cpp | 814 +++++++++++++++++ Analysis/src/Quantify.cpp | 36 + Analysis/src/Substitution.cpp | 138 ++- Analysis/src/ToDot.cpp | 31 + Analysis/src/ToString.cpp | 165 +++- Analysis/src/TopoSortStatements.cpp | 1 + Analysis/src/TxnLog.cpp | 44 +- Analysis/src/TypeAttach.cpp | 13 + Analysis/src/TypeInfer.cpp | 488 +++++++++-- Analysis/src/TypePack.cpp | 46 +- Analysis/src/TypeVar.cpp | 28 +- Analysis/src/Unifier.cpp | 337 +++++-- Ast/include/Luau/DenseHash.h | 118 ++- Ast/include/Luau/Lexer.h | 2 +- Ast/src/Lexer.cpp | 10 +- Ast/src/Parser.cpp | 5 +- Compiler/src/Compiler.cpp | 4 +- Compiler/src/CostModel.cpp | 258 ++++++ Compiler/src/CostModel.h | 18 + Sources.cmake | 6 + VM/src/ltable.cpp | 47 +- VM/src/ltablib.cpp | 2 +- VM/src/lvmexecute.cpp | 4 +- tests/CostModel.test.cpp | 101 +++ tests/Fixture.cpp | 4 +- tests/JsonEncoder.test.cpp | 4 +- tests/Linter.test.cpp | 3 +- tests/Module.test.cpp | 75 +- tests/NonstrictMode.test.cpp | 34 + tests/Normalize.test.cpp | 967 +++++++++++++++++++++ tests/Parser.test.cpp | 20 + tests/ToDot.test.cpp | 77 +- tests/ToString.test.cpp | 2 + tests/TopoSort.test.cpp | 32 +- tests/Transpiler.test.cpp | 2 +- tests/TypeInfer.annotations.test.cpp | 10 + tests/TypeInfer.builtins.test.cpp | 13 +- tests/TypeInfer.classes.test.cpp | 3 + tests/TypeInfer.functions.test.cpp | 177 +++- tests/TypeInfer.generics.test.cpp | 77 +- tests/TypeInfer.intersectionTypes.test.cpp | 46 +- tests/TypeInfer.oop.test.cpp | 16 +- tests/TypeInfer.operators.test.cpp | 3 +- tests/TypeInfer.provisional.test.cpp | 135 ++- tests/TypeInfer.refinements.test.cpp | 12 +- tests/TypeInfer.singletons.test.cpp | 11 +- tests/TypeInfer.tables.test.cpp | 61 +- tests/TypeInfer.test.cpp | 66 +- tests/TypeInfer.typePacks.cpp | 38 +- tests/TypeInfer.unionTypes.test.cpp | 25 +- tests/conformance/nextvar.lua | 15 + 76 files changed, 4575 insertions(+), 639 deletions(-) create mode 100644 Analysis/include/Luau/Normalize.h create mode 100644 Analysis/src/Normalize.cpp create mode 100644 Compiler/src/CostModel.cpp create mode 100644 Compiler/src/CostModel.h create mode 100644 tests/CostModel.test.cpp create mode 100644 tests/Normalize.test.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 917ef801..78aa92c7 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -14,12 +14,15 @@ using SeenTypePacks = std::unordered_map; struct CloneState { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + int recursionCount = 0; bool encounteredFreeType = false; }; -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 53b946a0..70683141 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -287,12 +287,20 @@ struct TypesAreUnrelated bool operator==(const TypesAreUnrelated& rhs) const; }; +struct NormalizationTooComplex +{ + bool operator==(const NormalizationTooComplex&) const + { + return true; + } +}; + using TypeErrorData = Variant; + MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated, NormalizationTooComplex>; struct TypeError { diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 2266f548..e24e433c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -70,6 +70,7 @@ struct SourceNode std::vector> requireLocations; bool dirty = true; bool dirtyAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; struct FrontendOptions diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 3d510d5f..afb71415 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -35,8 +35,12 @@ const LValue* baseof(const LValue& lvalue); std::optional tryGetLValue(const class AstExpr& expr); // Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. +// TODO: remove with FFlagLuauTypecheckOptPass std::pair> getFullName(const LValue& lvalue); +// Utility function: breaks down an LValue to get at the Symbol +Symbol getBaseSymbol(const LValue& lvalue); + template const T* get(const LValue& lvalue) { diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 9a32f614..0dd44188 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -113,7 +113,7 @@ struct Module // This helps us to force TypeVar ownership into a DAG rather than a DCG. // Returns true if there were any free types encountered in the public interface. This // indicates a bug in the type checker that we want to surface. - bool clonePublicInterface(); + bool clonePublicInterface(InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h new file mode 100644 index 00000000..262b54b2 --- /dev/null +++ b/Analysis/include/Luau/Normalize.h @@ -0,0 +1,19 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Module.h" + +namespace Luau +{ + +struct InternalErrorReporter; + +bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); + +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice); + +} // namespace Luau diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 89632cea..03ae2c83 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -4,10 +4,21 @@ #include "Luau/Common.h" #include +#include + +LUAU_FASTFLAG(LuauRecursionLimitException); namespace Luau { +struct RecursionLimitException : public std::exception +{ + const char* what() const noexcept + { + return "Internal recursion counter limit exceeded"; + } +}; + struct RecursionCounter { RecursionCounter(int* count) @@ -28,11 +39,22 @@ private: struct RecursionLimiter : RecursionCounter { - RecursionLimiter(int* count, int limit) + // TODO: remove ctx after LuauRecursionLimitException is removed + RecursionLimiter(int* count, int limit, const char* ctx) : RecursionCounter(count) { + LUAU_ASSERT(ctx); if (limit > 0 && *count > limit) - throw std::runtime_error("Internal recursion counter limit exceeded"); + { + if (FFlag::LuauRecursionLimitException) + throw RecursionLimitException(); + else + { + std::string m = "Internal recursion counter limit exceeded: "; + m += ctx; + throw std::runtime_error(m); + } + } } }; diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 9662d5b3..6f5931e1 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -90,6 +90,7 @@ struct Tarjan std::vector lowlink; int childCount = 0; + int childLimit = 0; // This should never be null; ensure you initialize it before calling // substitution methods. diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 49ee82fe..f4db5e35 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -28,6 +28,7 @@ struct ToStringOptions bool functionTypeArguments = false; // If true, output function type argument names when they are available bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. + bool indent = false; size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -73,6 +74,8 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp std::string dump(TypeId ty); std::string dump(TypePackId ty); +std::string dump(const std::shared_ptr& scope, const char* name); + std::string generateName(size_t n); } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index c8ebaaeb..995ed6c6 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -7,7 +7,7 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauShareTxnSeen); +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -64,13 +64,17 @@ T* getMutable(PendingTypePack* pending) struct TxnLog { TxnLog() - : ownedSeen() + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , ownedSeen() , sharedSeen(&ownedSeen) { } explicit TxnLog(TxnLog* parent) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) { if (parent) { @@ -83,14 +87,19 @@ struct TxnLog } explicit TxnLog(std::vector>* sharedSeen) - : sharedSeen(sharedSeen) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , sharedSeen(sharedSeen) { } TxnLog(TxnLog* parent, std::vector>* sharedSeen) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) , sharedSeen(sharedSeen) { + LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); } TxnLog(const TxnLog&) = delete; @@ -243,6 +252,12 @@ struct TxnLog return Luau::getMutable(ty); } + template + const T* get(TID ty) const + { + return this->getMutable(ty); + } + // Returns whether a given type or type pack is a given state, respecting the // log's pending state. // @@ -263,11 +278,8 @@ private: // unique_ptr is used to give us stable pointers across insertions into the // map. Otherwise, it would be really easy to accidentally invalidate the // pointers returned from queue/pending. - // - // We can't use a DenseHashMap here because we need a non-const iterator - // over the map when we concatenate. - std::unordered_map, DenseHashPointer> typeVarChanges; - std::unordered_map, DenseHashPointer> typePackChanges; + DenseHashMap> typeVarChanges; + DenseHashMap> typePackChanges; TxnLog* parent = nullptr; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 215da67f..ac880135 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -76,19 +76,32 @@ struct Instantiation : Substitution // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack) + Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) : Substitution(TxnLog::empty(), arena) + , iceHandler(iceHandler) , anyType(anyType) , anyTypePack(anyTypePack) { } + InternalErrorReporter* iceHandler; + TypeId anyType; TypePackId anyTypePack; + bool normalizationTooComplex = false; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; TypePackId clean(TypePackId tp) override; + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } }; // A substitution which replaces the type parameters of a type function by arguments @@ -139,6 +152,7 @@ struct TypeChecker TypeChecker& operator=(const TypeChecker&) = delete; ModulePtr check(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); + ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); std::vector> getScopes() const; @@ -160,6 +174,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); ExprResult checkExpr( @@ -172,6 +187,7 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); @@ -258,6 +274,8 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -395,6 +413,7 @@ private: void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; + bool useConstrainedIntersections() const; public: /** Extract the types in a type pack, given the assumption that the pack must have some exact length. @@ -421,7 +440,10 @@ public: std::vector requireCycles; + // Type inference limits std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; public: const TypeId nilType; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 85fa467f..bbc65f94 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -40,6 +40,7 @@ struct TypePack struct VariadicTypePack { TypeId ty; + bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail. }; struct TypePackVar @@ -109,10 +110,10 @@ private: }; TypePackIterator begin(TypePackId tp); -TypePackIterator begin(TypePackId tp, TxnLog* log); +TypePackIterator begin(TypePackId tp, const TxnLog* log); TypePackIterator end(TypePackId tp); -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); @@ -122,7 +123,7 @@ TypePackId follow(TypePackId tp, std::function mapper); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); size_t size(const TypePack& tp, TxnLog* log = nullptr); -std::optional first(TypePackId tp); +std::optional first(TypePackId tp, bool ignoreHiddenVariadics = true); TypePackVar* asMutable(TypePackId tp); TypePack* asMutable(const TypePack* tp); @@ -154,5 +155,12 @@ bool isEmpty(TypePackId tp); /// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known std::pair, std::optional> flatten(TypePackId tp); +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log); + +/// Returs true if the type pack arose from a function that is declared to be variadic. +/// Returns *false* for function argument packs that are inferred to be safe to oversaturate! +bool isVariadic(TypePackId tp); +bool isVariadic(TypePackId tp, const TxnLog& log); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index f61e4044..ae7d1377 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -109,6 +109,23 @@ struct PrimitiveTypeVar } }; +struct ConstrainedTypeVar +{ + explicit ConstrainedTypeVar(TypeLevel level) + : level(level) + { + } + + explicit ConstrainedTypeVar(TypeLevel level, const std::vector& parts) + : parts(parts) + , level(level) + { + } + + std::vector parts; + TypeLevel level; +}; + // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton @@ -248,6 +265,7 @@ struct FunctionTypeVar MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. bool hasSelf; Tags tags; + bool hasNoGenerics = false; }; enum class TableState @@ -418,8 +436,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -436,6 +454,7 @@ struct TypeVar final TypeVar(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) + , normal(persistent) // We assume that all persistent types are irreducable. { } @@ -446,6 +465,10 @@ struct TypeVar final // Persistent TypeVars do not get cloned. bool persistent = false; + // Normalization sets this for types that are fully normalized. + // This implies that they are transitively immutable. + bool normal = false; + std::optional documentationSymbol; // Pointer to the type arena that allocated this type. @@ -458,7 +481,7 @@ struct TypeVar final TypeVar& operator=(TypeVariant&& rhs); }; -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real @@ -545,6 +568,8 @@ void persist(TypePackId tp); const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); +std::optional getLevel(TypePackId tp); + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index e8eafe68..64fa131d 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -56,6 +56,14 @@ struct TypeLevel } }; +inline TypeLevel max(const TypeLevel& a, const TypeLevel& b) +{ + if (a.subsumes(b)) + return b; + else + return a; +} + inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) { if (a.subsumes(b)) @@ -64,7 +72,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) return b; } -namespace Unifiable +} // namespace Luau + +namespace Luau::Unifiable { using Name = std::string; @@ -125,7 +135,6 @@ private: }; template -using Variant = Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Generic, Error, Value...>; -} // namespace Unifiable -} // namespace Luau +} // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 474af50c..340feb7f 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -49,14 +49,14 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; + bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, + UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -106,7 +106,12 @@ private: std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); + void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy); + void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); + public: + void unifyLowerBound(TypePackId subTy, TypePackId superTy); + // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); void occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); @@ -115,12 +120,7 @@ public: Unifier makeChildUnifier(); - // A utility function that appends the given error to the unifier's error log. - // This allows setting a breakpoint wherever the unifier reports an error. - void reportError(TypeError error) - { - errors.push_back(error); - } + void reportError(TypeError err); private: bool isNonstrictMode() const; @@ -135,4 +135,6 @@ private: std::optional firstPackErrorPos; }; +void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, TypePackId tp); + } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 9a3ba56d..1a0b8b76 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -28,7 +28,9 @@ struct TypeIdPairHash struct UnifierCounters { int recursionCount = 0; + int recursionLimit = 0; int iterationCount = 0; + int iterationLimit = 0; }; struct UnifierSharedState diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 740854b3..d11cbd0d 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -82,6 +82,15 @@ void visit(TypeId ty, F& f, Set& seen) else if (auto etv = get(ty)) apply(ty, *etv, seen, f); + else if (auto ctv = get(ty)) + { + if (apply(ty, *ctv, seen, f)) + { + for (TypeId part : ctv->parts) + visit(part, f, seen); + } + } + else if (auto ptv = get(ty)) apply(ty, *ptv, seen, f); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index b7201ab3..e0e79cb4 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -151,8 +151,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp auto idxExpr = nodes.back()->as(); bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto args = Luau::flatten(func->argTypes); - bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value(); + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ac9705a7..8e7f7c07 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -6,7 +6,10 @@ #include "Luau/TypePack.h" #include "Luau/Unifiable.h" +LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) + LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -23,11 +26,11 @@ struct TypePackCloner; struct TypeCloner { - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) : dest(dest) , typeId(typeId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) , cloneState(cloneState) { } @@ -46,6 +49,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); @@ -65,11 +69,11 @@ struct TypePackCloner SeenTypePacks& seenTypePacks; CloneState& cloneState; - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) : dest(dest) , typePackId(typePackId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) , cloneState(cloneState) { } @@ -103,13 +107,15 @@ struct TypePackCloner // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypePackId cloned = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}}); seenTypePacks[typePackId] = cloned; } void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); seenTypePacks[typePackId] = cloned; } @@ -121,10 +127,10 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + destTp->head.push_back(clone(ty, dest, cloneState)); if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); + destTp->tail = clone(*t.tail, dest, cloneState); } }; @@ -150,7 +156,9 @@ void TypeCloner::operator()(const Unifiable::Generic& t) void TypeCloner::operator()(const Unifiable::Bound& t) { - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypeId boundTo = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + boundTo = dest.addType(BoundTypeVar{boundTo}); seenTypes[typeId] = boundTo; } @@ -164,6 +172,23 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const ConstrainedTypeVar& t) +{ + cloneState.encounteredFreeType = true; + + TypeId res = dest.addType(ConstrainedTypeVar{t.level}); + ConstrainedTypeVar* ctv = getMutable(res); + LUAU_ASSERT(ctv); + + seenTypes[typeId] = res; + + std::vector parts; + for (TypeId part : t.parts) + parts.push_back(clone(part, dest, cloneState)); + + ctv->parts = std::move(parts); +} + void TypeCloner::operator()(const SingletonTypeVar& t) { defaultClone(t); @@ -178,23 +203,26 @@ void TypeCloner::operator()(const FunctionTypeVar& t) seenTypes[typeId] = result; for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); + ftv->generics.push_back(clone(generic, dest, cloneState)); for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); + ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); + ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); + ftv->retType = clone(t.retType, dest, cloneState); + + if (FFlag::LuauTypecheckOptPass) + ftv->hasNoGenerics = t.hasNoGenerics; } void TypeCloner::operator()(const TableTypeVar& t) { // If table is now bound to another one, we ignore the content of the original - if (t.boundTo) + if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypeId boundTo = clone(*t.boundTo, dest, cloneState); seenTypes[typeId] = boundTo; return; } @@ -209,18 +237,20 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; + if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, cloneState); + for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + arg = clone(arg, dest, cloneState); for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + arg = clone(arg, dest, cloneState); if (ttv->state == TableState::Free) { @@ -240,8 +270,8 @@ void TypeCloner::operator()(const MetatableTypeVar& t) MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); + mtv->table = clone(t.table, dest, cloneState); + mtv->metatable = clone(t.metatable, dest, cloneState); } void TypeCloner::operator()(const ClassTypeVar& t) @@ -252,13 +282,13 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); + ctv->parent = clone(*t.parent, dest, cloneState); if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); + ctv->metatable = clone(*t.metatable, dest, cloneState); } void TypeCloner::operator()(const AnyTypeVar& t) @@ -272,7 +302,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) options.reserve(t.options.size()); for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + options.push_back(clone(ty, dest, cloneState)); TypeId result = dest.addType(UnionTypeVar{std::move(options)}); seenTypes[typeId] = result; @@ -287,7 +317,7 @@ void TypeCloner::operator()(const IntersectionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + option->parts.push_back(clone(ty, dest, cloneState)); } void TypeCloner::operator()(const LazyTypeVar& t) @@ -297,36 +327,36 @@ void TypeCloner::operator()(const LazyTypeVar& t) } // anonymous namespace -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypePackId"); - TypePackId& res = seenTypePacks[tp]; + TypePackId& res = cloneState.seenTypePacks[tp]; if (res == nullptr) { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; + TypePackCloner cloner{dest, tp, cloneState}; Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } return res; } -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) { if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypeId"); - TypeId& res = seenTypes[typeId]; + TypeId& res = cloneState.seenTypes[typeId]; if (res == nullptr) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; + TypeCloner cloner{dest, typeId, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. // Persistent types are not being cloned and we get the original type back which might be read-only @@ -337,33 +367,33 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks return res; } -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { TypeFun result; for (auto param : typeFun.typeParams) { - TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + TypeId ty = clone(param.ty, dest, cloneState); std::optional defaultValue; if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + defaultValue = clone(*param.defaultValue, dest, cloneState); result.typeParams.push_back({ty, defaultValue}); } for (auto param : typeFun.typePackParams) { - TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + TypePackId tp = clone(param.tp, dest, cloneState); std::optional defaultValue; if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + defaultValue = clone(*param.defaultValue, dest, cloneState); result.typePackParams.push_back({tp, defaultValue}); } - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); + result.type = clone(typeFun.type, dest, cloneState); return result; } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 5eb2ea2a..cbec0b15 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,7 +8,6 @@ #include -LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -252,14 +251,7 @@ struct ErrorConverter std::string operator()(const Luau::SyntaxError& e) const { - if (FFlag::BetterDiagnosticCodesInStudio) - { - return e.message; - } - else - { - return "Syntax error: " + e.message; - } + return e.message; } std::string operator()(const Luau::CodeTooComplex&) const @@ -451,6 +443,11 @@ struct ErrorConverter { return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; } + + std::string operator()(const NormalizationTooComplex&) const + { + return "Code is too complex to typecheck! Consider simplifying the code around this area"; + } }; struct InvalidNameChecker @@ -716,14 +713,14 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState) +void copyError(T& e, TypeArena& destArena, CloneState cloneState) { auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState); + return ::Luau::clone(ty, destArena, cloneState); }; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; if constexpr (false) @@ -844,18 +841,19 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& e.left = clone(e.left); e.right = clone(e.right); } + else if constexpr (std::is_same_v) + { + } else static_assert(always_false_v, "Non-exhaustive type switch"); } void copyErrors(ErrorVec& errors, TypeArena& destArena) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 000769fe..8b0b2210 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -11,16 +11,18 @@ #include "Luau/TimeTrace.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/Common.h" #include #include #include +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauCyclicModuleTypeSurface) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) namespace Luau @@ -97,13 +99,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; @@ -113,7 +113,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -440,13 +440,42 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) + typeCheckerForAutocomplete.instantiationChildLimit = + std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; + } + ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; + double duration = getTimestamp() - timestamp; + if (moduleForAutocomplete->timeout) + { checkResult.timeoutHits.push_back(moduleName); - stats.timeCheck += getTimestamp() - timestamp; + if (FFlag::LuauAutocompleteDynamicLimits) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + } + else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0) + { + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + } + + stats.timeCheck += duration; stats.filesStrict += 1; sourceNode.dirtyAutocomplete = false; diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index a8f67589..0eaa485e 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -184,6 +184,8 @@ static void errorToString(std::ostream& stream, const T& err) } else if constexpr (std::is_same_v) stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; + else if constexpr (std::is_same_v) + stream << "NormalizationTooComplex { }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index c9466a40..72555ab4 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauTypecheckOptPass) + namespace Luau { @@ -79,6 +81,8 @@ std::optional tryGetLValue(const AstExpr& node) std::pair> getFullName(const LValue& lvalue) { + LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); + const LValue* current = &lvalue; std::vector keys; while (auto field = get(*current)) @@ -92,6 +96,19 @@ std::pair> getFullName(const LValue& lvalue) return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } +Symbol getBaseSymbol(const LValue& lvalue) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOptPass); + + const LValue* current = &lvalue; + while (auto field = get(*current)) + current = baseof(*current); + + const Symbol* symbol = get(*current); + LUAU_ASSERT(symbol); + return *symbol; +} + void merge(RefinementMap& l, const RefinementMap& r, std::function f) { for (const auto& [k, a] : r) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index b7480e34..5608e4b3 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,7 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) -LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false) namespace Luau { @@ -1140,25 +1139,8 @@ private: Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. Kind_Vector, // 'vector' but only used when type is used Kind_Userdata, // custom userdata type - - // TODO: remove these with LuauLintNoRobloxBits - Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. - Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. }; - bool containsPropName(TypeId ty, const std::string& propName) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - if (auto ctv = get(ty)) - return lookupClassProp(ctv, propName) != nullptr; - - if (auto ttv = get(ty)) - return ttv->props.find(propName) != ttv->props.end(); - - return false; - } - TypeKind getTypeKind(const std::string& name) { if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || @@ -1168,23 +1150,10 @@ private: if (name == "vector") return Kind_Vector; - if (FFlag::LuauLintNoRobloxBits) - { - if (std::optional maybeTy = context->scope->lookupType(name)) - return Kind_Userdata; + if (std::optional maybeTy = context->scope->lookupType(name)) + return Kind_Userdata; - return Kind_Unknown; - } - else - { - if (std::optional maybeTy = context->scope->lookupType(name)) - // Kind_Userdata is probably not 100% precise but is close enough - return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; - else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) - return Kind_Enum; - - return Kind_Unknown; - } + return Kind_Unknown; } void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) @@ -1202,67 +1171,11 @@ private: { if (kind == ek) return; - - // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type - if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) - return; } emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString); } - bool acceptsClassName(AstName method) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || - method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); - } - - bool visit(AstExprCall* node) override - { - // TODO: Simply remove the override - if (FFlag::LuauLintNoRobloxBits) - return true; - - if (AstExprIndexName* index = node->func->as()) - { - AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; - - if (arg0) - { - if (node->self && index->index == "IsA" && node->args.size == 1) - { - validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type"); - } - else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1) - { - AstExprGlobal* g = index->expr->as(); - - if (g && (g->name == "game" || g->name == "Game")) - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - else if (node->self && acceptsClassName(index->index) && node->args.size == 1) - { - validateType(arg0, {Kind_Class}, "class type"); - } - else if (!node->self && index->index == "new" && node->args.size <= 2) - { - AstExprGlobal* g = index->expr->as(); - - if (g && g->name == "Instance") - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - } - } - - return true; - } - bool visit(AstExprBinary* node) override { if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 6bb45245..e2e3b436 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,8 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" -#include "Luau/Common.h" #include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -14,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) namespace Luau { @@ -143,32 +145,51 @@ Module::~Module() unfreeze(internalTypes); } -bool Module::clonePublicInterface() +bool Module::clonePublicInterface(InternalErrorReporter& ice) { LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; ScopePtr moduleScope = getModuleScope(); - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState); + moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, cloneState); if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); + moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, cloneState); + + if (FFlag::LuauLowerBoundsCalculation) + { + normalize(moduleScope->returnType, interfaceTypes, ice); + if (moduleScope->varargPack) + normalize(*moduleScope->varargPack, interfaceTypes, ice); + } for (auto& [name, tf] : moduleScope->exportedTypeBindings) - tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); + { + tf = clone(tf, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(tf.type, interfaceTypes, ice); + } for (TypeId ty : moduleScope->returnType) + { if (get(follow(ty))) - *asMutable(ty) = AnyTypeVar{}; + { + auto t = asMutable(ty); + t->ty = AnyTypeVar{}; + t->normal = true; + } + } if (FFlag::LuauCloneDeclaredGlobals) { for (auto& [name, ty] : declaredGlobals) - ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState); + { + ty = clone(ty, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(ty, interfaceTypes, ice); + } } freeze(internalTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp new file mode 100644 index 00000000..40341ac1 --- /dev/null +++ b/Analysis/src/Normalize.cpp @@ -0,0 +1,814 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Normalize.h" + +#include + +#include "Luau/Clone.h" +#include "Luau/DenseHash.h" +#include "Luau/Substitution.h" +#include "Luau/Unifier.h" +#include "Luau/VisitTypeVar.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) + +// This could theoretically be 2000 on amd64, but x86 requires this. +LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineIntersectionFix, false); + +namespace Luau +{ + +namespace +{ + +struct Replacer : Substitution +{ + TypeId sourceType; + TypeId replacedType; + DenseHashMap replacedTypes{nullptr}; + DenseHashMap replacedPacks{nullptr}; + + Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) + : Substitution(TxnLog::empty(), arena) + , sourceType(sourceType) + , replacedType(replacedType) + { + } + + bool isDirty(TypeId ty) override + { + if (!sourceType) + return false; + + auto vecHasSourceType = [sourceType = sourceType](const auto& vec) { + return end(vec) != std::find(begin(vec), end(vec), sourceType); + }; + + // Walk every kind of TypeVar and find pointers to sourceType + if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + { + if (vecHasSourceType(t->generics)) + return true; + + return false; + } + else if (auto t = get(ty)) + { + if (t->boundTo) + return *t->boundTo == sourceType; + + for (const auto& [_name, prop] : t->props) + { + if (prop.type == sourceType) + return true; + } + + if (auto indexer = t->indexer) + { + if (indexer->indexType == sourceType || indexer->indexResultType == sourceType) + return true; + } + + if (vecHasSourceType(t->instantiatedTypeParams)) + return true; + + return false; + } + else if (auto t = get(ty)) + return t->table == sourceType || t->metatable == sourceType; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->options); + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + + LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type"); + LUAU_UNREACHABLE(); + } + + bool isDirty(TypePackId tp) override + { + if (auto it = replacedPacks.find(tp)) + return false; + + if (auto pack = get(tp)) + { + for (TypeId ty : pack->head) + { + if (ty == sourceType) + return true; + } + return false; + } + else if (auto vtp = get(tp)) + return vtp->ty == sourceType; + else + return false; + } + + TypeId clean(TypeId ty) override + { + LUAU_ASSERT(sourceType && replacedType); + + // Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType + // Before returning, memoize the result for later use. + + // Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This + // function returns the identity for things like primitives. + TypeId res = clone(ty); + + if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + // The constituent typepacks are cleaned separately. We just need to walk the generics array. + for (TypeId& g : t->generics) + { + if (g == sourceType) + g = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (auto& [_key, prop] : t->props) + { + if (prop.type == sourceType) + prop.type = replacedType; + } + } + else if (auto t = getMutable(res)) + { + if (t->table == sourceType) + t->table = replacedType; + if (t->metatable == sourceType) + t->table = replacedType; + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& option : t->options) + { + if (option == sourceType) + option = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else + LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type"); + + replacedTypes[ty] = res; + return res; + } + + TypePackId clean(TypePackId tp) override + { + TypePackId res = clone(tp); + + if (auto pack = getMutable(res)) + { + for (TypeId& type : pack->head) + { + if (type == sourceType) + type = replacedType; + } + } + else if (auto vtp = getMutable(res)) + { + if (vtp->ty == sourceType) + vtp->ty = replacedType; + } + + replacedPacks[tp] = res; + return res; + } + + TypeId smartClone(TypeId t) + { + std::optional res = replace(t); + LUAU_ASSERT(res.has_value()); // TODO think about this + if (*res == t) + return clone(t); + return *res; + } +}; + +} // anonymous namespace + +bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + +template +static bool areNormal_(const T& t, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + int count = 0; + auto isNormal = [&](TypeId ty) { + ++count; + if (count >= FInt::LuauNormalizeIterationLimit) + ice.ice("Luau::areNormal hit iteration limit"); + + return ty->normal || seen.find(asMutable(ty)); + }; + + return std::all_of(begin(t), end(t), isNormal); +} + +static bool areNormal(const std::vector& types, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + return areNormal_(types, seen, ice); +} + +static bool areNormal(TypePackId tp, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + tp = follow(tp); + if (get(tp)) + return false; + + auto [head, tail] = flatten(tp); + + if (!areNormal_(head, seen, ice)) + return false; + + if (!tail) + return true; + + if (auto vtp = get(*tail)) + return vtp->ty->normal || seen.find(asMutable(vtp->ty)); + + return true; +} + +#define CHECK_ITERATION_LIMIT(...) \ + do \ + { \ + if (iterationLimit > FInt::LuauNormalizeIterationLimit) \ + { \ + limitExceeded = true; \ + return __VA_ARGS__; \ + } \ + ++iterationLimit; \ + } while (false) + +struct Normalize +{ + TypeArena& arena; + InternalErrorReporter& ice; + + // Debug data. Types being normalized are invalidated but trying to see what's going on is painful. + // To actually see the original type, read it by using the pointer of the type being normalized. + // e.g. in lldb, `e dump(originalTys[ty])`. + SeenTypes originalTys; + SeenTypePacks originalTps; + + int iterationLimit = 0; + bool limitExceeded = false; + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + template + void cycle(TID) + { + } + + bool operator()(TypeId ty, const FreeTypeVar&) + { + LUAU_ASSERT(!ty->normal); + return false; + } + + bool operator()(TypeId ty, const BoundTypeVar& btv) + { + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type. + LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); + + asMutable(ty)->normal = btv.boundTo->normal; + return !ty->normal; + } + + bool operator()(TypeId ty, const PrimitiveTypeVar&) + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool operator()(TypeId ty, const GenericTypeVar&) + { + if (!ty->normal) + asMutable(ty)->normal = true; + + return false; + } + + bool operator()(TypeId ty, const ErrorTypeVar&) + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + ConstrainedTypeVar* ctv = const_cast(&ctvRef); + + std::vector parts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId part : parts) + visit_detail::visit(part, *this, seen); + + std::vector newParts = normalizeUnion(parts); + + const bool normal = areNormal(newParts, seen, ice); + + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const FunctionTypeVar& ftv) = delete; + bool operator()(TypeId ty, const FunctionTypeVar& ftv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + visit_detail::visit(ftv.argTypes, *this, seen); + visit_detail::visit(ftv.retType, *this, seen); + + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + + return false; + } + + bool operator()(TypeId ty, const TableTypeVar& ttv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + bool normal = true; + + auto checkNormal = [&](TypeId t) { + // if t is on the stack, it is possible that this type is normal. + // If t is not normal and it is not on the stack, this type is definitely not normal. + if (!t->normal && !seen.find(asMutable(t))) + normal = false; + }; + + if (ttv.boundTo) + { + visit_detail::visit(*ttv.boundTo, *this, seen); + asMutable(ty)->normal = (*ttv.boundTo)->normal; + return false; + } + + for (const auto& [_name, prop] : ttv.props) + { + visit_detail::visit(prop.type, *this, seen); + checkNormal(prop.type); + } + + if (ttv.indexer) + { + visit_detail::visit(ttv.indexer->indexType, *this, seen); + checkNormal(ttv.indexer->indexType); + visit_detail::visit(ttv.indexer->indexResultType, *this, seen); + checkNormal(ttv.indexer->indexResultType); + } + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const MetatableTypeVar& mtv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + visit_detail::visit(mtv.table, *this, seen); + visit_detail::visit(mtv.metatable, *this, seen); + + asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; + + return false; + } + + bool operator()(TypeId ty, const ClassTypeVar& ctv) + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool operator()(TypeId ty, const AnyTypeVar&) + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool operator()(TypeId ty, const UnionTypeVar& utvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + UnionTypeVar* utv = &const_cast(utvRef); + std::vector options = std::move(utv->options); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId option : options) + visit_detail::visit(option, *this, seen); + + std::vector newOptions = normalizeUnion(options); + + const bool normal = areNormal(newOptions, seen, ice); + + LUAU_ASSERT(!newOptions.empty()); + + if (newOptions.size() == 1) + *asMutable(ty) = BoundTypeVar{newOptions[0]}; + else + utv->options = std::move(newOptions); + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + IntersectionTypeVar* itv = &const_cast(itvRef); + + std::vector oldParts = std::move(itv->parts); + + for (TypeId part : oldParts) + visit_detail::visit(part, *this, seen); + + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, itv, part); + } + } + + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + itv->parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + itv->parts.push_back(newTable); + } + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } + + return false; + } + + bool operator()(TypeId ty, const LazyTypeVar&) + { + return false; + } + + std::vector normalizeUnion(const std::vector& options) + { + if (options.size() == 1) + return options; + + std::vector result; + + for (TypeId part : options) + combineIntoUnion(result, part); + + return result; + } + + void combineIntoUnion(std::vector& result, TypeId ty) + { + ty = follow(ty); + if (auto utv = get(ty)) + { + for (TypeId t : utv) + combineIntoUnion(result, t); + return; + } + + for (TypeId& part : result) + { + if (isSubtype(ty, part, ice)) + return; // no need to do anything + else if (isSubtype(part, ty, ice)) + { + part = ty; // replace the less general type by the more general one + return; + } + } + + result.push_back(ty); + } + + /** + * @param replacer knows how to clone a type such that any recursive references point at the new containing type. + * @param result is an intersection that is safe for us to mutate in-place. + */ + void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + ty = follow(ty); + if (auto itv = get(ty)) + { + for (TypeId part : itv->parts) + combineIntoIntersection(replacer, result, part); + return; + } + + // Let's say that the last part of our result intersection is always a table, if any table is part of this intersection + if (get(ty)) + { + if (result->parts.empty()) + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + + TypeId theTable = result->parts.back(); + + if (!get(FFlag::LuauNormalizeCombineIntersectionFix ? follow(theTable) : theTable)) + { + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + theTable = result->parts.back(); + } + + TypeId newTable = replacer.smartClone(theTable); + result->parts.back() = newTable; + + combineIntoTable(replacer, getMutable(newTable), ty); + } + else if (auto ftv = get(ty)) + { + bool merged = false; + for (TypeId& part : result->parts) + { + if (isSubtype(part, ty, ice)) + { + merged = true; + break; // no need to do anything + } + else if (isSubtype(ty, part, ice)) + { + merged = true; + part = ty; // replace the less general type by the more general one + break; + } + } + + if (!merged) + result->parts.push_back(ty); + } + else + result->parts.push_back(ty); + } + + TableState combineTableStates(TableState lhs, TableState rhs) + { + if (lhs == rhs) + return lhs; + + if (lhs == TableState::Free || rhs == TableState::Free) + return TableState::Free; + + if (lhs == TableState::Unsealed || rhs == TableState::Unsealed) + return TableState::Unsealed; + + return lhs; + } + + /** + * @param replacer gives us a way to clone a type such that recursive references are rewritten to the new + * "containing" type. + * @param table always points into a table that is safe for us to mutate. + */ + void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + LUAU_ASSERT(table); + + ty = follow(ty); + + TableTypeVar* tyTable = getMutable(ty); + LUAU_ASSERT(tyTable); + + for (const auto& [propName, prop] : tyTable->props) + { + if (auto it = table->props.find(propName); it != table->props.end()) + { + /** + * If we are going to recursively merge intersections of tables, we need to ensure that we never mutate + * a table that comes from somewhere else in the type graph. + * + * smarClone() does some nice things for us: It will perform a clone that is as shallow as possible + * while still rewriting any cyclic references back to the new 'root' table. + * + * replacer also keeps a mapping of types that have previously been copied, so we have the added + * advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is + * safe for us to mutate in-place. + */ + TypeId clone = replacer.smartClone(it->second.type); + it->second.type = combine(replacer, clone, prop.type); + } + else + table->props.insert({propName, prop}); + } + + table->state = combineTableStates(table->state, tyTable->state); + table->level = max(table->level, tyTable->level); + } + + /** + * @param a is always cloned by the caller. It is safe to mutate in-place. + * @param b will never be mutated. + */ + TypeId combine(Replacer& replacer, TypeId a, TypeId b) + { + if (FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + + if (!get(a) && !get(a)) + { + if (!FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + else + return arena.addType(IntersectionTypeVar{{a, b}}); + } + + if (auto itv = getMutable(a)) + { + combineIntoIntersection(replacer, itv, b); + return a; + } + else if (auto ttv = getMutable(a)) + { + if (FFlag::LuauNormalizeCombineTableFix && !get(follow(b))) + return arena.addType(IntersectionTypeVar{{a, b}}); + combineIntoTable(replacer, ttv, b); + return a; + } + + LUAU_ASSERT(!"Impossible"); + LUAU_UNREACHABLE(); + } +}; + +#undef CHECK_ITERATION_LIMIT + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(ty, arena, state); + + Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, n, seen); + + return {ty, !n.limitExceeded}; +} + +// TODO: Think about using a temporary arena and cloning types out of it so that we +// reclaim memory used by wantonly allocated intermediate types here. +// The main wrinkle here is that we don't want clone() to copy a type if the source and dest +// arena are the same. +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(ty, module->internalTypes, ice); +} + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(tp, arena, state); + + Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(tp, n, seen); + + return {tp, !n.limitExceeded}; +} + +std::pair normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(tp, module->internalTypes, ice); +} + +} // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 94e169f1..305f83ce 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,6 +4,8 @@ #include "Luau/VisitTypeVar.h" +LUAU_FASTFLAG(LuauTypecheckOptPass) + namespace Luau { @@ -12,6 +14,8 @@ struct Quantifier TypeLevel level; std::vector generics; std::vector genericPacks; + bool seenGenericType = false; + bool seenMutableType = false; Quantifier(TypeLevel level) : level(level) @@ -23,6 +27,9 @@ struct Quantifier bool operator()(TypeId ty, const FreeTypeVar& ftv) { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + if (!level.subsumes(ftv.level)) return false; @@ -44,17 +51,40 @@ struct Quantifier return true; } + bool operator()(TypeId ty, const ConstrainedTypeVar&) + { + return true; + } + bool operator()(TypeId ty, const TableTypeVar&) { TableTypeVar& ttv = *getMutable(ty); + if (FFlag::LuauTypecheckOptPass) + { + if (ttv.state == TableState::Generic) + seenGenericType = true; + + if (ttv.state == TableState::Free) + seenMutableType = true; + } + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) return false; if (!level.subsumes(ttv.level)) + { + if (FFlag::LuauTypecheckOptPass && ttv.state == TableState::Unsealed) + seenMutableType = true; return false; + } if (ttv.state == TableState::Free) + { ttv.state = TableState::Generic; + + if (FFlag::LuauTypecheckOptPass) + seenGenericType = true; + } else if (ttv.state == TableState::Unsealed) ttv.state = TableState::Sealed; @@ -65,6 +95,9 @@ struct Quantifier bool operator()(TypePackId tp, const FreeTypePack& ftp) { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + if (!level.subsumes(ftp.level)) return false; @@ -84,6 +117,9 @@ void quantify(TypeId ty, TypeLevel level) LUAU_ASSERT(ftv); ftv->generics = q.generics; ftv->genericPacks = q.genericPacks; + + if (FFlag::LuauTypecheckOptPass && ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 770c7a47..8648b21e 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -7,24 +7,36 @@ #include #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAG(LuauTypecheckOptPass) +LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) namespace Luau { void Tarjan::visitChildren(TypeId ty, int index) { - ty = log->follow(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); if (ignoreChildren(ty)) return; - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto pty = log->pending(ty)) + ty = &pty->pending; + } + + if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { visitChild(ftv->argTypes); visitChild(ftv->retType); } - else if (const TableTypeVar* ttv = log->getMutable(ty)) + else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -41,38 +53,52 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = log->getMutable(ty)) + else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { for (TypeId part : itv->parts) visitChild(part); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + for (TypeId part : ctv->parts) + visitChild(part); + } } void Tarjan::visitChildren(TypePackId tp, int index) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); if (ignoreChildren(tp)) return; - if (const TypePack* tpp = log->getMutable(tp)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + } + + if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { visitChild(vtp->ty); } @@ -80,7 +106,10 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - ty = log->follow(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -98,7 +127,10 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -141,7 +173,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount) + if (childLimit > 0 && childLimit < childCount) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -229,6 +261,9 @@ TarjanResult Tarjan::loop() TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + ty = log->follow(ty); auto [index, fresh] = indexify(ty); @@ -239,6 +274,9 @@ TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypePackId tp) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + tp = log->follow(tp); auto [index, fresh] = indexify(tp); @@ -347,7 +385,13 @@ TypeId Substitution::clone(TypeId ty) TypeId result = ty; - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto pty = log->pending(ty)) + ty = &pty->pending; + } + + if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; @@ -357,7 +401,7 @@ TypeId Substitution::clone(TypeId ty) clone.argNames = ftv->argNames; result = addType(std::move(clone)); } - else if (const TableTypeVar* ttv = log->getMutable(ty)) + else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -370,24 +414,29 @@ TypeId Substitution::clone(TypeId ty) clone.tags = ttv->tags; result = addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = addType(std::move(clone)); } - else if (const UnionTypeVar* utv = log->getMutable(ty)) + else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { UnionTypeVar clone; clone.options = utv->options; result = addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { IntersectionTypeVar clone; clone.parts = itv->parts; result = addType(std::move(clone)); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + ConstrainedTypeVar clone{ctv->level, ctv->parts}; + result = addType(std::move(clone)); + } asMutable(result)->documentationSymbol = ty->documentationSymbol; return result; @@ -396,14 +445,21 @@ TypeId Substitution::clone(TypeId ty) TypePackId Substitution::clone(TypePackId tp) { tp = log->follow(tp); - if (const TypePack* tpp = log->getMutable(tp)) + + if (FFlag::LuauTypecheckOptPass) + { + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + } + + if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -415,25 +471,34 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - ty = log->follow(ty); - if (isDirty(ty)) - newTypes[ty] = clean(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); else - newTypes[ty] = clone(ty); + ty = log->follow(ty); + + if (isDirty(ty)) + newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(ty)) : clean(ty); + else + newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(ty)) : clone(ty); } void Substitution::foundDirty(TypePackId tp) { - tp = log->follow(tp); - if (isDirty(tp)) - newPacks[tp] = clean(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); else - newPacks[tp] = clone(tp); + tp = log->follow(tp); + + if (isDirty(tp)) + newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(tp)) : clean(tp); + else + newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(tp)) : clone(tp); } TypeId Substitution::replace(TypeId ty) { ty = log->follow(ty); + if (TypeId* prevTy = newTypes.find(ty)) return *prevTy; else @@ -443,6 +508,7 @@ TypeId Substitution::replace(TypeId ty) TypePackId Substitution::replace(TypePackId tp) { tp = log->follow(tp); + if (TypePackId* prevTp = newPacks.find(tp)) return *prevTp; else @@ -451,7 +517,13 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - ty = log->follow(ty); + if (BoundTypeVar* btv = log->getMutable(ty); FFlag::LuauLowerBoundsCalculation && btv) + btv->boundTo = replace(btv->boundTo); + + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); if (ignoreChildren(ty)) return; @@ -493,11 +565,19 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } + else if (ConstrainedTypeVar* ctv = getMutable(ty)) + { + for (TypeId& part : ctv->parts) + part = replace(part); + } } void Substitution::replaceChildren(TypePackId tp) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index df9d4188..cb54bfc1 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -237,6 +237,15 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + formatAppend(result, "ConstrainedTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : ctv->parts) + visitChild(part, index); + } else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); @@ -258,6 +267,28 @@ void StateDot::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable, index, "[metatable]"); } + else if (const SingletonTypeVar* stv = get(ty)) + { + std::string res; + + if (const StringSingleton* ss = get(stv)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(stv)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonTypeVar %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); + } else { LUAU_ASSERT(!"unknown type kind"); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 59ee6de2..610842da 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + /* * Prefix generic typenames with gen- * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 @@ -33,8 +35,8 @@ struct FindCyclicTypes bool exhaustive = false; std::unordered_set visited; std::unordered_set visitedPacks; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; void cycle(TypeId ty) { @@ -86,7 +88,7 @@ struct FindCyclicTypes }; template -void findCyclicTypes(std::unordered_set& cycles, std::unordered_set& cycleTPs, TID ty, bool exhaustive) +void findCyclicTypes(std::set& cycles, std::set& cycleTPs, TID ty, bool exhaustive) { FindCyclicTypes fct; fct.exhaustive = exhaustive; @@ -124,6 +126,7 @@ struct StringifierState std::unordered_map cycleTpNames; std::unordered_set seen; std::unordered_set usedNames; + size_t indentation = 0; bool exhaustive; @@ -216,6 +219,34 @@ struct StringifierState result.name += s; } + + void indent() + { + indentation += 4; + } + + void dedent() + { + indentation -= 4; + } + + void newline() + { + if (!opts.useLineBreaks) + return emit(" "); + + emit("\n"); + emitIndentation(); + } + +private: + void emitIndentation() + { + if (!opts.indent) + return; + + emit(std::string(indentation, ' ')); + } }; struct TypeVarStringifier @@ -321,7 +352,7 @@ struct TypeVarStringifier stringify(btv.boundTo); } - void operator()(TypeId ty, const Unifiable::Generic& gtv) + void operator()(TypeId ty, const GenericTypeVar& gtv) { if (gtv.explicitName) { @@ -332,6 +363,26 @@ struct TypeVarStringifier state.emit(state.getName(ty)); } + void operator()(TypeId, const ConstrainedTypeVar& ctv) + { + state.result.invalid = true; + + state.emit("[["); + + bool first = true; + for (TypeId ty : ctv.parts) + { + if (first) + first = false; + else + state.emit("|"); + + stringify(ty); + } + + state.emit("]]"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -415,10 +466,25 @@ struct TypeVarStringifier state.emit(") -> "); bool plural = true; - if (auto retPack = get(follow(ftv.retType))) + + if (FFlag::LuauLowerBoundsCalculation) { - if (retPack->head.size() == 1 && !retPack->tail) - plural = false; + auto retBegin = begin(ftv.retType); + auto retEnd = end(ftv.retType); + if (retBegin != retEnd) + { + ++retBegin; + if (retBegin == retEnd && !retBegin.tail()) + plural = false; + } + } + else + { + if (auto retPack = get(follow(ftv.retType))) + { + if (retPack->head.size() == 1 && !retPack->tail) + plural = false; + } } if (plural) @@ -511,6 +577,7 @@ struct TypeVarStringifier } state.emit(openbrace); + state.indent(); bool comma = false; if (ttv.indexer) @@ -527,7 +594,10 @@ struct TypeVarStringifier for (const auto& [name, prop] : ttv.props) { if (comma) - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + { + state.emit(","); + state.newline(); + } size_t length = state.result.name.length() - oldLength; @@ -553,6 +623,7 @@ struct TypeVarStringifier ++index; } + state.dedent(); state.emit(closedbrace); state.unsee(&ttv); @@ -563,7 +634,8 @@ struct TypeVarStringifier state.result.invalid = true; state.emit("{ @metatable "); stringify(mtv.metatable); - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + state.emit(","); + state.newline(); stringify(mtv.table); state.emit(" }"); } @@ -784,13 +856,16 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { - const auto& tail = *tp.tail; - if (first) - first = false; - else - state.emit(", "); + TypePackId tail = follow(*tp.tail); + if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + { + if (first) + first = false; + else + state.emit(", "); - stringify(tail); + stringify(tail); + } } state.unsee(&tp); @@ -805,6 +880,8 @@ struct TypePackStringifier void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); + if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + state.emit(""); stringify(pack.ty); } @@ -858,15 +935,12 @@ void TypeVarStringifier::stringify(TypePackId tpid, const std::vector& cycles, const std::unordered_set& cycleTPs, +static void assignCycleNames(const std::set& cycles, const std::set& cycleTPs, std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) { int nextIndex = 1; - std::vector sortedCycles{cycles.begin(), cycles.end()}; - std::sort(sortedCycles.begin(), sortedCycles.end(), std::less{}); - - for (TypeId cycleTy : sortedCycles) + for (TypeId cycleTy : cycles) { std::string name; @@ -888,10 +962,7 @@ static void assignCycleNames(const std::unordered_set& cycles, const std cycleNames[cycleTy] = std::move(name); } - std::vector sortedCycleTps{cycleTPs.begin(), cycleTPs.end()}; - std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less()); - - for (TypePackId tp : sortedCycleTps) + for (TypePackId tp : cycleTPs) { std::string name = "tp" + std::to_string(nextIndex); ++nextIndex; @@ -913,8 +984,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive); @@ -1016,8 +1087,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) ToStringResult result; StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive); @@ -1058,7 +1129,7 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto t) { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1163,14 +1234,18 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp if (argPackIter.tail()) { - if (!first) - state.emit(", "); + if (auto vtp = get(*argPackIter.tail()); !vtp || !vtp->hidden) + { + if (!first) + state.emit(", "); - state.emit("...: "); - if (auto vtp = get(*argPackIter.tail())) - tvs.stringify(vtp->ty); - else - tvs.stringify(*argPackIter.tail()); + state.emit("...: "); + + if (vtp) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } } state.emit("): "); @@ -1210,6 +1285,24 @@ std::string dump(TypePackId ty) return s; } +std::string dump(const ScopePtr& scope, const char* name) +{ + auto binding = scope->linearSearchForBinding(name); + if (!binding) + { + printf("No binding %s\n", name); + return {}; + } + + TypeId ty = binding->typeId; + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; +} + std::string generateName(size_t i) { std::string n; diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 678001bf..1ea2e27d 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -215,6 +215,7 @@ struct ArcCollector : public AstVisitor } } + // Adds a dependency from the current node to the named node. void add(const Identifier& name) { Node** it = map.find(name); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 5fbb596d..a5f9d26c 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false) +LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) namespace Luau { @@ -161,18 +162,37 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + if (FFlag::LuauJustOneCallFrameForHaveSeen && !FFlag::LuauTypecheckOptPass) { - return true; - } + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); - if (parent) + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (current->sharedSeen->end() != std::find(current->sharedSeen->begin(), current->sharedSeen->end(), sortedPair)) + return true; + } + + return false; + } + else { - return parent->haveSeen(lhs, rhs); - } + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + { + return true; + } - return false; + if (!FFlag::LuauTypecheckOptPass && parent) + { + return parent->haveSeen(lhs, rhs); + } + + return false; + } } void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) @@ -222,8 +242,8 @@ PendingType* TxnLog::pending(TypeId ty) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) - return it->second.get(); + if (auto it = current->typeVarChanges.find(ty)) + return it->get(); } return nullptr; @@ -237,8 +257,8 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) - return it->second.get(); + if (auto it = current->typePackChanges.find(tp)) + return it->get(); } return nullptr; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index d575e023..bc8d0d4e 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,6 +94,16 @@ public: } } + AstType* operator()(const ConstrainedTypeVar& ctv) + { + AstArray types; + types.size = ctv.parts.size(); + types.data = static_cast(allocator->allocate(sizeof(AstType*) * ctv.parts.size())); + for (size_t i = 0; i < ctv.parts.size(); ++i) + types.data[i] = Luau::visit(*this, ctv.parts[i]->ty); + return allocator->alloc(Location(), types); + } + AstType* operator()(const SingletonTypeVar& stv) { if (const BooleanSingleton* bs = get(&stv)) @@ -364,6 +374,9 @@ public: AstTypePack* operator()(const VariadicTypePack& vtp) const { + if (vtp.hidden) + return nullptr; + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 10930248..af42a4e6 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,12 +3,15 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" +#include "Luau/Normalize.h" +#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" #include "Luau/TypePack.h" +#include "Luau/ToString.h" #include "Luau/TypeUtils.h" #include "Luau/ToString.h" #include "Luau/TypeVar.h" @@ -19,14 +22,17 @@ LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. +LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) @@ -39,6 +45,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -53,6 +60,8 @@ LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) +LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) +LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); namespace Luau { @@ -140,6 +149,34 @@ bool hasBreak(AstStat* node) } } +static bool hasReturn(const AstStat* node) +{ + struct Searcher : AstVisitor + { + bool result = false; + + bool visit(AstStat*) override + { + return !result; // if we've already found a return statement, don't bother to traverse inward anymore + } + + bool visit(AstStatReturn*) override + { + result = true; + return false; + } + + bool visit(AstExprFunction*) override + { + return false; // We don't care if the function uses a lambda that itself returns + } + }; + + Searcher searcher; + const_cast(node)->visit(&searcher); + return searcher.result; +} + // returns the last statement before the block exits, or nullptr if the block never exits const AstStat* getFallthrough(const AstStat* node) { @@ -253,6 +290,26 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) +{ + if (FFlag::LuauRecursionLimitException) + { + try + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } + } + else + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } +} + +ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) { LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); @@ -268,6 +325,12 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona iceHandler->moduleName = module.name; + if (FFlag::LuauAutocompleteDynamicLimits) + { + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; + } + ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); @@ -312,7 +375,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona prepareErrorsForDisplay(currentModule->errors); - bool encounteredFreeType = currentModule->clonePublicInterface(); + bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler); if (encounteredFreeType) { reportError(TypeError{module.root->location, @@ -415,7 +478,26 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + if (FFlag::LuauRecursionLimitException) + { + try + { + checkBlockWithoutRecursionCheck(scope, block); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(block.location); + return; + } + } + else + { + checkBlockWithoutRecursionCheck(scope, block); + } +} +void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +{ int subLevel = 0; std::vector sorted(block.body.data, block.body.data + block.body.size); @@ -435,6 +517,16 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) std::unordered_map> functionDecls; + auto isLocalLambda = [](AstStat* stat) -> AstStatLocal* { + AstStatLocal* local = stat->as(); + + if (FFlag::LuauLowerBoundsCalculation && local && local->vars.size == 1 && local->values.size == 1 && + local->values.data[0]->is()) + return local; + else + return nullptr; + }; + auto checkBody = [&](AstStat* stat) { if (auto fun = stat->as()) { @@ -482,7 +574,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCallOrReturn(**protoIter)) + if (containsFunctionCallOrReturn(**protoIter) || (FFlag::LuauLowerBoundsCalculation && isLocalLambda(*protoIter))) { while (checkIter != protoIter) { @@ -513,7 +605,8 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); + TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level)); + unify(funTy, leftType, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -658,6 +751,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location) +{ + Unifier state = mkUnifier(location); + state.unifyLowerBound(subTy, superTy); + + state.log.commit(); + + reportErrors(state.errors); +} + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; @@ -682,6 +785,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; + if (useConstrainedIntersections()) + { + unifyLowerBound(retPack, scope->returnType, return_.location); + return; + } + // HACK: Nonstrict mode gets a bit too smart and strict for us when we // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) @@ -1209,9 +1318,11 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco else if (tableSelf->state == TableState::Sealed) reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); + const bool tableIsExtendable = tableSelf && tableSelf->state != TableState::Sealed; + ty = follow(ty); - if (tableSelf && tableSelf->state != TableState::Sealed) + if (tableIsExtendable) tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; const FunctionTypeVar* funTy = get(ty); @@ -1224,7 +1335,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && tableSelf->state != TableState::Sealed) + if (tableIsExtendable) tableSelf->props[indexName->index.value] = { follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; } @@ -1372,7 +1483,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias for (auto param : binding->typePackParams) clone.instantiatedTypePackParams.push_back(param.tp); + bool isNormal = ty->normal; ty = addType(std::move(clone)); + + if (FFlag::LuauLowerBoundsCalculation) + asMutable(ty)->normal = isNormal; } } else @@ -1400,6 +1515,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauTwoPassAliasDefinitionFix && ok) bindingType = ty; + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(bindingType, currentModule, *iceHandler); + bindingType = t; + if (!ok) + reportError(typealias.location, NormalizationTooComplex{}); + } } } @@ -1673,10 +1796,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { - TypeId head = freshType(scope); - TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); + TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level; + TypeId head = freshType(level); + TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}}); unify(pack, retPack, expr.location); return {head, std::move(result.predicates)}; } @@ -1793,7 +1917,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : utv) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeForType unions"); // Not needed when we normalize types. if (get(follow(t))) @@ -1817,12 +1941,25 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule, + *iceHandler); // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away. - if (result.size() == 1) - return result[0]; + if (!ok) + reportError(location, NormalizationTooComplex{}); - return addType(UnionTypeVar{std::move(result)}); + return t; + } + else + { + std::vector result = reduceUnion(goodOptions); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); + } } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1830,7 +1967,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : itv->parts) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeFromType intersections"); if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) parts.push_back(*ty); @@ -1982,7 +2119,6 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { if (!std::any_of(begin(utv), end(utv), isNil)) return ty; - } if (std::optional strippedUnion = tryStripUnionFromNil(ty)) @@ -2124,7 +2260,26 @@ TypeId TypeChecker::checkExprTable( ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + if (FFlag::LuauTableUseCounterInstead) + { + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) + { + reportErrorCodeTooComplex(expr.location); + return {errorRecoveryType(scope)}; + } + + return checkExpr_(scope, expr, expectedType); + } + else + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "checkExpr for tables"); + return checkExpr_(scope, expr, expectedType); + } +} + +ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +{ std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; @@ -3176,6 +3331,10 @@ std::pair TypeChecker::checkFunctionSignature( funScope->varargPack = anyTypePack; } } + else if (FFlag::LuauLowerBoundsCalculation && !isNonstrictMode()) + { + funScope->varargPack = addTypePack(TypePackVar{VariadicTypePack{anyType, /*hidden*/ true}}); + } std::vector argTypes; @@ -3311,9 +3470,24 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE { check(scope, *function.body); - // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retType->ty)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (useConstrainedIntersections()) + { + TypePackId retPack = follow(funTy->retType); + // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type + // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) + if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) + { + auto level = getLevel(retPack); + if (level && scope->level.subsumes(*level)) + *asMutable(retPack) = TypePack{{}, std::nullopt}; + } + } + else + { + // We explicitly don't follow here to check if we have a 'true' free type instead of bound one + if (get_if(&funTy->retType->ty)) + *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + } bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; @@ -3418,6 +3592,19 @@ void TypeChecker::checkArgumentList( size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); + auto reportCountMismatchError = [&state, &argLocations, minParams, paramPack, argPack]() { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + + size_t mp = minParams; + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + mp = getMinParameterCount(&state.log, paramPack); + + state.reportError(TypeError{location, CountMismatch{mp, std::distance(begin(argPack), end(argPack))}}); + }; + while (true) { state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; @@ -3472,6 +3659,8 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. while (paramIter != endIter) { state.tryUnify(vtp->ty, *paramIter); @@ -3506,14 +3695,22 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.getMutable(t)) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.get(t)) { } // ok else { if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) minParams = getMinParameterCount(&state.log, paramPack); - bool isVariadic = FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic && !finite(paramPack, &state.log); + + bool isVariadic = false; + if (FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic) + { + std::optional tail = flatten(paramPack, state.log).second; + if (tail) + isVariadic = Luau::isVariadic(*tail); + } + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; } @@ -3532,14 +3729,7 @@ void TypeChecker::checkArgumentList( unify(errorRecoveryType(scope), *argIter, state.location); ++argIter; } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } TypePackId tail = state.log.follow(*paramIter.tail()); @@ -3551,6 +3741,21 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + if (FFlag::LuauLowerBoundsCalculation && vtp->hidden) + { + // We know that this function can technically be oversaturated, but we have its definition and we + // know that it's useless. + + TypeId e = errorRecoveryType(scope); + while (argIter != endIter) + { + unify(e, *argIter, state.location); + ++argIter; + } + + reportCountMismatchError(); + return; + } // Function is variadic and requires that all subsequent parameters // be compatible with a type. size_t argIndex = paramIndex; @@ -3595,14 +3800,7 @@ void TypeChecker::checkArgumentList( } else if (state.log.getMutable(tail)) { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } } @@ -3661,7 +3859,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = follow(actualFunctionType); TypePackId retPack; - if (!FFlag::LuauWidenIfSupertypeIsFree2) + if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) { retPack = freshTypePack(scope->level); } @@ -3809,21 +4007,49 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{errorRecoveryTypePack(scope)}}; } - if (get(fn)) + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. - TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree2) + if (useConstrainedIntersections()) { - UnifierOptions options; - options.isFunctionCall = true; - unify(r, fn, expr.location, options); + // This ternary is phrased deliberately. We need ties between sibling scopes to bias toward ftv->level. + const TypeLevel level = scope->level.subsumes(ftv->level) ? scope->level : ftv->level; + + std::vector adjustedArgTypes; + auto it = begin(argPack); + auto endIt = end(argPack); + Widen widen{¤tModule->internalTypes}; + for (; it != endIt; ++it) + { + TypeId t = *it; + TypeId widened = widen.substitute(t).value_or(t); // Surely widening is infallible + adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widened}})); + } + + TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); + + TxnLog log; + promoteTypeLevels(log, ¤tModule->internalTypes, level, retPack); + log.commit(); + + *asMutable(fn) = FunctionTypeVar{level, adjustedArgPack, retPack}; + return {{retPack}}; } else - unify(fn, r, expr.location); - return {{retPack}}; + { + TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); + if (FFlag::LuauWidenIfSupertypeIsFree2) + { + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, expr.location, options); + } + else + unify(fn, r, expr.location); + return {{retPack}}; + } } std::vector metaArgLocations; @@ -4363,10 +4589,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s bool Instantiation::isDirty(TypeId ty) { - if (log->getMutable(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + return false; + return true; + } else + { return false; + } } bool Instantiation::isDirty(TypePackId tp) @@ -4414,14 +4647,21 @@ TypePackId Instantiation::clean(TypePackId tp) bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + return true; + // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. // It's OK to use vector equality here, since we always generate fresh generics // whenever we quantify, so the vectors overlap if and only if they are equal. return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + } else + { return false; + } } bool ReplaceGenerics::isDirty(TypeId ty) @@ -4464,16 +4704,24 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { + if (ty->persistent) + return false; + if (const TableTypeVar* ttv = log->getMutable(ty)) return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); else if (log->getMutable(ty)) return true; + else if (get(ty)) + return true; else return false; } bool Anyification::isDirty(TypePackId tp) { + if (tp->persistent) + return false; + if (log->getMutable(tp)) return true; else @@ -4494,7 +4742,16 @@ TypeId Anyification::clean(TypeId ty) clone.syntheticName = ttv->syntheticName; clone.tags = ttv->tags; } - return addType(std::move(clone)); + TypeId res = addType(std::move(clone)); + asMutable(res)->normal = ty->normal; + return res; + } + else if (auto ctv = get(ty)) + { + auto [t, ok] = normalize(ty, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; } else return anyType; @@ -4511,16 +4768,34 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location ty = follow(ty); const FunctionTypeVar* ftv = get(ty); - if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) - return ty; + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + + if (FFlag::LuauLowerBoundsCalculation && ftv) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } - Luau::quantify(ty, scope->level); return ty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + if (FFlag::LuauTypecheckOptPass) + { + const FunctionTypeVar* ftv = get(follow(ty)); + if (ftv && ftv->hasNoGenerics) + return ty; + } + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; + + if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; @@ -4533,8 +4808,18 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); + if (anyification.normalizationTooComplex) + reportError(location, NormalizationTooComplex{}); if (any.has_value()) return *any; else @@ -4546,7 +4831,15 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4830,6 +5123,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation ToStringOptions opts; opts.exhaustive = true; opts.maxTableLength = 0; + opts.useLineBreaks = true; TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); @@ -5283,7 +5577,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, bool needsClone = follow(tf.type) == target; bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); TableTypeVar* ttv = getMutableTableType(target); - + if (shouldMutate && ttv && needsClone) { // Substitution::clone is a shallow clone. If this is a metatable type, we @@ -5487,25 +5781,82 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV // We need to search in the provided Scope. Find t.x.y first. // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. - const auto& [symbol, keys] = getFullName(lvalue); + if (!FFlag::LuauTypecheckOptPass) + { + const auto& [symbol, keys] = getFullName(lvalue); + + ScopePtr currentScope = scope; + while (currentScope) + { + std::optional found; + + std::vector childKeys; + const LValue* currentLValue = &lvalue; + while (currentLValue) + { + if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) + { + found = it->second; + break; + } + + childKeys.push_back(*currentLValue); + currentLValue = baseof(*currentLValue); + } + + if (!found) + { + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + found = it->second.typeId; + else + { + // Nothing exists in this Scope. Just skip and try the parent one. + currentScope = currentScope->parent; + continue; + } + } + + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) + { + const LValue& key = *it; + + // Symbol can happen. Skip. + if (get(key)) + continue; + else if (auto field = get(key)) + { + found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + if (!found) + return std::nullopt; // Turns out this type doesn't have the property at all. We're done. + } + else + LUAU_ASSERT(!"New LValue alternative not handled here."); + } + + return found; + } + + // No entry for it at all. Can happen when LValue root is a global. + return std::nullopt; + } + + const Symbol symbol = getBaseSymbol(lvalue); ScopePtr currentScope = scope; while (currentScope) { std::optional found; - std::vector childKeys; - const LValue* currentLValue = &lvalue; - while (currentLValue) + const LValue* topLValue = nullptr; + + for (topLValue = &lvalue; topLValue; topLValue = baseof(*topLValue)) { - if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) + if (auto it = currentScope->refinements.find(*topLValue); it != currentScope->refinements.end()) { found = it->second; break; } - - childKeys.push_back(*currentLValue); - currentLValue = baseof(*currentLValue); } if (!found) @@ -5521,9 +5872,15 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV } } + // We need to walk the l-value path in reverse, so we collect components into a vector + std::vector childKeys; + + for (const LValue* curr = &lvalue; curr != topLValue; curr = baseof(*curr)) + childKeys.push_back(curr); + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) { - const LValue& key = *it; + const LValue& key = **it; // Symbol can happen. Skip. if (get(key)) @@ -5938,6 +6295,11 @@ bool TypeChecker::isNonstrictMode() const return (currentModule->mode == Mode::Nonstrict) || (currentModule->mode == Mode::NoCheck); } +bool TypeChecker::useConstrainedIntersections() const +{ + return FFlag::LuauLowerBoundsCalculation && !isNonstrictMode(); +} + std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp, size_t expectedLength, const Location& location) { TypePackId expectedTypePack = addTypePack({}); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5bb05234..30503233 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -104,7 +104,7 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } -TypePackIterator begin(TypePackId tp, TxnLog* log) +TypePackIterator begin(TypePackId tp, const TxnLog* log) { return TypePackIterator{tp, log}; } @@ -256,7 +256,7 @@ size_t size(const TypePack& tp, TxnLog* log) return result; } -std::optional first(TypePackId tp) +std::optional first(TypePackId tp, bool ignoreHiddenVariadics) { auto it = begin(tp); auto endIter = end(tp); @@ -266,7 +266,7 @@ std::optional first(TypePackId tp) if (auto tail = it.tail()) { - if (auto vtp = get(*tail)) + if (auto vtp = get(*tail); vtp && (!vtp->hidden || !ignoreHiddenVariadics)) return vtp->ty; } @@ -299,6 +299,46 @@ std::pair, std::optional> flatten(TypePackId tp) return {res, iter.tail()}; } +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log) +{ + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; +} + +bool isVariadic(TypePackId tp) +{ + return isVariadic(tp, *TxnLog::empty()); +} + +bool isVariadic(TypePackId tp, const TxnLog& log) +{ + std::optional tail = flatten(tp, log).second; + + if (!tail) + return false; + + if (log.get(*tail)) + return true; + + if (auto vtp = log.get(*tail); vtp && !vtp->hidden) + return true; + + return false; +} + TypePackVar* asMutable(TypePackId tp) { return const_cast(tp); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index dbc412fc..0fbfdbf0 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -177,7 +177,7 @@ bool maybeString(TypeId ty) if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) { ty = follow(ty); - + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) return true; @@ -366,7 +366,7 @@ bool maybeSingleton(TypeId ty) bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) { - RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit, "hasLength"); ty = follow(ty); @@ -654,9 +654,9 @@ static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persi static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; -static TypeVar anyType_{AnyTypeVar{}}; -static TypeVar errorType_{ErrorTypeVar{}}; -static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; +static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; +static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; +static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}, /*persistent*/ true}; static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; static TypePackVar errorTypePack_{Unifiable::Error{}}; @@ -698,7 +698,7 @@ TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}}); const TypePackId oneStringPack = arena->addTypePack({stringType}); const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); @@ -802,6 +802,7 @@ void persist(TypeId ty) continue; asMutable(t)->persistent = true; + asMutable(t)->normal = true; // all persistent types are assumed to be normal if (auto btv = get(t)) queue.push_back(btv->boundTo); @@ -838,6 +839,11 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto ctv = get(t)) + { + for (TypeId opt : ctv->parts) + queue.push_back(opt); + } else if (auto mtv = get(t)) { queue.push_back(mtv->table); @@ -899,6 +905,16 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } +std::optional getLevel(TypePackId tp) +{ + tp = follow(tp); + + if (auto ftv = get(tp)) + return ftv->level; + else + return std::nullopt; +} + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 398dc9e2..f9ea58cc 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,9 +14,12 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); +LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) @@ -27,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -126,7 +130,6 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel visitTypeVarOnce(ty, ptl, seen); } -// TODO: use this and make it static. void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -305,8 +308,7 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog) +Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) , log(parentLog) @@ -326,6 +328,7 @@ Unifier::Unifier(TypeArena* types, Mode mode, std::vector 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTy = log.follow(superTy); @@ -354,6 +369,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (log.get(superTy)) + return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy); + auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -442,7 +460,18 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); - if (get(subTy) || get(subTy)) + if (get(subTy)) + { + if (anyIsTop) + { + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + return; + } + else + return tryUnifyWithAny(superTy, subTy); + } + + if (get(subTy)) return tryUnifyWithAny(superTy, subTy); bool cacheEnabled; @@ -484,7 +513,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionTypeVar* uv = log.getMutable(subTy)) + if (log.get(subTy)) + tryUnifyWithConstrainedSubTypeVar(subTy, superTy); + else if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); } @@ -946,7 +977,7 @@ struct WeirdIter LUAU_ASSERT(log.getMutable(newTail)); level = log.getMutable(packId)->level; - log.replace(packId, Unifiable::Bound(newTail)); + log.replace(packId, BoundTypePack(newTail)); packId = newTail; pack = log.getMutable(newTail); index = 0; @@ -994,39 +1025,32 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall tryUnify_(subTp, superTp, isFunctionCall); } -static std::pair, std::optional> logAwareFlatten(TypePackId tp, const TxnLog& log) -{ - tp = log.follow(tp); - - std::vector flattened; - std::optional tail = std::nullopt; - - TypePackIterator it(tp, &log); - - for (; it != end(tp); ++it) - { - flattened.push_back(*it); - } - - tail = it.tail(); - - return {flattened, tail}; -} - /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypePackId tryUnify_"); ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTp = log.follow(superTp); @@ -1087,8 +1111,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If the size of two heads does not match, but both packs have free tail // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = logAwareFlatten(superTp, log); - auto [subTypes, subTail] = logAwareFlatten(subTp, log); + auto [superTypes, superTail] = flatten(superTp, log); + auto [subTypes, subTail] = flatten(subTp, log); bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && (subTail && log.getMutable(*subTail)); @@ -1165,19 +1189,20 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && superIter.good() && isOptional(*superIter)) { superIter.advance(); continue; } - else if (subIter.good() && isOptional(*subIter)) + else if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && subIter.good() && isOptional(*subIter)) { subIter.advance(); continue; } // In nonstrict mode, any also marks an optional argument. - else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && + log.getMutable(log.follow(*superIter))) { superIter.advance(); continue; @@ -1195,7 +1220,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal return; } - if (!isFunctionCall && subIter.good()) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && !isFunctionCall && subIter.good()) { // Sometimes it is ok to pass too many arguments return; @@ -1418,14 +1443,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (FFlag::LuauAnyInIsOptionalIsOptional) { - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) + if (subIter == subTable->props.end() && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) missingProperties.push_back(propName); } else { bool isAny = log.getMutable(log.follow(superProp.type)); - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) + if (subIter == subTable->props.end() && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && + !isAny) missingProperties.push_back(propName); } } @@ -1438,8 +1466,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // And vice versa if we're invariant - if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && - superTable->state != TableState::Free) + if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free) { for (const auto& [propName, subProp] : subTable->props) { @@ -1453,7 +1480,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else { bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) + if (superIter == superTable->props.end() && + (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) extraProperties.push_back(propName); } } @@ -1499,13 +1527,15 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if (FFlag::LuauAnyInIsOptionalIsOptional && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if (FFlag::LuauAnyInIsOptionalIsOptional && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) + else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && + (isOptional(prop.type) || get(follow(prop.type)))) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: should isOptional(anyType) be true? @@ -1664,9 +1694,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (FFlag::LuauTxnLogDontRetryForIndexers) { - // Changing the indexer can invalidate the table pointers. - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); } else if (FFlag::LuauTxnLogCheckForInvalidation) { @@ -1921,8 +1951,6 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); - Unifier innerState = makeChildUnifier(); - std::vector missingPropertiesInSuper; bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; bool errorReported = false; @@ -1944,6 +1972,8 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } + Unifier innerState = makeChildUnifier(); + // Tables must have exactly the same props and their types must all unify for (const auto& it : superTable->props) { @@ -2376,6 +2406,180 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); } +void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) +{ + const ConstrainedTypeVar* subConstrained = get(subTy); + if (!subConstrained) + ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!"); + + const std::vector& subTyParts = subConstrained->parts; + + // A | B <: T if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + + const size_t count = subTyParts.size(); + + for (size_t i = 0; i < count; ++i) + { + TypeId type = subTyParts[i]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy); + + if (i == count - 1) + log.concat(std::move(innerState.log)); + + ++i; + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + + if (!innerState.errors.empty()) + { + failed = true; + break; + } + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (failed) + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + else + log.replace(subTy, BoundTypeVar{superTy}); +} + +void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) +{ + ConstrainedTypeVar* superC = log.getMutable(superTy); + if (!superC) + ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!"); + + // subTy could be a + // table + // metatable + // class + // function + // primitive + // free + // generic + // intersection + // union + // Do we really just tack it on? I think we might! + // We can certainly do some deduplication. + // Is there any point to deducing Player|Instance when we could just reduce to Instance? + // Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type? + // Maybe we do a simplification step during quantification. + + auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy); + if (it != superC->parts.end()) + return; + + superC->parts.push_back(subTy); +} + +void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) +{ + // The duplication between this and regular typepack unification is tragic. + + auto superIter = begin(superTy, &log); + auto superEndIter = end(superTy); + + auto subIter = begin(subTy, &log); + auto subEndIter = end(subTy); + + int count = FInt::LuauTypeInferLowerBoundsIterationLimit; + + for (; subIter != subEndIter; ++subIter) + { + if (0 >= --count) + ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound"); + + if (superIter != superEndIter) + { + tryUnify_(*subIter, *superIter); + ++superIter; + continue; + } + + if (auto t = superIter.tail()) + { + TypePackId tailPack = follow(*t); + + if (log.get(tailPack)) + occursCheck(tailPack, subTy); + + FreeTypePack* freeTailPack = log.getMutable(tailPack); + if (!freeTailPack) + return; + + TypeLevel level = freeTailPack->level; + + TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); + + for (; subIter != subEndIter; ++subIter) + { + tp->head.push_back(types->addType(ConstrainedTypeVar{level, {follow(*subIter)}})); + } + + tp->tail = subIter.tail(); + } + + return; + } + + if (superIter != superEndIter) + { + if (auto subTail = subIter.tail()) + { + TypePackId subTailPack = follow(*subTail); + if (get(subTailPack)) + { + TypePack* tp = getMutable(log.replace(subTailPack, TypePack{})); + + for (; superIter != superEndIter; ++superIter) + tp->head.push_back(*superIter); + } + } + else + { + while (superIter != superEndIter) + { + if (!isOptional(*superIter)) + { + errors.push_back(TypeError{location, CountMismatch{size(superTy), size(subTy), CountMismatch::Return}}); + return; + } + ++superIter; + } + } + + return; + } + + // Both iters are at their respective tails + auto subTail = subIter.tail(); + auto superTail = superIter.tail(); + if (subTail && superTail) + tryUnify(*subTail, *superTail); + else if (subTail) + { + const FreeTypePack* freeSubTail = log.getMutable(*subTail); + if (freeSubTail) + { + log.replace(*subTail, TypePack{}); + } + } + else if (superTail) + { + const FreeTypePack* freeSuperTail = log.getMutable(*superTail); + if (freeSuperTail) + { + log.replace(*superTail, TypePack{}); + } + } +} + void Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2385,7 +2589,8 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypeId"); auto check = [&](TypeId tv) { occursCheck(seen, needle, tv); @@ -2425,6 +2630,11 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays for (TypeId ty : a->parts) check(ty); } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } } void Unifier::occursCheck(TypePackId needle, TypePackId haystack) @@ -2450,7 +2660,8 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypePackId"); while (!log.getMutable(haystack)) { @@ -2474,7 +2685,23 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + if (FFlag::LuauTypecheckOptPass) + { + Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; + u.anyIsTop = anyIsTop; + return u; + } + + Unifier u = Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + u.anyIsTop = anyIsTop; + return u; +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +void Unifier::reportError(TypeError err) +{ + errors.push_back(std::move(err)); } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 65939bee..f8543111 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -32,6 +32,7 @@ class DenseHashTable { public: class const_iterator; + class iterator; DenseHashTable(const Key& empty_key, size_t buckets = 0) : count(0) @@ -43,7 +44,7 @@ public: // don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs: // https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547 if (buckets) - data.resize(buckets, ItemInterface::create(empty_key)); + resize_data(buckets); } void clear() @@ -125,7 +126,7 @@ public: if (data.empty() && data.capacity() >= newsize) { LUAU_ASSERT(count == 0); - data.resize(newsize, ItemInterface::create(empty_key)); + resize_data(newsize); return; } @@ -169,6 +170,21 @@ public: return const_iterator(this, data.size()); } + iterator begin() + { + size_t start = 0; + + while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + start++; + + return iterator(this, start); + } + + iterator end() + { + return iterator(this, data.size()); + } + size_t size() const { return count; @@ -233,7 +249,82 @@ public: size_t index; }; + class iterator + { + public: + iterator() + : set(0) + , index(0) + { + } + + iterator(DenseHashTable* set, size_t index) + : set(set) + , index(index) + { + } + + MutableItem& operator*() const + { + return *reinterpret_cast(&set->data[index]); + } + + MutableItem* operator->() const + { + return reinterpret_cast(&set->data[index]); + } + + bool operator==(const iterator& other) const + { + return set == other.set && index == other.index; + } + + bool operator!=(const iterator& other) const + { + return set != other.set || index != other.index; + } + + iterator& operator++() + { + size_t size = set->data.size(); + + do + { + index++; + } while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key)); + + return *this; + } + + iterator operator++(int) + { + iterator res = *this; + ++*this; + return res; + } + + private: + DenseHashTable* set; + size_t index; + }; + private: + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + data.resize(count, ItemInterface::create(empty_key)); + } + + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + size_t size = data.size(); + data.resize(count); + + for (size_t i = size; i < count; i++) + data[i].first = empty_key; + } + std::vector data; size_t count; Key empty_key; @@ -290,6 +381,7 @@ class DenseHashSet public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashSet(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -336,6 +428,16 @@ public: { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; // This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has @@ -348,6 +450,7 @@ class DenseHashMap public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashMap(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -401,10 +504,21 @@ public: { return impl.begin(); } + const_iterator end() const { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; } // namespace Luau diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index d7d867f4..4f3dbbd5 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -173,7 +173,7 @@ public: } const Lexeme& next(); - const Lexeme& next(bool skipComments); + const Lexeme& next(bool skipComments, bool updatePrevLocation); void nextline(); Lexeme lookahead(); diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 70c6c78d..5dd4f04e 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -349,13 +349,11 @@ void Lexer::setReadNames(bool read) const Lexeme& Lexer::next() { - return next(this->skipComments); + return next(this->skipComments, true); } -const Lexeme& Lexer::next(bool skipComments) +const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { - bool first = true; - // in skipComments mode we reject valid comments do { @@ -363,11 +361,11 @@ const Lexeme& Lexer::next(bool skipComments) while (isSpace(peekch())) consume(); - if (!FFlag::LuauParseLocationIgnoreCommentSkip || first) + if (!FFlag::LuauParseLocationIgnoreCommentSkip || updatePrevLocation) prevLocation = lexeme.location; lexeme = readNext(); - first = false; + updatePrevLocation = false; } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f9d32178..badd3fd3 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,6 +11,7 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false) +LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) namespace Luau { @@ -2789,7 +2790,7 @@ void Parser::nextLexeme() { if (options.captureComments) { - Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) { @@ -2813,7 +2814,7 @@ void Parser::nextLexeme() hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } - type = lexer.next(/* skipComments= */ false).type; + type = lexer.next(/* skipComments= */ false, !FFlag::LuauParseLocationIgnoreCommentSkipInCapture).type; } } else diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6330bf1f..8ef69e75 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -1386,8 +1386,8 @@ struct Compiler const Constant* cv = constants.find(expr->index); - if (cv && cv->type == Constant::Type_Number && double(int(cv->valueNumber)) == cv->valueNumber && cv->valueNumber >= 1 && - cv->valueNumber <= 256) + if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && + double(int(cv->valueNumber)) == cv->valueNumber) { uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp new file mode 100644 index 00000000..d8511bdb --- /dev/null +++ b/Compiler/src/CostModel.cpp @@ -0,0 +1,258 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "CostModel.h" + +#include "Luau/Common.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +namespace Compile +{ + +inline uint64_t parallelAddSat(uint64_t x, uint64_t y) +{ + uint64_t s = x + y; + uint64_t m = s & 0x8080808080808080ull; // saturation mask + + return (s ^ m) | (m - (m >> 7)); +} + +struct Cost +{ + static const uint64_t kLiteral = ~0ull; + + // cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant + uint64_t model; + // constant mask: 8-byte 0xff mask; equal to all ff's for literals, for variables only byte #i (1+) is set to align with model + uint64_t constant; + + Cost(int cost = 0, uint64_t constant = 0) + : model(cost < 0x7f ? cost : 0x7f) + , constant(constant) + { + } + + Cost operator+(const Cost& other) const + { + Cost result; + result.model = parallelAddSat(model, other.model); + return result; + } + + Cost& operator+=(const Cost& other) + { + model = parallelAddSat(model, other.model); + constant = 0; + return *this; + } + + static Cost fold(const Cost& x, const Cost& y) + { + uint64_t newmodel = parallelAddSat(x.model, y.model); + uint64_t newconstant = x.constant & y.constant; + + // the extra cost for folding is 1; the discount is 1 for the variable that is shared by x&y (or whichever one is used in x/y if the other is + // literal) + uint64_t extra = (newconstant == kLiteral) ? 0 : (1 | (0x0101010101010101ull & newconstant)); + + Cost result; + result.model = parallelAddSat(newmodel, extra); + result.constant = newconstant; + + return result; + } +}; + +struct CostVisitor : AstVisitor +{ + DenseHashMap vars; + Cost result; + + CostVisitor() + : vars(nullptr) + { + } + + Cost model(AstExpr* node) + { + if (AstExprGroup* expr = node->as()) + { + return model(expr->expr); + } + else if (node->is() || node->is() || node->is() || + node->is()) + { + return Cost(0, Cost::kLiteral); + } + else if (AstExprLocal* expr = node->as()) + { + const uint64_t* i = vars.find(expr->local); + + return Cost(0, i ? *i : 0); // locals typically don't require extra instructions to compute + } + else if (node->is()) + { + return 1; + } + else if (node->is()) + { + return 3; + } + else if (AstExprCall* expr = node->as()) + { + Cost cost = 3; + cost += model(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + { + Cost ac = model(expr->args.data[i]); + // for constants/locals we still need to copy them to the argument list + cost += ac.model == 0 ? Cost(1) : ac; + } + + return cost; + } + else if (AstExprIndexName* expr = node->as()) + { + return model(expr->expr) + 1; + } + else if (AstExprIndexExpr* expr = node->as()) + { + return model(expr->expr) + model(expr->index) + 1; + } + else if (AstExprFunction* expr = node->as()) + { + return 10; // high baseline cost due to allocation + } + else if (AstExprTable* expr = node->as()) + { + Cost cost = 10; // high baseline cost due to allocation + + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + cost += model(item.key); + + cost += model(item.value); + cost += 1; + } + + return cost; + } + else if (AstExprUnary* expr = node->as()) + { + return Cost::fold(model(expr->expr), Cost(0, Cost::kLiteral)); + } + else if (AstExprBinary* expr = node->as()) + { + return Cost::fold(model(expr->left), model(expr->right)); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + return model(expr->expr); + } + else if (AstExprIfElse* expr = node->as()) + { + return model(expr->condition) + model(expr->trueExpr) + model(expr->falseExpr) + 2; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + return {}; + } + } + + void assign(AstExpr* expr) + { + // variable assignments reset variable mask, so that further uses of this variable aren't discounted + // this doesn't work perfectly with backwards control flow like loops, but is good enough for a single pass + if (AstExprLocal* lv = expr->as()) + if (uint64_t* i = vars.find(lv->local)) + *i = 0; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside model() which makes it easier to get the resulting value of the subexpression + result += model(node); + + return false; + } + + bool visit(AstStat* node) override + { + if (node->is()) + result += 2; + else if (node->is() || node->is() || node->is() || node->is()) + result += 2; + else if (node->is() || node->is()) + result += 1; + + return true; + } + + bool visit(AstStatLocal* node) override + { + for (size_t i = 0; i < node->values.size; ++i) + { + Cost arg = model(node->values.data[i]); + + // propagate constant mask from expression through variables + if (arg.constant && i < node->vars.size) + vars[node->vars.data[i]] = arg.constant; + + result += arg; + } + + return false; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + return true; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + + // if lhs is not a local, setting it requires an extra table operation + result += node->var->is() ? 1 : 2; + + return true; + } +}; + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount) +{ + CostVisitor visitor; + for (size_t i = 0; i < varCount && i < 7; ++i) + visitor.vars[vars[i]] = 0xffull << (i * 8 + 8); + + root->visit(&visitor); + + return visitor.result.model; +} + +int computeCost(uint64_t model, const bool* varsConst, size_t varCount) +{ + int cost = int(model & 0x7f); + + // don't apply discounts to what is likely a saturated sum + if (cost == 0x7f) + return cost; + + for (size_t i = 0; i < varCount && i < 7; ++i) + cost -= int((model >> (8 * i + 8)) & 0x7f) * varsConst[i]; + + return cost; +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/CostModel.h b/Compiler/src/CostModel.h new file mode 100644 index 00000000..c27861ec --- /dev/null +++ b/Compiler/src/CostModel.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +namespace Luau +{ +namespace Compile +{ + +// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); + +// cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 6f110f1f..60e5dfda 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -32,11 +32,13 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/Compiler.cpp Compiler/src/Builtins.cpp Compiler/src/ConstantFolding.cpp + Compiler/src/CostModel.cpp Compiler/src/TableShape.cpp Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp Compiler/src/Builtins.h Compiler/src/ConstantFolding.h + Compiler/src/CostModel.h Compiler/src/TableShape.h Compiler/src/ValueTracking.h ) @@ -58,6 +60,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/LValue.h Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h + Analysis/include/Luau/Normalize.h Analysis/include/Luau/Predicate.h Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h @@ -94,6 +97,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Linter.cpp Analysis/src/LValue.cpp Analysis/src/Module.cpp + Analysis/src/Normalize.cpp Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp @@ -216,6 +220,7 @@ if(TARGET Luau.UnitTest) tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp + tests/CostModel.test.cpp tests/Config.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp @@ -224,6 +229,7 @@ if(TARGET Luau.UnitTest) tests/LValue.test.cpp tests/Module.test.cpp tests/NonstrictMode.test.cpp + tests/Normalize.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/StringUtils.test.cpp diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 1c75c0b0..dc40b6ef 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -34,7 +34,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) -LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary, false) +LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false) // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 @@ -390,6 +390,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) setarrayvector(L, t, nasize); /* create new hash part with appropriate size */ setnodevector(L, t, nhsize); + /* used for the migration check at the end */ + LuaNode* nnew = t->node; if (nasize < oldasize) { /* array part must shrink? */ t->sizearray = nasize; @@ -413,6 +415,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) /* shrink array */ luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); } + /* used for the migration check at the end */ + TValue* anew = t->array; /* re-insert elements from hash part */ if (FFlag::LuauTableRehashRework) { @@ -441,14 +445,30 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) } } + /* make sure we haven't recursively rehashed during element migration */ + LUAU_ASSERT(nnew == t->node); + LUAU_ASSERT(anew == t->array); + if (nold != dummynode) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ } +static int adjustasize(Table* t, int size, const TValue* ek) +{ + LUAU_ASSERT(FFlag::LuauTableNewBoundary2); + bool tbound = t->node != dummynode || size < t->sizearray; + int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; + /* move the array size up until the boundary is guaranteed to be inside the array part */ + while (size + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, size + 1)))) + size++; + return size; +} + void luaH_resizearray(lua_State* L, Table* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); - resize(L, t, nasize, nsize); + int asize = FFlag::LuauTableNewBoundary2 ? adjustasize(t, nasize, NULL) : nasize; + resize(L, t, asize, nsize); } void luaH_resizehash(lua_State* L, Table* t, int nhsize) @@ -470,21 +490,12 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) totaluse++; /* compute new size for array part */ int na = computesizes(nums, &nasize); + int nh = totaluse - na; /* enforce the boundary invariant; for performance, only do hash lookups if we must */ - if (FFlag::LuauTableNewBoundary) - { - bool tbound = t->node != dummynode || nasize < t->sizearray; - int ekindex = ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; - /* move the array size up until the boundary is guaranteed to be inside the array part */ - while (nasize + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, nasize + 1)))) - { - nasize++; - na++; - } - } + if (FFlag::LuauTableNewBoundary2) + nasize = adjustasize(t, nasize, ek); /* resize the table to new computed sizes */ - LUAU_ASSERT(na <= totaluse); - resize(L, t, nasize, totaluse - na); + resize(L, t, nasize, nh); } /* @@ -544,7 +555,7 @@ static LuaNode* getfreepos(Table* t) static TValue* newkey(lua_State* L, Table* t, const TValue* key) { /* enforce boundary invariant */ - if (FFlag::LuauTableNewBoundary && ttisnumber(key) && nvalue(key) == t->sizearray + 1) + if (FFlag::LuauTableNewBoundary2 && ttisnumber(key) && nvalue(key) == t->sizearray + 1) { rehash(L, t, key); /* grow table */ @@ -735,7 +746,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) { - LUAU_ASSERT(!FFlag::LuauTableNewBoundary); + LUAU_ASSERT(!FFlag::LuauTableNewBoundary2); unsigned int i = j; /* i is zero or a present index */ j++; /* find `i' and `j' such that i is present and j is not */ @@ -820,7 +831,7 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } - else if (FFlag::LuauTableNewBoundary) + else if (FFlag::LuauTableNewBoundary2) { /* validate boundary invariant */ LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 41887f4b..9c1f387e 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -199,7 +199,7 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; + void (*telemetrycb)(lua_State * L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f) { diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 34949efb..39c60eac 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,7 @@ #include -LUAU_FASTFLAG(LuauTableNewBoundary) +LUAU_FASTFLAG(LuauTableNewBoundary2) // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -2268,7 +2268,7 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (FFlag::LuauTableNewBoundary || (h->lsizenode == 0 && ttisnil(gval(h->node)))) + else if (FFlag::LuauTableNewBoundary2 || (h->lsizenode == 0 && ttisnil(gval(h->node)))) { // fallthrough to exit VM_NEXT(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp new file mode 100644 index 00000000..ec04932f --- /dev/null +++ b/tests/CostModel.test.cpp @@ -0,0 +1,101 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" + +#include "doctest.h" + +using namespace Luau; + +namespace Luau +{ +namespace Compile +{ + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau + +TEST_SUITE_BEGIN("CostModel"); + +static uint64_t modelFunction(const char* source) +{ + Allocator allocator; + AstNameTable names(allocator); + + ParseResult result = Parser::parse(source, strlen(source), names, allocator); + REQUIRE(result.root != nullptr); + + AstStatFunction* func = result.root->body.data[0]->as(); + REQUIRE(func); + + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size); +} + +TEST_CASE("Expression") +{ + uint64_t model = modelFunction(R"( +function test(a, b, c) + return a + (b + 1) * (b + 1) - c +end +)"); + + const bool args1[] = {false, false, false}; + const bool args2[] = {false, true, false}; + + CHECK_EQ(5, Luau::Compile::computeCost(model, args1, 3)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 3)); +} + +TEST_CASE("PropagateVariable") +{ + uint64_t model = modelFunction(R"( +function test(a) + local b = a * a * a + return b * b +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(0, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("LoopAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,3 do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // loop baseline cost is 2 + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("MutableVariable") +{ + uint64_t model = modelFunction(R"( +function test(a, b) + local x = a * a + x += b + return x * x +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 9dc9feee..d8b37a65 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -231,7 +231,7 @@ ModulePtr Fixture::getMainModule() SourceModule* Fixture::getMainSourceModule() { - return frontend.getSourceModule(fromString("MainModule")); + return frontend.getSourceModule(fromString(mainModuleName)); } std::optional Fixture::getPrimitiveType(TypeId ty) @@ -259,7 +259,7 @@ std::optional Fixture::getType(const std::string& name) TypeId Fixture::requireType(const std::string& name) { std::optional ty = getType(name); - REQUIRE(bool(ty)); + REQUIRE_MESSAGE(bool(ty), "Unable to requireType \"" << name << "\""); return follow(*ty); } diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index 6711d979..1d2ad645 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -68,7 +68,9 @@ TEST_CASE("encode_tables") REQUIRE(parseResult.errors.size() == 0); std::string json = toJson(parseResult.root); - CHECK(json == R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); } TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 9ce9a4c2..05ee9a7b 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -597,8 +597,6 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff("LuauLintNoRobloxBits", true); - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -1439,6 +1437,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index de063121..738893db 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -2,6 +2,7 @@ #include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/Scope.h" +#include "Luau/RecursionCounter.h" #include "Fixture.h" @@ -9,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -42,29 +45,23 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment") TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -90,12 +87,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId counterType = requireType("Cyclic"); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState); + CloneState cloneState; + TypeId counterCopy = clone(counterType, dest, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -112,8 +106,11 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); - CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type - CHECK_EQ(2, dest.typeVars.size()); // One table and one function + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(3, dest.typePacks.size()); // function args, its return type, and the hidden any... pack + else + CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type + CHECK_EQ(2, dest.typeVars.size()); // One table and one function } TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") @@ -143,15 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") TEST_CASE_FIXTURE(Fixture, "deepClone_union") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState); + TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -161,15 +155,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState); + TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -191,12 +182,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") std::nullopt, &exampleMetaClass, {}, {}}}; TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&exampleClass, dest, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -216,16 +204,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState); + TypeId clonedTy = clone(&freeTy, dest, cloneState); CHECK_EQ("any", toString(clonedTy)); CHECK(cloneState.encounteredFreeType); cloneState = {}; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState); + TypePackId clonedTp = clone(&freeTp, dest, cloneState); CHECK_EQ("...any", toString(clonedTp)); CHECK(cloneState.encounteredFreeType); } @@ -237,16 +223,32 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") ttv->state = TableState::Free; TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&tableTy, dest, cloneState); const TableTypeVar* clonedTtv = get(cloned); CHECK_EQ(clonedTtv->state, TableState::Sealed); CHECK(cloneState.encounteredFreeType); } +TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") +{ + TypeArena src; + + TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {getSingletonTypes().numberType, getSingletonTypes().stringType}}); + + TypeArena dest; + CloneState cloneState; + + TypeId cloned = clone(constrained, dest, cloneState); + CHECK_NE(constrained, cloned); + + const ConstrainedTypeVar* ctv = get(cloned); + REQUIRE_EQ(2, ctv->parts.size()); + CHECK_EQ(getSingletonTypes().numberType, ctv->parts[0]); + CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]); +} + TEST_CASE_FIXTURE(Fixture, "clone_self_property") { ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; @@ -284,6 +286,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + ScopedFastFlag sff{"LuauRecursionLimitException", true}; TypeArena src; @@ -299,11 +302,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") } TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error); + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index d3faea2a..a8a12b69 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -275,4 +275,38 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); } +TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): (boolean, string?) + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "returning_too_many_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): boolean + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp new file mode 100644 index 00000000..5a84201a --- /dev/null +++ b/tests/Normalize.test.cpp @@ -0,0 +1,967 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/Normalize.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +struct NormalizeFixture : Fixture +{ + ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; + ScopedFastFlag sff2{"LuauTableSubtypingVariance2", true}; +}; + +void createSomeClasses(TypeChecker& typeChecker) +{ + auto& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr}); + + ClassTypeVar* parentClass = getMutable(parentType); + parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; + + parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; + + addGlobalBinding(typeChecker, "Parent", {parentType}); + typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; + + TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr}); + + ClassTypeVar* childClass = getMutable(childType); + childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; + + addGlobalBinding(typeChecker, "Child", {childType}); + typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + + TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr}); + + addGlobalBinding(typeChecker, "Unrelated", {unrelatedType}); + typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; + + freeze(arena); +} + +static bool isSubtype(TypeId a, TypeId b) +{ + InternalErrorReporter ice; + return isSubtype(a, b, ice); +} + +TEST_SUITE_BEGIN("isSubtype"); + +TEST_CASE_FIXTURE(NormalizeFixture, "primitives") +{ + check(R"( + local a = 41 + local b = 32 + + local c = "hello" + local d = "world" + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(d, c)); + CHECK(!isSubtype(d, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions") +{ + check(R"( + function a(x: number): number return x end + function b(x: number): number return x end + + function c(x: number?): number return x end + function d(x: number): number? return x end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(d, a)); + CHECK(isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any") +{ + check(R"( + function a(n: number) return "string" end + function b(q: any) return 5 :: any end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + // Intuition: + // We cannot use b where a is required because we cannot rely on b to return a string. + // We cannot use a where b is required because we cannot rely on a to accept non-number arguments. + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +{ + check(R"( + type A = (any) -> () + type B = (any, any) -> () + type T = A & B + + local a: A + local b: B + local t: T + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(a, b)); // !! + CHECK(!isSubtype(b, a)); + + CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +{ + check(R"( + local a: (number) -> () + local b: () -> () + + local c: () -> number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(c, a)); + + CHECK(!isSubtype(a, b)); + CHECK(!isSubtype(c, b)); + + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +{ + /* + * (T0..TN) <: (T0..TN, A?) + * (T0..TN) <: (T0..TN, any) + * (T0..TN, A?) R <: U -> S if U <: T and R <: S + * A | B <: T if A <: T and B <: T + * T <: A | B if T <: A or T <: B + */ + check(R"( + local a: (number?) -> () + local b: (number) -> () + local c: (number, number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, number?) -> () <: (number) -> (number) + * The packs have inequal lengths, but (number) <: (number, number?) + * and number <: number + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * because (number, number?) () () + * because (number, number?) () + local b: (number) -> () + local c: (number, any) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, any) -> () (number) + * The packs have inequal lengths + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(a, c)); + + /* + * (number) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head") +{ + check(R"( + local a: (...number) -> () + local b: (...number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head") +{ + check(R"( + local a: (...number) -> () + local b: (number, number) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "union") +{ + check(R"( + local a: number | string + local b: number + local c: string + local d: number? + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(b, d)); + CHECK(!isSubtype(d, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop") +{ + check(R"( + local a: {x: number} + local b: {x: number?} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") +{ + check(R"( + local a: {x: number} + local b: {x: any} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection") +{ + check(R"( + local a: number & string + local b: number + local c: string + local d: number & nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); + + CHECK(!isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") +{ + check(R"( + local a: number & string + local b: number | nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_table_prop") +{ + check(R"( + type T = {x: {y: number}} & {x: {y: string}} + local a: T + )"); + + CHECK_EQ("{| x: {| y: number & string |} |}", toString(requireType("a"))); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "tables") +{ + check(R"( + local a: {x: number} + local b: {x: any} + local c: {y: number} + local d: {x: number, y: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); + + CHECK(!isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(d, b)); + CHECK(!isSubtype(b, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") +{ + check(R"( + local a: {[string]: number} + local b: {[string]: any} + local c: {[string]: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers") +{ + check(R"( + local a: {x: number} + local b: {[string]: number} + local c: {} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(!isSubtype(c, b)); + CHECK(isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") +{ + check(R"( + type A = {method: (A) -> ()} + local a: A + + type B = {method: (any) -> ()} + local b: B + + type C = {method: (C) -> ()} + local c: C + + type D = {method: (D) -> (), another: (D) -> ()} + local d: D + + type E = {method: (A) -> (), another: (E) -> ()} + local e: E + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + TypeId e = requireType("e"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(e, a)); + CHECK(!isSubtype(a, e)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "classes") +{ + createSomeClasses(typeChecker); + + TypeId p = typeChecker.globalScope->lookupType("Parent")->type; + TypeId c = typeChecker.globalScope->lookupType("Child")->type; + TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + + CHECK(isSubtype(c, p)); + CHECK(!isSubtype(p, c)); + CHECK(!isSubtype(u, p)); + CHECK(!isSubtype(p, u)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1}) +{ + check(R"( + local T = {} + T.__index = T + function T.new() + return setmetatable({}, T) + end + + function T:method() end + + local a: typeof(T.new) + local b: {method: (any) -> ()} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_tables") +{ + check(R"( + type T = {x: number} & ({x: number} & {y: string?}) + local t: T + )"); + + CHECK("{| x: number, y: string? |}" == toString(requireType("t"))); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Normalize"); + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_disjoint_tables") +{ + check(R"( + type T = {a: number} & {b: number} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: number, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number & string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_confluent_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: string, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship") +{ + check(R"( + local t: {x: number} | {x: number?} + )"); + + ModulePtr tempModule{new Module}; + + // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze + // the arena that the type lives in. + ModulePtr mainModule = getMainModule(); + unfreeze(mainModule->internalTypes); + + TypeId tType = requireType("t"); + normalize(tType, tempModule, *typeChecker.iceHandler); + + CHECK_EQ("{| x: number? |}", toString(tType, {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions") +{ + check(R"( + type T = ((any) -> string) & ((number) -> string) + local t: T + )"); + + CHECK_EQ("(any) -> string", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + --!nonstrict + + if Math.random() then + return function(initialState, handlers) + return function(state, action) + return state + end + end + else + return function(initialState, handlers) + return function(state, action) + return state + end + end + end + )"); + + CHECK_EQ("(any, any) -> (...any)", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection") +{ + check(R"( + function foo(x:number, y:number) + return x + y + end + )"); + + CHECK_EQ("(number, number) -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function") +{ + check(R"( + function apply(f, x) + return f(x) + end + + local a = apply(function(x: number) return x + x end, 5) + )"); + + TypeId aType = requireType("a"); + CHECK_MESSAGE(isNumber(follow(aType)), "Expected a number but got ", toString(aType)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation") +{ + check(R"( + function apply(f: (a) -> b, x) + return f(x) + end + )"); + + CHECK_EQ("((a) -> b, a) -> b", toString(requireType("apply"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + type Fiber = { + return_: Fiber? + } + + local f: Fiber + )"); + + TypeId t = requireType("f"); + CHECK(t->normal); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_tail_is_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type Weirdo = (...{x: number}) -> () + + local w: Weirdo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("w"); + auto ftv = get(t); + REQUIRE(ftv); + + auto [argHead, argTail] = flatten(ftv->argTypes); + CHECK(argHead.empty()); + REQUIRE(argTail.has_value()); + + auto vtp = get(*argTail); + REQUIRE(vtp); + CHECK(vtp->ty->normal); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") +{ + CheckResult result = check(R"( + local Cyclic = {} + function Cyclic.get() + return Cyclic + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = requireType("Cyclic"); + CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function fussy(a, b) + if math.random() > 0.5 then + return a + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(a, b) -> a | b" == toString(requireType("fussy"))); +} + +TEST_CASE_FIXTURE(Fixture, "constrained_intersection_of_intersections") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + local f : (() -> number) | ((number) -> number) + local g : (() -> number) | ((string) -> number) + + function h() + if math.random() then + return f + else + return g + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = requireType("h"); + + CHECK("() -> (() -> number) | ((number) -> number) | ((string) -> number)" == toString(h)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type X = {} + type Y = {y: number} + type Z = {z: string} + type W = {w: boolean} + type T = {x: Y & X} & {x:Z & W} + + local x: X + local y: Y + local z: Z + local w: W + local t: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("{| |}" == toString(requireType("x"), {true})); + CHECK("{| y: number |}" == toString(requireType("y"), {true})); + CHECK("{| z: string |}" == toString(requireType("z"), {true})); + CHECK("{| w: boolean |}" == toString(requireType("w"), {true})); + CHECK("{| x: {| w: boolean, y: number, z: string |} |}" == toString(requireType("t"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_2") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(w, x, y, z) + y.y = 5 + z.z = "five" + w.w = true + + type Z = {x: typeof(x) & typeof(y)} & {x: typeof(w) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(4 == args.size()); + CHECK("{+ w: boolean +}" == toString(args[0])); + CHECK("a" == toString(args[1])); + CHECK("{+ y: number +}" == toString(args[2])); + CHECK("{+ z: string +}" == toString(args[3])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: a & {- w: boolean, y: number, z: string -} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + y.y = y + z.z = "five" + + type Z = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: {- x: boolean, y: t1, z: string -} |} where t1 = {+ y: t1 +}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + z.z = "five" + + type R = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + local r: R + + y.y = r + + return r + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("{+ y: t1 +} where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("t1 where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeCombineTableFix", true}, + }; + // CLI-52787 + // ends up combining {_:any} with any, recursively + // which used to ICE because this combines a table with a non-table. + CheckResult result = check(R"( + export type t0 = any & { _: {_:any} } & { _:any } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeCombineIntersectionFix", true}, + }; + + CheckResult result = check(R"( + export type t0 = {_:{_:any} & {_:any|string}} & {_:{_:{}}} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 79f9ecab..b941103d 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1618,6 +1618,26 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); } +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") +{ + ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; + ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true}; + + // Same should hold when comments are captured + ParseOptions opts; + opts.captureComments = true; + + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )", + opts); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") { matchParseError("break", "break statement must be inside a loop"); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 29bdd866..f3fda54e 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + using namespace Luau; struct ToDotClassFixture : Fixture @@ -101,9 +103,34 @@ local function f(a, ...: string) return a end )"); LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a, ...string) -> a", toString(requireType("f"))); + ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n6 -> n7; +n7 [label="BoundTypeVar 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); + } + else + { + CHECK_EQ(R"(digraph graphname { n1 [label="FunctionTypeVar 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -119,7 +146,8 @@ n6 -> n7; n7 [label="TypePack 7"]; n7 -> n3; })", - toDot(requireType("f"), opts)); + toDot(requireType("f"), opts)); + } } TEST_CASE_FIXTURE(Fixture, "union") @@ -361,4 +389,49 @@ n3 [label="number"]; toDot(*ty, opts)); } +TEST_CASE_FIXTURE(Fixture, "constrained") +{ + // ConstrainedTypeVars never appear in the final type graph, so we have to create one directly + // to dotify it. + TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}}; + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="ConstrainedTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +n1 -> n3; +n3 [label="string"]; +n1 -> n4; +n4 [label="nil"]; +})", + toDot(&t, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "singletontypes") +{ + CheckResult result = check(R"( + local x: "hi" | "\"hello\"" | true | false + )"); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="SingletonTypeVar string: hi"]; +n1 -> n3; +)" +"n3 [label=\"SingletonTypeVar string: \\\"hello\\\"\"];" +R"( +n1 -> n4; +n4 [label="SingletonTypeVar boolean: true"]; +n1 -> n5; +n5 [label="SingletonTypeVar boolean: false"]; +})", toDot(requireType("x"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 3051e209..ccf5c583 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); + TEST_SUITE_BEGIN("ToString"); TEST_CASE_FIXTURE(Fixture, "primitive") diff --git a/tests/TopoSort.test.cpp b/tests/TopoSort.test.cpp index 9b990866..1f14ae88 100644 --- a/tests/TopoSort.test.cpp +++ b/tests/TopoSort.test.cpp @@ -340,26 +340,28 @@ TEST_CASE_FIXTURE(Fixture, "nested_type_annotations_depends_on_later_typealiases TEST_CASE_FIXTURE(Fixture, "return_comes_last") { - CheckResult result = check(R"( -export type Module = { bar: (number) -> boolean, foo: () -> string } + AstStatBlock* program = parse(R"( + local module = {} -return function() : Module - local module = {} + local function confuseCompiler() return module.foo() end - local function confuseCompiler() return module.foo() end - - module.foo = function() return "" end + module.foo = function() return "" end - function module.bar(x:number) - confuseCompiler() - return true - end - - return module -end + function module.bar(x:number) + confuseCompiler() + return true + end + + return module )"); - LUAU_REQUIRE_NO_ERRORS(result); + auto sorted = toposort(*program); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[2], program->body.data[1]); + CHECK_EQ(sorted[1], program->body.data[2]); + CHECK_EQ(sorted[3], program->body.data[3]); + CHECK_EQ(sorted[4], program->body.data[4]); } TEST_CASE_FIXTURE(Fixture, "break_comes_last") diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 5ac45ff2..0c324cd0 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -388,7 +388,7 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") std::string actual = decorateWithTypes(code); - CHECK_EQ(expected, decorateWithTypes(code)); + CHECK_EQ(expected, actual); } TEST_CASE_FIXTURE(Fixture, "function_type_location") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2ad11d01..e2971ad5 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -753,4 +753,14 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar") REQUIRE(ocf); } +TEST_CASE_FIXTURE(Fixture, "instantiation_clone_has_to_follow") +{ + CheckResult result = check(R"( + export type t8 = (t0)&(((true)|(any))->"") + export type t0 = ({})&({_:{[any]:number},}) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index c6fbebed..1ae65947 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -557,9 +559,9 @@ TEST_CASE_FIXTURE(Fixture, "xpcall") )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("a"))); - REQUIRE_EQ("number", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("boolean", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); } TEST_CASE_FIXTURE(Fixture, "see_thru_select") @@ -881,7 +883,10 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); + else + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 98fa66eb..8e3629e7 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -91,6 +91,9 @@ struct ClassFixture : Fixture typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test"); + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + persist(tf.type); + freeze(arena); } }; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 1713216a..65993681 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_CASE_FIXTURE(Fixture, "tc_function") @@ -98,7 +100,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") end return result - end + end return T )"); @@ -274,6 +276,10 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( function f(g) return f(f) @@ -281,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); + CHECK_EQ("t1 where t1 = (t1) -> (a...)", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") @@ -481,10 +487,10 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") std::vector fArgs = flatten(fType->argTypes).first; - TypeId xType = argVec[1]; + TypeId xType = follow(argVec[1]); CHECK_EQ(1, fArgs.size()); - CHECK_EQ(xType, fArgs[0]); + CHECK_EQ(xType, follow(fArgs[0])); } TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") @@ -1043,13 +1049,16 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") @@ -1142,13 +1151,16 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") @@ -1338,6 +1350,126 @@ end CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } +TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(a: boolean, b: number) + if a then + return nil + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(boolean, number) -> number?", toString(requireType("foo"))); + + // TODO: Test multiple returns + // Think of various cases where typepacks need to grow. maybe consult other tests + // Basic normalization of ConstrainedTypeVars during quantification +} + +TEST_CASE_FIXTURE(Fixture, "inconsistent_higher_order_function") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5) + f("six") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + + +/* The bug here is that we are using the same level 2.0 for both the body of resolveDispatcher and the + * lambda useCallback. + * + * I think what we want to do is, at each scope level, never reuse the same sublevel. + * + * We also adjust checkBlock to consider the syntax `local x = function() ... end` to be sortable + * in the same way as `local function x() ... end`. This causes the function `resolveDispatcher` to be + * checked before the lambda. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useCallback: (any) -> any} + end + + local useCallback = function(deps: any) + return resolveDispatcher().useCallback(deps) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s: %s\n", toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2") +{ + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useContext: (number?) -> any} + end + + local useContext + useContext = function(unstable_observedBits: number?) + resolveDispatcher().useContext(unstable_observedBits) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time3") +{ + CheckResult result = check(R"( + local foo + + foo():bar(function() + return foo() + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite") { ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; @@ -1471,4 +1603,17 @@ pcall(wrapper, test) CHECK(acm->isVariadic); } +TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") +{ + CheckResult result = check(R"( + function f() + return 5, f() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index f360a77c..49d31fc6 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -230,8 +230,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") @@ -253,8 +253,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") @@ -705,10 +705,10 @@ end TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") { ScopedFastFlag sffs[] = { - { "LuauTableSubtypingVariance2", true }, - { "LuauUnsealedTableLiteral", true }, - { "LuauPropertiesGetExpectedType", true }, - { "LuauRecursiveTypeParameterRestriction", true }, + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + {"LuauPropertiesGetExpectedType", true}, + {"LuauRecursiveTypeParameterRestriction", true}, }; CheckResult result = check(R"( @@ -843,6 +843,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_function") LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a) -> a", toString(requireType("id"))); CHECK_EQ(*typeChecker.numberType, *requireType("a")); CHECK_EQ(*typeChecker.nilType, *requireType("b")); } @@ -1037,25 +1038,39 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -return sum(2, 3, function(a, b) return a + b end) + local function sum(x: a, y: a, f: (a, a) -> a) + return f(x, y) + end + return sum(2, 3, function(a, b) return a + b end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end -local a = {1, 2, 3} -local r = map(a, function(a) return a + a > 100 end) + local function map(arr: {a}, f: (a) -> b) + local r = {} + for i,v in ipairs(arr) do + table.insert(r, f(v)) + end + return r + end + local a = {1, 2, 3} + local r = map(a, function(a) return a + a > 100 end) )"); LUAU_REQUIRE_NO_ERRORS(result); REQUIRE_EQ("{boolean}", toString(requireType("r"))); check(R"( -local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end -local a = {1, 2, 3} -local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + local function foldl(arr: {a}, init: b, f: (b, a) -> b) + local r = init + for i,v in ipairs(arr) do + r = f(r, v) + end + return r + end + local a = {1, 2, 3} + local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1065,25 +1080,19 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { CheckResult result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12(1, function(x) return x + x end) -g12(1, 2, function(x, y) return x + y end) + g12(1, function(x) return x + x end) + g12(1, 2, function(x, y) return x + y end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12({x=1}, function(x) return {x=-x.x} end) -g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + g12({x=1}, function(x) return {x=-x.x} end) + g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1121,12 +1130,12 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { CheckResult result = check(R"( -type A = { x: number } -local a: A = { x = 1 } -local b = a -type B = typeof(b) -type X = T -local c: X + type A = { x: number } + local a: A = { x = 1 } + local b = a + type B = typeof(b) + type X = T + local c: X )"); LUAU_REQUIRE_NO_ERRORS(result); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ac7a6532..3675919f 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("IntersectionTypes"); TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") @@ -306,7 +308,10 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table '{| x: number, y: number |}'"); + else + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") @@ -314,27 +319,34 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; CheckResult result = check(R"( - type X = { x: (number) -> number } - type Y = { y: (string) -> string } + type X = { x: (number) -> number } + type Y = { y: (string) -> string } - type XY = X & Y + type XY = X & Y - local xy : XY = { - x = function(a: number) return -a end, - y = function(a: string) return a .. "b" end - } - function xy.z(a:number) return a * 10 end - function xy:y(a:number) return a * 10 end - function xy:w(a:number) return a * 10 end + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end )"); LUAU_REQUIRE_ERROR_COUNT(4, result); CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' caused by: Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); - CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); + + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") @@ -375,6 +387,8 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -393,6 +407,8 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -427,8 +443,8 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") repeat type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) function _(l0):(t0)&(t0) - while nil do - end + while nil do + end end until _(_)(_)._ )"); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 40831bf6..5cd3f3ba 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -199,16 +199,16 @@ end TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( ---!nonstrict -local f = {} -function f:foo(a: number, b: number) end + --!nonstrict + local f = {} + function f:foo(a: number, b: number) end -function bar(...) - f.foo(f, 1, ...) -end + function bar(...) + f.foo(f, 1, ...) + end -bar(2) -)"); + bar(2) + )"); LUAU_REQUIRE_NO_ERRORS(result); } diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 6a8a9d93..5f2e2404 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -91,7 +91,8 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") const FunctionTypeVar* functionType = get(requireType("add")); std::optional retType = first(functionType->retType); - CHECK_EQ(std::optional(typeChecker.numberType), retType); + REQUIRE(retType.has_value()); + CHECK_EQ(typeChecker.numberType, follow(*retType)); CHECK_EQ(requireType("n"), typeChecker.numberType); CHECK_EQ(requireType("s"), typeChecker.stringType); } diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2e16b21e..6b3741fa 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -527,6 +528,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_NO_ERRORS(result); } +// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { ScopedFastFlag sff[]{ @@ -556,10 +558,19 @@ TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type LUAU_REQUIRE_NO_ERRORS(result); - // f and g should have the type () -> () - CHECK_EQ("() -> (a...)", toString(requireType("f"))); - CHECK_EQ("() -> (a...)", toString(requireType("g"))); - CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ("() -> ()", toString(requireType("f"))); + CHECK_EQ("() -> ()", toString(requireType("g"))); + CHECK_EQ("nil", toString(requireType("x"))); + } + else + { + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + } } TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") @@ -575,6 +586,10 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + CheckResult result = check(R"( local function f() return end local g = function() return f() end @@ -585,6 +600,10 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + CheckResult result = check(R"( --!strict local function f(...) return ... end @@ -594,4 +613,112 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } +TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5, 'a') + f('b', 6) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We incorrectly infer that the argument to foo could be called with (number, number) or (string, string) + // even though that is strictly more permissive than the actual source text shows. + CHECK("((number | string, number | string) -> (a...)) -> ()" == toString(requireType("foo"))); +} + +// Once fixed, move this to Normalize.test.cpp +TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables") +{ +#if defined(_DEBUG) || defined(_NOOPT) + ScopedFastInt sfi("LuauNormalizeIterationLimit", 500); +#endif + + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y) + x.x = y + y.x = x + + type R = {x: typeof(x)} & {x: typeof(y)} + local r: R + + return r + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") +{ + CheckResult result = check(R"( + local function f(): () end + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_ERRORS(result); + // LUAU_REQUIRE_NO_ERRORS(result); + // CHECK_EQ("boolean", toString(requireType("ok"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall") +{ + CheckResult result = check(R"( + local function f(): number + if math.random() > 0.5 then + return 5 + else + error("something") + end + end + + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "function_returns_many_things_but_first_of_it_is_forgotten") +{ + CheckResult result = check(R"( + local function f(): (number, string, boolean) + if math.random() > 0.5 then + return 5, "hello", true + else + error("something") + end + end + + local ok, res, s, b = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("boolean", toString(requireType("b"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index cddeab6e..ce22bcb1 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -8,6 +9,7 @@ LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -48,6 +50,7 @@ struct RefinementClassFixture : Fixture {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; + normalize(vec3, arena, *typeChecker.iceHandler); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); @@ -55,17 +58,21 @@ struct RefinementClassFixture : Fixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + normalize(isA, arena, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; + normalize(inst, arena, *typeChecker.iceHandler); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + normalize(folder, arena, *typeChecker.iceHandler); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); getMutable(part)->props = { {"Position", Property{vec3}}, }; + normalize(part, arena, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; @@ -697,7 +704,10 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("{| x: number, y: number |}", toString(requireTypeAtPosition({4, 28}))); + else + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index d39341ea..2b01c29e 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -5,8 +5,6 @@ #include "doctest.h" #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(BetterDiagnosticCodesInStudio) - using namespace Luau; TEST_SUITE_BEGIN("TypeSingletons"); @@ -261,14 +259,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::BetterDiagnosticCodesInStudio) - { - CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); - } - else - { - CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); - } + CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0484351d..ca1b8de7 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -1211,7 +1213,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_c )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK(get(result.errors[0])); + else + CHECK(get(result.errors[0])); } // This unit test could be flaky if the fix has regressed. @@ -2922,6 +2927,60 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +// The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. +TEST_CASE_FIXTURE(Fixture, "dont_leak_free_table_props") +{ + CheckResult result = check(R"( + local function a(state) + print(state.blah) + end + + local function b(state) -- The bug was that we inferred state: {blah: any, gwar: any} + print(state.gwar) + end + + return function() + return function(state) + a(state) + b(state) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); + CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); + CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + )"); + + CHECK_EQ("(t1) -> {| Byte: (b) -> (a...), PeekByte: (c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}", + toString(requireType("Base64FileReader"))); +} + TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") { ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true}; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 660ddcfc..6abd96b9 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -13,6 +13,7 @@ #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) @@ -177,7 +178,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_case") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); } TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") @@ -293,7 +293,7 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") // In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. -TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") +TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") { #if defined(LUAU_ENABLE_ASAN) int limit = 250; @@ -302,12 +302,14 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") #else int limit = 600; #endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; - CHECK_NOTHROW(check("print('Hello!')")); - CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); + ScopedFastFlag sff{"LuauTableUseCounterInstead", true}; + ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; + + CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") @@ -721,9 +723,9 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") local l0 do end while _ do - function _:_() - _ += _(_._(_:n0(xpcall,_))) - end + function _:_() + _ += _(_._(_:n0(xpcall,_))) + end end )"); @@ -978,4 +980,48 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") +{ + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); + ScopedFastFlag sff{"LuauRecursionLimitException", true}; + + CheckResult result = check(R"( + function complex() + function _(l0:t0): (any, ()->()) + return 0,_ + end + type t0 = t0 | {} + _(nil) + end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") +{ + ScopedFastFlag substituteFollowNewTypes{"LuauSubstituteFollowNewTypes", true}; + + CheckResult result = check(R"( + local obj = {} + + function obj:Method() + self.fieldA = function(object) + if object.a then + self.arr[object] = true + elseif object.b then + self.fieldB[object] = object:Connect(function(arg) + self.arr[arg] = nil + end) + end + end + end + + return obj + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 130f33d7..f141622f 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -27,8 +29,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const auto& [returns, tail] = flatten(takeTwoType->retType); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, returns[0]); - CHECK_EQ(typeChecker.numberType, returns[1]); + CHECK_EQ(typeChecker.numberType, follow(returns[0])); + CHECK_EQ(typeChecker.numberType, follow(returns[1])); CHECK(!tail); } @@ -74,9 +76,9 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const auto& [rets, tail] = flatten(takeOneMoreType->retType); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, rets[0]); - CHECK_EQ(typeChecker.numberType, rets[1]); - CHECK_EQ(typeChecker.numberType, rets[2]); + CHECK_EQ(typeChecker.numberType, follow(rets[0])); + CHECK_EQ(typeChecker.numberType, follow(rets[1])); + CHECK_EQ(typeChecker.numberType, follow(rets[2])); CHECK(!tail); } @@ -91,26 +93,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* applyType = get(requireType("apply")); - REQUIRE(applyType != nullptr); - - std::vector applyArgs = flatten(applyType->argTypes).first; - REQUIRE_EQ(3, applyArgs.size()); - - const FunctionTypeVar* fType = get(follow(applyArgs[0])); - REQUIRE(fType != nullptr); - - const FunctionTypeVar* gType = get(follow(applyArgs[1])); - REQUIRE(gType != nullptr); - - std::vector gArgs = flatten(gType->argTypes).first; - REQUIRE_EQ(1, gArgs.size()); - - // function(function(t1, T2...): (t3, T4...), function(t5): (t1, T2...), t5): (t3, T4...) - - REQUIRE_EQ(*gArgs[0], *applyArgs[2]); - REQUIRE_EQ(toString(fType->argTypes), toString(gType->retType)); - REQUIRE_EQ(toString(fType->retType), toString(applyType->retType)); + CHECK_EQ("((b...) -> (c...), (a) -> (b...), a) -> (c...)", toString(requireType("apply"))); } TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") @@ -328,7 +311,10 @@ local c: Packed auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); - CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); + else + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); REQUIRE(ttvA->instantiatedTypeParams.size() == 1); REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ff207a18..96bdd534 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -6,6 +6,7 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauEqConstraint) using namespace Luau; @@ -254,11 +255,11 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { CheckResult result = check(R"( -local a = {} -function a.foo(x:number, y:number) return x + y end -type A = typeof(a) -local b: A? = a -local c = b.foo(1, 2) + local a = {} + function a.foo(x:number, y:number) return x + y end + type A = typeof(a) + local b: A? = a + local c = b.foo(1, 2) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -356,7 +357,10 @@ a.x = 2 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", toString(result.errors[0])); + else + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") @@ -533,8 +537,13 @@ TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect") LUAU_REQUIRE_ERROR_COUNT(1, result); // NOTE: union normalization will improve this message - CHECK_EQ(toString(result.errors[0]), - R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Type '(string) -> number' could not be converted into '(number) -> string'\n" + "caused by:\n" + " Argument #1 type is not compatible. Type 'number' could not be converted into 'string'"); + else + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); } diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index ab9be42c..c8176456 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -581,4 +581,19 @@ do assert(#arr == 5) end +-- test boundary invariant maintenance when table is filled using SETLIST opcode +do + local arr = {[2]=2,1} + assert(#arr == 2) +end + +-- test boundary invariant maintenance when table is filled using table.move +do + local t1 = {1, 2, 3, 4, 5} + local t2 = {[6] = 6} + + table.move(t1, 1, 5, 1, t2) + assert(#t2 == 6) +end + return"OK"