diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index d4457638..6154f3d1 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -14,8 +14,6 @@ struct GlobalTypes; struct TypeChecker; struct TypeArena; -void registerBuiltinTypes(GlobalTypes& globals); - void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 3404c6a2..afe47322 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -2,12 +2,12 @@ #pragma once #include "Luau/Config.h" +#include "Luau/GlobalTypes.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" #include "Luau/TypeCheckLimits.h" -#include "Luau/TypeInfer.h" #include "Luau/Variant.h" #include diff --git a/Analysis/include/Luau/GlobalTypes.h b/Analysis/include/Luau/GlobalTypes.h new file mode 100644 index 00000000..86bfd943 --- /dev/null +++ b/Analysis/include/Luau/GlobalTypes.h @@ -0,0 +1,26 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" +#include "Luau/TypeArena.h" + +namespace Luau +{ + +struct BuiltinTypes; + +struct GlobalTypes +{ + explicit GlobalTypes(NotNull builtinTypes); + + NotNull builtinTypes; // Global types are based on builtin types + + TypeArena globalTypes; + SourceModule globalNames; // names for symbols entered into globalScope + ScopePtr globalScope; // shared by all modules +}; + +} diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index a69952f6..70cd8bae 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -19,6 +19,7 @@ class TypeIds; class Normalizer; struct NormalizedType; struct NormalizedClassType; +struct NormalizedFunctionType; struct SubtypingResult { @@ -103,6 +104,7 @@ private: SubtypingResult isSubtype_(const NormalizedType* subNorm, const NormalizedType* superNorm); SubtypingResult isSubtype_(const NormalizedClassType& subClass, const NormalizedClassType& superClass, const TypeIds& superTables); + SubtypingResult isSubtype_(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction); SubtypingResult isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes); SubtypingResult isSubtype_(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index cc88d54b..ffbe3fa0 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -798,12 +798,13 @@ struct BuiltinTypes TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; + friend TypeId makeStringMetatable(NotNull builtinTypes); + friend struct GlobalTypes; + private: std::unique_ptr arena; bool debugFreezeArena = false; - TypeId makeStringMetatable(); - public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9a44af49..abae8b92 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -57,17 +57,6 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -struct GlobalTypes -{ - GlobalTypes(NotNull builtinTypes); - - NotNull builtinTypes; // Global types are based on builtin types - - TypeArena globalTypes; - SourceModule globalNames; // names for symbols entered into globalScope - ScopePtr globalScope; // shared by all modules -}; - // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index baeac469..4a5638b1 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,8 +13,6 @@ #include LUAU_FASTFLAG(DebugLuauReadWriteProperties) -LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled1, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteLastTypecheck, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDoEnd, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false); @@ -611,7 +609,6 @@ std::optional getLocalTypeInScopeAt(const Module& module, Position posit template static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); ToStringOptions opts; opts.useLineBreaks = false; opts.hideTableKind = true; @@ -630,23 +627,7 @@ static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool if (!canSuggestInferredType(scope, ty)) return std::nullopt; - if (FFlag::LuauAnonymousAutofilled1) - { - return tryToStringDetailed(scope, ty, functionTypeArguments); - } - else - { - ToStringOptions opts; - opts.useLineBreaks = false; - opts.hideTableKind = true; - opts.scope = scope; - ToStringResult name = toStringDetailed(ty, opts); - - if (name.error || name.invalid || name.cycle || name.truncated) - return std::nullopt; - - return name.name; - } + return tryToStringDetailed(scope, ty, functionTypeArguments); } static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) @@ -1417,7 +1398,6 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector an static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); std::string result = "function("; auto [args, tail] = Luau::flatten(funcTy.argTypes); @@ -1483,7 +1463,6 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func static std::optional makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector& ancestry) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); const AstExprCall* call = node->as(); if (!call && ancestry.size() > 1) call = ancestry[ancestry.size() - 2]->as(); @@ -1801,17 +1780,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (node->asExpr()) { - if (FFlag::LuauAnonymousAutofilled1) - { - AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) - ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); - return ret; - } - else - { - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - } + AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; } else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1821,15 +1793,6 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - if (!FFlag::LuauAutocompleteLastTypecheck) - { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(moduleName, opts); - } - const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index c55a88eb..0200ee3e 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -201,18 +201,6 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& } } -void registerBuiltinTypes(GlobalTypes& globals) -{ - globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType}); - globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType}); - globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType}); - globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType}); - globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType}); - globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType}); - globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType}); - globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); -} - void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) { LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); @@ -310,6 +298,520 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } +static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) +{ + const char* options = "cdiouxXeEfgGqs*"; + + std::vector result; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i < size && data[i] == '%') + continue; + + // we just ignore all characters (including flags/precision) up until first alphabetic character + while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) + i++; + + if (i == size) + break; + + if (data[i] == 'q' || data[i] == 's') + result.push_back(builtinTypes->stringType); + else if (data[i] == '*') + result.push_back(builtinTypes->unknownType); + else if (strchr(options, data[i])) + result.push_back(builtinTypes->numberType); + else + result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); + } + } + + return result; +} + +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* fmt = nullptr; + if (auto index = expr.func->as(); index && expr.self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!expr.self && expr.args.size > 0) + fmt = expr.args.data[0]->as(); + + if (!fmt) + return std::nullopt; + + std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(paramPack); + + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; + + typechecker.unify(params[i + paramOffset], expected[i], scope, location); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + return WithPredicate{arena.addTypePack({typechecker.stringType})}; +} + +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(builtinTypes->optionalNumberType); + continue; + } + + ++depth; + result.push_back(builtinTypes->optionalStringType); + } + else if (data[i] == ')') + { + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; + } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(builtinTypes->optionalStringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} + +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 3) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() == 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 4) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + bool plain = false; + size_t plainIndex = expr.self ? 2 : 3; + if (expr.args.size > plainIndex) + { + AstExprConstantBool* p = expr.args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + } + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() >= 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + if (params.size() == 4 && expr.args.size > plainIndex) + typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull builtinTypes = context.solver->builtinTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; +} + +TypeId makeStringMetatable(NotNull builtinTypes) +{ + NotNull arena{builtinTypes->arena.get()}; + + const TypeId nilType = builtinTypes->nilType; + const TypeId numberType = builtinTypes->numberType; + const TypeId booleanType = builtinTypes->booleanType; + const TypeId stringType = builtinTypes->stringType; + const TypeId anyType = builtinTypes->anyType; + + const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); + const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); + + const TypePackId oneStringPack = arena->addTypePack({stringType}); + const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); + + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); + const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); + + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); + + const TypeId replArgType = + arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); + const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + + const TypeId matchFunc = arena->addType( + FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + + const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + + TableType::Props stringLib = { + {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, + {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, + {"pack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, anyTypePack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"unpack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + anyTypePack, + })}}, + }; + + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; + + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); +} + static std::optional> magicFunctionSelect( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index c6c360b8..b71a9354 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckLimitControls, false) +LUAU_FASTFLAGVARIABLE(CorrectEarlyReturnInMarkDirty, false) namespace Luau { @@ -928,7 +929,6 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) { // The autocomplete typecheck is always in strict mode with DM awareness // to provide better type information for IDE features - TypeCheckLimits typeCheckLimits; if (autocompleteTimeLimit != 0.0) typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; @@ -1149,8 +1149,16 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) - return; + if (FFlag::CorrectEarlyReturnInMarkDirty) + { + if (sourceNodes.count(name) == 0) + return; + } + else + { + if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) + return; + } std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) diff --git a/Analysis/src/GlobalTypes.cpp b/Analysis/src/GlobalTypes.cpp new file mode 100644 index 00000000..9e26a2e3 --- /dev/null +++ b/Analysis/src/GlobalTypes.cpp @@ -0,0 +1,34 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/GlobalTypes.h" + +LUAU_FASTFLAG(LuauInitializeStringMetatableInGlobalTypes) + +namespace Luau +{ + +GlobalTypes::GlobalTypes(NotNull builtinTypes) + : builtinTypes(builtinTypes) +{ + globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); + + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); + + if (FFlag::LuauInitializeStringMetatableInGlobalTypes) + { + unfreeze(*builtinTypes->arena); + TypeId stringMetatableTy = makeStringMetatable(builtinTypes); + asMutable(builtinTypes->stringType)->ty.emplace(PrimitiveType::String, stringMetatableTy); + persist(stringMetatableTy); + freeze(*builtinTypes->arena); + } +} + +} diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index 3f3c9319..8012bac7 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -664,9 +664,8 @@ SubtypingResult Subtyping::isSubtype_(const NormalizedType* subNorm, const Norma result.andAlso(isSubtype_(subNorm->tables, superNorm->tables)); // isSubtype_(subNorm->tables, superNorm->strings); // isSubtype_(subNorm->tables, superNorm->classes); - // isSubtype_(subNorm->functions, superNorm->functions); + result.andAlso(isSubtype_(subNorm->functions, superNorm->functions)); // isSubtype_(subNorm->tyvars, superNorm->tyvars); - return result; } @@ -703,6 +702,16 @@ SubtypingResult Subtyping::isSubtype_(const NormalizedClassType& subClass, const return {true}; } +SubtypingResult Subtyping::isSubtype_(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction) +{ + if (subFunction.isNever()) + return {true}; + else if (superFunction.isTop) + return {true}; + else + return isSubtype_(subFunction.parts, superFunction.parts); +} + SubtypingResult Subtyping::isSubtype_(const TypeIds& subTypes, const TypeIds& superTypes) { std::vector results; diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index c3a1db4c..04d04470 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -9,6 +9,8 @@ #include #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -52,7 +54,7 @@ bool StateDot::canDuplicatePrimitive(TypeId ty) if (get(ty)) return false; - return get(ty) || get(ty); + return get(ty) || get(ty) || get(ty) || get(ty); } void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) @@ -76,6 +78,10 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"unknown\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"never\"];\n", index); } else { @@ -139,159 +145,215 @@ void StateDot::visitChildren(TypeId ty, int index) startNode(index); startNodeLabel(); - if (const BoundType* btv = get(ty)) + auto go = [&](auto&& t) { - formatAppend(result, "BoundType %d", index); - finishNodeLabel(ty); - finishNode(); + using T = std::decay_t; - visitChild(btv->boundTo, index); - } - else if (const FunctionType* ftv = get(ty)) - { - formatAppend(result, "FunctionType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retTypes, index, "ret"); - } - else if (const TableType* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableType %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableType %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableType %d", index); - finishNodeLabel(ty); - finishNode(); - - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); - - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type(), index, name.c_str()); - if (ttv->indexer) + if constexpr (std::is_same_v) { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); + formatAppend(result, "BoundType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.boundTo, index); } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); - - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp, index, "typePackParam"); - } - else if (const MetatableType* mtv = get(ty)) - { - formatAppend(result, "MetatableType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionType* utv = get(ty)) - { - formatAppend(result, "UnionType %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionType* itv = get(ty)) - { - formatAppend(result, "IntersectionType %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericType* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericType %s", gtv->name.c_str()); - else - formatAppend(result, "GenericType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeType* ftv = get(ty)) - { - formatAppend(result, "FreeType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "AnyType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "ErrorType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassType* ctv = get(ty)) - { - formatAppend(result, "ClassType %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); - - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type(), index, name.c_str()); - - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); - - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - - if (ctv->indexer) + else if constexpr (std::is_same_v) { - visitChild(ctv->indexer->indexType, index, "[index]"); - visitChild(ctv->indexer->indexResultType, index, "[value]"); + formatAppend(result, "BlockedType %d", index); + finishNodeLabel(ty); + finishNode(); } - } - else if (const SingletonType* stv = get(ty)) - { - std::string res; + else if constexpr (std::is_same_v) + { + formatAppend(result, "FunctionType %d", index); + finishNodeLabel(ty); + finishNode(); - if (const StringSingleton* ss = get(stv)) - { - // Don't put in quotes anywhere. If it's outside of the call to escape, - // then it's invalid syntax. If it's inside, then escaping is super noisy. - res = "string: " + escape(ss->value); + visitChild(t.argTypes, index, "arg"); + visitChild(t.retTypes, index, "ret"); } - else if (const BooleanSingleton* bs = get(stv)) + else if constexpr (std::is_same_v) { - res = "boolean: "; - res += bs->value ? "true" : "false"; + if (t.name) + formatAppend(result, "TableType %s", t.name->c_str()); + else if (t.syntheticName) + formatAppend(result, "TableType %s", t.syntheticName->c_str()); + else + formatAppend(result, "TableType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (t.boundTo) + return visitChild(*t.boundTo, index, "boundTo"); + + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : t.instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : t.instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "MetatableType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.table, index, "table"); + visitChild(t.metatable, index, "metatable"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnionType %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId opt : t.options) + visitChild(opt, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "IntersectionType %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : t.parts) + visitChild(part, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "LazyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PendingExpansionType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + if (t.explicitName) + formatAppend(result, "GenericType %s", t.name.c_str()); + else + formatAppend(result, "GenericType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "FreeType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (!get(t.lowerBound)) + visitChild(t.lowerBound, index, "[lowerBound]"); + + if (!get(t.upperBound)) + visitChild(t.upperBound, index, "[upperBound]"); + } + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "AnyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnknownType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NeverType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ErrorType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ClassType %s", t.name.c_str()); + finishNodeLabel(ty); + finishNode(); + + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); + + if (t.parent) + visitChild(*t.parent, index, "[parent]"); + + if (t.metatable) + visitChild(*t.metatable, index, "[metatable]"); + + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + } + else if constexpr (std::is_same_v) + { + std::string res; + + if (const StringSingleton* ss = get(&t)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(&t)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonType %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NegationType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.ty, index, "[negated]"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "TypeFamilyInstanceType %d", index); + finishNodeLabel(ty); + finishNode(); } else - LUAU_ASSERT(!"unknown singleton type"); + static_assert(always_false_v, "unknown type kind"); + }; - formatAppend(result, "SingletonType %s", res.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } + visit(go, ty->ty); } void StateDot::visitChildren(TypePackId tp, int index) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 2590e4dc..86564cf5 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -27,26 +27,11 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAGVARIABLE(LuauInitializeStringMetatableInGlobalTypes, false) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); - -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); - -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); - -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFind(MagicFunctionCallContext context); - // LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) { @@ -933,6 +918,8 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); +TypeId makeStringMetatable(NotNull builtinTypes); // BuiltinDefinitions.cpp + BuiltinTypes::BuiltinTypes() : arena(new TypeArena) , debugFreezeArena(FFlag::DebugLuauFreezeArena) @@ -961,9 +948,12 @@ BuiltinTypes::BuiltinTypes() , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { - TypeId stringMetatable = makeStringMetatable(); - asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; - persist(stringMetatable); + if (!FFlag::LuauInitializeStringMetatableInGlobalTypes) + { + TypeId stringMetatable = makeStringMetatable(NotNull{this}); + asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; + persist(stringMetatable); + } freeze(*arena); } @@ -980,82 +970,6 @@ BuiltinTypes::~BuiltinTypes() FFlag::DebugLuauFreezeArena.value = prevFlag; } -TypeId BuiltinTypes::makeStringMetatable() -{ - const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); - const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); - - const TypePackId oneStringPack = arena->addTypePack({stringType}); - const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); - - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); - const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); - - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); - - const TypeId replArgType = - arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); - const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); - const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - - const TypeId matchFunc = arena->addType( - FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - - const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - - TableType::Props stringLib = { - {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; - - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); -} - TypeId BuiltinTypes::errorRecoveryType() const { return errorType; @@ -1261,436 +1175,6 @@ IntersectionTypeIterator end(const IntersectionType* itv) return IntersectionTypeIterator{}; } -static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) -{ - const char* options = "cdiouxXeEfgGqs*"; - - std::vector result; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - i++; - - if (i < size && data[i] == '%') - continue; - - // we just ignore all characters (including flags/precision) up until first alphabetic character - while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) - i++; - - if (i == size) - break; - - if (data[i] == 'q' || data[i] == 's') - result.push_back(builtinTypes->stringType); - else if (data[i] == '*') - result.push_back(builtinTypes->unknownType); - else if (strchr(options, data[i])) - result.push_back(builtinTypes->numberType); - else - result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); - } - } - - return result; -} - -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* fmt = nullptr; - if (auto index = expr.func->as(); index && expr.self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!expr.self && expr.args.size > 0) - fmt = expr.args.data[0]->as(); - - if (!fmt) - return std::nullopt; - - std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(paramPack); - - size_t paramOffset = 1; - size_t dataOffset = expr.self ? 0 : 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - - typechecker.unify(params[i + paramOffset], expected[i], scope, location); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - return WithPredicate{arena.addTypePack({typechecker.stringType})}; -} - -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) -{ - TypeArena* arena = context.solver->arena; - - AstExprConstantString* fmt = nullptr; - if (auto index = context.callSite->func->as(); index && context.callSite->self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!context.callSite->self && context.callSite->args.size > 0) - fmt = context.callSite->args.data[0]->as(); - - if (!fmt) - return false; - - std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(context.arguments); - - size_t paramOffset = 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); - asMutable(context.result)->ty.emplace(resultPack); - - return true; -} - -static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) -{ - std::vector result; - int depth = 0; - bool parsingSet = false; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - ++i; - if (!parsingSet && i < size && data[i] == 'b') - i += 2; - } - else if (!parsingSet && data[i] == '[') - { - parsingSet = true; - if (i + 1 < size && data[i + 1] == ']') - i += 1; - } - else if (parsingSet && data[i] == ']') - { - parsingSet = false; - } - else if (data[i] == '(') - { - if (parsingSet) - continue; - - if (i + 1 < size && data[i + 1] == ')') - { - i++; - result.push_back(builtinTypes->optionalNumberType); - continue; - } - - ++depth; - result.push_back(builtinTypes->optionalStringType); - } - else if (data[i] == ')') - { - if (parsingSet) - continue; - - --depth; - - if (depth < 0) - break; - } - } - - if (depth != 0 || parsingSet) - return std::vector(); - - if (result.empty()) - result.push_back(builtinTypes->optionalStringType); - - return result; -} - -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() != 2) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t index = expr.self ? 0 : 1; - if (expr.args.size > index) - pattern = expr.args.data[index]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypePackId emptyPack = arena.addTypePack({}); - const TypePackId returnList = arena.addTypePack(returnTypes); - const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); - return WithPredicate{arena.addTypePack({iteratorType})}; -} - -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() != 2) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t index = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > index) - pattern = context.callSite->args.data[index]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId returnList = arena->addTypePack(returnTypes); - const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); - const TypePackId resTypePack = arena->addTypePack({iteratorType}); - asMutable(context.result)->ty.emplace(resTypePack); - - return true; -} - -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 3) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() == 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 3) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); - - const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() == 3 && context.callSite->args.size > initIndex) - context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - - return true; -} - -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 4) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - bool plain = false; - size_t plainIndex = expr.self ? 2 : 3; - if (expr.args.size > plainIndex) - { - AstExprConstantBool* p = expr.args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - } - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() >= 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - if (params.size() == 4 && expr.args.size > plainIndex) - typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionFind(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 4) - return false; - - TypeArena* arena = context.solver->arena; - NotNull builtinTypes = context.solver->builtinTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - bool plain = false; - size_t plainIndex = context.callSite->self ? 2 : 3; - if (context.callSite->args.size > plainIndex) - { - AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - } - - context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType); - - const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() >= 3 && context.callSite->args.size > initIndex) - context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); - - if (params.size() == 4 && context.callSite->args.size > plainIndex) - context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - return true; -} - TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope) { return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType}); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 5349f16a..61c90ba8 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -38,6 +38,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) +LUAU_FASTFLAGVARIABLE(LuauVariadicOverloadFix, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAG(LuauParseDeclareClassIndexer) LUAU_FASTFLAG(LuauFloorDivision); @@ -210,21 +211,6 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -GlobalTypes::GlobalTypes(NotNull builtinTypes) - : builtinTypes(builtinTypes) -{ - globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - - globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); - globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); - globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); - globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); - globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); - globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); - globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); - globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); -} - TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) : globalScope(globalScope) , resolver(resolver) @@ -4038,7 +4024,13 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argIndex < argLocations.size()) location = argLocations[argIndex]; - unify(*argIter, vtp->ty, scope, location); + if (FFlag::LuauVariadicOverloadFix) + { + state.location = location; + state.tryUnify(*argIter, vtp->ty); + } + else + unify(*argIter, vtp->ty, scope, location); ++argIter; ++argIndex; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index bc8ef018..7cf05cda 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -25,7 +25,6 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauTableUnifyRecursionLimit, false) namespace Luau { @@ -2260,23 +2259,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (superTable != newSuperTable || subTable != newSubTable) { - if (FFlag::LuauTableUnifyRecursionLimit) + if (errors.empty()) { - if (errors.empty()) - { - RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - tryUnifyTables(subTy, superTy, isIntersection); - } + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } - return; - } - else - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + return; } } @@ -2351,23 +2340,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (superTable != newSuperTable || subTable != newSubTable) { - if (FFlag::LuauTableUnifyRecursionLimit) + if (errors.empty()) { - if (errors.empty()) - { - RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - tryUnifyTables(subTy, superTy, isIntersection); - } + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } - return; - } - else - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + return; } } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 894d2dd7..1795243c 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -7,7 +7,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauFloorDivision, false) -LUAU_FASTFLAGVARIABLE(LuauLexerConsumeFast, false) LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) namespace Luau @@ -460,19 +459,8 @@ Position Lexer::position() const LUAU_FORCEINLINE void Lexer::consume() { - if (isNewline(buffer[offset])) - { - // TODO: When the flag is removed, remove the outer condition - if (FFlag::LuauLexerConsumeFast) - { - LUAU_ASSERT(!isNewline(buffer[offset])); - } - else - { - line++; - lineOffset = offset + 1; - } - } + // consume() assumes current character is known to not be a newline; use consumeAny if this is not guaranteed + LUAU_ASSERT(!isNewline(buffer[offset])); offset++; } diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index d854b400..b953c888 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -66,6 +66,8 @@ struct IrBuilder bool inTerminatedBlock = false; + bool interruptRequested = false; + bool activeFastcallFallback = false; IrOp fastcallFallbackReturn; int fastcallSkipTarget = -1; @@ -76,6 +78,8 @@ struct IrBuilder std::vector instIndexToBlock; // Block index at the bytecode instruction + std::vector loopStepStack; + // Similar to BytecodeBuilder, duplicate constants are removed used the same method struct ConstantKey { diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 12465906..298258c1 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -199,24 +199,12 @@ enum class IrCmd : uint8_t // D: block (if false) JUMP_EQ_TAG, - // Jump if two int numbers are equal - // A, B: int - // C: block (if true) - // D: block (if false) - JUMP_EQ_INT, - - // Jump if A < B - // A, B: int - // C: block (if true) - // D: block (if false) - JUMP_LT_INT, - - // Jump if unsigned(A) >= unsigned(B) + // Perform a conditional jump based on the result of integer comparison // A, B: int // C: condition // D: block (if true) // E: block (if false) - JUMP_GE_UINT, + JUMP_CMP_INT, // Jump if pointers are equal // A, B: pointer (*) diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3def51a8..5db5f6f1 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -94,9 +94,7 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: - case IrCmd::JUMP_EQ_INT: - case IrCmd::JUMP_LT_INT: - case IrCmd::JUMP_GE_UINT: + case IrCmd::JUMP_CMP_INT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_SLOT_MATCH: diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index c831c0bc..97749fbe 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -12,6 +12,8 @@ #include "lgc.h" #include "lstate.h" +#include + namespace Luau { namespace CodeGen @@ -22,10 +24,15 @@ namespace X64 void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label) { // Refresher on comi/ucomi EFLAGS: + // all zero: greater // CF only: less // ZF only: equal // PF+CF+ZF: unordered (NaN) + // To avoid the lack of conditional jumps that check for "greater" conditions in IEEE 754 compliant way, we use "less" forms to emulate these + if (cond == IrCondition::Greater || cond == IrCondition::GreaterEqual || cond == IrCondition::NotGreater || cond == IrCondition::NotGreaterEqual) + std::swap(lhs, rhs); + if (rhs.cat == CategoryX64::reg) { build.vucomisd(rhs, lhs); @@ -41,18 +48,22 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, switch (cond) { case IrCondition::NotLessEqual: + case IrCondition::NotGreaterEqual: // (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN build.jcc(ConditionX64::NotAboveEqual, label); break; case IrCondition::LessEqual: + case IrCondition::GreaterEqual: // (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN build.jcc(ConditionX64::AboveEqual, label); break; case IrCondition::NotLess: + case IrCondition::NotGreater: // (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN build.jcc(ConditionX64::NotAbove, label); break; case IrCondition::Less: + case IrCondition::Greater: // (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN build.jcc(ConditionX64::Above, label); break; @@ -66,6 +77,44 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } +ConditionX64 getConditionInt(IrCondition cond) +{ + switch (cond) + { + case IrCondition::Equal: + return ConditionX64::Equal; + case IrCondition::NotEqual: + return ConditionX64::NotEqual; + case IrCondition::Less: + return ConditionX64::Less; + case IrCondition::NotLess: + return ConditionX64::NotLess; + case IrCondition::LessEqual: + return ConditionX64::LessEqual; + case IrCondition::NotLessEqual: + return ConditionX64::NotLessEqual; + case IrCondition::Greater: + return ConditionX64::Greater; + case IrCondition::NotGreater: + return ConditionX64::NotGreater; + case IrCondition::GreaterEqual: + return ConditionX64::GreaterEqual; + case IrCondition::NotGreaterEqual: + return ConditionX64::NotGreaterEqual; + case IrCondition::UnsignedLess: + return ConditionX64::Below; + case IrCondition::UnsignedLessEqual: + return ConditionX64::BelowEqual; + case IrCondition::UnsignedGreater: + return ConditionX64::Above; + case IrCondition::UnsignedGreaterEqual: + return ConditionX64::AboveEqual; + default: + LUAU_ASSERT(!"Unsupported condition"); + return ConditionX64::Zero; + } +} + void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos) { LUAU_ASSERT(tmp != node); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 3288a164..3d9a59ff 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -195,6 +195,8 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); +ConditionX64 getConditionInt(IrCondition cond); + void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 3ee82c76..e467ca68 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -149,6 +149,12 @@ void IrBuilder::buildFunctionIr(Proto* proto) // We skip dead bytecode instructions when they appear after block was already terminated if (!inTerminatedBlock) { + if (interruptRequested) + { + interruptRequested = false; + inst(IrCmd::INTERRUPT, constUint(i)); + } + translateInst(op, pc, i); if (fastcallSkipTarget != -1) diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 7ed1a295..21b2f702 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -157,12 +157,8 @@ const char* getCmdName(IrCmd cmd) return "JUMP_IF_FALSY"; case IrCmd::JUMP_EQ_TAG: return "JUMP_EQ_TAG"; - case IrCmd::JUMP_EQ_INT: - return "JUMP_EQ_INT"; - case IrCmd::JUMP_LT_INT: - return "JUMP_LT_INT"; - case IrCmd::JUMP_GE_UINT: - return "JUMP_GE_UINT"; + case IrCmd::JUMP_CMP_INT: + return "JUMP_CMP_INT"; case IrCmd::JUMP_EQ_POINTER: return "JUMP_EQ_POINTER"; case IrCmd::JUMP_CMP_NUM: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 4369d120..a030f955 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -58,6 +58,58 @@ inline ConditionA64 getConditionFP(IrCondition cond) } } +inline ConditionA64 getConditionInt(IrCondition cond) +{ + switch (cond) + { + case IrCondition::Equal: + return ConditionA64::Equal; + + case IrCondition::NotEqual: + return ConditionA64::NotEqual; + + case IrCondition::Less: + return ConditionA64::Minus; + + case IrCondition::NotLess: + return ConditionA64::Plus; + + case IrCondition::LessEqual: + return ConditionA64::LessEqual; + + case IrCondition::NotLessEqual: + return ConditionA64::Greater; + + case IrCondition::Greater: + return ConditionA64::Greater; + + case IrCondition::NotGreater: + return ConditionA64::LessEqual; + + case IrCondition::GreaterEqual: + return ConditionA64::GreaterEqual; + + case IrCondition::NotGreaterEqual: + return ConditionA64::Less; + + case IrCondition::UnsignedLess: + return ConditionA64::CarryClear; + + case IrCondition::UnsignedLessEqual: + return ConditionA64::UnsignedLessEqual; + + case IrCondition::UnsignedGreater: + return ConditionA64::UnsignedGreater; + + case IrCondition::UnsignedGreaterEqual: + return ConditionA64::CarrySet; + + default: + LUAU_ASSERT(!"Unexpected condition code"); + return ConditionA64::Always; + } +} + static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA64 src, size_t offset) { LUAU_ASSERT(dst != src); @@ -714,31 +766,25 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } break; } - case IrCmd::JUMP_EQ_INT: - if (intOp(inst.b) == 0) + case IrCmd::JUMP_CMP_INT: + { + IrCondition cond = conditionOp(inst.c); + + if (cond == IrCondition::Equal && intOp(inst.b) == 0) { - build.cbz(regOp(inst.a), labelOp(inst.c)); + build.cbz(regOp(inst.a), labelOp(inst.d)); + } + else if (cond == IrCondition::NotEqual && intOp(inst.b) == 0) + { + build.cbnz(regOp(inst.a), labelOp(inst.d)); } else { LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); - build.b(ConditionA64::Equal, labelOp(inst.c)); + build.b(getConditionInt(cond), labelOp(inst.d)); } - jumpOrFallthrough(blockOp(inst.d), next); - break; - case IrCmd::JUMP_LT_INT: - LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); - build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); - build.b(ConditionA64::Less, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); - break; - case IrCmd::JUMP_GE_UINT: - { - LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); - build.cmp(regOp(inst.a), uint16_t(unsigned(intOp(inst.b)))); - build.b(ConditionA64::CarrySet, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); + jumpOrFallthrough(blockOp(inst.e), next); break; } case IrCmd::JUMP_EQ_POINTER: diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 261f5717..fe5127ac 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -655,42 +655,36 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } break; } - case IrCmd::JUMP_EQ_INT: - if (intOp(inst.b) == 0) + case IrCmd::JUMP_CMP_INT: + { + IrCondition cond = conditionOp(inst.c); + + if ((cond == IrCondition::Equal || cond == IrCondition::NotEqual) && intOp(inst.b) == 0) { + bool invert = cond == IrCondition::NotEqual; + build.test(regOp(inst.a), regOp(inst.a)); - if (isFallthroughBlock(blockOp(inst.c), next)) + if (isFallthroughBlock(blockOp(inst.d), next)) { - build.jcc(ConditionX64::NotZero, labelOp(inst.d)); - jumpOrFallthrough(blockOp(inst.c), next); + build.jcc(invert ? ConditionX64::Zero : ConditionX64::NotZero, labelOp(inst.e)); + jumpOrFallthrough(blockOp(inst.d), next); } else { - build.jcc(ConditionX64::Zero, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); + build.jcc(invert ? ConditionX64::NotZero : ConditionX64::Zero, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); } } else { build.cmp(regOp(inst.a), intOp(inst.b)); - build.jcc(ConditionX64::Equal, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); + build.jcc(getConditionInt(cond), labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); } break; - case IrCmd::JUMP_LT_INT: - build.cmp(regOp(inst.a), intOp(inst.b)); - - build.jcc(ConditionX64::Less, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); - break; - case IrCmd::JUMP_GE_UINT: - build.cmp(regOp(inst.a), unsigned(intOp(inst.b))); - - build.jcc(ConditionX64::AboveEqual, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); - break; + } case IrCmd::JUMP_EQ_POINTER: build.cmp(regOp(inst.a), regOp(inst.b)); @@ -703,7 +697,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::xmmword}; - // TODO: jumpOnNumberCmp should work on IrCondition directly jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 8513f786..279eabea 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -411,7 +411,7 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp( IrOp falsey = build.block(IrBlockKind::Internal); IrOp truthy = build.block(IrBlockKind::Internal); IrOp exit = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_EQ_INT, res, build.constInt(0), falsey, truthy); + build.inst(IrCmd::JUMP_CMP_INT, res, build.constInt(0), build.cond(IrCondition::Equal), falsey, truthy); build.beginBlock(falsey); build.inst(IrCmd::STORE_INT, build.vmReg(ra), build.constInt(0)); @@ -484,7 +484,7 @@ static BuiltinImplResult translateBuiltinBit32Shift( if (!knownGoodShift) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block); + build.inst(IrCmd::JUMP_CMP_INT, vbi, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block); build.beginBlock(block); } @@ -549,36 +549,56 @@ static BuiltinImplResult translateBuiltinBit32Extract( IrOp vb = builtinLoadDouble(build, args); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); - IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); IrOp value; if (nparams == 2) { - IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); - build.beginBlock(block); + if (vb.kind == IrOpKind::Constant) + { + int f = int(build.function.doubleOp(vb)); - // TODO: this can be optimized using a bit-select instruction (bt on x86) - IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); - value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1)); + if (unsigned(f) >= 32) + build.inst(IrCmd::JUMP, fallback); + + // TODO: this pair can be optimized using a bit-select instruction (bt on x86) + if (f) + value = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); + + if ((f + 1) < 32) + value = build.inst(IrCmd::BITAND_UINT, value, build.constInt(1)); + } + else + { + IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); + + IrOp block = build.block(IrBlockKind::Internal); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block); + build.beginBlock(block); + + // TODO: this pair can be optimized using a bit-select instruction (bt on x86) + IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); + value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1)); + } } else { + IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); + builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); IrOp block1 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1); build.beginBlock(block1); IrOp block2 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2); + build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2); build.beginBlock(block2); IrOp block3 = build.block(IrBlockKind::Internal); IrOp fw = build.inst(IrCmd::ADD_INT, f, w); - build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); + build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback); build.beginBlock(block3); IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); @@ -615,10 +635,15 @@ static BuiltinImplResult translateBuiltinBit32ExtractK( uint32_t m = ~(0xfffffffeu << w1); - IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); - IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constInt(m)); + IrOp result = n; - IrOp value = build.inst(IrCmd::UINT_TO_NUM, and_); + if (f) + result = build.inst(IrCmd::BITRSHIFT_UINT, result, build.constInt(f)); + + if ((f + w1 + 1) < 32) + result = build.inst(IrCmd::BITAND_UINT, result, build.constInt(m)); + + IrOp value = build.inst(IrCmd::UINT_TO_NUM, result); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); if (ra != arg) @@ -673,7 +698,7 @@ static BuiltinImplResult translateBuiltinBit32Replace( if (nparams == 3) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block); build.beginBlock(block); // TODO: this can be optimized using a bit-select instruction (btr on x86) @@ -694,16 +719,16 @@ static BuiltinImplResult translateBuiltinBit32Replace( IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); IrOp block1 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1); build.beginBlock(block1); IrOp block2 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2); + build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2); build.beginBlock(block2); IrOp block3 = build.block(IrBlockKind::Internal); IrOp fw = build.inst(IrCmd::ADD_INT, f, w); - build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); + build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback); build.beginBlock(block3); IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 26ad727a..f1eea645 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -12,6 +12,8 @@ #include "lstate.h" #include "ltm.h" +LUAU_FASTFLAGVARIABLE(LuauImproveForN, false) + namespace Luau { namespace CodeGen @@ -170,7 +172,7 @@ void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, b build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(IrCondition::Equal)); - build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), not_ ? target : next, not_ ? next : target); + build.inst(IrCmd::JUMP_CMP_INT, result, build.constInt(0), build.cond(IrCondition::Equal), not_ ? target : next, not_ ? next : target); build.beginBlock(next); } @@ -218,7 +220,7 @@ void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos, } IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond)); - build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), reverse ? target : next, reverse ? next : target); + build.inst(IrCmd::JUMP_CMP_INT, result, build.constInt(0), build.cond(IrCondition::Equal), reverse ? target : next, reverse ? next : target); build.beginBlock(next); } @@ -262,7 +264,7 @@ void translateInstJumpxEqB(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(checkValue); IrOp va = build.inst(IrCmd::LOAD_INT, build.vmReg(ra)); - build.inst(IrCmd::JUMP_EQ_INT, va, build.constInt(aux & 0x1), not_ ? next : target, not_ ? target : next); + build.inst(IrCmd::JUMP_CMP_INT, va, build.constInt(aux & 0x1), build.cond(IrCondition::Equal), not_ ? next : target, not_ ? target : next); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(next)) @@ -607,6 +609,27 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool return fallback; } +// numeric for loop always ends with the computation of step that targets ra+1 +// any conditionals would result in a split basic block, so we can recover the step constants by pattern matching the IR we generated for LOADN/K +static IrOp getLoopStepK(IrBuilder& build, int ra) +{ + IrBlock& active = build.function.blocks[build.activeBlockIdx]; + + if (active.start + 2 < build.function.instructions.size()) + { + IrInst& sv = build.function.instructions[build.function.instructions.size() - 2]; + IrInst& st = build.function.instructions[build.function.instructions.size() - 1]; + + // We currently expect to match IR generated from LOADN/LOADK so we match a particular sequence of opcodes + // In the future this can be extended to cover opposite STORE order as well as STORE_SPLIT_TVALUE + if (sv.cmd == IrCmd::STORE_DOUBLE && sv.a.kind == IrOpKind::VmReg && sv.a.index == ra + 1 && sv.b.kind == IrOpKind::Constant && + st.cmd == IrCmd::STORE_TAG && st.a.kind == IrOpKind::VmReg && st.a.index == ra + 1 && build.function.tagOp(st.b) == LUA_TNUMBER) + return sv.b; + } + + return build.undef(); +} + void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); @@ -614,40 +637,103 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos)); - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); + if (FFlag::LuauImproveForN) + { + IrOp stepK = getLoopStepK(build, ra); + build.loopStepStack.push_back(stepK); - // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails - // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant - // To avoid that overhead for an extreemely rare case (that doesn't even typecheck), we exit to VM to handle it - IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); - build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); - build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); - build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails + // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant + // To avoid that overhead for an extremely rare case (that doesn't even typecheck), we exit to VM to handle it + IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp zero = build.constDouble(0.0); - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - // step <= 0 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + if (stepK.kind == IrOpKind::Undef) + { + IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); - // step <= 0 is false, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopStart, loopExit); + IrOp zero = build.constDouble(0.0); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - // step <= 0 is true, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); + // step > 0 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); + + // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx + // We invert the condition so that loopStart is the fallthrough (false) label + + // step > 0 is false, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + + // step > 0 is true, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + } + else + { + double stepN = build.function.doubleOp(stepK); + + // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx + // We invert the condition so that loopStart is the fallthrough (false) label + if (stepN > 0) + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + else + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + } + } + else + { + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails + // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant + // To avoid that overhead for an extreemely rare case (that doesn't even typecheck), we exit to VM to handle it + IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopStart, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); + } // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopStart)) build.beginBlock(loopStart); + + // VM places interrupt in FORNLOOP, but that creates a likely spill point for short loops that use loop index as INTERRUPT always spills + // We place the interrupt at the beginning of the loop body instead; VM uses FORNLOOP because it doesn't want to waste an extra instruction. + // Because loop block may not have been started yet (as it's started when lowering the first instruction!), we need to defer INTERRUPT placement. + if (FFlag::LuauImproveForN) + build.interruptRequested = true; } void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) @@ -657,29 +743,76 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); - build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + if (FFlag::LuauImproveForN) + { + LUAU_ASSERT(!build.loopStepStack.empty()); + IrOp stepK = build.loopStepStack.back(); + build.loopStepStack.pop_back(); - IrOp zero = build.constDouble(0.0); - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK; - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - idx = build.inst(IrCmd::ADD_NUM, idx, step); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + idx = build.inst(IrCmd::ADD_NUM, idx, step); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); + if (stepK.kind == IrOpKind::Undef) + { + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); - // step <= 0 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + // step > 0 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); - // step <= 0 is false, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx - // step <= 0 is true, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + // step > 0 is false, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step > 0 is true, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + } + else + { + double stepN = build.function.doubleOp(stepK); + + // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx + if (stepN > 0) + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + else + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + } + } + else + { + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + idx = build.inst(IrCmd::ADD_NUM, idx, step); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); + + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + } // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopExit)) diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 07704388..d263d3aa 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -72,9 +72,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: - case IrCmd::JUMP_EQ_INT: - case IrCmd::JUMP_LT_INT: - case IrCmd::JUMP_GE_UINT: + case IrCmd::JUMP_CMP_INT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_SLOT_MATCH: @@ -422,6 +420,45 @@ bool compare(double a, double b, IrCondition cond) return false; } +bool compare(int a, int b, IrCondition cond) +{ + switch (cond) + { + case IrCondition::Equal: + return a == b; + case IrCondition::NotEqual: + return a != b; + case IrCondition::Less: + return a < b; + case IrCondition::NotLess: + return !(a < b); + case IrCondition::LessEqual: + return a <= b; + case IrCondition::NotLessEqual: + return !(a <= b); + case IrCondition::Greater: + return a > b; + case IrCondition::NotGreater: + return !(a > b); + case IrCondition::GreaterEqual: + return a >= b; + case IrCondition::NotGreaterEqual: + return !(a >= b); + case IrCondition::UnsignedLess: + return unsigned(a) < unsigned(b); + case IrCondition::UnsignedLessEqual: + return unsigned(a) <= unsigned(b); + case IrCondition::UnsignedGreater: + return unsigned(a) > unsigned(b); + case IrCondition::UnsignedGreaterEqual: + return unsigned(a) >= unsigned(b); + default: + LUAU_ASSERT(!"Unsupported condition"); + } + + return false; +} + void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t index) { IrInst& inst = function.instructions[index]; @@ -540,31 +577,13 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 replace(function, block, index, {IrCmd::JUMP, inst.d}); } break; - case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_CMP_INT: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { - if (function.intOp(inst.a) == function.intOp(inst.b)) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else + if (compare(function.intOp(inst.a), function.intOp(inst.b), conditionOp(inst.c))) replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - case IrCmd::JUMP_LT_INT: - if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) - { - if (function.intOp(inst.a) < function.intOp(inst.b)) - replace(function, block, index, {IrCmd::JUMP, inst.c}); else - replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - case IrCmd::JUMP_GE_UINT: - if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) - { - if (unsigned(function.intOp(inst.a)) >= unsigned(function.intOp(inst.b))) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else - replace(function, block, index, {IrCmd::JUMP, inst.d}); + replace(function, block, index, {IrCmd::JUMP, inst.e}); } break; case IrCmd::JUMP_CMP_NUM: diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 03c26cdd..de839061 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -17,6 +17,7 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAGVARIABLE(LuauReuseHashSlots2, false) LUAU_FASTFLAGVARIABLE(LuauKeepVmapLinear, false) +LUAU_FASTFLAGVARIABLE(LuauMergeTagLoads, false) namespace Luau { @@ -502,9 +503,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { case IrCmd::LOAD_TAG: if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff) + { substitute(function, inst, build.constTag(tag)); + } else if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + { + if (FFlag::LuauMergeTagLoads) + state.substituteOrRecordVmRegLoad(inst); + else + state.createRegLink(index, inst.a); + } break; case IrCmd::LOAD_POINTER: if (inst.a.kind == IrOpKind::VmReg) @@ -716,44 +724,20 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& else replace(function, block, index, {IrCmd::JUMP, inst.d}); } + else if (FFlag::LuauMergeTagLoads && inst.a == inst.b) + { + replace(function, block, index, {IrCmd::JUMP, inst.c}); + } break; } - case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_CMP_INT: { std::optional valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a)); std::optional valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); if (valueA && valueB) { - if (*valueA == *valueB) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else - replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - } - case IrCmd::JUMP_LT_INT: - { - std::optional valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a)); - std::optional valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); - - if (valueA && valueB) - { - if (*valueA < *valueB) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else - replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - } - case IrCmd::JUMP_GE_UINT: - { - std::optional valueA = function.asUintOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a)); - std::optional valueB = function.asUintOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); - - if (valueA && valueB) - { - if (*valueA >= *valueB) + if (compare(*valueA, *valueB, conditionOp(inst.c))) replace(function, block, index, {IrCmd::JUMP, inst.c}); else replace(function, block, index, {IrCmd::JUMP, inst.d}); diff --git a/Sources.cmake b/Sources.cmake index 4776b6b8..ddd514f5 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -167,6 +167,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h Analysis/include/Luau/Frontend.h + Analysis/include/Luau/GlobalTypes.h Analysis/include/Luau/InsertionOrderedMap.h Analysis/include/Luau/Instantiation.h Analysis/include/Luau/IostreamHelpers.h @@ -226,6 +227,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp + Analysis/src/GlobalTypes.cpp Analysis/src/Instantiation.cpp Analysis/src/IostreamHelpers.cpp Analysis/src/JsonEmitter.cpp @@ -365,6 +367,8 @@ if(TARGET Luau.UnitTest) tests/AstQueryDsl.cpp tests/AstQueryDsl.h tests/AstVisitor.test.cpp + tests/RegisterCallbacks.h + tests/RegisterCallbacks.cpp tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/ClassFixture.cpp @@ -447,6 +451,8 @@ endif() if(TARGET Luau.Conformance) # Luau.Conformance Sources target_sources(Luau.Conformance PRIVATE + tests/RegisterCallbacks.h + tests/RegisterCallbacks.cpp tests/Conformance.test.cpp tests/main.cpp) endif() @@ -464,6 +470,8 @@ if(TARGET Luau.CLI.Test) CLI/Profiler.cpp CLI/Repl.cpp + tests/RegisterCallbacks.h + tests/RegisterCallbacks.cpp tests/Repl.test.cpp tests/main.cpp) endif() diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 44bba9c0..68efc0e4 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -135,6 +135,8 @@ // Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. #define VM_HAS_NATIVE 1 +void (*lua_iter_call_telemetry)(lua_State* L, int gtt, int stt, int itt) = NULL; + LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { ptrdiff_t base = savestack(L, L->base); @@ -2289,6 +2291,10 @@ reentry: { // table or userdata with __call, will be called during FORGLOOP // TODO: we might be able to stop supporting this depending on whether it's used in practice + void (*telemetrycb)(lua_State* L, int gtt, int stt, int itt) = lua_iter_call_telemetry; + + if (telemetrycb) + telemetrycb(L, ttype(ra), ttype(ra + 1), ttype(ra + 2)); } else if (ttistable(ra)) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 6bfdc0f4..dd90dd8e 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,7 +15,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauAutocompleteLastTypecheck) using namespace Luau; @@ -34,36 +33,27 @@ struct ACFixtureImpl : BaseType AutocompleteResult autocomplete(unsigned row, unsigned column) { - if (FFlag::LuauAutocompleteLastTypecheck) - { - FrontendOptions opts; - opts.forAutocomplete = true; - this->frontend.check("MainModule", opts); - } + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker, StringCompletionCallback callback = nullCallback) { - if (FFlag::LuauAutocompleteLastTypecheck) - { - FrontendOptions opts; - opts.forAutocomplete = true; - this->frontend.check("MainModule", opts); - } + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); } AutocompleteResult autocomplete(const ModuleName& name, Position pos, StringCompletionCallback callback = nullCallback) { - if (FFlag::LuauAutocompleteLastTypecheck) - { - FrontendOptions opts; - opts.forAutocomplete = true; - this->frontend.check(name, opts); - } + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check(name, opts); return Luau::autocomplete(this->frontend, name, pos, callback); } @@ -3699,8 +3689,6 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_empty") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: () -> ()) a() @@ -3722,8 +3710,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, string) -> ()) a() @@ -3745,8 +3731,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_single_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, string) -> (string)) a() @@ -3768,8 +3752,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, string) -> (string, number)) a() @@ -3791,8 +3773,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__noargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: () -> (string, number)) a() @@ -3814,8 +3794,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__varargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (...number) -> (string, number)) a() @@ -3837,8 +3815,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (string, ...number) -> (string, number)) a() @@ -3860,8 +3836,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_varargs_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (string, ...number) -> ...number) a() @@ -3883,8 +3857,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_varargs_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (string, ...number) -> (boolean, ...number)) a() @@ -3906,8 +3878,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_named_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (foo: number, bar: string) -> (string, number)) a() @@ -3929,8 +3899,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, bar: string) -> (string, number)) a() @@ -3952,8 +3920,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args_last") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (foo: number, string) -> (string, number)) a() @@ -3975,8 +3941,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local t = { a = 1, b = 2 } @@ -4000,8 +3964,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (tbl: { x: number, y: number }) -> number) return a({x=2, y = 3}) end foo(@1) @@ -4020,8 +3982,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_returns") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local t = { a = 1, b = 2 } @@ -4045,8 +4005,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: () -> { x: number, y: number }) return {x=2, y = 3} end foo(@1) @@ -4065,8 +4023,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local t = { a = 1, b = 2 } @@ -4090,8 +4046,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_type_pack_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (...A) -> number, ...: A) return a(...) @@ -4113,8 +4067,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (...: T...) -> number) return a(4, 5, 6) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 75c38762..c8918954 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -282,6 +282,8 @@ TEST_CASE("Assert") TEST_CASE("Basic") { ScopedFastFlag sffs{"LuauFloorDivision", true}; + ScopedFastFlag sfff{"LuauImproveForN", true}; + runConformance("basic.lua"); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index d4fa7178..f31727e0 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -9,6 +9,7 @@ #include "Luau/Parser.h" #include "Luau/Type.h" #include "Luau/TypeAttach.h" +#include "Luau/TypeInfer.h" #include "Luau/Transpiler.h" #include "doctest.h" @@ -144,8 +145,6 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend.globals); - Luau::freeze(frontend.globals.globalTypes); Luau::freeze(frontend.globalsForAutocomplete.globalTypes); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index ec0c213a..96032fd1 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1222,4 +1222,28 @@ TEST_CASE_FIXTURE(FrontendFixture, "parse_only") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(FrontendFixture, "markdirty_early_return") +{ + ScopedFastFlag fflag("CorrectEarlyReturnInMarkDirty", true); + + constexpr char moduleName[] = "game/Gui/Modules/A"; + fileResolver.source[moduleName] = R"( + return 1 + )"; + + { + std::vector markedDirty; + frontend.markDirty(moduleName, &markedDirty); + CHECK(markedDirty.empty()); + } + + frontend.parse(moduleName); + + { + std::vector markedDirty; + frontend.markDirty(moduleName, &markedDirty); + CHECK(!markedDirty.empty()); + } +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index a736f5cf..b511156d 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -621,11 +621,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") }); withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(0), a, b); + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(0), build.cond(IrCondition::Equal), a, b); }); withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), a, b); + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), a, b); }); updateUseCounts(build.function); @@ -1359,7 +1359,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") build.beginBlock(block); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5)); IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); - build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock); + build.inst(IrCmd::JUMP_CMP_INT, value, build.constInt(5), build.cond(IrCondition::Equal), trueBlock, falseBlock); build.beginBlock(trueBlock); build.inst(IrCmd::RETURN, build.constUint(1)); @@ -1556,7 +1556,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") IrOp repeat = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), block, exit1); build.beginBlock(exit1); build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); @@ -2785,4 +2785,37 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "TagSelfEqualityCheckRemoval") +{ + ScopedFastFlag luauMergeTagLoads{"LuauMergeTagLoads", true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + build.inst(IrCmd::JUMP_EQ_TAG, tag1, tag2, trueBlock, falseBlock); + + build.beginBlock(trueBlock); + build.inst(IrCmd::RETURN, build.constUint(1)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::RETURN, build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + JUMP bb_1 + +bb_1: + RETURN 1u + +)"); +} + TEST_SUITE_END(); diff --git a/tests/RegisterCallbacks.cpp b/tests/RegisterCallbacks.cpp new file mode 100644 index 00000000..9f471933 --- /dev/null +++ b/tests/RegisterCallbacks.cpp @@ -0,0 +1,20 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "RegisterCallbacks.h" + +namespace Luau +{ + +std::unordered_set& getRegisterCallbacks() +{ + static std::unordered_set cbs; + return cbs; +} + +int addTestCallback(RegisterCallback cb) +{ + getRegisterCallbacks().insert(cb); + return 0; +} + +} // namespace Luau diff --git a/tests/RegisterCallbacks.h b/tests/RegisterCallbacks.h new file mode 100644 index 00000000..f62ac0e7 --- /dev/null +++ b/tests/RegisterCallbacks.h @@ -0,0 +1,22 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau +{ + +using RegisterCallback = void (*)(); + +/// Gets a set of callbacks to run immediately before running tests, intended +/// for registering new tests at runtime. +std::unordered_set& getRegisterCallbacks(); + +/// Adds a new callback to be ran immediately before running tests. +/// +/// @param cb the callback to add. +/// @returns a dummy integer to satisfy a doctest internal contract. +int addTestCallback(RegisterCallback cb); + +} // namespace Luau diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index 23d05f91..6089f036 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -2,7 +2,9 @@ #include "doctest.h" #include "Fixture.h" +#include "RegisterCallbacks.h" +#include "Luau/Normalize.h" #include "Luau/Subtyping.h" #include "Luau/TypePack.h" @@ -344,14 +346,72 @@ struct SubtypeFixture : Fixture CHECK_MESSAGE(!result.isErrorSuppressing, "Expected " << leftTy << " to error-suppress " << rightTy); \ } while (0) +/// Internal macro for registering a generated test case. +/// +/// @param der the name of the derived fixture struct +/// @param reg the name of the registration callback, invoked immediately before +/// tests are ran to register the test +/// @param run the name of the run callback, invoked to actually run the test case +#define TEST_REGISTER(der, reg, run) \ + static inline DOCTEST_NOINLINE void run() \ + { \ + der fix; \ + fix.test(); \ + } \ + static inline DOCTEST_NOINLINE void reg() \ + { \ + /* we have to mark this as `static` to ensure the memory remains alive \ + for the entirety of the test process */ \ + static std::string name = der().testName; \ + doctest::detail::regTest(doctest::detail::TestCase(run, __FILE__, __LINE__, \ + doctest_detail_test_suite_ns::getCurrentTestSuite()) /* the test case's name, determined at runtime */ \ + * name.c_str() /* getCurrentTestSuite() only works at static initialization \ + time due to implementation details. To ensure that test cases \ + are grouped where they should be, manually override the suite \ + with the test_suite decorator. */ \ + * doctest::test_suite("Subtyping")); \ + } \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), addTestCallback(reg)); + +/// Internal macro for deriving a test case fixture. Roughly analogous to +/// DOCTEST_IMPLEMENT_FIXTURE. +/// +/// @param op a function (or macro) to call that compares the subtype to +/// the supertype. +/// @param symbol the symbol to use in stringification +/// @param der the name of the derived fixture struct +/// @param left the subtype expression +/// @param right the supertype expression +#define TEST_DERIVE(op, symbol, der, left, right) \ + namespace \ + { \ + struct der : SubtypeFixture \ + { \ + const TypeId subTy = (left); \ + const TypeId superTy = (right); \ + const std::string testName = toString(subTy) + " " symbol " " + toString(superTy); \ + inline DOCTEST_NOINLINE void test() \ + { \ + op(subTy, superTy); \ + } \ + }; \ + TEST_REGISTER(der, DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_)); \ + } + +/// Generates a test that checks if a type is a subtype of another. +#define TEST_IS_SUBTYPE(left, right) TEST_DERIVE(CHECK_IS_SUBTYPE, "<:", DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), left, right) + +/// Generates a test that checks if a type is _not_ a subtype of another. +/// Uses numberType, builtinTypes->anyType); -} +TEST_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->anyType); +TEST_IS_NOT_SUBTYPE(builtinTypes->numberType, builtinTypes->stringType); TEST_CASE_FIXTURE(SubtypeFixture, "any numberType, builtinTypes->numberType); } -TEST_CASE_FIXTURE(SubtypeFixture, "number numberType, builtinTypes->stringType); -} - TEST_CASE_FIXTURE(SubtypeFixture, "number <: number?") { CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->optionalNumberType); @@ -895,6 +950,16 @@ TEST_CASE_FIXTURE(SubtypeFixture, "string ( CHECK_IS_NOT_SUBTYPE(builtinTypes->stringType, tableWithoutScalarProp); } +TEST_CASE_FIXTURE(SubtypeFixture, "~fun & (string) -> number <: (string) -> number") +{ + CHECK_IS_SUBTYPE(meet(negate(builtinTypes->functionType), numberToStringType), numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(string) -> number <: ~fun & (string) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, meet(negate(builtinTypes->functionType), numberToStringType)); +} + /* * (A) -> A <: (X) -> X * A can be bound to X. diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 9293bfb2..ac01b5f3 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -44,25 +44,34 @@ TEST_SUITE_BEGIN("ToDot"); TEST_CASE_FIXTURE(Fixture, "primitive") { - CheckResult result = check(R"( -local a: nil -local b: number -local c: any -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_NE("nil", toDot(requireType("a"))); + CHECK_EQ(R"(digraph graphname { +n1 [label="nil"]; +})", + toDot(builtinTypes->nilType)); CHECK_EQ(R"(digraph graphname { n1 [label="number"]; })", - toDot(requireType("b"))); + toDot(builtinTypes->numberType)); CHECK_EQ(R"(digraph graphname { n1 [label="any"]; })", - toDot(requireType("c"))); + toDot(builtinTypes->anyType)); + CHECK_EQ(R"(digraph graphname { +n1 [label="unknown"]; +})", + toDot(builtinTypes->unknownType)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="never"]; +})", + toDot(builtinTypes->neverType)); +} + +TEST_CASE_FIXTURE(Fixture, "no_duplicatePrimitives") +{ ToDotOptions opts; opts.showPointers = false; opts.duplicatePrimitives = false; @@ -70,12 +79,22 @@ n1 [label="any"]; CHECK_EQ(R"(digraph graphname { n1 [label="PrimitiveType number"]; })", - toDot(requireType("b"), opts)); + toDot(builtinTypes->numberType, opts)); CHECK_EQ(R"(digraph graphname { n1 [label="AnyType 1"]; })", - toDot(requireType("c"), opts)); + toDot(builtinTypes->anyType, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="UnknownType 1"]; +})", + toDot(builtinTypes->unknownType, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="NeverType 1"]; +})", + toDot(builtinTypes->neverType, opts)); } TEST_CASE_FIXTURE(Fixture, "bound") @@ -283,6 +302,30 @@ n1 [label="FreeType 1"]; toDot(&type, opts)); } +TEST_CASE_FIXTURE(Fixture, "free_with_constraints") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; + + Type type{TypeVariant{FreeType{nullptr, builtinTypes->numberType, builtinTypes->optionalNumberType}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeType 1"]; +n1 -> n2 [label="[lowerBound]"]; +n2 [label="number"]; +n1 -> n3 [label="[upperBound]"]; +n3 [label="UnionType 3"]; +n3 -> n4; +n4 [label="number"]; +n3 -> n5; +n5 [label="nil"]; +})", + toDot(&type, opts)); +} + TEST_CASE_FIXTURE(Fixture, "error") { Type type{TypeVariant{ErrorType{}}}; @@ -440,4 +483,19 @@ n5 [label="SingletonType boolean: false"]; toDot(requireType("x"), opts)); } +TEST_CASE_FIXTURE(Fixture, "negation") +{ + TypeArena arena; + TypeId t = arena.addType(NegationType{builtinTypes->stringType}); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK(R"(digraph graphname { +n1 [label="NegationType 1"]; +n1 -> n2 [label="[negated]"]; +n2 [label="string"]; +})" == toDot(t, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 6e6dba09..c7ab66dd 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeFamily.h" + +#include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Fixture.h" diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 794171fb..d4dbfa91 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1006,8 +1006,6 @@ end // We would prefer this unification to be able to complete, but at least it should not crash TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_infinite_recursion") { - ScopedFastFlag luauTableUnifyRecursionLimit{"LuauTableUnifyRecursionLimit", true}; - #if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; #endif diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 8af5c684..aeabf0ac 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1404,4 +1404,32 @@ TEST_CASE_FIXTURE(Fixture, "promote_tail_type_packs") LUAU_REQUIRE_NO_ERRORS(result); } +/* + * CLI-49876 + * + * We had a bug where we would not use the correct TxnLog when evaluating a + * variadic overload. We could therefore get into a state where the TxnLog has + * logged that a generic matches to one type, but the variadic tail has already + * been bound to another type outside of that TxnLog. + * + * This caused type checking to succeed when it should have failed. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload") +{ + ScopedFastFlag sff{"LuauVariadicOverloadFix", true}; + + CheckResult result = check(R"( + local function concat(target: {T}, ...: {T} | T): {T} + return (nil :: any) :: {T} + end + + local res = concat({"alic"}, 1, 2) + )"); + + LUAU_REQUIRE_ERRORS(result); + + for (const auto& e: result.errors) + CHECK(5 == e.location.begin.line); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 42030c55..17f4497a 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -177,6 +177,33 @@ assert((function() local a = 1 for b=1,9 do a = a * 2 if a == 128 then break els -- make sure internal index is protected against modification assert((function() local a = 1 for b=9,1,-2 do a = a * 2 b = nil end return a end)() == 32) +-- make sure that when step is 0, we treat it as backward iteration (and as such, iterate zero times or indefinitely) +-- this is consistent with Lua 5.1; future Lua versions emit an error when step is 0; LuaJIT instead treats 0 as forward iteration +-- we repeat tests twice, with and without constant folding +local zero = tonumber("0") +assert((function() local c = 0 for i=1,10,0 do c += 1 if c > 10 then break end end return c end)() == 0) +assert((function() local c = 0 for i=10,1,0 do c += 1 if c > 10 then break end end return c end)() == 11) +assert((function() local c = 0 for i=1,10,zero do c += 1 if c > 10 then break end end return c end)() == 0) +assert((function() local c = 0 for i=10,1,zero do c += 1 if c > 10 then break end end return c end)() == 11) + +-- make sure that when limit is nan, we iterate zero times (this is consistent with Lua 5.1; future Lua versions break this) +-- we repeat tests twice, with and without constant folding +local nan = tonumber("nan") +assert((function() local c = 0 for i=1,0/0 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,0/0,-1 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,nan do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,nan,-1 do c += 1 end return c end)() == 0) + +-- make sure that when step is nan, we treat it as backward iteration and as such iterate once iff start<=limit +assert((function() local c = 0 for i=1,10,0/0 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=10,1,0/0 do c += 1 end return c end)() == 1) +assert((function() local c = 0 for i=1,10,nan do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=10,1,nan do c += 1 end return c end)() == 1) + +-- make sure that when index becomes nan mid-iteration, we correctly exit the loop (this is broken in Lua 5.1; future Lua versions fix this) +assert((function() local c = 0 for i=-math.huge,0,math.huge do c += 1 end return c end)() == 1) +assert((function() local c = 0 for i=math.huge,math.huge,-math.huge do c += 1 end return c end)() == 1) + -- generic for -- ipairs assert((function() local a = '' for k in ipairs({5, 6, 7}) do a = a .. k end return a end)() == "123") @@ -286,6 +313,10 @@ assert((function() return result end)() == "ArcticDunesCanyonsWaterMountainsHillsLavaflowPlainsMarsh") +-- table literals may contain duplicate fields; the language doesn't specify assignment order but we currently assign left to right +assert((function() local t = {data = 4, data = nil, data = 42} return t.data end)() == 42) +assert((function() local t = {data = 4, data = nil, data = 42, data = nil} return t.data end)() == nil) + -- multiple returns -- local= assert((function() function foo() return 2, 3, 4 end local a, b, c = foo() return ''..a..b..c end)() == "234") diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 6b5bfc5f..4027bba0 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -189,6 +189,26 @@ do -- testing NaN assert(a[NaN] == nil) end +-- extra NaN tests, hidden in a function +do + function neq(a) return a ~= a end + function eq(a) return a == a end + function lt(a) return a < a end + function le(a) return a <= a end + function gt(a) return a > a end + function ge(a) return a >= a end + + local NaN -- to avoid constant folding + NaN = 10e500 - 10e400 + + assert(neq(NaN)) + assert(not eq(NaN)) + assert(not lt(NaN)) + assert(not le(NaN)) + assert(not gt(NaN)) + assert(not ge(NaN)) +end + -- require "checktable" -- stat(a) diff --git a/tests/main.cpp b/tests/main.cpp index 9435c61a..26872196 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -6,6 +6,8 @@ #define DOCTEST_CONFIG_OPTIONS_PREFIX "" #include "doctest.h" +#include "RegisterCallbacks.h" + #ifdef _WIN32 #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN @@ -327,6 +329,14 @@ int main(int argc, char** argv) } } + // These callbacks register unit tests that need runtime support to be + // correctly set up. Running them here means that all command line flags + // have been parsed, fast flags have been set, and we've potentially already + // exited. Once doctest::Context::run is invoked, the test list will be + // picked up from global state. + for (Luau::RegisterCallback cb : Luau::getRegisterCallbacks()) + cb(); + int result = context.run(); if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) {