Sync to upstream/release/649 (#1489)
Some checks failed
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Has been cancelled
build / ${{matrix.os.name}} (map[name:macos version:macos-latest]) (push) Has been cancelled
build / ${{matrix.os.name}} (map[name:macos-arm version:macos-14]) (push) Has been cancelled
build / ${{matrix.os.name}} (map[name:ubuntu version:ubuntu-latest]) (push) Has been cancelled
build / windows (Win32) (push) Has been cancelled
build / windows (x64) (push) Has been cancelled
build / coverage (push) Has been cancelled
build / web (push) Has been cancelled
release / ${{matrix.os.name}} (map[name:macos version:macos-latest]) (push) Has been cancelled
release / ${{matrix.os.name}} (map[name:ubuntu version:ubuntu-20.04]) (push) Has been cancelled
release / ${{matrix.os.name}} (map[name:windows version:windows-latest]) (push) Has been cancelled
release / web (push) Has been cancelled

This commit is contained in:
aaron 2024-10-25 16:15:01 -04:00 committed by GitHub
parent e491128f95
commit db809395bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 1374 additions and 448 deletions

View File

@ -82,4 +82,16 @@ std::optional<Binding> tryGetGlobalBinding(GlobalTypes& globals, const std::stri
Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name); Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name);
TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name); TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name);
/** A number of built-in functions are magical enough that we need to match on them specifically by
* name when they are called. These are listed here to be used whenever necessary, instead of duplicating this logic repeatedly.
*/
bool matchSetMetatable(const AstExprCall& call);
bool matchTableFreeze(const AstExprCall& call);
bool matchAssert(const AstExprCall& call);
// Returns `true` if the function should introduce typestate for its first argument.
bool shouldTypestateForFirstArgument(const AstExprCall& call);
} // namespace Luau } // namespace Luau

View File

@ -22,6 +22,15 @@ struct CloneState
SeenTypePacks seenTypePacks; SeenTypePacks seenTypePacks;
}; };
/** `shallowClone` will make a copy of only the _top level_ constructor of the type,
* while `clone` will make a deep copy of the entire type and its every component.
*
* Be mindful about which behavior you actually _want_.
*/
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState);
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);

View File

@ -146,6 +146,8 @@ struct ConstraintGenerator
*/ */
void visitModuleRoot(AstStatBlock* block); void visitModuleRoot(AstStatBlock* block);
void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block);
private: private:
std::vector<std::vector<TypeId>> interiorTypes; std::vector<std::vector<TypeId>> interiorTypes;

View File

@ -3,6 +3,7 @@
#pragma once #pragma once
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Location.h" #include "Luau/Location.h"
@ -69,6 +70,9 @@ struct ConstraintSolver
NotNull<Scope> rootScope; NotNull<Scope> rootScope;
ModuleName currentModuleName; ModuleName currentModuleName;
// The dataflow graph of the program, used in constraint generation and for magic functions.
NotNull<const DataFlowGraph> dfg;
// Constraints that the solver has generated, rather than sourcing from the // Constraints that the solver has generated, rather than sourcing from the
// scope tree. // scope tree.
std::vector<std::unique_ptr<Constraint>> solverConstraints; std::vector<std::unique_ptr<Constraint>> solverConstraints;
@ -120,6 +124,7 @@ struct ConstraintSolver
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles, std::vector<RequireCycle> requireCycles,
DcrLogger* logger, DcrLogger* logger,
NotNull<const DataFlowGraph> dfg,
TypeCheckLimits limits TypeCheckLimits limits
); );
@ -167,9 +172,9 @@ public:
*/ */
bool tryDispatch(NotNull<const Constraint> c, bool force); bool tryDispatch(NotNull<const Constraint> c, bool force);
bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
@ -194,14 +199,14 @@ public:
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do // for a, ... in some_table do
// also handles __iter metamethod // also handles __iter metamethod
bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in next_function, t, ... do // for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp( std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
NotNull<const Constraint> constraint, NotNull<const Constraint> constraint,

View File

@ -35,6 +35,8 @@ struct DataFlowGraph
DataFlowGraph& operator=(DataFlowGraph&&) = default; DataFlowGraph& operator=(DataFlowGraph&&) = default;
DefId getDef(const AstExpr* expr) const; DefId getDef(const AstExpr* expr) const;
// Look up the definition optionally, knowing it may not be present.
std::optional<DefId> getDefOptional(const AstExpr* expr) const;
// Look up for the rvalue def for a compound assignment. // Look up for the rvalue def for a compound assignment.
std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const; std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const;

View File

@ -12,6 +12,7 @@
namespace Luau namespace Luau
{ {
struct FrontendOptions;
struct FragmentAutocompleteAncestryResult struct FragmentAutocompleteAncestryResult
{ {
@ -29,15 +30,30 @@ struct FragmentParseResult
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>(); std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
}; };
struct FragmentTypeCheckResult
{
ModulePtr incrementalModule = nullptr;
Scope* freshScope = nullptr;
};
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos); FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos);
FragmentTypeCheckResult typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src
);
AutocompleteResult fragmentAutocomplete( AutocompleteResult fragmentAutocomplete(
Frontend& frontend, Frontend& frontend,
std::string_view src, std::string_view src,
const ModuleName& moduleName, const ModuleName& moduleName,
Position& cursorPosition, Position& cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback StringCompletionCallback callback
); );

View File

@ -60,7 +60,7 @@ struct ReplaceGenerics : Substitution
}; };
// A substitution which replaces generic functions by monomorphic functions // A substitution which replaces generic functions by monomorphic functions
struct Instantiation : Substitution struct Instantiation final : Substitution
{ {
Instantiation(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope) Instantiation(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope)
: Substitution(log, arena) : Substitution(log, arena)

View File

@ -23,8 +23,6 @@ using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice); bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
class TypeIds class TypeIds
{ {
@ -336,6 +334,7 @@ struct NormalizedType
}; };
using SeenTablePropPairs = Set<std::pair<TypeId, TypeId>, TypeIdPairHash>;
class Normalizer class Normalizer
{ {
@ -390,7 +389,13 @@ public:
void unionTablesWithTable(TypeIds& heres, TypeId there); void unionTablesWithTable(TypeIds& heres, TypeId there);
void unionTables(TypeIds& heres, const TypeIds& theres); void unionTables(TypeIds& heres, const TypeIds& theres);
NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars = -1); NormalizationResult unionNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes,
int ignoreSmallerTyvars = -1
);
// ------- Negations // ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here); std::optional<NormalizedType> negateNormal(const NormalizedType& here);
@ -407,16 +412,26 @@ public:
void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there); std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet); std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet);
void intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes); void intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes);
void intersectTables(TypeIds& heres, const TypeIds& theres); void intersectTables(TypeIds& heres, const TypeIds& theres);
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there); std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes); NormalizationResult intersectTyvarsWithTy(
NormalizedTyvars& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
);
NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes); NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes);
NormalizationResult normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet); NormalizationResult normalizeIntersections(
const std::vector<TypeId>& intersections,
NormalizedType& outType,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSet
);
// Check for inhabitance // Check for inhabitance
NormalizationResult isInhabited(TypeId ty); NormalizationResult isInhabited(TypeId ty);
@ -426,7 +441,7 @@ public:
// Check for intersections being inhabited // Check for intersections being inhabited
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); NormalizationResult isIntersectionInhabited(TypeId left, TypeId right);
NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet); NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet);
// -------- Convert back from a normalized type to a type // -------- Convert back from a normalized type to a type
TypeId typeFromNormal(const NormalizedType& norm); TypeId typeFromNormal(const NormalizedType& norm);

View File

@ -806,6 +806,13 @@ struct Type final
Type& operator=(const TypeVariant& rhs); Type& operator=(const TypeVariant& rhs);
Type& operator=(TypeVariant&& rhs); Type& operator=(TypeVariant&& rhs);
Type(Type&&) = default;
Type& operator=(Type&&) = default;
Type clone() const;
private:
Type(const Type&) = default;
Type& operator=(const Type& rhs); Type& operator=(const Type& rhs);
}; };

View File

@ -179,7 +179,7 @@ public:
bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed); bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed);
bool occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack); bool occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier(); std::unique_ptr<Unifier> makeChildUnifier();
void reportError(TypeError err); void reportError(TypeError err);
LUAU_NOINLINE void reportError(Location location, TypeErrorData data); LUAU_NOINLINE void reportError(Location location, TypeErrorData data);

View File

@ -16,7 +16,6 @@
#include <utility> #include <utility>
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAutocompleteNewSolverLimit)
LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions, false) LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions, false)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
@ -157,11 +156,8 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
NotNull{&iceReporter}, NotNull{&limits} NotNull{&iceReporter}, NotNull{&limits}
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
if (FFlag::LuauAutocompleteNewSolverLimit) unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
{ unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
}
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}};

View File

@ -2,6 +2,8 @@
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Clone.h"
#include "Luau/Error.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/Common.h" #include "Luau/Common.h"
@ -25,9 +27,12 @@
* about a function that takes any number of values, but where each value must have some specific type. * about a function that takes any number of values, but where each value must have some specific type.
*/ */
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins, false)
LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix, false)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions); LUAU_FASTFLAG(AutocompleteRequirePathSuggestions)
namespace Luau namespace Luau
{ {
@ -67,6 +72,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context);
static bool dcrMagicFunctionPack(MagicFunctionCallContext context); static bool dcrMagicFunctionPack(MagicFunctionCallContext context);
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types) TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{ {
@ -395,8 +401,10 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// but it'll be ok for now. // but it'll be ok for now.
TypeId genericTy = arena.addType(GenericType{"T"}); TypeId genericTy = arena.addType(GenericType{"T"});
TypePackId thePack = arena.addTypePack({genericTy}); TypePackId thePack = arena.addTypePack({genericTy});
TypeId idTyWithMagic = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack});
ttv->props["freeze"] = makeProperty(idTyWithMagic, "@luau/global/table.freeze");
TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack});
ttv->props["freeze"] = makeProperty(idTy, "@luau/global/table.freeze");
ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone"); ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone");
} }
else else
@ -413,6 +421,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack);
if (FFlag::LuauTypestateBuiltins)
attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze);
} }
if (FFlag::AutocompleteRequirePathSuggestions) if (FFlag::AutocompleteRequirePathSuggestions)
@ -574,7 +584,11 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex
fmt = context.callSite->args.data[0]->as<AstExprConstantString>(); fmt = context.callSite->args.data[0]->as<AstExprConstantString>();
if (!fmt) if (!fmt)
{
if (FFlag::LuauStringFormatArityFix)
context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location);
return; return;
}
std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); std::vector<TypeId> expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(context.arguments); const auto& [params, tail] = flatten(context.arguments);
@ -1324,6 +1338,58 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
return true; return true;
} }
static bool dcrMagicFunctionFreeze(MagicFunctionCallContext context)
{
LUAU_ASSERT(FFlag::LuauTypestateBuiltins);
TypeArena* arena = context.solver->arena;
const DataFlowGraph* dfg = context.solver->dfg.get();
Scope* scope = context.constraint->scope.get();
const auto& [paramTypes, paramTail] = extendTypePack(*arena, context.solver->builtinTypes, context.arguments, 1);
LUAU_ASSERT(paramTypes.size() >= 1);
TypeId inputType = follow(paramTypes.at(0));
// we'll check if it's a table first since this magic function also produces the error if it's not until we have bounded generics
if (!get<TableType>(inputType))
{
context.solver->reportError(TypeMismatch{context.solver->builtinTypes->tableType, inputType}, context.callSite->argLocation);
return false;
}
AstExpr* targetExpr = context.callSite->args.data[0];
std::optional<DefId> resultDef = dfg->getDefOptional(targetExpr);
std::optional<TypeId> resultTy = resultDef ? scope->lookup(*resultDef) : std::nullopt;
// Clone the input type, this will become our final result type after we mutate it.
CloneState cloneState{context.solver->builtinTypes};
TypeId clonedType = shallowClone(inputType, *arena, cloneState);
auto tableTy = getMutable<TableType>(clonedType);
// `clone` should not break this.
LUAU_ASSERT(tableTy);
tableTy->state = TableState::Sealed;
tableTy->syntheticName = std::nullopt;
// We'll mutate the table to make every property type read-only.
for (auto iter = tableTy->props.begin(); iter != tableTy->props.end();)
{
if (iter->second.isWriteOnly())
iter = tableTy->props.erase(iter);
else
{
iter->second.writeTy = std::nullopt;
iter++;
}
}
if (resultTy)
asMutable(*resultTy)->ty.emplace<BoundType>(clonedType);
asMutable(context.result)->ty.emplace<BoundTypePack>(arena->addTypePack({clonedType}));
return true;
}
static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
{ {
// require(foo.parent.bar) will technically work, but it depends on legacy goop that // require(foo.parent.bar) will technically work, but it depends on legacy goop that
@ -1415,4 +1481,52 @@ static bool dcrMagicFunctionRequire(MagicFunctionCallContext context)
return false; return false;
} }
bool matchSetMetatable(const AstExprCall& call)
{
const char* smt = "setmetatable";
if (call.args.size != 2)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != smt)
return false;
return true;
}
bool matchTableFreeze(const AstExprCall& call)
{
if (call.args.size < 1)
return false;
const AstExprIndexName* index = call.func->as<AstExprIndexName>();
if (!index || index->index != "freeze")
return false;
const AstExprGlobal* global = index->expr->as<AstExprGlobal>();
if (!global || global->name != "table")
return false;
return true;
}
bool matchAssert(const AstExprCall& call)
{
if (call.args.size < 1)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != "assert")
return false;
return true;
}
bool shouldTypestateForFirstArgument(const AstExprCall& call)
{
// TODO: magic function for setmetatable and assert and then add them
return matchTableFreeze(call);
}
} // namespace Luau } // namespace Luau

View File

@ -140,7 +140,7 @@ private:
} }
} }
private: public:
TypeId shallowClone(TypeId ty) TypeId shallowClone(TypeId ty)
{ {
// We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s.
@ -189,6 +189,7 @@ private:
return target; return target;
} }
private:
Property shallowClone(const Property& p) Property shallowClone(const Property& p)
{ {
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
@ -453,6 +454,24 @@ private:
} // namespace } // namespace
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
{
if (tp->persistent)
return tp;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.shallowClone(tp);
}
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
{
if (typeId->persistent)
return typeId;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.shallowClone(typeId);
}
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
{ {
if (tp->persistent) if (tp->persistent)

View File

@ -2,11 +2,12 @@
#include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintGenerator.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Def.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/Def.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
@ -30,6 +31,9 @@ LUAU_FASTINT(LuauCheckRecursionLimit)
LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTypestateBuiltins)
LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues, false)
namespace Luau namespace Luau
{ {
@ -54,20 +58,6 @@ static std::optional<AstExpr*> matchRequire(const AstExprCall& call)
return call.args.data[0]; return call.args.data[0];
} }
static bool matchSetmetatable(const AstExprCall& call)
{
const char* smt = "setmetatable";
if (call.args.size != 2)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != smt)
return false;
return true;
}
struct TypeGuard struct TypeGuard
{ {
bool isTypeof; bool isTypeof;
@ -110,18 +100,6 @@ static std::optional<TypeGuard> matchTypeGuard(const AstExprBinary* binary)
}; };
} }
static bool matchAssert(const AstExprCall& call)
{
if (call.args.size < 1)
return false;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != "assert")
return false;
return true;
}
namespace namespace
{ {
@ -285,6 +263,31 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
} }
} }
void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block)
{
visitBlockWithoutChildScope(resumeScope, block);
fillInInferredBindings(resumeScope, block);
if (logger)
logger->captureGenerationModule(module);
for (const auto& [ty, domain] : localTypes)
{
// FIXME: This isn't the most efficient thing.
TypeId domainTy = builtinTypes->neverType;
for (TypeId d : domain)
{
d = follow(d);
if (d == ty)
continue;
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result;
}
LUAU_ASSERT(get<BlockedType>(ty));
asMutable(ty)->ty.emplace<BoundType>(domainTy);
}
}
TypeId ConstraintGenerator::freshType(const ScopePtr& scope) TypeId ConstraintGenerator::freshType(const ScopePtr& scope)
{ {
return Luau::freshType(arena, builtinTypes, scope.get()); return Luau::freshType(arena, builtinTypes, scope.get());
@ -1075,9 +1078,17 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
else if (const AstExprCall* call = value->as<AstExprCall>()) else if (const AstExprCall* call = value->as<AstExprCall>())
{ {
if (const AstExprGlobal* global = call->func->as<AstExprGlobal>(); global && global->name == "setmetatable") if (FFlag::LuauTypestateBuiltins)
{ {
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); if (matchSetMetatable(*call))
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
}
else
{
if (const AstExprGlobal* global = call->func->as<AstExprGlobal>(); global && global->name == "setmetatable")
{
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
}
} }
} }
} }
@ -1975,7 +1986,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
Checkpoint argEndCheckpoint = checkpoint(this); Checkpoint argEndCheckpoint = checkpoint(this);
if (matchSetmetatable(*call)) if (matchSetMetatable(*call))
{ {
TypePack argTailPack; TypePack argTailPack;
if (argTail && args.size() < 2) if (argTail && args.size() < 2)
@ -2050,72 +2061,80 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}};
} }
else
if (FFlag::LuauTypestateBuiltins && shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0]))
{ {
if (matchAssert(*call) && !argumentRefinements.empty()) AstExpr* targetExpr = call->args.data[0];
applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); auto resultTy = arena->addType(BlockedType{});
// TODO: How do expectedTypes play into this? Do they? if (auto def = dfg->getDefOptional(targetExpr))
TypePackId rets = arena->addTypePack(BlockedTypePack{}); {
TypePackId argPack = addTypePack(std::move(args), argTail); scope->lvalueTypes[*def] = resultTy;
FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); scope->rvalueRefinements[*def] = resultTy;
}
/*
* To make bidirectional type checking work, we need to solve these constraints in a particular order:
*
* 1. Solve the function type
* 2. Propagate type information from the function type to the argument types
* 3. Solve the argument types
* 4. Solve the call
*/
NotNull<Constraint> checkConstraint = addConstraint(
scope,
call->func->location,
FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}}
);
forEachConstraint(
funcBeginCheckpoint,
funcEndCheckpoint,
this,
[checkConstraint](const ConstraintPtr& constraint)
{
checkConstraint->dependencies.emplace_back(constraint.get());
}
);
NotNull<Constraint> callConstraint = addConstraint(
scope,
call->func->location,
FunctionCallConstraint{
fnType,
argPack,
rets,
call,
std::move(discriminantTypes),
&module->astOverloadResolvedTypes,
}
);
getMutable<BlockedTypePack>(rets)->owner = callConstraint.get();
callConstraint->dependencies.push_back(checkConstraint);
forEachConstraint(
argBeginCheckpoint,
argEndCheckpoint,
this,
[checkConstraint, callConstraint](const ConstraintPtr& constraint)
{
constraint->dependencies.emplace_back(checkConstraint);
callConstraint->dependencies.emplace_back(constraint.get());
}
);
return InferencePack{rets, {refinementArena.variadic(returnRefinements)}};
} }
if (matchAssert(*call) && !argumentRefinements.empty())
applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]);
// TODO: How do expectedTypes play into this? Do they?
TypePackId rets = arena->addTypePack(BlockedTypePack{});
TypePackId argPack = addTypePack(std::move(args), argTail);
FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self);
/*
* To make bidirectional type checking work, we need to solve these constraints in a particular order:
*
* 1. Solve the function type
* 2. Propagate type information from the function type to the argument types
* 3. Solve the argument types
* 4. Solve the call
*/
NotNull<Constraint> checkConstraint = addConstraint(
scope, call->func->location, FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}}
);
forEachConstraint(
funcBeginCheckpoint,
funcEndCheckpoint,
this,
[checkConstraint](const ConstraintPtr& constraint)
{
checkConstraint->dependencies.emplace_back(constraint.get());
}
);
NotNull<Constraint> callConstraint = addConstraint(
scope,
call->func->location,
FunctionCallConstraint{
fnType,
argPack,
rets,
call,
std::move(discriminantTypes),
&module->astOverloadResolvedTypes,
}
);
getMutable<BlockedTypePack>(rets)->owner = callConstraint.get();
callConstraint->dependencies.push_back(checkConstraint);
forEachConstraint(
argBeginCheckpoint,
argEndCheckpoint,
this,
[checkConstraint, callConstraint](const ConstraintPtr& constraint)
{
constraint->dependencies.emplace_back(checkConstraint);
callConstraint->dependencies.emplace_back(constraint.get());
}
);
return InferencePack{rets, {refinementArena.variadic(returnRefinements)}};
} }
Inference ConstraintGenerator::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType, bool forceSingleton, bool generalize) Inference ConstraintGenerator::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType, bool forceSingleton, bool generalize)
@ -2703,7 +2722,16 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, Type
visitLValue(scope, e, rhsType); visitLValue(scope, e, rhsType);
else if (auto e = expr->as<AstExprError>()) else if (auto e = expr->as<AstExprError>())
{ {
// Nothing? if (FFlag::LuauNewSolverVisitErrorExprLvalues)
{
// If we end up with some sort of error expression in an lvalue
// position, at least go and check the expressions so that when
// we visit them later, there aren't any invalid assumptions.
for (auto subExpr : e->expressions)
{
check(scope, subExpr);
}
}
} }
else else
ice->ice("Unexpected lvalue expression", expr->location); ice->ice("Unexpected lvalue expression", expr->location);

View File

@ -326,6 +326,7 @@ ConstraintSolver::ConstraintSolver(
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles, std::vector<RequireCycle> requireCycles,
DcrLogger* logger, DcrLogger* logger,
NotNull<const DataFlowGraph> dfg,
TypeCheckLimits limits TypeCheckLimits limits
) )
: arena(normalizer->arena) : arena(normalizer->arena)
@ -335,6 +336,7 @@ ConstraintSolver::ConstraintSolver(
, constraints(std::move(constraints)) , constraints(std::move(constraints))
, rootScope(rootScope) , rootScope(rootScope)
, currentModuleName(std::move(moduleName)) , currentModuleName(std::move(moduleName))
, dfg(dfg)
, moduleResolver(moduleResolver) , moduleResolver(moduleResolver)
, requireCycles(std::move(requireCycles)) , requireCycles(std::move(requireCycles))
, logger(logger) , logger(logger)
@ -618,11 +620,11 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
bool success = false; bool success = false;
if (auto sc = get<SubtypeConstraint>(*constraint)) if (auto sc = get<SubtypeConstraint>(*constraint))
success = tryDispatch(*sc, constraint, force); success = tryDispatch(*sc, constraint);
else if (auto psc = get<PackSubtypeConstraint>(*constraint)) else if (auto psc = get<PackSubtypeConstraint>(*constraint))
success = tryDispatch(*psc, constraint, force); success = tryDispatch(*psc, constraint);
else if (auto gc = get<GeneralizationConstraint>(*constraint)) else if (auto gc = get<GeneralizationConstraint>(*constraint))
success = tryDispatch(*gc, constraint, force); success = tryDispatch(*gc, constraint);
else if (auto ic = get<IterableConstraint>(*constraint)) else if (auto ic = get<IterableConstraint>(*constraint))
success = tryDispatch(*ic, constraint, force); success = tryDispatch(*ic, constraint, force);
else if (auto nc = get<NameConstraint>(*constraint)) else if (auto nc = get<NameConstraint>(*constraint))
@ -650,14 +652,14 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
else if (auto rpc = get<ReducePackConstraint>(*constraint)) else if (auto rpc = get<ReducePackConstraint>(*constraint))
success = tryDispatch(*rpc, constraint, force); success = tryDispatch(*rpc, constraint, force);
else if (auto eqc = get<EqualityConstraint>(*constraint)) else if (auto eqc = get<EqualityConstraint>(*constraint))
success = tryDispatch(*eqc, constraint, force); success = tryDispatch(*eqc, constraint);
else else
LUAU_ASSERT(false); LUAU_ASSERT(false);
return success; return success;
} }
bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint)
{ {
if (isBlocked(c.subType)) if (isBlocked(c.subType))
return block(c.subType, constraint); return block(c.subType, constraint);
@ -669,7 +671,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Con
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint)
{ {
if (isBlocked(c.subPack)) if (isBlocked(c.subPack))
return block(c.subPack, constraint); return block(c.subPack, constraint);
@ -681,7 +683,7 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint)
{ {
TypeId generalizedType = follow(c.generalizedType); TypeId generalizedType = follow(c.generalizedType);
@ -828,7 +830,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
if (iterator.head.size() >= 2) if (iterator.head.size() >= 2)
tableTy = iterator.head[1]; tableTy = iterator.head[1];
return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); return tryDispatchIterableFunction(nextTy, tableTy, c, constraint);
} }
else else
@ -2165,7 +2167,7 @@ bool ConstraintSolver::tryDispatch(const ReducePackConstraint& c, NotNull<const
return reductionFinished; return reductionFinished;
} }
bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull<const Constraint> constraint)
{ {
unify(constraint, c.resultType, c.assignmentType); unify(constraint, c.resultType, c.assignmentType);
unify(constraint, c.assignmentType, c.resultType); unify(constraint, c.assignmentType, c.resultType);
@ -2328,8 +2330,7 @@ bool ConstraintSolver::tryDispatchIterableFunction(
TypeId nextTy, TypeId nextTy,
TypeId tableTy, TypeId tableTy,
const IterableConstraint& c, const IterableConstraint& c,
NotNull<const Constraint> constraint, NotNull<const Constraint> constraint
bool force
) )
{ {
const FunctionType* nextFn = get<FunctionType>(nextTy); const FunctionType* nextFn = get<FunctionType>(nextTy);

View File

@ -2,6 +2,7 @@
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Error.h" #include "Luau/Error.h"
@ -12,6 +13,7 @@
LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauTypestateBuiltins)
namespace Luau namespace Luau
{ {
@ -67,6 +69,14 @@ DefId DataFlowGraph::getDef(const AstExpr* expr) const
return NotNull{*def}; return NotNull{*def};
} }
std::optional<DefId> DataFlowGraph::getDefOptional(const AstExpr* expr) const
{
auto def = astDefs.find(expr);
if (!def)
return std::nullopt;
return NotNull{*def};
}
std::optional<DefId> DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const std::optional<DefId> DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const
{ {
auto def = compoundAssignDefs.find(expr); auto def = compoundAssignDefs.find(expr);
@ -929,6 +939,39 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c)
{ {
visitExpr(c->func); visitExpr(c->func);
if (FFlag::LuauTypestateBuiltins && shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin()))
{
AstExpr* firstArg = *c->args.begin();
// this logic has to handle the name-like subset of expressions.
std::optional<DataFlowResult> result;
if (auto l = firstArg->as<AstExprLocal>())
result = visitExpr(l);
else if (auto g = firstArg->as<AstExprGlobal>())
result = visitExpr(g);
else if (auto i = firstArg->as<AstExprIndexName>())
result = visitExpr(i);
else if (auto i = firstArg->as<AstExprIndexExpr>())
result = visitExpr(i);
else
LUAU_UNREACHABLE(); // This is unreachable because the whole thing is guarded by `isLValue`.
LUAU_ASSERT(result);
Location location = currentScope()->location;
// This scope starts at the end of the call site and continues to the end of the original scope.
location.begin = c->location.end;
DfgScope* child = makeChildScope(location);
scopeStack.push_back(child);
auto [def, key] = *result;
graph.astDefs[firstArg] = def;
if (key)
graph.astRefinementKeys[firstArg] = key;
visitLValue(firstArg, def);
}
for (AstExpr* arg : c->args) for (AstExpr* arg : c->args)
visitExpr(arg); visitExpr(arg);

View File

@ -4,11 +4,44 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
#include "Luau/TimeTrace.h"
#include "Luau/UnifierSharedState.h"
#include "Luau/TypeFunction.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/ParseOptions.h" #include "Luau/ParseOptions.h"
#include "Luau/Module.h" #include "Luau/Module.h"
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2);
namespace
{
template<typename T>
void copyModuleVec(std::vector<T>& result, const std::vector<T>& input)
{
result.insert(result.end(), input.begin(), input.end());
}
template<typename K, typename V>
void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K, V>& input)
{
for (auto [k, v] : input)
result[k] = v;
}
} // namespace
namespace Luau namespace Luau
{ {
@ -147,17 +180,173 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
return fragmentResult; return fragmentResult;
} }
ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
{
freeze(result->internalTypes);
freeze(result->interfaceTypes);
ModulePtr incrementalModule = std::make_shared<Module>();
incrementalModule->name = result->name;
incrementalModule->humanReadableName = result->humanReadableName;
incrementalModule->allocator = std::move(alloc);
// Don't need to keep this alive (it's already on the source module)
copyModuleVec(incrementalModule->scopes, result->scopes);
copyModuleMap(incrementalModule->astTypes, result->astTypes);
copyModuleMap(incrementalModule->astTypePacks, result->astTypePacks);
copyModuleMap(incrementalModule->astExpectedTypes, result->astExpectedTypes);
// Don't need to clone astOriginalCallTypes
copyModuleMap(incrementalModule->astOverloadResolvedTypes, result->astOverloadResolvedTypes);
// Don't need to clone astForInNextTypes
copyModuleMap(incrementalModule->astForInNextTypes, result->astForInNextTypes);
// Don't need to clone astResolvedTypes
// Don't need to clone astResolvedTypePacks
// Don't need to clone upperBoundContributors
copyModuleMap(incrementalModule->astScopes, result->astScopes);
// Don't need to clone declared Globals;
return incrementalModule;
}
FragmentTypeCheckResult typeCheckFragmentHelper(
Frontend& frontend,
AstStatBlock* root,
const ModulePtr& stale,
const ScopePtr& closestScope,
const Position& cursorPos,
std::unique_ptr<Allocator> astAllocator,
const FrontendOptions& opts
)
{
freeze(stale->internalTypes);
freeze(stale->interfaceTypes);
ModulePtr incrementalModule = copyModule(stale, std::move(astAllocator));
unfreeze(incrementalModule->internalTypes);
unfreeze(incrementalModule->interfaceTypes);
/// Setup typecheck limits
TypeCheckLimits limits;
if (opts.moduleTimeLimitSec)
limits.finishTime = TimeTrace::getClock() + *opts.moduleTimeLimitSec;
else
limits.finishTime = std::nullopt;
limits.cancellationToken = opts.cancellationToken;
/// Icehandler
NotNull<InternalErrorReporter> iceHandler{&frontend.iceHandler};
/// Make the shared state for the unifier (recursion + iteration limits)
UnifierSharedState unifierState{iceHandler};
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
/// Initialize the normalizer
Normalizer normalizer{&incrementalModule->internalTypes, frontend.builtinTypes, NotNull{&unifierState}};
/// User defined type functions runtime
TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits});
/// Create a DataFlowGraph just for the surrounding context
auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler);
/// Contraint Generator
ConstraintGenerator cg{
incrementalModule,
NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull{&frontend.moduleResolver},
frontend.builtinTypes,
iceHandler,
frontend.globals.globalScope,
nullptr,
nullptr,
NotNull{&updatedDfg},
{}
};
cg.rootScope = stale->getModuleScope().get();
// Any additions to the scope must occur in a fresh scope
auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope});
// closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy
closestScope->children.push_back(NotNull{freshChildOfNearestScope.get()});
// Visit just the root - we know the scope it should be in
cg.visitFragmentRoot(freshChildOfNearestScope, root);
// Trim nearestChild from the closestScope
Scope* back = closestScope->children.back().get();
LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back();
/// Initialize the constraint solver and run it
ConstraintSolver cs{
NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
incrementalModule->name,
NotNull{&frontend.moduleResolver},
{},
nullptr,
NotNull{&updatedDfg},
limits
};
try
{
cs.run();
}
catch (const TimeLimitError&)
{
stale->timeout = true;
}
catch (const UserCancelError&)
{
stale->cancelled = true;
}
// In frontend we would forbid internal types
// because this is just for autocomplete, we don't actually care
// We also don't even need to typecheck - just synthesize types as best as we can
freeze(incrementalModule->internalTypes);
freeze(incrementalModule->interfaceTypes);
return {std::move(incrementalModule), freshChildOfNearestScope.get()};
}
FragmentTypeCheckResult typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src
)
{
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {};
}
ModulePtr module = frontend.moduleResolver.getModule(moduleName);
const ScopePtr& closestScope = findClosestScope(module, cursorPos);
FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos);
FrontendOptions frontendOptions = opts.value_or(frontend.options);
return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions);
}
AutocompleteResult fragmentAutocomplete( AutocompleteResult fragmentAutocomplete(
Frontend& frontend, Frontend& frontend,
std::string_view src, std::string_view src,
const ModuleName& moduleName, const ModuleName& moduleName,
Position& cursorPosition, Position& cursorPosition,
const FrontendOptions& opts,
StringCompletionCallback callback StringCompletionCallback callback
) )
{ {
LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauSolverV2);
// TODO LUAU_ASSERT(FFlag::LuauAllowFragmentParsing);
LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2);
return {}; return {};
} }

View File

@ -49,7 +49,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection, false) LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection, false)
LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule, false) LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2, false)
namespace Luau namespace Luau
{ {
@ -1315,9 +1315,9 @@ ModulePtr check(
} }
} }
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); DataFlowGraph oldDfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler);
DataFlowGraph* dfgForConstraintGeneration = nullptr; DataFlowGraph* dfgForConstraintGeneration = nullptr;
if (FFlag::LuauStoreDFGOnModule) if (FFlag::LuauStoreDFGOnModule2)
{ {
auto [dfg, scopes] = DataFlowGraphBuilder::buildShared(sourceModule.root, iceHandler); auto [dfg, scopes] = DataFlowGraphBuilder::buildShared(sourceModule.root, iceHandler);
result->dataFlowGraph = std::move(dfg); result->dataFlowGraph = std::move(dfg);
@ -1326,7 +1326,7 @@ ModulePtr check(
} }
else else
{ {
dfgForConstraintGeneration = &dfg; dfgForConstraintGeneration = &oldDfg;
} }
UnifierSharedState unifierState{iceHandler}; UnifierSharedState unifierState{iceHandler};
@ -1365,6 +1365,7 @@ ModulePtr check(
moduleResolver, moduleResolver,
requireCycles, requireCycles,
logger.get(), logger.get(),
NotNull{dfgForConstraintGeneration},
limits limits
}; };
@ -1418,16 +1419,32 @@ ModulePtr check(
switch (mode) switch (mode)
{ {
case Mode::Nonstrict: case Mode::Nonstrict:
Luau::checkNonStrict( if (FFlag::LuauStoreDFGOnModule2)
builtinTypes, {
NotNull{&typeFunctionRuntime}, Luau::checkNonStrict(
iceHandler, builtinTypes,
NotNull{&unifierState}, NotNull{&typeFunctionRuntime},
NotNull{&dfg}, iceHandler,
NotNull{&limits}, NotNull{&unifierState},
sourceModule, NotNull{dfgForConstraintGeneration},
result.get() NotNull{&limits},
); sourceModule,
result.get()
);
}
else
{
Luau::checkNonStrict(
builtinTypes,
NotNull{&typeFunctionRuntime},
iceHandler,
NotNull{&unifierState},
NotNull{&oldDfg},
NotNull{&limits},
sourceModule,
result.get()
);
}
break; break;
case Mode::Definition: case Mode::Definition:
// fallthrough intentional // fallthrough intentional

View File

@ -20,8 +20,9 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauUseNormalizeIntersectionLimit, false)
LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200) LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200)
LUAU_FASTFLAGVARIABLE(LuauNormalizationTracksCyclicPairsThroughInhabitance, false);
LUAU_FASTFLAGVARIABLE(LuauIntersectNormalsNeedsToTrackResourceLimits, false);
namespace Luau namespace Luau
{ {
@ -570,10 +571,11 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set<TypeId>& seen)
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right)
{ {
Set<TypeId> seen{nullptr}; Set<TypeId> seen{nullptr};
return isIntersectionInhabited(left, right, seen); SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
return isIntersectionInhabited(left, right, seenTablePropPairs, seen);
} }
NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set<TypeId>& seenSet) NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet)
{ {
left = follow(left); left = follow(left);
right = follow(right); right = follow(right);
@ -586,7 +588,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
} }
NormalizedType norm{builtinTypes}; NormalizedType norm{builtinTypes};
NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet); NormalizationResult res = normalizeIntersections({left, right}, norm, seenTablePropPairs, seenSet);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
{ {
if (cacheInhabitance && res == NormalizationResult::False) if (cacheInhabitance && res == NormalizationResult::False)
@ -937,7 +939,8 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
NormalizedType norm{builtinTypes}; NormalizedType norm{builtinTypes};
Set<TypeId> seenSetTypes{nullptr}; Set<TypeId> seenSetTypes{nullptr};
NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes); SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
NormalizationResult res = unionNormalWithTy(norm, ty, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return nullptr; return nullptr;
@ -955,7 +958,12 @@ std::shared_ptr<const NormalizedType> Normalizer::normalize(TypeId ty)
return shared; return shared;
} }
NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType, Set<TypeId>& seenSet) NormalizationResult Normalizer::normalizeIntersections(
const std::vector<TypeId>& intersections,
NormalizedType& outType,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSet
)
{ {
if (!arena) if (!arena)
sharedState->iceHandler->ice("Normalizing types outside a module"); sharedState->iceHandler->ice("Normalizing types outside a module");
@ -964,7 +972,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
// Now we need to intersect the two types // Now we need to intersect the two types
for (auto ty : intersections) for (auto ty : intersections)
{ {
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet); NormalizationResult res = intersectNormalWithTy(norm, ty, seenTablePropPairs, seenSet);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
@ -1728,7 +1736,13 @@ NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, N
} }
// See above for an explaination of `ignoreSmallerTyvars`. // See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars) NormalizationResult Normalizer::unionNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes,
int ignoreSmallerTyvars
)
{ {
RecursionCounter _rc(&sharedState->counters.recursionCount); RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits()) if (!withinResourceLimits())
@ -1760,7 +1774,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
{ {
NormalizationResult res = unionNormalWithTy(here, *it, seenSetTypes); NormalizationResult res = unionNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
{ {
seenSetTypes.erase(there); seenSetTypes.erase(there);
@ -1781,7 +1795,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
norm.tops = builtinTypes->anyType; norm.tops = builtinTypes->anyType;
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
{ {
NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); NormalizationResult res = intersectNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
{ {
seenSetTypes.erase(there); seenSetTypes.erase(there);
@ -1881,7 +1895,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
for (auto& [tyvar, intersect] : here.tyvars) for (auto& [tyvar, intersect] : here.tyvars)
{ {
NormalizationResult res = unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar)); NormalizationResult res = unionNormalWithTy(*intersect, there, seenTablePropPairs, seenSetTypes, tyvarIndex(tyvar));
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
@ -2491,7 +2505,7 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
return arena->addTypePack({}); return arena->addTypePack({});
} }
std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, Set<TypeId>& seenSet) std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSet)
{ {
if (here == there) if (here == there)
return here; return here;
@ -2573,31 +2587,63 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
// if the intersection of the read types of a property is uninhabited, the whole table is `never`. // if the intersection of the read types of a property is uninhabited, the whole table is `never`.
// We've seen these table prop elements before and we're about to ask if their intersection // We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited // is inhabited
if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) if (FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance)
{ {
seenSet.erase(*hprop.readTy); auto pair1 = std::pair{*hprop.readTy, *tprop.readTy};
seenSet.erase(*tprop.readTy); auto pair2 = std::pair{*tprop.readTy, *hprop.readTy};
return {builtinTypes->neverType}; if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2))
{
seenTablePropPairs.erase(pair1);
seenTablePropPairs.erase(pair2);
return {builtinTypes->neverType};
}
else
{
seenTablePropPairs.insert(pair1);
seenTablePropPairs.insert(pair2);
}
Set<TypeId> seenSet{nullptr};
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet);
seenTablePropPairs.erase(pair1);
seenTablePropPairs.erase(pair2);
if (NormalizationResult::True != res)
return {builtinTypes->neverType};
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
} }
else else
{ {
seenSet.insert(*hprop.readTy);
seenSet.insert(*tprop.readTy); if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy))
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
return {builtinTypes->neverType};
}
else
{
seenSet.insert(*hprop.readTy);
seenSet.insert(*tprop.readTy);
}
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy);
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
if (NormalizationResult::True != res)
return {builtinTypes->neverType};
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
} }
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy);
// Cleanup
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
if (NormalizationResult::True != res)
return {builtinTypes->neverType};
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
} }
else else
{ {
@ -2703,7 +2749,7 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
if (tmtable && hmtable) if (tmtable && hmtable)
{ {
// NOTE: this assumes metatables are ivariant // NOTE: this assumes metatables are ivariant
if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenSet)) if (std::optional<TypeId> mtable = intersectionOfTables(hmtable, tmtable, seenTablePropPairs, seenSet))
{ {
if (table == htable && *mtable == hmtable) if (table == htable && *mtable == hmtable)
return here; return here;
@ -2733,12 +2779,12 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
return table; return table;
} }
void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set<TypeId>& seenSetTypes) void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set<TypeId>& seenSetTypes)
{ {
TypeIds tmp; TypeIds tmp;
for (TypeId here : heres) for (TypeId here : heres)
{ {
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes)) if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes))
tmp.insert(*inter); tmp.insert(*inter);
} }
heres.retain(tmp); heres.retain(tmp);
@ -2753,7 +2799,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres)
for (TypeId there : theres) for (TypeId there : theres)
{ {
Set<TypeId> seenSetTypes{nullptr}; Set<TypeId> seenSetTypes{nullptr};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenSetTypes)) SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}};
if (std::optional<TypeId> inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes))
tmp.insert(*inter); tmp.insert(*inter);
} }
} }
@ -2971,12 +3018,17 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali
} }
} }
NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes) NormalizationResult Normalizer::intersectTyvarsWithTy(
NormalizedTyvars& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
)
{ {
for (auto it = here.begin(); it != here.end();) for (auto it = here.begin(); it != here.end();)
{ {
NormalizedType& inter = *it->second; NormalizedType& inter = *it->second;
NormalizationResult res = intersectNormalWithTy(inter, there, seenSetTypes); NormalizationResult res = intersectNormalWithTy(inter, there, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
if (isShallowInhabited(inter)) if (isShallowInhabited(inter))
@ -2990,6 +3042,13 @@ NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, Ty
// See above for an explaination of `ignoreSmallerTyvars`. // See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars)
{ {
if (FFlag::LuauIntersectNormalsNeedsToTrackResourceLimits)
{
RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits())
return NormalizationResult::HitLimits;
}
if (!get<NeverType>(there.tops)) if (!get<NeverType>(there.tops))
{ {
here.tops = intersectionOfTops(here.tops, there.tops); here.tops = intersectionOfTops(here.tops, there.tops);
@ -3001,13 +3060,10 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor
return unionNormals(here, there, ignoreSmallerTyvars); return unionNormals(here, there, ignoreSmallerTyvars);
} }
if (FFlag::LuauUseNormalizeIntersectionLimit) // Limit based on worst-case expansion of the table intersection
{ // This restriction can be relaxed when table intersection simplification is improved
// Limit based on worst-case expansion of the table intersection if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit))
// This restriction can be relaxed when table intersection simplification is improved return NormalizationResult::HitLimits;
if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit))
return NormalizationResult::HitLimits;
}
here.booleans = intersectionOfBools(here.booleans, there.booleans); here.booleans = intersectionOfBools(here.booleans, there.booleans);
@ -3062,7 +3118,12 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor
return NormalizationResult::True; return NormalizationResult::True;
} }
NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes) NormalizationResult Normalizer::intersectNormalWithTy(
NormalizedType& here,
TypeId there,
SeenTablePropPairs& seenTablePropPairs,
Set<TypeId>& seenSetTypes
)
{ {
RecursionCounter _rc(&sharedState->counters.recursionCount); RecursionCounter _rc(&sharedState->counters.recursionCount);
if (!withinResourceLimits()) if (!withinResourceLimits())
@ -3078,14 +3139,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
else if (!get<NeverType>(here.tops)) else if (!get<NeverType>(here.tops))
{ {
clearNormal(here); clearNormal(here);
return unionNormalWithTy(here, there, seenSetTypes); return unionNormalWithTy(here, there, seenTablePropPairs, seenSetTypes);
} }
else if (const UnionType* utv = get<UnionType>(there)) else if (const UnionType* utv = get<UnionType>(there))
{ {
NormalizedType norm{builtinTypes}; NormalizedType norm{builtinTypes};
for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) for (UnionTypeIterator it = begin(utv); it != end(utv); ++it)
{ {
NormalizationResult res = unionNormalWithTy(norm, *it, seenSetTypes); NormalizationResult res = unionNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
@ -3095,7 +3156,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{ {
for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it)
{ {
NormalizationResult res = intersectNormalWithTy(here, *it, seenSetTypes); NormalizationResult res = intersectNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
} }
@ -3124,7 +3185,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{ {
TypeIds tables = std::move(here.tables); TypeIds tables = std::move(here.tables);
clearNormal(here); clearNormal(here);
intersectTablesWithTable(tables, there, seenSetTypes); intersectTablesWithTable(tables, there, seenTablePropPairs, seenSetTypes);
here.tables = std::move(tables); here.tables = std::move(tables);
} }
else if (get<ClassType>(there)) else if (get<ClassType>(there))
@ -3236,7 +3297,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (auto nt = get<NegationType>(t)) else if (auto nt = get<NegationType>(t))
return intersectNormalWithTy(here, nt->ty, seenSetTypes); return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes);
else else
{ {
// TODO negated unions, intersections, table, and function. // TODO negated unions, intersections, table, and function.
@ -3256,7 +3317,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenSetTypes); NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenTablePropPairs, seenSetTypes);
if (res != NormalizationResult::True) if (res != NormalizationResult::True)
return res; return res;
here.tyvars = std::move(tyvars); here.tyvars = std::move(tyvars);
@ -3456,38 +3517,4 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, N
} }
} }
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice)
{
LUAU_ASSERT(!FFlag::LuauSolverV2);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subTy, superTy);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
bool isConsistentSubtype(
TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter& ice
)
{
LUAU_ASSERT(!FFlag::LuauSolverV2);
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant};
u.tryUnify(subPack, superPack);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
} // namespace Luau } // namespace Luau

View File

@ -2,6 +2,7 @@
#include "Luau/Simplify.h" #include "Luau/Simplify.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/Set.h" #include "Luau/Set.h"
@ -14,6 +15,7 @@
LUAU_FASTINT(LuauTypeReductionRecursionLimit) LUAU_FASTINT(LuauTypeReductionRecursionLimit)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8); LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8);
LUAU_FASTFLAGVARIABLE(LuauFlagBasicIntersectFollows, false);
namespace Luau namespace Luau
{ {
@ -1064,6 +1066,12 @@ TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right)
std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right) std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right)
{ {
if (FFlag::LuauFlagBasicIntersectFollows)
{
left = follow(left);
right = follow(right);
}
if (get<AnyType>(left) && get<ErrorType>(right)) if (get<AnyType>(left) && get<ErrorType>(right))
return right; return right;
if (get<AnyType>(right) && get<ErrorType>(left)) if (get<AnyType>(right) && get<ErrorType>(left))

View File

@ -22,7 +22,6 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false); LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteNewSolverLimit, false);
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
namespace Luau namespace Luau
@ -512,19 +511,14 @@ struct SeenSetPopper
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull<Scope> scope) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy, NotNull<Scope> scope)
{ {
std::optional<RecursionCounter> rc; UnifierCounters& counters = normalizer->sharedState->counters;
RecursionCounter rc(&counters.recursionCount);
if (FFlag::LuauAutocompleteNewSolverLimit) if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount)
{ {
UnifierCounters& counters = normalizer->sharedState->counters; SubtypingResult result;
rc.emplace(&counters.recursionCount); result.normalizationTooComplex = true;
return result;
if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount)
{
SubtypingResult result;
result.normalizationTooComplex = true;
return result;
}
} }
subTy = follow(subTy); subTy = follow(subTy);

View File

@ -93,8 +93,8 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull<TypeArena> arena)
if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{ {
TypeId leftTy = arena->addType((*leftRep)->pending); TypeId leftTy = arena->addType((*leftRep)->pending.clone());
TypeId rightTy = arena->addType(rightRep->pending); TypeId rightTy = arena->addType(rightRep->pending.clone());
typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}}; typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}};
} }
else else
@ -170,8 +170,8 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull<TypeArena> arena)
if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{ {
TypeId leftTy = arena->addType((*leftRep)->pending); TypeId leftTy = arena->addType((*leftRep)->pending.clone());
TypeId rightTy = arena->addType(rightRep->pending); TypeId rightTy = arena->addType(rightRep->pending.clone());
if (follow(leftTy) == follow(rightTy)) if (follow(leftTy) == follow(rightTy))
typeVarChanges[ty] = std::move(rightRep); typeVarChanges[ty] = std::move(rightRep);
@ -217,7 +217,7 @@ TxnLog TxnLog::inverse()
for (auto& [ty, _rep] : typeVarChanges) for (auto& [ty, _rep] : typeVarChanges)
{ {
if (!_rep->dead) if (!_rep->dead)
inversed.typeVarChanges[ty] = std::make_unique<PendingType>(*ty); inversed.typeVarChanges[ty] = std::make_unique<PendingType>(ty->clone());
} }
for (auto& [tp, _rep] : typePackChanges) for (auto& [tp, _rep] : typePackChanges)
@ -292,7 +292,7 @@ PendingType* TxnLog::queue(TypeId ty)
auto& pending = typeVarChanges[ty]; auto& pending = typeVarChanges[ty];
if (!pending || (*pending).dead) if (!pending || (*pending).dead)
{ {
pending = std::make_unique<PendingType>(*ty); pending = std::make_unique<PendingType>(ty->clone());
pending->pending.owningArena = nullptr; pending->pending.owningArena = nullptr;
} }

View File

@ -999,6 +999,11 @@ Type& Type::operator=(const Type& rhs)
return *this; return *this;
} }
Type Type::clone() const
{
return *this;
}
TypeId makeFunction( TypeId makeFunction(
TypeArena& arena, TypeArena& arena,
std::optional<TypeId> selfType, std::optional<TypeId> selfType,

View File

@ -5037,17 +5037,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c
{ {
// First try unifying with the original uninstantiated type // First try unifying with the original uninstantiated type
// but if that fails, try the instantiated one. // but if that fails, try the instantiated one.
Unifier child = state.makeChildUnifier(); std::unique_ptr<Unifier> child = state.makeChildUnifier();
child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); child->tryUnify(subTy, superTy, /*isFunctionCall*/ false);
if (!child.errors.empty()) if (!child->errors.empty())
{ {
TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); TypeId instantiated = instantiate(scope, subTy, state.location, &child->log);
if (subTy == instantiated) if (subTy == instantiated)
{ {
// Instantiating the argument made no difference, so just report any child errors // Instantiating the argument made no difference, so just report any child errors
state.log.concat(std::move(child.log)); state.log.concat(std::move(child->log));
state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); state.errors.insert(state.errors.end(), child->errors.begin(), child->errors.end());
} }
else else
{ {
@ -5056,7 +5056,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c
} }
else else
{ {
state.log.concat(std::move(child.log)); state.log.concat(std::move(child->log));
} }
} }
} }

View File

@ -749,25 +749,25 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ
for (TypeId type : subUnion->options) for (TypeId type : subUnion->options)
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(type, superTy); innerState->tryUnify_(type, superTy);
if (useNewSolver) if (useNewSolver)
logs.push_back(std::move(innerState.log)); logs.push_back(std::move(innerState->log));
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState->errors))
unificationTooComplex = e; unificationTooComplex = e;
else if (innerState.failure) else if (innerState->failure)
{ {
// If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds.
if (innerState.errors.empty()) if (innerState->errors.empty())
logs.push_back(std::move(innerState.log)); logs.push_back(std::move(innerState->log));
// 'nil' option is skipped from extended report because we present the type in a special way - 'T?' // 'nil' option is skipped from extended report because we present the type in a special way - 'T?'
else if (!firstFailedOption && !isNil(type)) else if (!firstFailedOption && !isNil(type))
firstFailedOption = {innerState.errors.front()}; firstFailedOption = {innerState->errors.front()};
failed = true; failed = true;
errorsSuppressed &= innerState.errors.empty(); errorsSuppressed &= innerState->errors.empty();
} }
} }
@ -862,26 +862,26 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
for (size_t i = 0; i < uv->options.size(); ++i) for (size_t i = 0; i < uv->options.size(); ++i)
{ {
TypeId type = uv->options[(i + startIndex) % uv->options.size()]; TypeId type = uv->options[(i + startIndex) % uv->options.size()];
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.normalize = false; innerState->normalize = false;
innerState.tryUnify_(subTy, type, isFunctionCall); innerState->tryUnify_(subTy, type, isFunctionCall);
if (!innerState.failure) if (!innerState->failure)
{ {
found = true; found = true;
if (useNewSolver) if (useNewSolver)
logs.push_back(std::move(innerState.log)); logs.push_back(std::move(innerState->log));
else else
{ {
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
break; break;
} }
} }
else if (innerState.errors.empty()) else if (innerState->errors.empty())
{ {
errorsSuppressed = true; errorsSuppressed = true;
} }
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState->errors))
{ {
unificationTooComplex = e; unificationTooComplex = e;
} }
@ -890,7 +890,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
failedOptionCount++; failedOptionCount++;
if (!failedOption) if (!failedOption)
failedOption = {innerState.errors.front()}; failedOption = {innerState->errors.front()};
} }
} }
@ -906,25 +906,25 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
// It is possible that T <: A | B even though T </: A and T </:B // It is possible that T <: A | B even though T </: A and T </:B
// for example boolean <: true | false. // for example boolean <: true | false.
// We deal with this by type normalization. // We deal with this by type normalization.
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy); std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy); std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm) if (!subNorm || !superNorm)
return reportError(location, NormalizationTooComplex{}); return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes( innerState->tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption
); );
else else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); innerState->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
if (!innerState.failure) if (!innerState->failure)
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
else if (errorsSuppressed || innerState.errors.empty()) else if (errorsSuppressed || innerState->errors.empty())
failure = true; failure = true;
else else
reportError(std::move(innerState.errors.front())); reportError(std::move(innerState->errors.front()));
} }
else if (!found && normalize) else if (!found && normalize)
{ {
@ -963,22 +963,22 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
// T <: A & B if and only if T <: A and T <: B // T <: A & B if and only if T <: A and T <: B
for (TypeId type : uv->parts) for (TypeId type : uv->parts)
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); innerState->tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true);
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState->errors))
unificationTooComplex = e; unificationTooComplex = e;
else if (!innerState.errors.empty()) else if (!innerState->errors.empty())
{ {
if (!firstFailedOption) if (!firstFailedOption)
firstFailedOption = {innerState.errors.front()}; firstFailedOption = {innerState->errors.front()};
} }
if (useNewSolver) if (useNewSolver)
logs.push_back(std::move(innerState.log)); logs.push_back(std::move(innerState->log));
else else
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
failure |= innerState.failure; failure |= innerState->failure;
} }
if (useNewSolver) if (useNewSolver)
@ -1058,27 +1058,27 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
for (size_t i = 0; i < uv->parts.size(); ++i) for (size_t i = 0; i < uv->parts.size(); ++i)
{ {
TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; TypeId type = uv->parts[(i + startIndex) % uv->parts.size()];
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.normalize = false; innerState->normalize = false;
innerState.tryUnify_(type, superTy, isFunctionCall); innerState->tryUnify_(type, superTy, isFunctionCall);
// TODO: This sets errorSuppressed to true if any of the parts is error-suppressing, // TODO: This sets errorSuppressed to true if any of the parts is error-suppressing,
// in paricular any & T is error-suppressing. Really, errorSuppressed should be true if // in paricular any & T is error-suppressing. Really, errorSuppressed should be true if
// all of the parts are error-suppressing, but that fails to typecheck lua-apps. // all of the parts are error-suppressing, but that fails to typecheck lua-apps.
if (innerState.errors.empty()) if (innerState->errors.empty())
{ {
found = true; found = true;
errorsSuppressed = innerState.failure; errorsSuppressed = innerState->failure;
if (useNewSolver || innerState.failure) if (useNewSolver || innerState->failure)
logs.push_back(std::move(innerState.log)); logs.push_back(std::move(innerState->log));
else else
{ {
errorsSuppressed = false; errorsSuppressed = false;
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
break; break;
} }
} }
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState->errors))
{ {
unificationTooComplex = e; unificationTooComplex = e;
} }
@ -1204,16 +1204,16 @@ void Unifier::tryUnifyNormalizedTypes(
{ {
for (TypeId superTable : superNorm.tables) for (TypeId superTable : superNorm.tables)
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify(subClass, superTable); innerState->tryUnify(subClass, superTable);
if (innerState.errors.empty()) if (innerState->errors.empty())
{ {
found = true; found = true;
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
break; break;
} }
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState->errors))
return reportError(*e); return reportError(*e);
} }
} }
@ -1235,17 +1235,17 @@ void Unifier::tryUnifyNormalizedTypes(
break; break;
} }
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify(subTable, superTable); innerState->tryUnify(subTable, superTable);
if (innerState.errors.empty()) if (innerState->errors.empty())
{ {
found = true; found = true;
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
break; break;
} }
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState->errors))
return reportError(*e); return reportError(*e);
} }
if (!found) if (!found)
@ -1258,15 +1258,15 @@ void Unifier::tryUnifyNormalizedTypes(
return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()});
for (TypeId superFun : superNorm.functions.parts) for (TypeId superFun : superNorm.functions.parts)
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
const FunctionType* superFtv = get<FunctionType>(superFun); const FunctionType* superFtv = get<FunctionType>(superFun);
if (!superFtv) if (!superFtv)
return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()});
TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); TypePackId tgt = innerState->tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes);
innerState.tryUnify_(tgt, superFtv->retTypes); innerState->tryUnify_(tgt, superFtv->retTypes);
if (innerState.errors.empty()) if (innerState->errors.empty())
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState->errors))
return reportError(*e); return reportError(*e);
else else
return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()});
@ -1304,17 +1304,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
{ {
if (!firstFun) if (!firstFun)
firstFun = ftv; firstFun = ftv;
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(args, ftv->argTypes); innerState->tryUnify_(args, ftv->argTypes);
if (innerState.errors.empty()) if (innerState->errors.empty())
{ {
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
if (result) if (result)
{ {
innerState.log.clear(); innerState->log.clear();
innerState.tryUnify_(*result, ftv->retTypes); innerState->tryUnify_(*result, ftv->retTypes);
if (innerState.errors.empty()) if (innerState->errors.empty())
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
// Annoyingly, since we don't support intersection of generic type packs, // Annoyingly, since we don't support intersection of generic type packs,
// the intersection may fail. We rather arbitrarily use the first matching overload // the intersection may fail. We rather arbitrarily use the first matching overload
// in that case. // in that case.
@ -1324,7 +1324,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
else else
result = ftv->retTypes; result = ftv->retTypes;
} }
else if (auto e = hasUnificationTooComplex(innerState.errors)) else if (auto e = hasUnificationTooComplex(innerState->errors))
{ {
reportError(*e); reportError(*e);
return builtinTypes->errorRecoveryTypePack(args); return builtinTypes->errorRecoveryTypePack(args);
@ -1510,18 +1510,18 @@ void Unifier::enableNewSolver()
ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy)
{ {
Unifier s = makeChildUnifier(); std::unique_ptr<Unifier> s = makeChildUnifier();
s.tryUnify_(subTy, superTy); s->tryUnify_(subTy, superTy);
return s.errors; return s->errors;
} }
ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall)
{ {
Unifier s = makeChildUnifier(); std::unique_ptr<Unifier> s = makeChildUnifier();
s.tryUnify_(subTy, superTy, isFunctionCall); s->tryUnify_(subTy, superTy, isFunctionCall);
return s.errors; return s->errors;
} }
void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall)
@ -1884,9 +1884,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
// generic methods in tables to be marked read-only. // generic methods in tables to be marked read-only.
if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate)
{ {
Instantiation instantiation{&log, types, builtinTypes, scope->level, scope}; std::unique_ptr<Instantiation> instantiation = std::make_unique<Instantiation>(&log, types, builtinTypes, scope->level, scope);
std::optional<TypeId> instantiated = instantiation.substitute(subTy); std::optional<TypeId> instantiated = instantiation->substitute(subTy);
if (instantiated.has_value()) if (instantiated.has_value())
{ {
subFunction = log.getMutable<FunctionType>(*instantiated); subFunction = log.getMutable<FunctionType>(*instantiated);
@ -1930,54 +1930,54 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (!isFunctionCall) if (!isFunctionCall)
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.ctx = CountMismatch::Arg; innerState->ctx = CountMismatch::Arg;
innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); innerState->tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall);
bool reported = !innerState.errors.empty(); bool reported = !innerState->errors.empty();
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState->errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState->errors.empty() && innerState->firstPackErrorPos)
reportError( reportError(
location, location,
TypeMismatch{ TypeMismatch{
superTy, superTy,
subTy, subTy,
format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), format("Argument #%d type is not compatible.", *innerState->firstPackErrorPos),
innerState.errors.front(), innerState->errors.front(),
mismatchContext() mismatchContext()
} }
); );
else if (!innerState.errors.empty()) else if (!innerState->errors.empty())
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()});
innerState.ctx = CountMismatch::FunctionResult; innerState->ctx = CountMismatch::FunctionResult;
innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); innerState->tryUnify_(subFunction->retTypes, superFunction->retTypes);
if (!reported) if (!reported)
{ {
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState->errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) else if (!innerState->errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes))
reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState->errors.front(), mismatchContext()});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState->errors.empty() && innerState->firstPackErrorPos)
reportError( reportError(
location, location,
TypeMismatch{ TypeMismatch{
superTy, superTy,
subTy, subTy,
format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), format("Return #%d type is not compatible.", *innerState->firstPackErrorPos),
innerState.errors.front(), innerState->errors.front(),
mismatchContext() mismatchContext()
} }
); );
else if (!innerState.errors.empty()) else if (!innerState->errors.empty())
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()});
} }
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
} }
else else
{ {
@ -2115,14 +2115,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
if (!literalProperties || !literalProperties->contains(name)) if (!literalProperties || !literalProperties->contains(name))
variance = Invariant; variance = Invariant;
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(r->second.type(), prop.type()); innerState->tryUnify_(r->second.type(), prop.type());
checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy);
if (innerState.errors.empty()) if (innerState->errors.empty())
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
failure |= innerState.failure; failure |= innerState->failure;
} }
else if (subTable->indexer && maybeString(subTable->indexer->indexType)) else if (subTable->indexer && maybeString(subTable->indexer->indexType))
{ {
@ -2132,14 +2132,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
if (!literalProperties || !literalProperties->contains(name)) if (!literalProperties || !literalProperties->contains(name))
variance = Invariant; variance = Invariant;
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(subTable->indexer->indexResultType, prop.type()); innerState->tryUnify_(subTable->indexer->indexResultType, prop.type());
checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy);
if (innerState.errors.empty()) if (innerState->errors.empty())
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
failure |= innerState.failure; failure |= innerState->failure;
} }
else if (subTable->state == TableState::Unsealed && isOptional(prop.type())) else if (subTable->state == TableState::Unsealed && isOptional(prop.type()))
// This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }`
@ -2210,20 +2210,20 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
if (!literalProperties || !literalProperties->contains(name)) if (!literalProperties || !literalProperties->contains(name))
variance = Invariant; variance = Invariant;
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering)
innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType); innerState->tryUnify_(prop.type(), superTable->indexer->indexResultType);
else else
{ {
// Incredibly, the old solver depends on this bug somehow. // Incredibly, the old solver depends on this bug somehow.
innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); innerState->tryUnify_(superTable->indexer->indexResultType, prop.type());
} }
checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy);
if (innerState.errors.empty()) if (innerState->errors.empty())
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
failure |= innerState.failure; failure |= innerState->failure;
} }
else if (superTable->state == TableState::Unsealed) else if (superTable->state == TableState::Unsealed)
{ {
@ -2294,22 +2294,22 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
Resetter resetter{&variance}; Resetter resetter{&variance};
variance = Invariant; variance = Invariant;
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); innerState->tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType);
bool reported = !innerState.errors.empty(); bool reported = !innerState->errors.empty();
checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); checkChildUnifierTypeMismatch(innerState->errors, "[indexer key]", superTy, subTy);
innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); innerState->tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType);
if (!reported) if (!reported)
checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); checkChildUnifierTypeMismatch(innerState->errors, "[indexer value]", superTy, subTy);
if (innerState.errors.empty()) if (innerState->errors.empty())
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
failure |= innerState.failure; failure |= innerState->failure;
} }
else if (superTable->indexer) else if (superTable->indexer)
{ {
@ -2408,13 +2408,13 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
if (auto it = mttv->props.find("__index"); it != mttv->props.end()) if (auto it = mttv->props.find("__index"); it != mttv->props.end())
{ {
TypeId ty = it->second.type(); TypeId ty = it->second.type();
Unifier child = makeChildUnifier(); std::unique_ptr<Unifier> child = makeChildUnifier();
child.tryUnify_(ty, superTy); child->tryUnify_(ty, superTy);
// To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table
// There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed
// Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check
TypeId newSuperTy = child.log.follow(superTy); TypeId newSuperTy = child->log.follow(superTy);
if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty())
{ {
@ -2422,16 +2422,16 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
return; return;
} }
if (auto e = hasUnificationTooComplex(child.errors)) if (auto e = hasUnificationTooComplex(child->errors))
reportError(*e); reportError(*e);
else if (!child.errors.empty()) else if (!child->errors.empty())
fail(child.errors.front()); fail(child->errors.front());
log.concat(std::move(child.log)); log.concat(std::move(child->log));
// To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table
// We return success because subtype <: free table which means that correct unification is to replace free table with the subtype // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype
if (child.errors.empty()) if (child->errors.empty())
log.replace(superTy, BoundType{subTy}); log.replace(superTy, BoundType{subTy});
return; return;
@ -2476,19 +2476,19 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
if (const MetatableType* subMetatable = log.getMutable<MetatableType>(subTy)) if (const MetatableType* subMetatable = log.getMutable<MetatableType>(subTy))
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(subMetatable->table, superMetatable->table); innerState->tryUnify_(subMetatable->table, superMetatable->table);
innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); innerState->tryUnify_(subMetatable->metatable, superMetatable->metatable);
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState->errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty()) else if (!innerState->errors.empty())
reportError( reportError(
location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()}
); );
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
failure |= innerState.failure; failure |= innerState->failure;
} }
else if (TableType* subTable = log.getMutable<TableType>(subTy)) else if (TableType* subTable = log.getMutable<TableType>(subTy))
{ {
@ -2498,14 +2498,14 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
{ {
if (useNewSolver) if (useNewSolver)
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
bool missingProperty = false; bool missingProperty = false;
for (const auto& [propName, prop] : subTable->props) for (const auto& [propName, prop] : subTable->props)
{ {
if (std::optional<TypeId> mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) if (std::optional<TypeId> mtPropTy = findTablePropertyRespectingMeta(superTy, propName))
{ {
innerState.tryUnify(prop.type(), *mtPropTy); innerState->tryUnify(prop.type(), *mtPropTy);
} }
else else
{ {
@ -2520,18 +2520,18 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
// TODO: Unify indexers. // TODO: Unify indexers.
} }
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState->errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty()) else if (!innerState->errors.empty())
reportError(TypeError{ reportError(TypeError{
location, location,
TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()}
}); });
else if (!missingProperty) else if (!missingProperty)
{ {
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
log.bindTable(subTy, superTy); log.bindTable(subTy, superTy);
failure |= innerState.failure; failure |= innerState->failure;
} }
} }
else else
@ -2618,15 +2618,15 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
} }
else else
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState.tryUnify_(classProp->type(), prop.type()); innerState->tryUnify_(classProp->type(), prop.type());
checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); checkChildUnifierTypeMismatch(innerState->errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy);
if (innerState.errors.empty()) if (innerState->errors.empty())
{ {
log.concat(std::move(innerState.log)); log.concat(std::move(innerState->log));
failure |= innerState.failure; failure |= innerState->failure;
} }
else else
{ {
@ -2662,9 +2662,9 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy)
return reportError(location, NormalizationTooComplex{}); return reportError(location, NormalizationTooComplex{});
// T </: ~U iff T <: U // T </: ~U iff T <: U
Unifier state = makeChildUnifier(); std::unique_ptr<Unifier> state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, ""); state->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty()) if (state->errors.empty())
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
} }
@ -2889,27 +2889,27 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed)
if (occurs) if (occurs)
{ {
Unifier innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
if (const UnionType* ut = get<UnionType>(haystack)) if (const UnionType* ut = get<UnionType>(haystack))
{ {
if (reversed) if (reversed)
innerState.tryUnifyUnionWithType(haystack, ut, needle); innerState->tryUnifyUnionWithType(haystack, ut, needle);
else else
innerState.tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false); innerState->tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false);
} }
else if (const IntersectionType* it = get<IntersectionType>(haystack)) else if (const IntersectionType* it = get<IntersectionType>(haystack))
{ {
if (reversed) if (reversed)
innerState.tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false); innerState->tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false);
else else
innerState.tryUnifyTypeWithIntersection(needle, haystack, it); innerState->tryUnifyTypeWithIntersection(needle, haystack, it);
} }
else else
{ {
innerState.failure = true; innerState->failure = true;
} }
if (innerState.failure) if (innerState->failure)
{ {
reportError(location, OccursCheckFailed{}); reportError(location, OccursCheckFailed{});
log.replace(needle, BoundType{builtinTypes->errorRecoveryType()}); log.replace(needle, BoundType{builtinTypes->errorRecoveryType()});
@ -3014,14 +3014,14 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
return false; return false;
} }
Unifier Unifier::makeChildUnifier() std::unique_ptr<Unifier> Unifier::makeChildUnifier()
{ {
Unifier u = Unifier{normalizer, scope, location, variance, &log}; std::unique_ptr<Unifier> u = std::make_unique<Unifier>(normalizer, scope, location, variance, &log);
u.normalize = normalize; u->normalize = normalize;
u.checkInhabited = checkInhabited; u->checkInhabited = checkInhabited;
if (useNewSolver) if (useNewSolver)
u.enableNewSolver(); u->enableNewSolver();
return u; return u;
} }

View File

@ -1490,6 +1490,7 @@ public:
} }
}; };
bool isLValue(const AstExpr*);
AstName getIdentifier(AstExpr*); AstName getIdentifier(AstExpr*);
Location getLocation(const AstTypeList& typeList); Location getLocation(const AstTypeList& typeList);
@ -1520,4 +1521,4 @@ struct hash<Luau::AstName>
} }
}; };
} // namespace std } // namespace std

View File

@ -1146,6 +1146,14 @@ void AstTypePackGeneric::visit(AstVisitor* visitor)
visitor->visit(this); visitor->visit(this);
} }
bool isLValue(const AstExpr* expr)
{
return expr->is<AstExprLocal>()
|| expr->is<AstExprGlobal>()
|| expr->is<AstExprIndexName>()
|| expr->is<AstExprIndexExpr>();
}
AstName getIdentifier(AstExpr* node) AstName getIdentifier(AstExpr* node)
{ {
if (AstExprGlobal* expr = node->as<AstExprGlobal>()) if (AstExprGlobal* expr = node->as<AstExprGlobal>())
@ -1170,4 +1178,4 @@ Location getLocation(const AstTypeList& typeList)
return result; return result;
} }
} // namespace Luau } // namespace Luau

View File

@ -2,9 +2,13 @@
// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details
#include "lualib.h" #include "lualib.h"
#include "ldebug.h"
#include "lstate.h" #include "lstate.h"
#include "lvm.h" #include "lvm.h"
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCoroCheckStack, false)
LUAU_DYNAMIC_FASTFLAG(LuauStackLimit)
#define CO_STATUS_ERROR -1 #define CO_STATUS_ERROR -1
#define CO_STATUS_BREAK -2 #define CO_STATUS_BREAK -2
@ -37,6 +41,12 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
luaL_error(L, "too many arguments to resume"); luaL_error(L, "too many arguments to resume");
lua_xmove(L, co, narg); lua_xmove(L, co, narg);
} }
else if (DFFlag::LuauCoroCheckStack)
{
// coroutine might be completely full already
if ((co->top - co->base) > LUAI_MAXCSTACK)
luaL_error(L, "too many arguments to resume");
}
co->singlestep = L->singlestep; co->singlestep = L->singlestep;
@ -227,8 +237,22 @@ static int coclose(lua_State* L)
else else
{ {
lua_pushboolean(L, false); lua_pushboolean(L, false);
if (lua_gettop(co))
lua_xmove(co, L, 1); // move error message if (DFFlag::LuauStackLimit)
{
if (co->status == LUA_ERRMEM)
lua_pushstring(L, LUA_MEMERRMSG);
else if (co->status == LUA_ERRERR)
lua_pushstring(L, LUA_ERRERRMSG);
else if (lua_gettop(co))
lua_xmove(co, L, 1); // move error message
}
else
{
if (lua_gettop(co))
lua_xmove(co, L, 1); // move error message
}
lua_resetthread(co); lua_resetthread(co);
return 2; return 2;
} }

View File

@ -17,6 +17,11 @@
#include <string.h> #include <string.h>
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauStackLimit, false)
// keep max stack allocation request under 1GB
#define MAX_STACK_SIZE (int(1024 / sizeof(TValue)) * 1024 * 1024)
/* /*
** {====================================================== ** {======================================================
** Error-recovery functions ** Error-recovery functions
@ -176,6 +181,10 @@ static void correctstack(lua_State* L, TValue* oldstack)
void luaD_reallocstack(lua_State* L, int newsize) void luaD_reallocstack(lua_State* L, int newsize)
{ {
// throw 'out of memory' error because space for a custom error message cannot be guaranteed here
if (DFFlag::LuauStackLimit && newsize > MAX_STACK_SIZE)
luaD_throw(L, LUA_ERRMEM);
TValue* oldstack = L->stack; TValue* oldstack = L->stack;
int realsize = newsize + EXTRA_STACK; int realsize = newsize + EXTRA_STACK;
LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK);

View File

@ -14,6 +14,8 @@
#include <string.h> #include <string.h>
LUAU_DYNAMIC_FASTFLAG(LuauCoroCheckStack)
/* /*
* Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * Luau uses an incremental non-generational non-moving mark&sweep garbage collector.
* *
@ -436,12 +438,27 @@ static void shrinkstack(lua_State* L)
int s_used = cast_int(lim - L->stack); // part of stack in use int s_used = cast_int(lim - L->stack); // part of stack in use
if (L->size_ci > LUAI_MAXCALLS) // handling overflow? if (L->size_ci > LUAI_MAXCALLS) // handling overflow?
return; // do not touch the stacks return; // do not touch the stacks
if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci)
luaD_reallocCI(L, L->size_ci / 2); // still big enough... if (DFFlag::LuauCoroCheckStack)
condhardstacktests(luaD_reallocCI(L, ci_used + 1)); {
if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci)
luaD_reallocstack(L, L->stacksize / 2); // still big enough... luaD_reallocCI(L, L->size_ci / 2); // still big enough...
condhardstacktests(luaD_reallocstack(L, s_used)); condhardstacktests(luaD_reallocCI(L, ci_used + 1));
if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize)
luaD_reallocstack(L, L->stacksize / 2); // still big enough...
condhardstacktests(luaD_reallocstack(L, s_used));
}
else
{
if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci)
luaD_reallocCI(L, L->size_ci / 2); // still big enough...
condhardstacktests(luaD_reallocCI(L, ci_used + 1));
if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize)
luaD_reallocstack(L, L->stacksize / 2); // still big enough...
condhardstacktests(luaD_reallocstack(L, s_used));
}
} }
/* /*

View File

@ -15,9 +15,7 @@
LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2)
LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel)
LUAU_FASTFLAG(LuauAutocompleteNewSolverLimit)
LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauUseNormalizeIntersectionLimit)
using namespace Luau; using namespace Luau;
@ -3824,7 +3822,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_subtyping_recursion_limit")
if (!FFlag::LuauSolverV2) if (!FFlag::LuauSolverV2)
return; return;
ScopedFastFlag luauAutocompleteNewSolverLimit{FFlag::LuauAutocompleteNewSolverLimit, true};
ScopedFastInt luauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 10}; ScopedFastInt luauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 10};
const int parts = 100; const int parts = 100;

View File

@ -35,6 +35,7 @@ LUAU_FASTFLAG(LuauMathMap)
LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTFLAG(DebugLuauAbortingChecks)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTFLAG(LuauNativeAttribute) LUAU_FASTFLAG(LuauNativeAttribute)
LUAU_DYNAMIC_FASTFLAG(LuauStackLimit)
static lua_CompileOptions defaultOptions() static lua_CompileOptions defaultOptions()
{ {
@ -755,6 +756,8 @@ TEST_CASE("Closure")
TEST_CASE("Calls") TEST_CASE("Calls")
{ {
ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true};
runConformance("calls.lua"); runConformance("calls.lua");
} }
@ -794,6 +797,8 @@ static int cxxthrow(lua_State* L)
TEST_CASE("PCall") TEST_CASE("PCall")
{ {
ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true};
runConformance( runConformance(
"pcall.lua", "pcall.lua",
[](lua_State* L) [](lua_State* L)

View File

@ -44,7 +44,7 @@ void ConstraintGeneratorFixture::solve(const std::string& code)
{ {
generateConstraints(code); generateConstraints(code);
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {} NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {}
}; };
cs.run(); cs.run();
} }

View File

@ -5,14 +5,17 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Frontend.h"
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2);
struct FragmentAutocompleteFixture : Fixture struct FragmentAutocompleteFixture : Fixture
{ {
ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}};
FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos)
{ {
@ -31,11 +34,20 @@ struct FragmentAutocompleteFixture : Fixture
FragmentParseResult parseFragment(const std::string& document, const Position& cursorPos) FragmentParseResult parseFragment(const std::string& document, const Position& cursorPos)
{ {
ScopedFastFlag sffs[]{{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}};
SourceModule* srcModule = this->getMainSourceModule(); SourceModule* srcModule = this->getMainSourceModule();
std::string_view srcString = document; std::string_view srcString = document;
return Luau::parseFragment(*srcModule, srcString, cursorPos); return Luau::parseFragment(*srcModule, srcString, cursorPos);
} }
FragmentTypeCheckResult checkFragment(const std::string& document, const Position& cursorPos)
{
FrontendOptions options;
options.retainFullTypeGraphs = true;
// Don't strictly need this in the new solver
options.forAutocomplete = true;
options.runLintChecks = false;
return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document);
}
}; };
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests");
@ -267,3 +279,56 @@ local y = 5
} }
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_simple_fragment")
{
auto res = check(
R"(
local x = 4
local y = 5
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = checkFragment(
R"(
local x = 4
local y = 5
local z = x + y
)",
Position{3, 15}
);
auto opt = linearSearchForBinding(fragment.freshScope, "z");
REQUIRE(opt);
CHECK_EQ("number", toString(*opt));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_fragment_inserted_inline")
{
auto res = check(
R"(
local x = 4
local y = 5
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = checkFragment(
R"(
local x = 4
local z = x
local y = 5
)",
Position{2, 11}
);
auto correct = linearSearchForBinding(fragment.freshScope, "z");
REQUIRE(correct);
CHECK_EQ("number", toString(*correct));
}
TEST_SUITE_END();

View File

@ -12,6 +12,7 @@
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauNormalizationTracksCyclicPairsThroughInhabitance)
using namespace Luau; using namespace Luau;
namespace namespace
@ -1026,4 +1027,109 @@ TEST_CASE_FIXTURE(NormalizeFixture, "truthy_table_property_and_optional_table_wi
CHECK("{ x: number }" == toString(ty)); CHECK("{ x: number }" == toString(ty));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "normalizer_should_be_able_to_detect_cyclic_tables_and_not_stack_overflow")
{
if (!FFlag::LuauSolverV2)
return;
ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0};
ScopedFastFlag sff{FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance, true};
CheckResult result = check(R"(
--!strict
type Array<T> = { [number] : T}
type Object = { [number] : any}
type Set<T> = typeof(setmetatable(
{} :: {
size: number,
-- method definitions
add: (self: Set<T>, T) -> Set<T>,
clear: (self: Set<T>) -> (),
delete: (self: Set<T>, T) -> boolean,
has: (self: Set<T>, T) -> boolean,
ipairs: (self: Set<T>) -> any,
},
{} :: {
__index: Set<T>,
__iter: (self: Set<T>) -> (<K, V>({ [K]: V }, K?) -> (K, V), T),
}
))
type Map<K, V> = typeof(setmetatable(
{} :: {
size: number,
-- method definitions
set: (self: Map<K, V>, K, V) -> Map<K, V>,
get: (self: Map<K, V>, K) -> V | nil,
clear: (self: Map<K, V>) -> (),
delete: (self: Map<K, V>, K) -> boolean,
[K]: V,
has: (self: Map<K, V>, K) -> boolean,
keys: (self: Map<K, V>) -> Array<K>,
values: (self: Map<K, V>) -> Array<V>,
entries: (self: Map<K, V>) -> Array<Tuple<K, V>>,
ipairs: (self: Map<K, V>) -> any,
_map: { [K]: V },
_array: { [number]: K },
__index: (self: Map<K, V>, key: K) -> V,
__iter: (self: Map<K, V>) -> (<K, V>({ [K]: V }, K?) -> (K?, V), V),
__newindex: (self: Map<K, V>, key: K, value: V) -> (),
},
{} :: {
__index: Map<K, V>,
__iter: (self: Map<K, V>) -> (<K, V>({ [K]: V }, K?) -> (K, V), V),
__newindex: (self: Map<K, V>, key: K, value: V) -> (),
}
))
type mapFn<T, U> = (element: T, index: number) -> U
type mapFnWithThisArg<T, U> = (thisArg: any, element: T, index: number) -> U
function fromSet<T, U>(
value: Set<T>,
mapFn: (mapFn<T, U> | mapFnWithThisArg<T, U>)?,
thisArg: Object?
-- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts
): Array<U> | Array<T> | Array<string>
local array : { [number] : string} = {"foo"}
return array
end
function instanceof(tbl: any, class: any): boolean
return true
end
function fromArray<T, U>(
value: Array<T>,
mapFn: (mapFn<T, U> | mapFnWithThisArg<T, U>)?,
thisArg: Object?
-- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts
): Array<U> | Array<T> | Array<string>
local array : {[number] : string} = {}
return array
end
return function<T, U>(
value: string | Array<T> | Set<T> | Map<any, any>,
mapFn: (mapFn<T, U> | mapFnWithThisArg<T, U>)?,
thisArg: Object?
-- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts
): Array<U> | Array<T> | Array<string>
if value == nil then
error("cannot create array from a nil value")
end
local array: Array<U> | Array<T> | Array<string>
if instanceof(value, Set) then
array = fromSet(value :: Set<T>, mapFn, thisArg)
else
array = {}
end
return array
end
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h"
#include "Fixture.h" #include "Fixture.h"
@ -8,7 +9,10 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTypestateBuiltins)
LUAU_FASTFLAG(LuauStringFormatArityFix)
TEST_SUITE_BEGIN("BuiltinTests"); TEST_SUITE_BEGIN("BuiltinTests");
@ -802,6 +806,19 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method")
CHECK_EQ(tm->givenType, builtinTypes->numberType); CHECK_EQ(tm->givenType, builtinTypes->numberType);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_trivial_arity")
{
ScopedFastFlag sff{FFlag::LuauStringFormatArityFix, true};
CheckResult result = check(R"(
string.format()
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Argument count mismatch. Function 'string.format' expects at least 1 argument, but none are specified", toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
@ -1109,15 +1126,28 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic")
local c = tf3[2] local c = tf3[2]
local d = tf1.b local d = tf1.b
local a2 = t1.a
local b2 = t2.b
local c2 = t3[2]
)"); )");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins)
CHECK("Key 'b' not found in table '{ read a: number }'" == toString(result.errors[0]));
else if (FFlag::LuauSolverV2)
CHECK("Key 'b' not found in table '{ a: number }'" == toString(result.errors[0])); CHECK("Key 'b' not found in table '{ a: number }'" == toString(result.errors[0]));
else else
CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0]));
CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location);
if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins)
{
CHECK_EQ("{ read a: number }", toString(requireTypeAtPosition({15, 19})));
CHECK_EQ("{ read b: string }", toString(requireTypeAtPosition({16, 19})));
CHECK_EQ("{boolean}", toString(requireTypeAtPosition({17, 19})));
}
CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("number", toString(requireType("a")));
CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("string", toString(requireType("b")));
CHECK_EQ("boolean", toString(requireType("c"))); CHECK_EQ("boolean", toString(requireType("c")));
@ -1126,6 +1156,86 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic")
CHECK_EQ("any", toString(requireType("d"))); CHECK_EQ("any", toString(requireType("d")));
else else
CHECK_EQ("*error-type*", toString(requireType("d"))); CHECK_EQ("*error-type*", toString(requireType("d")));
CHECK_EQ("number", toString(requireType("a2")));
CHECK_EQ("string", toString(requireType("b2")));
CHECK_EQ("boolean", toString(requireType("c2")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_does_not_retroactively_block_mutation")
{
CheckResult result = check(R"(
local t1 = {a = 42}
t1.q = ":3"
local tf1 = table.freeze(t1)
local a = tf1.a
local b = t1.a
)");
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::LuauTypestateBuiltins)
{
CHECK_EQ("t1 | { read a: number, read q: string }", toString(requireType("t1")));
// before the assignment, it's `t1`
CHECK_EQ("t1", toString(requireTypeAtPosition({3, 8})));
// after the assignment, it's read-only.
CHECK_EQ("{ read a: number, read q: string }", toString(requireTypeAtPosition({8, 18})));
}
CHECK_EQ("number", toString(requireType("a")));
CHECK_EQ("number", toString(requireType("b")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_no_generic_table")
{
ScopedFastFlag sff{FFlag::LuauSolverV2, true};
CheckResult result = check(R"(
--!strict
type k = {
read k: string,
}
function _(): k
return table.freeze({
k = "",
})
end
)");
if (FFlag::LuauTypestateBuiltins)
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_non_tables")
{
CheckResult result = check(R"(
--!strict
table.freeze(42)
)");
// this does not error in the new solver without the typestate builtins functionality.
if (FFlag::LuauSolverV2 && !FFlag::LuauTypestateBuiltins)
{
LUAU_REQUIRE_NO_ERRORS(result);
return;
}
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins)
CHECK_EQ(toString(tm->wantedType), "table");
else
CHECK_EQ(toString(tm->wantedType), "{- -}");
CHECK_EQ(toString(tm->givenType), "number");
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments")

View File

@ -12,6 +12,7 @@
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauTypestateBuiltins)
using namespace Luau; using namespace Luau;
@ -152,6 +153,45 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function")
CHECK(get<VariadicTypePack>(*iter.tail())); CHECK(get<VariadicTypePack>(*iter.tail()));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "cross_module_table_freeze")
{
fileResolver.source["game/A"] = R"(
--!strict
return {
a = 1,
}
)";
fileResolver.source["game/B"] = R"(
--!strict
return table.freeze(require(game.A))
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CheckResult bResult = frontend.check("game/B");
LUAU_REQUIRE_NO_ERRORS(bResult);
ModulePtr a = frontend.moduleResolver.getModule("game/A");
REQUIRE(a != nullptr);
// confirm that no cross-module mutation happened here!
if (FFlag::LuauSolverV2)
CHECK(toString(a->returnType) == "{ a: number }");
else
CHECK(toString(a->returnType) == "{| a: number |}");
ModulePtr b = frontend.moduleResolver.getModule("game/B");
REQUIRE(b != nullptr);
// confirm that no cross-module mutation happened here!
if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins)
CHECK(toString(b->returnType) == "{ read a: number }");
else if (FFlag::LuauSolverV2)
CHECK(toString(b->returnType) == "{ a: number }");
else
CHECK(toString(b->returnType) == "{| a: number |}");
}
TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(

View File

@ -8,7 +8,6 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauUseNormalizeIntersectionLimit)
using namespace Luau; using namespace Luau;
@ -2327,8 +2326,6 @@ end)
TEST_CASE_FIXTURE(Fixture, "refinements_table_intersection_limits" * doctest::timeout(0.5)) TEST_CASE_FIXTURE(Fixture, "refinements_table_intersection_limits" * doctest::timeout(0.5))
{ {
ScopedFastFlag LuauUseNormalizeIntersectionLimit{FFlag::LuauUseNormalizeIntersectionLimit, true};
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
type Dir = { type Dir = {

View File

@ -23,6 +23,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit);
LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauNormalizeCacheLimit);
LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues)
using namespace Luau; using namespace Luau;
@ -1706,4 +1707,29 @@ TEST_CASE_FIXTURE(Fixture, "react_lua_follow_free_type_ub")
)")); )"));
} }
TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue")
{
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauNewSolverVisitErrorExprLvalues, true}
};
// This should always fail to parse, but shouldn't assert. Previously this
// would assert as we end up _roughly_ parsing this (with a lot of error
// nodes) as:
//
// do
// x :: T, y = z
// end
//
// We assume that `T` has some resolved type that is set up during
// constraint generation and resolved during constraint solving to
// be used during typechecking. We didn't descend into error nodes
// in lvalue positions.
LUAU_REQUIRE_ERRORS(check(R"(
--!strict
(::,
)"));
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -32,7 +32,7 @@ TEST_SUITE_BEGIN("TryUnifyTests");
TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify")
{ {
Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}};
Type numberTwo = numberOne; Type numberTwo = numberOne.clone();
state.tryUnify(&numberTwo, &numberOne); state.tryUnify(&numberTwo, &numberOne);
@ -64,13 +64,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved")
Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))
}}; }};
Type functionOneSaved = functionOne; Type functionOneSaved = functionOne.clone();
TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}};
Type functionTwo{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) Type functionTwo{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType}))
}}; }};
Type functionTwoSaved = functionTwo; Type functionTwoSaved = functionTwo.clone();
state.tryUnify(&functionTwo, &functionOne); state.tryUnify(&functionTwo, &functionOne);
CHECK(state.failure); CHECK(state.failure);

View File

@ -236,4 +236,12 @@ if not limitedstack then
assert(not err and string.find(msg, "error")) assert(not err and string.find(msg, "error"))
end end
-- testing deep nested calls with a large thread stack
do
function recurse(n, ...) return n <= 1 and (1 + #{...}) or recurse(n-1, table.unpack(table.create(4000, 1))) + 1 end
local ok, msg = pcall(recurse, 19000)
assert(not ok and string.find(msg, "not enough memory"))
end
return('OK') return('OK')

View File

@ -168,6 +168,10 @@ checkresults({ false, "oops" }, xpcall(function() table.create(1e6) end, functio
checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end)) checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end))
checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end)) checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end))
co = coroutine.create(function() table.create(1e6) end)
coroutine.resume(co)
checkresults({ false, "not enough memory" }, coroutine.close(co))
-- ensure that pcall and xpcall close upvalues when handling error -- ensure that pcall and xpcall close upvalues when handling error
local upclo local upclo
local function uptest(y) local function uptest(y)