diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 2223c29e..c7bc58b5 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -107,6 +107,11 @@ struct FunctionCallConstraint TypePackId result; class AstExprCall* callSite; std::vector> discriminantTypes; + + // When we dispatch this constraint, we update the key at this map to record + // the overload that we selected. + DenseHashMap* astOriginalCallTypes; + DenseHashMap* astOverloadResolvedTypes; }; // result ~ prim ExpectedType SomeSingletonType MultitonType diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 67e840ee..14bf2e2e 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -28,6 +28,7 @@ struct FileResolver; struct ModuleResolver; struct ParseResult; struct HotComment; +struct BuildQueueItem; struct LoadDefinitionFileResult { @@ -171,7 +172,18 @@ struct Frontend LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete = false); + // Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult' + // If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete + void queueModuleCheck(const std::vector& names); + void queueModuleCheck(const ModuleName& name); + std::vector checkQueuedModules(std::optional optionOverride = {}, + std::function task)> executeTask = {}, std::function progress = {}); + + std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); + private: + CheckResult check_DEPRECATED(const ModuleName& name, std::optional optionOverride = {}); + struct TypeCheckLimits { std::optional finishTime; @@ -185,7 +197,14 @@ private: std::pair getSourceNode(const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, const ModuleName& root, bool forAutocomplete); + bool parseGraph( + std::vector& buildQueue, const ModuleName& root, bool forAutocomplete, std::function canSkip = {}); + + void addBuildQueueItems(std::vector& items, std::vector& buildQueue, bool cycleDetected, + std::unordered_set& seen, const FrontendOptions& frontendOptions); + void checkBuildQueueItem(BuildQueueItem& item); + void checkBuildQueueItems(std::vector& items); + void recordItemResult(const BuildQueueItem& item); static LintResult classifyLints(const std::vector& warnings, const Config& config); @@ -212,11 +231,13 @@ public: InternalErrorReporter iceHandler; std::function prepareModuleScope; - std::unordered_map sourceNodes; - std::unordered_map sourceModules; + std::unordered_map> sourceNodes; + std::unordered_map> sourceModules; std::unordered_map requireTrace; Stats stats = {}; + + std::vector moduleQueue; }; ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 6c808286..2ec5406f 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -226,10 +226,6 @@ struct NormalizedType NormalizedClassType classes; - // The class part of the type. - // Each element of this set is a class, and none of the classes are subclasses of each other. - TypeIds DEPRECATED_classes; - // The error part of the type. // This type is either never or the error type. TypeId errors; @@ -333,8 +329,6 @@ public: // ------- Normalizing intersections TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); - void DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres); - void DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there); void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres); void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 5d92cbd0..c615b8f5 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -694,7 +694,7 @@ bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs); // Follow BoundTypes until we get to something real TypeId follow(TypeId t); -TypeId follow(TypeId t, std::function mapper); +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)); std::vector flattenIntersection(TypeId ty); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 2ae56e5f..e78a66b8 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -169,7 +169,7 @@ using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -TypePackId follow(TypePackId tp, std::function mapper); +TypePackId follow(TypePackId t, const void* context, TypePackId (*mapper)(const void*, TypePackId)); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index e3b0a878..742f029c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -163,5 +163,6 @@ private: void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); std::optional hasUnificationTooComplex(const ErrorVec& errors); +std::optional hasCountMismatch(const ErrorVec& errors); } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 611f420a..e07fe701 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauMagicTypes); -LUAU_FASTFLAG(LuauNegatedClassTypes); namespace Luau { @@ -1016,7 +1015,7 @@ static bool isMetamethod(const Name& name) ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass->superName) { Name superName = Name(declaredClass->superName->value); @@ -1420,6 +1419,8 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa rets, call, std::move(discriminantTypes), + &module->astOriginalCallTypes, + &module->astOverloadResolvedTypes, }); // We force constraints produced by checking function arguments to wait @@ -1772,7 +1773,7 @@ std::tuple ConstraintGraphBuilder::checkBinary( TypeId ty = follow(typeFun->type); // We're only interested in the root class of any classes. - if (auto ctv = get(ty); !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent == builtinTypes->classType) : !ctv->parent)) + if (auto ctv = get(ty); !ctv || ctv->parent == builtinTypes->classType) discriminantTy = ty; } @@ -1786,8 +1787,10 @@ std::tuple ConstraintGraphBuilder::checkBinary( } else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) { - TypeId leftType = check(scope, binary->left, ValueContext::RValue, expectedType, true).ty; - TypeId rightType = check(scope, binary->right, ValueContext::RValue, expectedType, true).ty; + // We are checking a binary expression of the form a op b + // Just because a op b is epxected to return a bool, doesn't mean a, b are expected to be bools too + TypeId leftType = check(scope, binary->left, ValueContext::RValue, {}, true).ty; + TypeId rightType = check(scope, binary->right, ValueContext::RValue, {}, true).ty; RefinementId leftRefinement = nullptr; if (auto bc = dfg->getBreadcrumb(binary->left)) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index ec63b25e..f1f868ad 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1172,6 +1172,9 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) fn = collapse(it).value_or(fn); + if (c.callSite) + (*c.astOriginalCallTypes)[c.callSite] = fn; + // We don't support magic __call metamethods. if (std::optional callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { @@ -1219,10 +1222,22 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); - std::vector overloads = flattenIntersection(fn); + const NormalizedType* normFn = normalizer->normalize(fn); + if (!normFn) + { + reportError(UnificationTooComplex{}, constraint->location); + return true; + } + + // TODO: It would be nice to not need to convert the normalized type back to + // an intersection and flatten it. + TypeId normFnTy = normalizer->typeFromNormal(*normFn); + std::vector overloads = flattenIntersection(normFnTy); Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + std::vector arityMatchingOverloads; + for (TypeId overload : overloads) { overload = follow(overload); @@ -1247,8 +1262,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(*e)->context != CountMismatch::Context::Arg) && get(*instantiated)) + { + arityMatchingOverloads.push_back(*instantiated); + } + if (u.errors.empty()) { + if (c.callSite) + (*c.astOverloadResolvedTypes)[c.callSite] = *instantiated; + // We found a matching overload. const auto [changedTypes, changedPacks] = u.log.getChanges(); u.log.commit(); @@ -1260,6 +1284,15 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullscope, Location{}, Covariant}; u.useScopes = true; @@ -1267,8 +1300,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullanyType); u.tryUnify(fn, builtinTypes->anyType); - LUAU_ASSERT(u.errors.empty()); // unifying with any should never fail - const auto [changedTypes, changedPacks] = u.log.getChanges(); u.log.commit(); @@ -2166,13 +2197,24 @@ void ConstraintSolver::unblock(NotNull progressed) void ConstraintSolver::unblock(TypeId progressed) { - if (logger) - logger->popBlock(progressed); + DenseHashSet seen{nullptr}; - unblock_(progressed); + while (true) + { + if (seen.find(progressed)) + iceReporter.ice("ConstraintSolver::unblock encountered a self-bound type!"); + seen.insert(progressed); - if (auto bt = get(progressed)) - unblock(bt->boundTo); + if (logger) + logger->popBlock(progressed); + + unblock_(progressed); + + if (auto bt = get(progressed)) + progressed = bt->boundTo; + else + break; + } } void ConstraintSolver::unblock(TypePackId progressed) diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 486ef696..b6b315cf 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -21,6 +21,9 @@ #include #include +#include +#include +#include #include #include @@ -34,10 +37,36 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAG(LuauRequirePathTrueModuleName) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) +LUAU_FASTFLAGVARIABLE(LuauSplitFrontendProcessing, false) namespace Luau { +struct BuildQueueItem +{ + ModuleName name; + ModuleName humanReadableName; + + // Parameters + std::shared_ptr sourceNode; + std::shared_ptr sourceModule; + Config config; + ScopePtr environmentScope; + std::vector requireCycles; + FrontendOptions options; + bool recordJsonLog = false; + + // Queue state + std::vector reverseDeps; + int dirtyDependencies = 0; + bool processing = false; + + // Result + std::exception_ptr exception; + ModulePtr module; + Frontend::Stats stats; +}; + std::optional parseMode(const std::vector& hotcomments) { for (const HotComment& hc : hotcomments) @@ -220,7 +249,7 @@ namespace { static ErrorVec accumulateErrors( - const std::unordered_map& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) + const std::unordered_map>& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) { std::unordered_set seen; std::vector queue{name}; @@ -240,7 +269,7 @@ static ErrorVec accumulateErrors( if (it == sourceNodes.end()) continue; - const SourceNode& sourceNode = it->second; + const SourceNode& sourceNode = *it->second; queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. @@ -285,8 +314,8 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector getRequireCycles( - const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) +std::vector getRequireCycles(const FileResolver* resolver, + const std::unordered_map>& sourceNodes, const SourceNode* start, bool stopAtFirst = false) { std::vector result; @@ -302,7 +331,7 @@ std::vector getRequireCycles( if (dit == sourceNodes.end()) continue; - stack.push_back(&dit->second); + stack.push_back(dit->second.get()); while (!stack.empty()) { @@ -343,7 +372,7 @@ std::vector getRequireCycles( auto rit = sourceNodes.find(reqName); if (rit != sourceNodes.end()) - stack.push_back(&rit->second); + stack.push_back(rit->second.get()); } } } @@ -389,6 +418,52 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c } CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) +{ + if (!FFlag::LuauSplitFrontendProcessing) + return check_DEPRECATED(name, optionOverride); + + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + FrontendOptions frontendOptions = optionOverride.value_or(options); + + if (std::optional result = getCheckResult(name, true, frontendOptions.forAutocomplete)) + return std::move(*result); + + std::vector buildQueue; + bool cycleDetected = parseGraph(buildQueue, name, frontendOptions.forAutocomplete); + + std::unordered_set seen; + std::vector buildQueueItems; + addBuildQueueItems(buildQueueItems, buildQueue, cycleDetected, seen, frontendOptions); + LUAU_ASSERT(!buildQueueItems.empty()); + + if (FFlag::DebugLuauLogSolverToJson) + { + LUAU_ASSERT(buildQueueItems.back().name == name); + buildQueueItems.back().recordJsonLog = true; + } + + checkBuildQueueItems(buildQueueItems); + + // Collect results only for checked modules, 'getCheckResult' produces a different result + CheckResult checkResult; + + for (const BuildQueueItem& item : buildQueueItems) + { + if (item.module->timeout) + checkResult.timeoutHits.push_back(item.name); + + checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end()); + + if (item.name == name) + checkResult.lintResult = item.module->lintResult; + } + + return checkResult; +} + +CheckResult Frontend::check_DEPRECATED(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -399,7 +474,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.hasDirtyModule(frontendOptions.forAutocomplete)) + if (it != sourceNodes.end() && !it->second->hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. ModulePtr module = resolver.getModule(name); @@ -421,13 +496,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalgetConfig(moduleName); @@ -583,7 +658,241 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, const ModuleName& root, bool forAutocomplete) +void Frontend::queueModuleCheck(const std::vector& names) +{ + moduleQueue.insert(moduleQueue.end(), names.begin(), names.end()); +} + +void Frontend::queueModuleCheck(const ModuleName& name) +{ + moduleQueue.push_back(name); +} + +std::vector Frontend::checkQueuedModules(std::optional optionOverride, + std::function task)> executeTask, std::function progress) +{ + FrontendOptions frontendOptions = optionOverride.value_or(options); + + // By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown + std::vector currModuleQueue; + std::swap(currModuleQueue, moduleQueue); + + std::unordered_set seen; + std::vector buildQueueItems; + + for (const ModuleName& name : currModuleQueue) + { + if (seen.count(name)) + continue; + + if (!isDirty(name, frontendOptions.forAutocomplete)) + { + seen.insert(name); + continue; + } + + std::vector queue; + bool cycleDetected = parseGraph(queue, name, frontendOptions.forAutocomplete, [&seen](const ModuleName& name) { + return seen.count(name); + }); + + addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions); + } + + if (buildQueueItems.empty()) + return {}; + + // We need a mapping from modules to build queue slots + std::unordered_map moduleNameToQueue; + + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + moduleNameToQueue[item.name] = i; + } + + // Default task execution is single-threaded and immediate + if (!executeTask) + { + executeTask = [](std::function task) { + task(); + }; + } + + std::mutex mtx; + std::condition_variable cv; + std::vector readyQueueItems; + + size_t processing = 0; + size_t remaining = buildQueueItems.size(); + + auto itemTask = [&](size_t i) { + BuildQueueItem& item = buildQueueItems[i]; + + try + { + checkBuildQueueItem(item); + } + catch (...) + { + item.exception = std::current_exception(); + } + + { + std::unique_lock guard(mtx); + readyQueueItems.push_back(i); + } + + cv.notify_one(); + }; + + auto sendItemTask = [&](size_t i) { + BuildQueueItem& item = buildQueueItems[i]; + + item.processing = true; + processing++; + + executeTask([&itemTask, i]() { + itemTask(i); + }); + }; + + auto sendCycleItemTask = [&] { + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + + if (!item.processing) + { + sendItemTask(i); + break; + } + } + }; + + // In a first pass, check modules that have no dependencies and record info of those modules that wait + for (size_t i = 0; i < buildQueueItems.size(); i++) + { + BuildQueueItem& item = buildQueueItems[i]; + + for (const ModuleName& dep : item.sourceNode->requireSet) + { + if (auto it = sourceNodes.find(dep); it != sourceNodes.end()) + { + if (it->second->hasDirtyModule(frontendOptions.forAutocomplete)) + { + item.dirtyDependencies++; + + buildQueueItems[moduleNameToQueue[dep]].reverseDeps.push_back(i); + } + } + } + + if (item.dirtyDependencies == 0) + sendItemTask(i); + } + + // Not a single item was found, a cycle in the graph was hit + if (processing == 0) + sendCycleItemTask(); + + std::vector nextItems; + + while (remaining != 0) + { + { + std::unique_lock guard(mtx); + + // If nothing is ready yet, wait + if (readyQueueItems.empty()) + { + cv.wait(guard, [&readyQueueItems] { + return !readyQueueItems.empty(); + }); + } + + // Handle checked items + for (size_t i : readyQueueItems) + { + const BuildQueueItem& item = buildQueueItems[i]; + recordItemResult(item); + + // Notify items that were waiting for this dependency + for (size_t reverseDep : item.reverseDeps) + { + BuildQueueItem& reverseDepItem = buildQueueItems[reverseDep]; + + LUAU_ASSERT(reverseDepItem.dirtyDependencies != 0); + reverseDepItem.dirtyDependencies--; + + // In case of a module cycle earlier, check if unlocked an item that was already processed + if (!reverseDepItem.processing && reverseDepItem.dirtyDependencies == 0) + nextItems.push_back(reverseDep); + } + } + + LUAU_ASSERT(processing >= readyQueueItems.size()); + processing -= readyQueueItems.size(); + + LUAU_ASSERT(remaining >= readyQueueItems.size()); + remaining -= readyQueueItems.size(); + readyQueueItems.clear(); + } + + if (progress) + progress(buildQueueItems.size() - remaining, buildQueueItems.size()); + + // Items cannot be submitted while holding the lock + for (size_t i : nextItems) + sendItemTask(i); + nextItems.clear(); + + // If we aren't done, but don't have anything processing, we hit a cycle + if (remaining != 0 && processing == 0) + sendCycleItemTask(); + } + + std::vector checkedModules; + checkedModules.reserve(buildQueueItems.size()); + + for (size_t i = 0; i < buildQueueItems.size(); i++) + checkedModules.push_back(std::move(buildQueueItems[i].name)); + + return checkedModules; +} + +std::optional Frontend::getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete) +{ + auto it = sourceNodes.find(name); + + if (it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete)) + return std::nullopt; + + auto& resolver = forAutocomplete ? moduleResolverForAutocomplete : moduleResolver; + + ModulePtr module = resolver.getModule(name); + + if (module == nullptr) + throw InternalCompilerError("Frontend does not have module: " + name, name); + + CheckResult checkResult; + + if (module->timeout) + checkResult.timeoutHits.push_back(name); + + if (accumulateNested) + checkResult.errors = accumulateErrors(sourceNodes, resolver, name); + else + checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); + + // Get lint result only for top checked module + checkResult.lintResult = module->lintResult; + + return checkResult; +} + +bool Frontend::parseGraph( + std::vector& buildQueue, const ModuleName& root, bool forAutocomplete, std::function canSkip) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -654,14 +963,18 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.hasDirtyModule(forAutocomplete)) + if (!it->second->hasDirtyModule(forAutocomplete)) + continue; + + // This module might already be in the outside build queue + if (canSkip && canSkip(dep)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization // calling getSourceNode twice in succession will reparse the file, since getSourceNode leaves dirty flag set - if (seen.contains(&it->second)) + if (seen.contains(it->second.get())) { - stack.push_back(&it->second); + stack.push_back(it->second.get()); continue; } } @@ -681,6 +994,210 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& return cyclic; } +void Frontend::addBuildQueueItems(std::vector& items, std::vector& buildQueue, bool cycleDetected, + std::unordered_set& seen, const FrontendOptions& frontendOptions) +{ + LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); + + for (const ModuleName& moduleName : buildQueue) + { + if (seen.count(moduleName)) + continue; + seen.insert(moduleName); + + LUAU_ASSERT(sourceNodes.count(moduleName)); + std::shared_ptr& sourceNode = sourceNodes[moduleName]; + + if (!sourceNode->hasDirtyModule(frontendOptions.forAutocomplete)) + continue; + + LUAU_ASSERT(sourceModules.count(moduleName)); + std::shared_ptr& sourceModule = sourceModules[moduleName]; + + BuildQueueItem data{moduleName, fileResolver->getHumanReadableModuleName(moduleName), sourceNode, sourceModule}; + + data.config = configResolver->getConfig(moduleName); + data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); + + Mode mode = sourceModule->mode.value_or(data.config.mode); + + // in NoCheck mode we only need to compute the value of .cyclic for typeck + // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself + // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term + // all correct programs must be acyclic so this code triggers rarely + if (cycleDetected) + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); + + data.options = frontendOptions; + + // This is used by the type checker to replace the resulting type of cyclic modules with any + sourceModule->cyclic = !data.requireCycles.empty(); + + items.push_back(std::move(data)); + } +} + +void Frontend::checkBuildQueueItem(BuildQueueItem& item) +{ + LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); + + SourceNode& sourceNode = *item.sourceNode; + const SourceModule& sourceModule = *item.sourceModule; + const Config& config = item.config; + Mode mode = sourceModule.mode.value_or(config.mode); + ScopePtr environmentScope = item.environmentScope; + double timestamp = getTimestamp(); + const std::vector& requireCycles = item.requireCycles; + + if (item.options.forAutocomplete) + { + double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + + // The autocomplete typecheck is always in strict mode with DM awareness + // to provide better type information for IDE features + TypeCheckLimits typeCheckLimits; + + if (autocompleteTimeLimit != 0.0) + typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; + else + typeCheckLimits.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; + + ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, + /*recordJsonLog*/ false, typeCheckLimits); + + double duration = getTimestamp() - timestamp; + + if (moduleForAutocomplete->timeout) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + else if (duration < autocompleteTimeLimit / 2.0) + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + + item.stats.timeCheck += duration; + item.stats.filesStrict += 1; + + item.module = moduleForAutocomplete; + return; + } + + ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, {}); + + item.stats.timeCheck += getTimestamp() - timestamp; + item.stats.filesStrict += mode == Mode::Strict; + item.stats.filesNonstrict += mode == Mode::Nonstrict; + + if (module == nullptr) + throw InternalCompilerError("Frontend::check produced a nullptr module for " + item.name, item.name); + + if (FFlag::DebugLuauDeferredConstraintResolution && mode == Mode::NoCheck) + module->errors.clear(); + + if (item.options.runLintChecks) + { + LUAU_TIMETRACE_SCOPE("lint", "Frontend"); + + LintOptions lintOptions = item.options.enabledLintWarnings.value_or(config.enabledLint); + filterLintOptions(lintOptions, sourceModule.hotcomments, mode); + + double timestamp = getTimestamp(); + + std::vector warnings = + Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); + + item.stats.timeLint += getTimestamp() - timestamp; + + module->lintResult = classifyLints(warnings, config); + } + + if (!item.options.retainFullTypeGraphs) + { + // copyErrors needs to allocate into interfaceTypes as it copies + // types out of internalTypes, so we unfreeze it here. + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + + module->internalTypes.clear(); + + module->astTypes.clear(); + module->astTypePacks.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astOverloadResolvedTypes.clear(); + module->astResolvedTypes.clear(); + module->astOriginalResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->astScopes.clear(); + + module->scopes.clear(); + } + + if (mode != Mode::NoCheck) + { + for (const RequireCycle& cyc : requireCycles) + { + TypeError te{cyc.location, item.name, ModuleHasCyclicDependency{cyc.path}}; + + module->errors.push_back(te); + } + } + + ErrorVec parseErrors; + + for (const ParseError& pe : sourceModule.parseErrors) + parseErrors.push_back(TypeError{pe.getLocation(), item.name, SyntaxError{pe.what()}}); + + module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); + + item.module = module; +} + +void Frontend::checkBuildQueueItems(std::vector& items) +{ + LUAU_ASSERT(FFlag::LuauSplitFrontendProcessing); + + for (BuildQueueItem& item : items) + { + checkBuildQueueItem(item); + recordItemResult(item); + } +} + +void Frontend::recordItemResult(const BuildQueueItem& item) +{ + if (item.exception) + std::rethrow_exception(item.exception); + + if (item.options.forAutocomplete) + { + moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } + + stats.timeCheck += item.stats.timeCheck; + stats.timeLint += item.stats.timeLint; + + stats.filesStrict += item.stats.filesStrict; + stats.filesNonstrict += item.stats.filesNonstrict; +} + ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const { ScopePtr result; @@ -711,7 +1228,7 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); - return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); + return it == sourceNodes.end() || it->second->hasDirtyModule(forAutocomplete); } /* @@ -728,7 +1245,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requireSet) + for (const auto& dep : module.second->requireSet) reverseDeps[dep].push_back(module.first); } @@ -740,7 +1257,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked queue.pop_back(); LUAU_ASSERT(sourceNodes.count(next) > 0); - SourceNode& sourceNode = sourceNodes[next]; + SourceNode& sourceNode = *sourceNodes[next]; if (markedDirty) markedDirty->push_back(next); @@ -766,7 +1283,7 @@ SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) { auto it = sourceModules.find(moduleName); if (it != sourceModules.end()) - return &it->second; + return it->second.get(); else return nullptr; } @@ -901,22 +1418,22 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(const ModuleName& name) { - LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) + if (it != sourceNodes.end() && !it->second->hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) - return {&it->second, &moduleIt->second}; + return {it->second.get(), moduleIt->second.get()}; else { LUAU_ASSERT(!"Everything in sourceNodes should also be in sourceModules"); - return {&it->second, nullptr}; + return {it->second.get(), nullptr}; } } + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + double timestamp = getTimestamp(); std::optional source = fileResolver->readSource(name); @@ -939,30 +1456,37 @@ std::pair Frontend::getSourceNode(const ModuleName& RequireTraceResult& require = requireTrace[name]; require = traceRequires(fileResolver, result.root, name); - SourceNode& sourceNode = sourceNodes[name]; - SourceModule& sourceModule = sourceModules[name]; + std::shared_ptr& sourceNode = sourceNodes[name]; - sourceModule = std::move(result); - sourceModule.environmentName = environmentName; + if (!sourceNode) + sourceNode = std::make_shared(); - sourceNode.name = sourceModule.name; - sourceNode.humanReadableName = sourceModule.humanReadableName; - sourceNode.requireSet.clear(); - sourceNode.requireLocations.clear(); - sourceNode.dirtySourceModule = false; + std::shared_ptr& sourceModule = sourceModules[name]; + + if (!sourceModule) + sourceModule = std::make_shared(); + + *sourceModule = std::move(result); + sourceModule->environmentName = environmentName; + + sourceNode->name = sourceModule->name; + sourceNode->humanReadableName = sourceModule->humanReadableName; + sourceNode->requireSet.clear(); + sourceNode->requireLocations.clear(); + sourceNode->dirtySourceModule = false; if (it == sourceNodes.end()) { - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; + sourceNode->dirtyModule = true; + sourceNode->dirtyModuleForAutocomplete = true; } for (const auto& [moduleName, location] : require.requireList) - sourceNode.requireSet.insert(moduleName); + sourceNode->requireSet.insert(moduleName); - sourceNode.requireLocations = require.requireList; + sourceNode->requireLocations = require.requireList; - return {&sourceNode, &sourceModule}; + return {sourceNode.get(), sourceModule.get()}; } /** Try to parse a source file into a SourceModule. diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 29f8b2e6..cfc0ae13 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,8 +17,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -232,15 +230,8 @@ NormalizedType::NormalizedType(NotNull builtinTypes) static bool isShallowInhabited(const NormalizedType& norm) { - bool inhabitedClasses; - - if (FFlag::LuauNegatedClassTypes) - inhabitedClasses = !norm.classes.isNever(); - else - inhabitedClasses = !norm.DEPRECATED_classes.empty(); - // This test is just a shallow check, for example it returns `true` for `{ p : never }` - return !get(norm.tops) || !get(norm.booleans) || inhabitedClasses || !get(norm.errors) || + return !get(norm.tops) || !get(norm.booleans) || !norm.classes.isNever() || !get(norm.errors) || !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } @@ -257,14 +248,8 @@ bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_setclasses.isNever(); - else - inhabitedClasses = !norm->DEPRECATED_classes.empty(); - if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || !get(norm->nils) || - !get(norm->numbers) || !get(norm->threads) || inhabitedClasses || !norm->strings.isNever() || + !get(norm->numbers) || !get(norm->threads) || !norm->classes.isNever() || !norm->strings.isNever() || !norm->functions.isNever()) return true; @@ -466,7 +451,7 @@ static bool areNormalizedTables(const TypeIds& tys) if (!pt) return false; - if (pt->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + if (pt->type == PrimitiveType::Table) continue; return false; @@ -475,14 +460,6 @@ static bool areNormalizedTables(const TypeIds& tys) return true; } -static bool areNormalizedClasses(const TypeIds& tys) -{ - for (TypeId ty : tys) - if (!get(ty)) - return false; - return true; -} - static bool areNormalizedClasses(const NormalizedClassType& tys) { for (const auto& [ty, negations] : tys.classes) @@ -567,7 +544,6 @@ static void assertInvariant(const NormalizedType& norm) LUAU_ASSERT(isNormalizedTop(norm.tops)); LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); - LUAU_ASSERT(areNormalizedClasses(norm.DEPRECATED_classes)); LUAU_ASSERT(areNormalizedClasses(norm.classes)); LUAU_ASSERT(isNormalizedError(norm.errors)); LUAU_ASSERT(isNormalizedNil(norm.nils)); @@ -629,7 +605,6 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.tops = builtinTypes->neverType; norm.booleans = builtinTypes->neverType; norm.classes.resetToNever(); - norm.DEPRECATED_classes.clear(); norm.errors = builtinTypes->neverType; norm.nils = builtinTypes->neverType; norm.numbers = builtinTypes->neverType; @@ -1253,18 +1228,11 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) { for (TypeId there : theres) { - if (FFlag::LuauNegatedTableTypes) + if (there == builtinTypes->tableType) { - if (there == builtinTypes->tableType) - { - heres.clear(); - heres.insert(there); - return; - } - else - { - unionTablesWithTable(heres, there); - } + heres.clear(); + heres.insert(there); + return; } else { @@ -1320,10 +1288,7 @@ bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, } here.booleans = unionOfBools(here.booleans, there.booleans); - if (FFlag::LuauNegatedClassTypes) - unionClasses(here.classes, there.classes); - else - unionClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); + unionClasses(here.classes, there.classes); here.errors = (get(there.errors) ? here.errors : there.errors); here.nils = (get(there.nils) ? here.nils : there.nils); @@ -1414,16 +1379,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (get(there) || get(there)) unionTablesWithTable(here.tables, there); else if (get(there)) - { - if (FFlag::LuauNegatedClassTypes) - { - unionClassesWithClass(here.classes, there); - } - else - { - unionClassesWithClass(here.DEPRECATED_classes, there); - } - } + unionClassesWithClass(here.classes, there); else if (get(there)) here.errors = there; else if (const PrimitiveType* ptv = get(there)) @@ -1442,7 +1398,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor { here.functions.resetToTop(); } - else if (ptv->type == PrimitiveType::Table && FFlag::LuauNegatedTableTypes) + else if (ptv->type == PrimitiveType::Table) { here.tables.clear(); here.tables.insert(there); @@ -1527,36 +1483,29 @@ std::optional Normalizer::negateNormal(const NormalizedType& her result.booleans = builtinTypes->trueType; } - if (FFlag::LuauNegatedClassTypes) + if (here.classes.isNever()) { - if (here.classes.isNever()) - { - resetToTop(builtinTypes, result.classes); - } - else if (isTop(builtinTypes, result.classes)) - { - result.classes.resetToNever(); - } - else - { - TypeIds rootNegations{}; - - for (const auto& [hereParent, hereNegations] : here.classes.classes) - { - if (hereParent != builtinTypes->classType) - rootNegations.insert(hereParent); - - for (TypeId hereNegation : hereNegations) - unionClassesWithClass(result.classes, hereNegation); - } - - if (!rootNegations.empty()) - result.classes.pushPair(builtinTypes->classType, rootNegations); - } + resetToTop(builtinTypes, result.classes); + } + else if (isTop(builtinTypes, result.classes)) + { + result.classes.resetToNever(); } else { - result.DEPRECATED_classes = negateAll(here.DEPRECATED_classes); + TypeIds rootNegations{}; + + for (const auto& [hereParent, hereNegations] : here.classes.classes) + { + if (hereParent != builtinTypes->classType) + rootNegations.insert(hereParent); + + for (TypeId hereNegation : hereNegations) + unionClassesWithClass(result.classes, hereNegation); + } + + if (!rootNegations.empty()) + result.classes.pushPair(builtinTypes->classType, rootNegations); } result.nils = get(here.nils) ? builtinTypes->nilType : builtinTypes->neverType; @@ -1584,15 +1533,12 @@ std::optional Normalizer::negateNormal(const NormalizedType& her * types are not runtime-testable. Thus, we prohibit negation of anything * other than `table` and `never`. */ - if (FFlag::LuauNegatedTableTypes) - { - if (here.tables.empty()) - result.tables.insert(builtinTypes->tableType); - else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) - result.tables.clear(); - else - return std::nullopt; - } + if (here.tables.empty()) + result.tables.insert(builtinTypes->tableType); + else if (here.tables.size() == 1 && here.tables.front() == builtinTypes->tableType) + result.tables.clear(); + else + return std::nullopt; // TODO: negating tables // TODO: negating tyvars? @@ -1662,7 +1608,6 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) here.functions.resetToNever(); break; case PrimitiveType::Table: - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables.clear(); break; } @@ -1734,64 +1679,6 @@ TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) return there; } -void Normalizer::DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres) -{ - TypeIds tmp; - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - bool keep = false; - for (TypeId there : theres) - { - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - if (isSubclass(hctv, tctv)) - { - keep = true; - break; - } - else if (isSubclass(tctv, hctv)) - { - keep = false; - tmp.insert(there); - break; - } - } - if (keep) - it++; - else - it = heres.erase(it); - } - heres.insert(tmp.begin(), tmp.end()); -} - -void Normalizer::DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there) -{ - bool foundSuper = false; - const ClassType* tctv = get(there); - LUAU_ASSERT(tctv); - for (auto it = heres.begin(); it != heres.end();) - { - const ClassType* hctv = get(*it); - LUAU_ASSERT(hctv); - if (isSubclass(hctv, tctv)) - it++; - else if (isSubclass(tctv, hctv)) - { - foundSuper = true; - break; - } - else - it = heres.erase(it); - } - if (foundSuper) - { - heres.clear(); - heres.insert(there); - } -} - void Normalizer::intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres) { if (theres.isNever()) @@ -2504,15 +2391,7 @@ bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& th here.booleans = intersectionOfBools(here.booleans, there.booleans); - if (FFlag::LuauNegatedClassTypes) - { - intersectClasses(here.classes, there.classes); - } - else - { - DEPRECATED_intersectClasses(here.DEPRECATED_classes, there.DEPRECATED_classes); - } - + intersectClasses(here.classes, there.classes); here.errors = (get(there.errors) ? there.errors : here.errors); here.nils = (get(there.nils) ? there.nils : here.nils); here.numbers = (get(there.numbers) ? there.numbers : here.numbers); @@ -2619,20 +2498,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) } else if (get(there)) { - if (FFlag::LuauNegatedClassTypes) - { - NormalizedClassType nct = std::move(here.classes); - clearNormal(here); - intersectClassesWithClass(nct, there); - here.classes = std::move(nct); - } - else - { - TypeIds classes = std::move(here.DEPRECATED_classes); - clearNormal(here); - DEPRECATED_intersectClassesWithClass(classes, there); - here.DEPRECATED_classes = std::move(classes); - } + NormalizedClassType nct = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(nct, there); + here.classes = std::move(nct); } else if (get(there)) { @@ -2665,10 +2534,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) else if (ptv->type == PrimitiveType::Function) here.functions = std::move(functions); else if (ptv->type == PrimitiveType::Table) - { - LUAU_ASSERT(FFlag::LuauNegatedTableTypes); here.tables = std::move(tables); - } else LUAU_ASSERT(!"Unreachable"); } @@ -2696,7 +2562,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) subtractPrimitive(here, ntv->ty); else if (const SingletonType* stv = get(t)) subtractSingleton(here, follow(ntv->ty)); - else if (get(t) && FFlag::LuauNegatedClassTypes) + else if (get(t)) { const NormalizedType* normal = normalize(t); std::optional negated = negateNormal(*normal); @@ -2730,7 +2596,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) LUAU_ASSERT(!"Unimplemented"); } } - else if (get(there) && FFlag::LuauNegatedClassTypes) + else if (get(there)) { here.classes.resetToNever(); } @@ -2756,53 +2622,46 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) if (!get(norm.booleans)) result.push_back(norm.booleans); - if (FFlag::LuauNegatedClassTypes) + if (isTop(builtinTypes, norm.classes)) { - if (isTop(builtinTypes, norm.classes)) - { - result.push_back(builtinTypes->classType); - } - else if (!norm.classes.isNever()) - { - std::vector parts; - parts.reserve(norm.classes.classes.size()); - - for (const TypeId normTy : norm.classes.ordering) - { - const TypeIds& normNegations = norm.classes.classes.at(normTy); - - if (normNegations.empty()) - { - parts.push_back(normTy); - } - else - { - std::vector intersection; - intersection.reserve(normNegations.size() + 1); - - intersection.push_back(normTy); - for (TypeId negation : normNegations) - { - intersection.push_back(arena->addType(NegationType{negation})); - } - - parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); - } - } - - if (parts.size() == 1) - { - result.push_back(parts.at(0)); - } - else if (parts.size() > 1) - { - result.push_back(arena->addType(UnionType{std::move(parts)})); - } - } + result.push_back(builtinTypes->classType); } - else + else if (!norm.classes.isNever()) { - result.insert(result.end(), norm.DEPRECATED_classes.begin(), norm.DEPRECATED_classes.end()); + std::vector parts; + parts.reserve(norm.classes.classes.size()); + + for (const TypeId normTy : norm.classes.ordering) + { + const TypeIds& normNegations = norm.classes.classes.at(normTy); + + if (normNegations.empty()) + { + parts.push_back(normTy); + } + else + { + std::vector intersection; + intersection.reserve(normNegations.size() + 1); + + intersection.push_back(normTy); + for (TypeId negation : normNegations) + { + intersection.push_back(arena->addType(NegationType{negation})); + } + + parts.push_back(arena->addType(IntersectionType{std::move(intersection)})); + } + } + + if (parts.size() == 1) + { + result.push_back(parts.at(0)); + } + else if (parts.size() > 1) + { + result.push_back(arena->addType(UnionType{std::move(parts)})); + } } if (!get(norm.errors)) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 26618313..33554ce9 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -382,8 +382,9 @@ std::optional TxnLog::getLevel(TypeId ty) const TypeId TxnLog::follow(TypeId ty) const { - return Luau::follow(ty, [this](TypeId ty) { - PendingType* state = this->pending(ty); + return Luau::follow(ty, this, [](const void* ctx, TypeId ty) -> TypeId { + const TxnLog* self = static_cast(ctx); + PendingType* state = self->pending(ty); if (state == nullptr) return ty; @@ -397,8 +398,9 @@ TypeId TxnLog::follow(TypeId ty) const TypePackId TxnLog::follow(TypePackId tp) const { - return Luau::follow(tp, [this](TypePackId tp) { - PendingTypePack* state = this->pending(tp); + return Luau::follow(tp, this, [](const void* ctx, TypePackId tp) -> TypePackId { + const TxnLog* self = static_cast(ctx); + PendingTypePack* state = self->pending(tp); if (state == nullptr) return tp; diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 2ca39b41..e8a2bc5d 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -48,19 +48,39 @@ static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); static bool dcrMagicFunctionFind(MagicFunctionCallContext context); +// LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable +static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) +{ + TypeId unwrapped = ltv->unwrapped.load(); + + if (unwrapped) + return unwrapped; + + ltv->unwrap(*ltv); + unwrapped = ltv->unwrapped.load(); + + if (!unwrapped) + throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); + + if (get(unwrapped)) + throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); + + return unwrapped; +} + TypeId follow(TypeId t) { - return follow(t, [](TypeId t) { + return follow(t, nullptr, [](const void*, TypeId t) -> TypeId { return t; }); } -TypeId follow(TypeId t, std::function mapper) +TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) { - auto advance = [&mapper](TypeId ty) -> std::optional { + auto advance = [context, mapper](TypeId ty) -> std::optional { if (FFlag::LuauBoundLazyTypes2) { - TypeId mapped = mapper(ty); + TypeId mapped = mapper(context, ty); if (auto btv = get>(mapped)) return btv->boundTo; @@ -69,39 +89,25 @@ TypeId follow(TypeId t, std::function mapper) return ttv->boundTo; if (auto ltv = getMutable(mapped)) - { - TypeId unwrapped = ltv->unwrapped.load(); - - if (unwrapped) - return unwrapped; - - ltv->unwrap(*ltv); - unwrapped = ltv->unwrapped.load(); - - if (!unwrapped) - throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); - - if (get(unwrapped)) - throw InternalCompilerError("Lazy Type cannot resolve to another Lazy Type"); - - return unwrapped; - } + return unwrapLazy(ltv); return std::nullopt; } else { - if (auto btv = get>(mapper(ty))) + if (auto btv = get>(mapper(context, ty))) return btv->boundTo; - else if (auto ttv = get(mapper(ty))) + else if (auto ttv = get(mapper(context, ty))) return ttv->boundTo; else return std::nullopt; } }; - auto force = [&mapper](TypeId ty) { - if (auto ltv = get_if(&mapper(ty)->ty)) + auto force = [context, mapper](TypeId ty) { + TypeId mapped = mapper(context, ty); + + if (auto ltv = get_if(&mapped->ty)) { TypeId res = ltv->thunk_DEPRECATED(); if (get(res)) @@ -120,6 +126,12 @@ TypeId follow(TypeId t, std::function mapper) else return t; + if (FFlag::LuauBoundLazyTypes2) + { + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + } + while (true) { if (!FFlag::LuauBoundLazyTypes2) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a103df14..2a2fe69c 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -22,8 +22,6 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauDontReduceTypes) -LUAU_FASTFLAG(LuauNegatedClassTypes) - namespace Luau { @@ -519,18 +517,39 @@ struct TypeChecker2 auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); if (minCount > 2) - reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); + } if (maxCount && *maxCount < 2) - reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); + } TypePack flattenedArgTypes = extendTypePack(arena, builtinTypes, iterFtv->argTypes, 2); size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; size_t actualArgCount = expectedVariableTypes.head.size(); - if (firstIterationArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + } + else if (actualArgCount < minCount) - reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + } + if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) { @@ -841,125 +860,31 @@ struct TypeChecker2 // TODO! } - ErrorVec visitOverload(AstExprCall* call, NotNull overloadFunctionType, const std::vector& argLocs, - TypePackId expectedArgTypes, TypePackId expectedRetType) - { - ErrorVec overloadErrors = - tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult); - - size_t argIndex = 0; - auto inferredArgIt = begin(overloadFunctionType->argTypes); - auto expectedArgIt = begin(expectedArgTypes); - while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) - { - Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; - ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt); - for (TypeError e : argErrors) - overloadErrors.emplace_back(e); - - ++argIndex; - ++inferredArgIt; - ++expectedArgIt; - } - - // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad - ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes); - for (TypeError e : argumentErrors) - if (get(e) != nullptr) - overloadErrors.emplace_back(std::move(e)); - - return overloadErrors; - } - - void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, - const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) - { - if (overloads.size() == 1) - { - reportErrors(std::get<0>(overloadsErrors.front())); - return; - } - - std::vector overloadTypes = overloadsThatMatchArgCount; - if (overloadsThatMatchArgCount.size() == 0) - { - reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); - // If no overloads match argument count, just list all overloads. - overloadTypes = overloads; - } - else - { - // Report errors of the first argument-count-matching, but failing overload - TypeId overload = overloadsThatMatchArgCount[0]; - - // Remove the overload we are reporting errors about from the list of alternatives - overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); - - const FunctionType* ftv = get(overload); - LUAU_ASSERT(ftv); // overload must be a function type here - - auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [overload](const std::pair& e) { - return overload == e.second; - }); - - LUAU_ASSERT(error != overloadsErrors.end()); - reportErrors(std::get<0>(*error)); - - // If only one overload matched, we don't need this error because we provided the previous errors. - if (overloadsThatMatchArgCount.size() == 1) - return; - } - - std::string s; - for (size_t i = 0; i < overloadTypes.size(); ++i) - { - TypeId overload = follow(overloadTypes[i]); - - if (i > 0) - s += "; "; - - if (i > 0 && i == overloadTypes.size() - 1) - s += "and "; - - s += toString(overload); - } - - if (overloadsThatMatchArgCount.size() == 0) - reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); - else - reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); - } - // Note: this is intentionally separated from `visit(AstExprCall*)` for stack allocation purposes. void visitCall(AstExprCall* call) { - TypeArena* arena = &testArena; - Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; - - TypePackId expectedRetType = lookupExpectedPack(call, *arena); - TypeId functionType = lookupType(call->func); - TypeId testFunctionType = functionType; + TypePackId expectedRetType = lookupExpectedPack(call, testArena); TypePack args; std::vector argLocs; argLocs.reserve(call->args.size + 1); - if (get(functionType) || get(functionType) || get(functionType)) + TypeId* maybeOriginalCallTy = module->astOriginalCallTypes.find(call); + TypeId* maybeSelectedOverload = module->astOverloadResolvedTypes.find(call); + + if (!maybeOriginalCallTy) return; - else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, functionType, "__call", call->func->location)) + + TypeId originalCallTy = follow(*maybeOriginalCallTy); + std::vector overloads = flattenIntersection(originalCallTy); + + if (get(originalCallTy) || get(originalCallTy) || get(originalCallTy)) + return; + else if (std::optional callMm = findMetatableEntry(builtinTypes, module->errors, originalCallTy, "__call", call->func->location)) { if (get(follow(*callMm))) { - if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) - { - args.head.push_back(functionType); - argLocs.push_back(call->func->location); - testFunctionType = follow(*instantiatedCallMm); - } - else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } + args.head.push_back(originalCallTy); + argLocs.push_back(call->func->location); } else { @@ -969,29 +894,16 @@ struct TypeChecker2 return; } } - else if (get(functionType)) + else if (get(originalCallTy) || get(originalCallTy)) { - if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) - { - testFunctionType = *instantiatedFunctionType; - } - else - { - reportError(UnificationTooComplex{}, call->func->location); - return; - } } - else if (auto itv = get(functionType)) - { - // We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function. - } - else if (auto utv = get(functionType)) + else if (auto utv = get(originalCallTy)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. // Another scenario we might run into it is if the union has a nil member. In this case, we want to throw an error - if (isOptional(functionType)) + if (isOptional(originalCallTy)) { - reportError(OptionalValueAccess{functionType}, call->location); + reportError(OptionalValueAccess{originalCallTy}, call->location); return; } std::optional fst; @@ -1001,7 +913,7 @@ struct TypeChecker2 fst = follow(ty); else if (fst != follow(ty)) { - reportError(CannotCallNonFunction{functionType}, call->func->location); + reportError(CannotCallNonFunction{originalCallTy}, call->func->location); return; } } @@ -1009,19 +921,16 @@ struct TypeChecker2 if (!fst) ice->ice("UnionType had no elements, so fst is nullopt?"); - if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) + originalCallTy = follow(*fst); + if (!get(originalCallTy)) { - testFunctionType = *instantiatedFunctionType; - } - else - { - reportError(UnificationTooComplex{}, call->func->location); + reportError(CannotCallNonFunction{originalCallTy}, call->func->location); return; } } else { - reportError(CannotCallNonFunction{functionType}, call->func->location); + reportError(CannotCallNonFunction{originalCallTy}, call->func->location); return; } @@ -1054,63 +963,134 @@ struct TypeChecker2 args.head.push_back(builtinTypes->anyType); } - TypePackId expectedArgTypes = arena->addTypePack(args); + TypePackId expectedArgTypes = testArena.addTypePack(args); - std::vector overloads = flattenIntersection(testFunctionType); - std::vector> overloadsErrors; - overloadsErrors.reserve(overloads.size()); - - std::vector overloadsThatMatchArgCount; - - for (TypeId overload : overloads) + if (maybeSelectedOverload) { - overload = follow(overload); + // This overload might not work still: the constraint solver will + // pass the type checker an instantiated function type that matches + // in arity, but not in subtyping, in order to allow the type + // checker to report better error messages. - const FunctionType* overloadFn = get(overload); - if (!overloadFn) + TypeId selectedOverload = follow(*maybeSelectedOverload); + const FunctionType* ftv; + + if (get(selectedOverload) || get(selectedOverload) || get(selectedOverload)) { - reportError(CannotCallNonFunction{overload}, call->func->location); return; } + else if (const FunctionType* overloadFtv = get(selectedOverload)) + { + ftv = overloadFtv; + } else { - // We may have to instantiate the overload in order for it to typecheck. - if (std::optional instantiatedFunctionType = instantiation.substitute(overload)) - { - overloadFn = get(*instantiatedFunctionType); - } - else - { - overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overload); - return; - } - } - - ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); - if (overloadErrors.empty()) + reportError(CannotCallNonFunction{selectedOverload}, call->func->location); return; + } - bool argMismatch = false; - for (auto error : overloadErrors) + LUAU_ASSERT(ftv); + reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return)); + + auto it = begin(expectedArgTypes); + size_t i = 0; + std::vector slice; + for (TypeId arg : ftv->argTypes) { - CountMismatch* cm = get(error); - if (!cm) - continue; - - if (cm->context == CountMismatch::Arg) + if (it == end(expectedArgTypes)) { - argMismatch = true; - break; + slice.push_back(arg); + continue; + } + + TypeId expectedArg = *it; + + Location argLoc = argLocs.at(i >= argLocs.size() ? argLocs.size() - 1 : i); + + reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg)); + + ++it; + ++i; + } + + if (slice.size() > 0 && it == end(expectedArgTypes)) + { + if (auto tail = it.tail()) + { + TypePackId remainingArgs = testArena.addTypePack(TypePack{std::move(slice), std::nullopt}); + reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs)); } } - if (!argMismatch) - overloadsThatMatchArgCount.push_back(overload); - - overloadsErrors.emplace_back(std::move(overloadErrors), overload); + // We do not need to do an arity test because this overload was + // selected based on its arity already matching. } + else + { + // No overload worked, even when instantiated. We need to filter the + // set of overloads to those that match the arity of the incoming + // argument set, and then report only those as not matching. - reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); + std::vector arityMatchingOverloads; + ErrorVec empty; + for (TypeId overload : overloads) + { + overload = follow(overload); + if (const FunctionType* ftv = get(overload)) + { + if (size(ftv->argTypes) == size(expectedArgTypes)) + { + arityMatchingOverloads.push_back(overload); + } + } + else if (const std::optional callMm = findMetatableEntry(builtinTypes, empty, overload, "__call", call->location)) + { + if (const FunctionType* ftv = get(follow(*callMm))) + { + if (size(ftv->argTypes) == size(expectedArgTypes)) + { + arityMatchingOverloads.push_back(overload); + } + } + else + { + reportError(CannotCallNonFunction{}, call->location); + } + } + } + + if (arityMatchingOverloads.size() == 0) + { + reportError( + GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); + } + else + { + // We have handled the case of a singular arity-matching + // overload above, in the case where an overload was selected. + // LUAU_ASSERT(arityMatchingOverloads.size() > 1); + reportError(GenericError{"None of the overloads for function that accept " + std::to_string(size(expectedArgTypes)) + + " arguments are compatible."}, + call->location); + } + + std::string s; + std::vector& stringifyOverloads = arityMatchingOverloads.size() == 0 ? overloads : arityMatchingOverloads; + for (size_t i = 0; i < stringifyOverloads.size(); ++i) + { + TypeId overload = follow(stringifyOverloads[i]); + + if (i > 0) + s += "; "; + + if (i > 0 && i == stringifyOverloads.size() - 1) + s += "and "; + + s += toString(overload); + } + + reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); + } } void visit(AstExprCall* call) @@ -2077,17 +2057,9 @@ struct TypeChecker2 fetch(norm->tops); fetch(norm->booleans); - if (FFlag::LuauNegatedClassTypes) + for (const auto& [ty, _negations] : norm->classes.classes) { - for (const auto& [ty, _negations] : norm->classes.classes) - { - fetch(ty); - } - } - else - { - for (TypeId ty : norm->DEPRECATED_classes) - fetch(ty); + fetch(ty); } fetch(norm->errors); fetch(norm->nils); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8f9e1851..1ccba91e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,7 +35,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) @@ -1701,7 +1700,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { - std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass.superName) { Name superName = Name(declaredClass.superName->value); @@ -5968,17 +5967,13 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r TypeId type = follow(typeFun->type); // You cannot refine to the top class type. - if (FFlag::LuauNegatedClassTypes) + if (type == builtinTypes->classType) { - if (type == builtinTypes->classType) - { - return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - } + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); } // We're only interested in the root class of any classes. - if (auto ctv = get(type); - !ctv || (FFlag::LuauNegatedClassTypes ? (ctv->parent != builtinTypes->classType) : (ctv->parent != std::nullopt))) + if (auto ctv = get(type); !ctv || ctv->parent != builtinTypes->classType) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 6873820a..0db0e5a1 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -255,15 +255,17 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - return follow(tp, [](TypePackId t) { + return follow(tp, nullptr, [](const void*, TypePackId t) { return t; }); } -TypePackId follow(TypePackId tp, std::function mapper) +TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId)) { - auto advance = [&mapper](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(mapper(ty))) + auto advance = [context, mapper](TypePackId ty) -> std::optional { + TypePackId mapped = mapper(context, ty); + + if (const Unifiable::Bound* btv = get>(mapped)) return btv->boundTo; else return std::nullopt; @@ -275,6 +277,9 @@ TypePackId follow(TypePackId tp, std::function mapper) else return tp; + if (!advance(cycleTester)) // Short circuit traversal for the rather common case when advance(advance(t)) == null + return cycleTester; + while (true) { auto a1 = advance(tp); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 3ca93591..6047a49b 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -26,8 +26,6 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAG(LuauNegatedClassTypes) -LUAU_FASTFLAG(LuauNegatedTableTypes) namespace Luau { @@ -344,6 +342,19 @@ std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +std::optional hasCountMismatch(const ErrorVec& errors) +{ + auto isCountMismatch = [](const TypeError& te) { + return nullptr != get(te); + }; + + auto it = std::find_if(errors.begin(), errors.end(), isCountMismatch); + if (it == errors.end()) + return std::nullopt; + else + return *it; +} + // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { @@ -620,7 +631,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Ok. Do nothing. forall functions F, F <: function } - else if (FFlag::LuauNegatedTableTypes && isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) + else if (isPrim(superTy, PrimitiveType::Table) && (get(subTy) || get(subTy))) { // Ok, do nothing: forall tables T, T <: table } @@ -1183,81 +1194,59 @@ void Unifier::tryUnifyNormalizedTypes( if (!get(superNorm.errors)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - if (FFlag::LuauNegatedClassTypes) + for (const auto& [subClass, _] : subNorm.classes.classes) { - for (const auto& [subClass, _] : subNorm.classes.classes) + bool found = false; + const ClassType* subCtv = get(subClass); + LUAU_ASSERT(subCtv); + + for (const auto& [superClass, superNegations] : superNorm.classes.classes) { - bool found = false; - const ClassType* subCtv = get(subClass); - LUAU_ASSERT(subCtv); + const ClassType* superCtv = get(superClass); + LUAU_ASSERT(superCtv); - for (const auto& [superClass, superNegations] : superNorm.classes.classes) + if (isSubclass(subCtv, superCtv)) { - const ClassType* superCtv = get(superClass); - LUAU_ASSERT(superCtv); + found = true; - if (isSubclass(subCtv, superCtv)) + for (TypeId negation : superNegations) { - found = true; + const ClassType* negationCtv = get(negation); + LUAU_ASSERT(negationCtv); - for (TypeId negation : superNegations) + if (isSubclass(subCtv, negationCtv)) { - const ClassType* negationCtv = get(negation); - LUAU_ASSERT(negationCtv); - - if (isSubclass(subCtv, negationCtv)) - { - found = false; - break; - } - } - - if (found) - break; - } - } - - if (FFlag::DebugLuauDeferredConstraintResolution) - { - for (TypeId superTable : superNorm.tables) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify(subClass, superTable); - - if (innerState.errors.empty()) - { - found = true; - log.concat(std::move(innerState.log)); + found = false; break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - return reportError(*e); } - } - if (!found) - { - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + if (found) + break; } } - } - else - { - for (TypeId subClass : subNorm.DEPRECATED_classes) + + if (FFlag::DebugLuauDeferredConstraintResolution) { - bool found = false; - const ClassType* subCtv = get(subClass); - for (TypeId superClass : superNorm.DEPRECATED_classes) + for (TypeId superTable : superNorm.tables) { - const ClassType* superCtv = get(superClass); - if (isSubclass(subCtv, superCtv)) + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(subClass, superTable); + + if (innerState.errors.empty()) { found = true; + log.concat(std::move(innerState.log)); break; } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); } - if (!found) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + + if (!found) + { + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); } } @@ -1266,7 +1255,7 @@ void Unifier::tryUnifyNormalizedTypes( bool found = false; for (TypeId superTable : superNorm.tables) { - if (FFlag::LuauNegatedTableTypes && isPrim(superTable, PrimitiveType::Table)) + if (isPrim(superTable, PrimitiveType::Table)) { found = true; break; diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index 89e79528..01f2a74f 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -14,8 +14,6 @@ enum class Mode struct ParseOptions { - bool allowTypeAnnotations = true; - bool supportContinueStatement = true; bool allowDeclarationSyntax = false; bool captureComments = false; }; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6a76eda2..7cae609d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,8 +14,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) - #define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" namespace Luau @@ -327,22 +325,19 @@ AstStat* Parser::parseStat() // we know this isn't a call or an assignment; therefore it must be a context-sensitive keyword such as `type` or `continue` AstName ident = getIdentifier(expr); - if (options.allowTypeAnnotations) - { - if (ident == "type") - return parseTypeAlias(expr->location, /* exported= */ false); + if (ident == "type") + return parseTypeAlias(expr->location, /* exported= */ false); - if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") - { - nextLexeme(); - return parseTypeAlias(expr->location, /* exported= */ true); - } + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") + { + nextLexeme(); + return parseTypeAlias(expr->location, /* exported= */ true); } - if (options.supportContinueStatement && ident == "continue") + if (ident == "continue") return parseContinue(expr->location); - if (options.allowTypeAnnotations && options.allowDeclarationSyntax) + if (options.allowDeclarationSyntax) { if (ident == "declare") return parseDeclaration(expr->location); @@ -1123,7 +1118,7 @@ std::tuple Parser::parseBindingList(TempVector& result, TempVector Parser::parseOptionalReturnType() { - if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) + if (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow) { if (lexer.current().type == Lexeme::SkinnyArrow) report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); @@ -2056,7 +2051,7 @@ AstExpr* Parser::parseAssertionExpr() Location start = lexer.current().location; AstExpr* expr = parseSimpleExpr(); - if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) + if (lexer.current().type == Lexeme::DoubleColon) { nextLexeme(); AstType* annotation = parseType(); @@ -2449,24 +2444,13 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - Lexeme packBegin = lexer.current(); - if (shouldParseTypePack(lexer)) { AstTypePack* typePack = parseTypePack(); namePacks.push_back({name, nameLocation, typePack}); } - else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') - { - auto [type, typePack] = parseTypeOrPack(); - - if (type) - report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); - - namePacks.push_back({name, nameLocation, typePack}); - } - else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + else { auto [type, typePack] = parseTypeOrPack(); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 6d1f5451..50fef7fc 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -9,6 +9,13 @@ #include "FileUtils.h" #include "Flags.h" +#include +#include +#include +#include +#include +#include + #ifdef CALLGRIND #include #endif @@ -64,26 +71,29 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin report(format, name, warning.location, Luau::LintWarning::getName(warning.code), warning.text.c_str()); } -static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) +static bool reportModuleResult(Luau::Frontend& frontend, const Luau::ModuleName& name, ReportFormat format, bool annotate) { - Luau::CheckResult cr; + std::optional cr = frontend.getCheckResult(name, false); - if (frontend.isDirty(name)) - cr = frontend.check(name); - - if (!frontend.getSourceModule(name)) + if (!cr) { - fprintf(stderr, "Error opening %s\n", name); + fprintf(stderr, "Failed to find result for %s\n", name.c_str()); return false; } - for (auto& error : cr.errors) + if (!frontend.getSourceModule(name)) + { + fprintf(stderr, "Error opening %s\n", name.c_str()); + return false; + } + + for (auto& error : cr->errors) reportError(frontend, format, error); std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); - for (auto& error : cr.lintResult.errors) + for (auto& error : cr->lintResult.errors) reportWarning(format, humanReadableName.c_str(), error); - for (auto& warning : cr.lintResult.warnings) + for (auto& warning : cr->lintResult.warnings) reportWarning(format, humanReadableName.c_str(), warning); if (annotate) @@ -98,7 +108,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat printf("%s", annotated.c_str()); } - return cr.errors.empty() && cr.lintResult.errors.empty(); + return cr->errors.empty() && cr->lintResult.errors.empty(); } static void displayHelp(const char* argv0) @@ -216,6 +226,70 @@ struct CliConfigResolver : Luau::ConfigResolver } }; +struct TaskScheduler +{ + TaskScheduler(unsigned threadCount) + : threadCount(threadCount) + { + for (unsigned i = 0; i < threadCount; i++) + { + workers.emplace_back([this] { + workerFunction(); + }); + } + } + + ~TaskScheduler() + { + for (unsigned i = 0; i < threadCount; i++) + push({}); + + for (std::thread& worker : workers) + worker.join(); + } + + std::function pop() + { + std::unique_lock guard(mtx); + + cv.wait(guard, [this] { + return !tasks.empty(); + }); + + std::function task = tasks.front(); + tasks.pop(); + return task; + } + + void push(std::function task) + { + { + std::unique_lock guard(mtx); + tasks.push(std::move(task)); + } + + cv.notify_one(); + } + + static unsigned getThreadCount() + { + return std::max(std::thread::hardware_concurrency(), 1u); + } + +private: + void workerFunction() + { + while (std::function task = pop()) + task(); + } + + unsigned threadCount = 1; + std::mutex mtx; + std::condition_variable cv; + std::vector workers; + std::queue> tasks; +}; + int main(int argc, char** argv) { Luau::assertHandler() = assertionHandler; @@ -231,6 +305,7 @@ int main(int argc, char** argv) ReportFormat format = ReportFormat::Default; Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; + int threadCount = 0; for (int i = 1; i < argc; ++i) { @@ -249,6 +324,8 @@ int main(int argc, char** argv) FFlag::DebugLuauTimeTracing.value = true; else if (strncmp(argv[i], "--fflags=", 9) == 0) setLuauFlags(argv[i] + 9); + else if (strncmp(argv[i], "-j", 2) == 0) + threadCount = strtol(argv[i] + 2, nullptr, 10); } #if !defined(LUAU_ENABLE_TIME_TRACE) @@ -276,10 +353,28 @@ int main(int argc, char** argv) std::vector files = getSourceFiles(argc, argv); + for (const std::string& path : files) + frontend.queueModuleCheck(path); + + std::vector checkedModules; + + // If thread count is not set, try to use HW thread count, but with an upper limit + // When we improve scalability of typechecking, upper limit can be adjusted/removed + if (threadCount <= 0) + threadCount = std::min(TaskScheduler::getThreadCount(), 8u); + + { + TaskScheduler scheduler(threadCount); + + checkedModules = frontend.checkQueuedModules(std::nullopt, [&](std::function f) { + scheduler.push(std::move(f)); + }); + } + int failed = 0; - for (const std::string& path : files) - failed += !analyzeFile(frontend, path.c_str(), format, annotate); + for (const Luau::ModuleName& name : checkedModules) + failed += !reportModuleResult(frontend, name, format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/Ast.cpp b/CLI/Ast.cpp index 99c58393..b5a922aa 100644 --- a/CLI/Ast.cpp +++ b/CLI/Ast.cpp @@ -64,8 +64,6 @@ int main(int argc, char** argv) Luau::ParseOptions options; options.captureComments = true; - options.supportContinueStatement = true; - options.allowTypeAnnotations = true; options.allowDeclarationSyntax = true; Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 75b4940a..5418009a 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -35,6 +35,8 @@ struct RegisterSet uint8_t varargStart = 0; }; +void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart); + struct CfgInfo { std::vector predecessors; @@ -43,10 +45,22 @@ struct CfgInfo std::vector successors; std::vector successorsOffsets; + // VM registers that are live when the block is entered + // Additionally, an active variadic sequence can exist at the entry of the block std::vector in; + + // VM registers that are defined inside the block + // It can also contain a variadic sequence definition if that hasn't been consumed inside the block + // Note that this means that checking 'def' set might not be enough to say that register has not been written to std::vector def; + + // VM registers that are coming out from the block + // These might be registers that are defined inside the block or have been defined at the entry of the block + // Additionally, an active variadic sequence can exist at the exit of the block std::vector out; + // VM registers captured by nested closures + // This set can never have an active variadic sequence RegisterSet captured; }; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index addd18f6..4bc9c823 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -575,7 +575,7 @@ enum class IrCmd : uint8_t // Calls native libm function with 1 or 2 arguments // A: builtin function ID // B: double - // C: double (optional, 2nd argument) + // C: double/int (optional, 2nd argument) INVOKE_LIBM, }; diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 1bc31d9d..179edd0d 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -30,7 +30,7 @@ void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo); +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, bool includeUseInfo); void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title std::string toString(const IrFunction& function, bool includeUseInfo); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3cf18cd4..a1211d46 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -114,6 +114,28 @@ inline bool isBlockTerminator(IrCmd cmd) return false; } +inline bool isNonTerminatingJump(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::TRY_CALL_FASTGETTM: + case IrCmd::CHECK_FASTCALL_RES: + case IrCmd::CHECK_TAG: + case IrCmd::CHECK_READONLY: + case IrCmd::CHECK_NO_METATABLE: + case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::CHECK_NODE_NO_NEXT: + return true; + default: + break; + } + + return false; +} + inline bool hasResult(IrCmd cmd) { switch (cmd) diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index 8fe55ba6..8a44629f 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -1,8 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/RegisterA64.h" #include "Luau/RegisterX64.h" +#include + #include #include @@ -17,22 +20,36 @@ static uint32_t kFullBlockFuncton = ~0u; class UnwindBuilder { public: + enum Arch + { + X64, + A64 + }; + virtual ~UnwindBuilder() = default; virtual void setBeginOffset(size_t beginOffset) = 0; virtual size_t getBeginOffset() const = 0; - virtual void startInfo() = 0; - + virtual void startInfo(Arch arch) = 0; virtual void startFunction() = 0; - virtual void spill(int espOffset, X64::RegisterX64 reg) = 0; - virtual void save(X64::RegisterX64 reg) = 0; - virtual void allocStack(int size) = 0; - virtual void setupFrameReg(X64::RegisterX64 reg, int espOffset) = 0; virtual void finishFunction(uint32_t beginOffset, uint32_t endOffset) = 0; - virtual void finishInfo() = 0; + // A64-specific; prologue must look like this: + // sub sp, sp, stackSize + // store sequence that saves regs to [sp..sp+regs.size*8) in the order specified in regs; regs should start with x29, x30 (fp, lr) + // mov x29, sp + virtual void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) = 0; + + // X64-specific; prologue must look like this: + // optional, indicated by setupFrame: + // push rbp + // mov rbp, rsp + // push reg in the order specified in regs + // sub rsp, stackSize + virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) = 0; + virtual size_t getSize() const = 0; virtual size_t getFunctionCount() const = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 9f862d23..66749bfc 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -24,17 +24,14 @@ public: void setBeginOffset(size_t beginOffset) override; size_t getBeginOffset() const override; - void startInfo() override; - + void startInfo(Arch arch) override; void startFunction() override; - void spill(int espOffset, X64::RegisterX64 reg) override; - void save(X64::RegisterX64 reg) override; - void allocStack(int size) override; - void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; - void finishInfo() override; + void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + size_t getSize() const override; size_t getFunctionCount() const override; @@ -49,8 +46,6 @@ private: uint8_t rawData[kRawDataLimit]; uint8_t* pos = rawData; - uint32_t stackOffset = 0; - // We will remember the FDE location to write some of the fields like entry length, function start and size later uint8_t* fdeEntryStart = nullptr; }; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index ccd7125d..5afed693 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -44,17 +44,14 @@ public: void setBeginOffset(size_t beginOffset) override; size_t getBeginOffset() const override; - void startInfo() override; - + void startInfo(Arch arch) override; void startFunction() override; - void spill(int espOffset, X64::RegisterX64 reg) override; - void save(X64::RegisterX64 reg) override; - void allocStack(int size) override; - void setupFrameReg(X64::RegisterX64 reg, int espOffset) override; void finishFunction(uint32_t beginOffset, uint32_t endOffset) override; - void finishInfo() override; + void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + size_t getSize() const override; size_t getFunctionCount() const override; @@ -75,7 +72,6 @@ private: uint8_t prologSize = 0; X64::RegisterX64 frameReg = X64::noreg; uint8_t frameRegOffset = 0; - uint32_t stackOffset = 0; }; } // namespace CodeGen diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index ccd15fac..9e338071 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -22,12 +22,25 @@ extern "C" void __register_frame(const void*); extern "C" void __deregister_frame(const void*); +extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); + #endif -#if defined(__APPLE__) -// On Mac, each FDE inside eh_frame section has to be handled separately +namespace Luau +{ +namespace CodeGen +{ + +#if !defined(_WIN32) static void visitFdeEntries(char* pos, void (*cb)(const void*)) { + // When using glibc++ unwinder, we need to call __register_frame/__deregister_frame on the entire .eh_frame data + // When using libc++ unwinder (libunwind), each FDE has to be handled separately + // libc++ unwinder is the macOS unwinder, but on Linux the unwinder depends on the library the executable is linked with + // __unw_add_dynamic_fde is specific to libc++ unwinder, as such we determine the library based on its existence + if (__unw_add_dynamic_fde == nullptr) + return cb(pos); + for (;;) { unsigned partLength; @@ -47,11 +60,6 @@ static void visitFdeEntries(char* pos, void (*cb)(const void*)) } #endif -namespace Luau -{ -namespace CodeGen -{ - void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) { UnwindBuilder* unwind = (UnwindBuilder*)context; @@ -70,10 +78,8 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz LUAU_ASSERT(!"failed to allocate function table"); return nullptr; } -#elif defined(__APPLE__) - visitFdeEntries(unwindData, __register_frame); #elif !defined(_WIN32) - __register_frame(unwindData); + visitFdeEntries(unwindData, __register_frame); #endif beginOffset = unwindSize + unwind->getBeginOffset(); @@ -85,10 +91,8 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) #if defined(_WIN32) && defined(_M_X64) if (!RtlDeleteFunctionTable((RUNTIME_FUNCTION*)unwindData)) LUAU_ASSERT(!"failed to deallocate function table"); -#elif defined(__APPLE__) - visitFdeEntries((char*)unwindData, __deregister_frame); #elif !defined(_WIN32) - __deregister_frame(unwindData); + visitFdeEntries((char*)unwindData, __deregister_frame); #endif } diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index f0be5b3d..ab092faa 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -134,7 +134,6 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& for (size_t i = 0; i < sortedBlocks.size(); ++i) { uint32_t blockIndex = sortedBlocks[i]; - IrBlock& block = function.blocks[blockIndex]; if (block.kind == IrBlockKind::Dead) @@ -191,10 +190,13 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& continue; } + // Either instruction result value is not referenced or the use count is not zero + LUAU_ASSERT(inst.lastUse == 0 || inst.useCount != 0); + if (options.includeIr) { build.logAppend("# "); - toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); + toStringDetailed(ctx, block, blockIndex, inst, index, /* includeUseInfo */ true); } IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; @@ -409,9 +411,11 @@ bool isSupported() if (sizeof(LuaNode) != 32) return false; - // TODO: A64 codegen does not generate correct unwind info at the moment so it requires longjmp instead of C++ exceptions +#ifdef _WIN32 + // Unwind info is not supported for Windows-on-ARM yet if (!LUA_USE_LONGJMP) return false; +#endif return true; #else diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 415cfc92..fbe44e23 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -123,9 +123,6 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde // Arguments: x0 = lua_State*, x1 = Proto*, x2 = native code pointer to jump to, x3 = NativeContext* locations.start = build.setLabel(); - unwind.startFunction(); - - unwind.allocStack(8); // TODO: this is just a hack to make UnwindBuilder assertions cooperate // prologue build.sub(sp, sp, kStackSize); @@ -140,6 +137,8 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde locations.prologueEnd = build.setLabel(); + uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start); + // Setup native execution environment build.mov(rState, x0); build.mov(rNativeContext, x3); @@ -168,6 +167,8 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde build.ret(); // Our entry function is special, it spans the whole remaining code area + unwind.startFunction(); + unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24}); unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); return locations; @@ -178,7 +179,7 @@ bool initHeaderFunctions(NativeState& data) AssemblyBuilderA64 build(/* logText= */ false); UnwindBuilder& unwind = *data.unwindBuilder.get(); - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::A64); EntryLocations entryLocations = buildEntryFunction(build, unwind); diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 2acb69f9..5f2cd614 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -58,43 +58,44 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde unwind.startFunction(); // Save common non-volatile registers - build.push(rbp); - unwind.save(rbp); - if (build.abi == ABIX64::SystemV) { + // We need to use a standard rbp-based frame setup for debuggers to work with JIT code + build.push(rbp); build.mov(rbp, rsp); - unwind.setupFrameReg(rbp, 0); } build.push(rbx); - unwind.save(rbx); build.push(r12); - unwind.save(r12); build.push(r13); - unwind.save(r13); build.push(r14); - unwind.save(r14); build.push(r15); - unwind.save(r15); if (build.abi == ABIX64::Windows) { // Save non-volatile registers that are specific to Windows x64 ABI build.push(rdi); - unwind.save(rdi); build.push(rsi); - unwind.save(rsi); + + // On Windows, rbp is available as a general-purpose non-volatile register; we currently don't use it, but we need to push an even number + // of registers for stack alignment... + build.push(rbp); // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here } // Allocate stack space (reg home area + local data) build.sub(rsp, kStackSize + kLocalsSize); - unwind.allocStack(kStackSize + kLocalsSize); locations.prologueEnd = build.setLabel(); + uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start); + + if (build.abi == ABIX64::SystemV) + unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15}); + else if (build.abi == ABIX64::Windows) + unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp}); + // Setup native execution environment build.mov(rState, rArg1); build.mov(rNativeContext, rArg4); @@ -118,6 +119,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde if (build.abi == ABIX64::Windows) { + build.pop(rbp); build.pop(rsi); build.pop(rdi); } @@ -127,7 +129,10 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde build.pop(r13); build.pop(r12); build.pop(rbx); - build.pop(rbp); + + if (build.abi == ABIX64::SystemV) + build.pop(rbp); + build.ret(); // Our entry function is special, it spans the whole remaining code area @@ -141,7 +146,7 @@ bool initHeaderFunctions(NativeState& data) AssemblyBuilderX64 build(/* logText= */ false); UnwindBuilder& unwind = *data.unwindBuilder.get(); - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::X64); EntryLocations entryLocations = buildEntryFunction(build, unwind); diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index af4c529a..474dabf6 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -18,19 +18,6 @@ namespace CodeGen namespace X64 { -static void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, OperandX64 arg2) -{ - ScopedRegX64 tmp{regs, SizeX64::qword}; - build.vcvttsd2si(tmp.reg, arg2); - - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_ldexp)]); - - build.vmovsd(luauRegValue(ra), xmm0); -} - static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults) { IrCallWrapperX64 callWrap(regs, build); @@ -115,9 +102,6 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r { switch (bfid) { - case LBF_MATH_LDEXP: - LUAU_ASSERT(nparams == 2 && nresults == 1); - return emitBuiltinMathLdexp(regs, build, ra, arg, arg2); case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index efe9fcc0..efcacb04 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -162,7 +162,7 @@ uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block) return getLiveInOutValueCount(function, block).second; } -static void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) +void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) { if (!defRs.varargSeq) { diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 50c1848e..062321ba 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -62,6 +62,7 @@ static const char* getTagName(uint8_t tag) case LUA_TTHREAD: return "tthread"; default: + LUAU_ASSERT(!"Unknown type tag"); LUAU_UNREACHABLE(); } } @@ -410,27 +411,6 @@ void toString(std::string& result, IrConst constant) } } -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo) -{ - size_t start = ctx.result.size(); - - toString(ctx, inst, index); - - if (includeUseInfo) - { - padToDetailColumn(ctx.result, start); - - if (inst.useCount == 0 && hasSideEffects(inst.cmd)) - append(ctx.result, "; %%%u, has side-effects\n", index); - else - append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); - } - else - { - ctx.result.append("\n"); - } -} - static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) { bool comma = false; @@ -470,6 +450,86 @@ static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs, con } } +static RegisterSet getJumpTargetExtraLiveIn(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst) +{ + RegisterSet extraRs; + + if (blockIdx >= ctx.cfg.in.size()) + return extraRs; + + const RegisterSet& defRs = ctx.cfg.in[blockIdx]; + + // Find first block argument, for guard instructions (isNonTerminatingJump), that's the first and only one + LUAU_ASSERT(isNonTerminatingJump(inst.cmd)); + IrOp op = inst.a; + + if (inst.b.kind == IrOpKind::Block) + op = inst.b; + else if (inst.c.kind == IrOpKind::Block) + op = inst.c; + else if (inst.d.kind == IrOpKind::Block) + op = inst.d; + else if (inst.e.kind == IrOpKind::Block) + op = inst.e; + else if (inst.f.kind == IrOpKind::Block) + op = inst.f; + + if (op.kind == IrOpKind::Block && op.index < ctx.cfg.in.size()) + { + const RegisterSet& inRs = ctx.cfg.in[op.index]; + + extraRs.regs = inRs.regs & ~defRs.regs; + + if (inRs.varargSeq) + requireVariadicSequence(extraRs, defRs, inRs.varargStart); + } + + return extraRs; +} + +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, bool includeUseInfo) +{ + size_t start = ctx.result.size(); + + toString(ctx, inst, instIdx); + + if (includeUseInfo) + { + padToDetailColumn(ctx.result, start); + + if (inst.useCount == 0 && hasSideEffects(inst.cmd)) + { + if (isNonTerminatingJump(inst.cmd)) + { + RegisterSet extraRs = getJumpTargetExtraLiveIn(ctx, block, blockIdx, inst); + + if (extraRs.regs.any() || extraRs.varargSeq) + { + append(ctx.result, "; %%%u, extra in: ", instIdx); + appendRegisterSet(ctx, extraRs, ", "); + ctx.result.append("\n"); + } + else + { + append(ctx.result, "; %%%u\n", instIdx); + } + } + else + { + append(ctx.result, "; %%%u\n", instIdx); + } + } + else + { + append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); + } + } + else + { + ctx.result.append("\n"); + } +} + void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo) { // Report captured registers for entry block @@ -581,7 +641,7 @@ std::string toString(const IrFunction& function, bool includeUseInfo) continue; append(ctx.result, " "); - toStringDetailed(ctx, inst, index, includeUseInfo); + toStringDetailed(ctx, block, uint32_t(i), inst, index, includeUseInfo); } append(ctx.result, "\n"); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 7fd684b4..6dec8024 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -122,42 +122,6 @@ static bool emitBuiltin( { switch (bfid) { - case LBF_MATH_LDEXP: - LUAU_ASSERT(nparams == 2 && nresults == 1); - - if (args.kind == IrOpKind::VmReg) - { - build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n))); - build.fcvtzs(w0, d1); - } - else if (args.kind == IrOpKind::VmConst) - { - size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n); - - // Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range - // that doesn't require temporaries - if (constantOffset / 8 <= AddressA64::kMaxOffset) - { - build.ldr(d1, mem(rConstants, int(constantOffset))); - } - else - { - emitAddOffset(build, x0, rConstants, constantOffset); - build.ldr(d1, x0); - } - - build.fcvtzs(w0, d1); - } - else if (args.kind == IrOpKind::Constant) - build.mov(w0, int(function.doubleOp(args))); - else if (args.kind != IrOpKind::Undef) - LUAU_ASSERT(!"Unsupported instruction form"); - - build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); - build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, libm_ldexp))); - build.blr(x1); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - return true; case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); @@ -1610,12 +1574,20 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { if (inst.c.kind != IrOpKind::None) { + bool isInt = (inst.c.kind == IrOpKind::Constant) ? constOp(inst.c).kind == IrConstKind::Int + : getCmdValueKind(function.instOp(inst.c).cmd) == IrValueKind::Int; + RegisterA64 temp1 = tempDouble(inst.b); - RegisterA64 temp2 = tempDouble(inst.c); - RegisterA64 temp3 = regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill + RegisterA64 temp2 = isInt ? tempInt(inst.c) : tempDouble(inst.c); + RegisterA64 temp3 = isInt ? noreg : regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill regs.spill(build, index, {temp1, temp2}); - if (d0 != temp2) + if (isInt) + { + build.fmov(d0, temp1); + build.mov(w0, temp2); + } + else if (d0 != temp2) { build.fmov(d0, temp1); build.fmov(d1, temp2); @@ -1634,8 +1606,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fmov(d0, temp1); } - build.ldr(x0, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a)))); - build.blr(x0); + build.ldr(x1, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a)))); + build.blr(x1); inst.regA64 = regs.takeReg(d0, index); break; } diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index bc617571..8c1f2b04 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -1304,7 +1304,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); if (inst.c.kind != IrOpKind::None) - callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c); + { + bool isInt = (inst.c.kind == IrOpKind::Constant) ? constOp(inst.c).kind == IrConstKind::Int + : getCmdValueKind(function.instOp(inst.c).cmd) == IrValueKind::Int; + + if (isInt) + callWrap.addArgument(SizeX64::dword, memRegUintOp(inst.c), inst.c); + else + callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c); + } callWrap.call(qword[rNativeContext + getNativeContextOffset(uintOp(inst.a))]); inst.regX64 = regs.takeReg(xmm0, index); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index e58d0a12..cfa4bc6c 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -71,23 +71,6 @@ static BuiltinImplResult translateBuiltinNumberToNumberLibm( return {BuiltinImplType::UsesFallback, 1}; } -// (number, number, ...) -> number -static BuiltinImplResult translateBuiltin2NumberToNumber( - IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) -{ - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - builtinCheckDouble(build, build.vmReg(arg), fallback); - builtinCheckDouble(build, args, fallback); - build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(2), build.constInt(1)); - - if (ra != arg) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - - return {BuiltinImplType::UsesFallback, 1}; -} - static BuiltinImplResult translateBuiltin2NumberToNumberLibm( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { @@ -110,6 +93,30 @@ static BuiltinImplResult translateBuiltin2NumberToNumberLibm( return {BuiltinImplType::UsesFallback, 1}; } +static BuiltinImplResult translateBuiltinMathLdexp( + IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 2 || nresults > 1) + return {BuiltinImplType::None, -1}; + + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); + + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); + + IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb); + + IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va, vbi); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::UsesFallback, 1}; +} + // (number, ...) -> (number, number) static BuiltinImplResult translateBuiltinNumberTo2Number( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) @@ -778,7 +785,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_ATAN2: return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_LDEXP: - return translateBuiltin2NumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + return translateBuiltinMathLdexp(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_FREXP: case LBF_MATH_MODF: return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index a3af4344..03a6c9c4 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -299,6 +299,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl removeUse(function, inst.e); removeUse(function, inst.f); + // Inherit existing use count (last use is skipped as it will be defined later) + replacement.useCount = inst.useCount; + inst = replacement; // Removing the earlier extra reference, this might leave the block without users without marking it as dead @@ -775,6 +778,8 @@ uint32_t getNativeContextOffset(int bfid) return offsetof(NativeContext, libm_pow); case LBF_IR_MATH_LOG2: return offsetof(NativeContext, libm_log2); + case LBF_MATH_LDEXP: + return offsetof(NativeContext, libm_ldexp); default: LUAU_ASSERT(!"Unsupported bfid"); } diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index e7663666..926ead3d 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -290,6 +290,20 @@ struct ConstPropState valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index; } + void clear() + { + for (int i = 0; i <= maxReg; ++i) + regs[i] = RegisterInfo(); + + maxReg = 0; + + inSafeEnv = false; + checkedGc = false; + + instLink.clear(); + valueMap.clear(); + } + IrFunction& function; bool useValueNumbering = false; @@ -854,12 +868,11 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s state.valueMap.clear(); } -static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, bool useValueNumbering) +static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, ConstPropState& state) { IrFunction& function = build.function; - ConstPropState state{function}; - state.useValueNumbering = useValueNumbering; + state.clear(); while (block) { @@ -936,7 +949,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st return path; } -static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock, bool useValueNumbering) +static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock, ConstPropState& state) { IrFunction& function = build.function; @@ -965,8 +978,9 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited return; // Initialize state with the knowledge of our current block - ConstPropState state{function}; - state.useValueNumbering = useValueNumbering; + state.clear(); + + // TODO: using values from the first block can cause 'live out' of the linear block predecessor to not have all required registers constPropInBlock(build, startingBlock, state); // Veryfy that target hasn't changed @@ -981,10 +995,43 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited replace(function, termInst.a, newBlock); - // Clone the collected path int our fresh block + // Clone the collected path into our fresh block for (uint32_t pathBlockIdx : path) build.clone(function.blocks[pathBlockIdx], /* removeCurrentTerminator */ true); + // If all live in/out data is defined aside from the new block, generate it + // Note that liveness information is not strictly correct after optimization passes and may need to be recomputed before next passes + // The information generated here is consistent with current state that could be outdated, but still useful in IR inspection + if (function.cfg.in.size() == newBlock.index) + { + LUAU_ASSERT(function.cfg.in.size() == function.cfg.out.size()); + LUAU_ASSERT(function.cfg.in.size() == function.cfg.def.size()); + + // Live in is the same as the input of the original first block + function.cfg.in.push_back(function.cfg.in[path.front()]); + + // Live out is the same as the result of the original last block + function.cfg.out.push_back(function.cfg.out[path.back()]); + + // Defs are tricky, registers are joined together, but variadic sequences can be consumed inside the block + function.cfg.def.push_back({}); + RegisterSet& def = function.cfg.def.back(); + + for (uint32_t pathBlockIdx : path) + { + const RegisterSet& pathDef = function.cfg.def[pathBlockIdx]; + + def.regs |= pathDef.regs; + + // Taking only the last defined variadic sequence if it's not consumed before before the end + if (pathDef.varargSeq && function.cfg.out.back().varargSeq) + { + def.varargSeq = true; + def.varargStart = pathDef.varargStart; + } + } + } + // Optimize our linear block IrBlock& linearBlock = function.blockOp(newBlock); constPropInBlock(build, linearBlock, state); @@ -994,6 +1041,9 @@ void constPropInBlockChains(IrBuilder& build, bool useValueNumbering) { IrFunction& function = build.function; + ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; + std::vector visited(function.blocks.size(), false); for (IrBlock& block : function.blocks) @@ -1004,7 +1054,7 @@ void constPropInBlockChains(IrBuilder& build, bool useValueNumbering) if (visited[function.getBlockIndex(block)]) continue; - constPropInBlockChain(build, visited, &block, useValueNumbering); + constPropInBlockChain(build, visited, &block, state); } } @@ -1015,6 +1065,9 @@ void createLinearBlocks(IrBuilder& build, bool useValueNumbering) // new 'block' will only be reachable from a single one and all gathered information can be preserved. IrFunction& function = build.function; + ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; + std::vector visited(function.blocks.size(), false); // This loop can create new 'linear' blocks, so index-based loop has to be used (and it intentionally won't reach those new blocks) @@ -1029,7 +1082,7 @@ void createLinearBlocks(IrBuilder& build, bool useValueNumbering) if (visited[function.getBlockIndex(block)]) continue; - tryCreateLinearBlock(build, visited, block, useValueNumbering); + tryCreateLinearBlock(build, visited, block, state); } } diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index a4be95ff..e9df184d 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -36,27 +36,25 @@ #define DW_CFA_lo_user 0x1c #define DW_CFA_hi_user 0x3f -// Register numbers for x64 (System V ABI, page 57, ch. 3.7, figure 3.36) -#define DW_REG_RAX 0 -#define DW_REG_RDX 1 -#define DW_REG_RCX 2 -#define DW_REG_RBX 3 -#define DW_REG_RSI 4 -#define DW_REG_RDI 5 -#define DW_REG_RBP 6 -#define DW_REG_RSP 7 -#define DW_REG_R8 8 -#define DW_REG_R9 9 -#define DW_REG_R10 10 -#define DW_REG_R11 11 -#define DW_REG_R12 12 -#define DW_REG_R13 13 -#define DW_REG_R14 14 -#define DW_REG_R15 15 -#define DW_REG_RA 16 +// Register numbers for X64 (System V ABI, page 57, ch. 3.7, figure 3.36) +#define DW_REG_X64_RAX 0 +#define DW_REG_X64_RDX 1 +#define DW_REG_X64_RCX 2 +#define DW_REG_X64_RBX 3 +#define DW_REG_X64_RSI 4 +#define DW_REG_X64_RDI 5 +#define DW_REG_X64_RBP 6 +#define DW_REG_X64_RSP 7 +#define DW_REG_X64_RA 16 -const int regIndexToDwRegX64[16] = {DW_REG_RAX, DW_REG_RCX, DW_REG_RDX, DW_REG_RBX, DW_REG_RSP, DW_REG_RBP, DW_REG_RSI, DW_REG_RDI, DW_REG_R8, - DW_REG_R9, DW_REG_R10, DW_REG_R11, DW_REG_R12, DW_REG_R13, DW_REG_R14, DW_REG_R15}; +// Register numbers for A64 (DWARF for the Arm 64-bit Architecture, ch. 4.1) +#define DW_REG_A64_FP 29 +#define DW_REG_A64_LR 30 +#define DW_REG_A64_SP 31 + +// X64 register mapping from real register index to DWARF2 (r8..r15 are mapped 1-1, but named registers aren't) +const int regIndexToDwRegX64[16] = {DW_REG_X64_RAX, DW_REG_X64_RCX, DW_REG_X64_RDX, DW_REG_X64_RBX, DW_REG_X64_RSP, DW_REG_X64_RBP, DW_REG_X64_RSI, + DW_REG_X64_RDI, 8, 9, 10, 11, 12, 13, 14, 15}; const int kCodeAlignFactor = 1; const int kDataAlignFactor = 8; @@ -85,7 +83,7 @@ static uint8_t* defineSavedRegisterLocation(uint8_t* pos, int dwReg, uint32_t st { LUAU_ASSERT(stackOffset % kDataAlignFactor == 0 && "stack offsets have to be measured in kDataAlignFactor units"); - if (dwReg <= 15) + if (dwReg <= 0x3f) { pos = writeu8(pos, DW_CFA_offset + dwReg); } @@ -99,8 +97,9 @@ static uint8_t* defineSavedRegisterLocation(uint8_t* pos, int dwReg, uint32_t st return pos; } -static uint8_t* advanceLocation(uint8_t* pos, uint8_t offset) +static uint8_t* advanceLocation(uint8_t* pos, unsigned int offset) { + LUAU_ASSERT(offset < 256); pos = writeu8(pos, DW_CFA_advance_loc1); pos = writeu8(pos, offset); return pos; @@ -132,8 +131,10 @@ size_t UnwindBuilderDwarf2::getBeginOffset() const return beginOffset; } -void UnwindBuilderDwarf2::startInfo() +void UnwindBuilderDwarf2::startInfo(Arch arch) { + LUAU_ASSERT(arch == A64 || arch == X64); + uint8_t* cieLength = pos; pos = writeu32(pos, 0); // Length (to be filled later) @@ -142,15 +143,24 @@ void UnwindBuilderDwarf2::startInfo() pos = writeu8(pos, 0); // CIE augmentation String "" + int ra = arch == A64 ? DW_REG_A64_LR : DW_REG_X64_RA; + pos = writeuleb128(pos, kCodeAlignFactor); // Code align factor pos = writeuleb128(pos, -kDataAlignFactor & 0x7f); // Data align factor of (as signed LEB128) - pos = writeu8(pos, DW_REG_RA); // Return address register + pos = writeu8(pos, ra); // Return address register // Optional CIE augmentation section (not present) - // Call frame instructions (common for all FDEs, of which we have 1) - pos = defineCfaExpression(pos, DW_REG_RSP, 8); // Define CFA to be the rsp + 8 - pos = defineSavedRegisterLocation(pos, DW_REG_RA, 8); // Define return address register (RA) to be located at CFA - 8 + // Call frame instructions (common for all FDEs) + if (arch == A64) + { + pos = defineCfaExpression(pos, DW_REG_A64_SP, 0); // Define CFA to be the sp + } + else + { + pos = defineCfaExpression(pos, DW_REG_X64_RSP, 8); // Define CFA to be the rsp + 8 + pos = defineSavedRegisterLocation(pos, DW_REG_X64_RA, 8); // Define return address register (RA) to be located at CFA - 8 + } pos = alignPosition(cieLength, pos); writeu32(cieLength, unsigned(pos - cieLength - 4)); // Length field itself is excluded from length @@ -165,8 +175,6 @@ void UnwindBuilderDwarf2::startFunction() func.fdeEntryStartPos = uint32_t(pos - rawData); unwindFunctions.push_back(func); - stackOffset = 8; // Return address was pushed by calling the function - fdeEntryStart = pos; // Will be written at the end pos = writeu32(pos, 0); // Length (to be filled later) pos = writeu32(pos, unsigned(pos - rawData)); // CIE pointer @@ -178,42 +186,11 @@ void UnwindBuilderDwarf2::startFunction() // Function call frame instructions to follow } -void UnwindBuilderDwarf2::spill(int espOffset, X64::RegisterX64 reg) -{ - pos = advanceLocation(pos, 5); // REX.W mov [rsp + imm8], reg -} - -void UnwindBuilderDwarf2::save(X64::RegisterX64 reg) -{ - stackOffset += 8; - pos = advanceLocation(pos, 2); // REX.W push reg - pos = defineCfaExpressionOffset(pos, stackOffset); - pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); -} - -void UnwindBuilderDwarf2::allocStack(int size) -{ - stackOffset += size; - pos = advanceLocation(pos, 4); // REX.W sub rsp, imm8 - pos = defineCfaExpressionOffset(pos, stackOffset); -} - -void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset) -{ - if (espOffset != 0) - pos = advanceLocation(pos, 5); // REX.W lea rbp, [rsp + imm8] - else - pos = advanceLocation(pos, 3); // REX.W mov rbp, rsp - - // Cfa is based on rsp, so no additonal commands are required -} - void UnwindBuilderDwarf2::finishFunction(uint32_t beginOffset, uint32_t endOffset) { unwindFunctions.back().beginOffset = beginOffset; unwindFunctions.back().endOffset = endOffset; - LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); LUAU_ASSERT(fdeEntryStart != nullptr); pos = alignPosition(fdeEntryStart, pos); @@ -228,6 +205,69 @@ void UnwindBuilderDwarf2::finishInfo() LUAU_ASSERT(getSize() <= kRawDataLimit); } +void UnwindBuilderDwarf2::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) +{ + LUAU_ASSERT(stackSize % 16 == 0); + LUAU_ASSERT(regs.size() >= 2 && regs.begin()[0] == A64::x29 && regs.begin()[1] == A64::x30); + LUAU_ASSERT(regs.size() * 8 <= stackSize); + + // sub sp, sp, stackSize + pos = advanceLocation(pos, 4); + pos = defineCfaExpressionOffset(pos, stackSize); + + // stp/str to store each register to stack in order + pos = advanceLocation(pos, prologueSize - 4); + + for (size_t i = 0; i < regs.size(); ++i) + { + LUAU_ASSERT(regs.begin()[i].kind == A64::KindA64::x); + pos = defineSavedRegisterLocation(pos, regs.begin()[i].index, stackSize - unsigned(i * 8)); + } +} + +void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +{ + LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + + unsigned int stackOffset = 8; // Return address was pushed by calling the function + unsigned int prologueOffset = 0; + + if (setupFrame) + { + // push rbp + stackOffset += 8; + prologueOffset += 2; + pos = advanceLocation(pos, 2); + pos = defineCfaExpressionOffset(pos, stackOffset); + pos = defineSavedRegisterLocation(pos, DW_REG_X64_RBP, stackOffset); + + // mov rbp, rsp + prologueOffset += 3; + pos = advanceLocation(pos, 3); + } + + // push reg + for (X64::RegisterX64 reg : regs) + { + LUAU_ASSERT(reg.size == X64::SizeX64::qword); + + stackOffset += 8; + prologueOffset += 2; + pos = advanceLocation(pos, 2); + pos = defineCfaExpressionOffset(pos, stackOffset); + pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); + } + + // sub rsp, stackSize + stackOffset += stackSize; + prologueOffset += 4; + pos = advanceLocation(pos, 4); + pos = defineCfaExpressionOffset(pos, stackOffset); + + LUAU_ASSERT(stackOffset % 16 == 0); + LUAU_ASSERT(prologueOffset == prologueSize); +} + size_t UnwindBuilderDwarf2::getSize() const { return size_t(pos - rawData); @@ -244,14 +284,14 @@ void UnwindBuilderDwarf2::finalize(char* target, size_t offset, void* funcAddres for (const UnwindFunctionDwarf2& func : unwindFunctions) { - uint8_t* fdeEntryStart = (uint8_t*)target + func.fdeEntryStartPos; + uint8_t* fdeEntry = (uint8_t*)target + func.fdeEntryStartPos; - writeu64(fdeEntryStart + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); + writeu64(fdeEntry + kFdeInitialLocationOffset, uintptr_t(funcAddress) + offset + func.beginOffset); if (func.endOffset == kFullBlockFuncton) - writeu64(fdeEntryStart + kFdeAddressRangeOffset, funcSize - offset); + writeu64(fdeEntry + kFdeAddressRangeOffset, funcSize - offset); else - writeu64(fdeEntryStart + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); + writeu64(fdeEntry + kFdeAddressRangeOffset, func.endOffset - func.beginOffset); } } diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 5f4f16a9..f9b927c5 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -31,7 +31,10 @@ size_t UnwindBuilderWin::getBeginOffset() const return beginOffset; } -void UnwindBuilderWin::startInfo() {} +void UnwindBuilderWin::startInfo(Arch arch) +{ + LUAU_ASSERT(arch == X64); +} void UnwindBuilderWin::startFunction() { @@ -50,45 +53,6 @@ void UnwindBuilderWin::startFunction() // rax has register index 0, which in Windows unwind info means that frame register is not used frameReg = X64::rax; frameRegOffset = 0; - - // Return address was pushed by calling the function - stackOffset = 8; -} - -void UnwindBuilderWin::spill(int espOffset, X64::RegisterX64 reg) -{ - prologSize += 5; // REX.W mov [rsp + imm8], reg -} - -void UnwindBuilderWin::save(X64::RegisterX64 reg) -{ - prologSize += 2; // REX.W push reg - stackOffset += 8; - unwindCodes.push_back({prologSize, UWOP_PUSH_NONVOL, reg.index}); -} - -void UnwindBuilderWin::allocStack(int size) -{ - LUAU_ASSERT(size >= 8 && size <= 128 && size % 8 == 0); - - prologSize += 4; // REX.W sub rsp, imm8 - stackOffset += size; - unwindCodes.push_back({prologSize, UWOP_ALLOC_SMALL, uint8_t((size - 8) / 8)}); -} - -void UnwindBuilderWin::setupFrameReg(X64::RegisterX64 reg, int espOffset) -{ - LUAU_ASSERT(espOffset < 256 && espOffset % 16 == 0); - - frameReg = reg; - frameRegOffset = uint8_t(espOffset / 16); - - if (espOffset != 0) - prologSize += 5; // REX.W lea rbp, [rsp + imm8] - else - prologSize += 3; // REX.W mov rbp, rsp - - unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset}); } void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) @@ -99,8 +63,6 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) // Windows unwind code count is stored in uint8_t, so we can't have more LUAU_ASSERT(unwindCodes.size() < 256); - LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); - UnwindInfoWin info; info.version = 1; info.flags = 0; // No EH @@ -142,6 +104,54 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) void UnwindBuilderWin::finishInfo() {} +void UnwindBuilderWin::prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) +{ + LUAU_ASSERT(!"Not implemented"); +} + +void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +{ + LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + LUAU_ASSERT(prologueSize < 256); + + unsigned int stackOffset = 8; // Return address was pushed by calling the function + unsigned int prologueOffset = 0; + + if (setupFrame) + { + // push rbp + stackOffset += 8; + prologueOffset += 2; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, X64::rbp.index}); + + // mov rbp, rsp + prologueOffset += 3; + frameReg = X64::rbp; + frameRegOffset = 0; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_SET_FPREG, frameRegOffset}); + } + + // push reg + for (X64::RegisterX64 reg : regs) + { + LUAU_ASSERT(reg.size == X64::SizeX64::qword); + + stackOffset += 8; + prologueOffset += 2; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, reg.index}); + } + + // sub rsp, stackSize + stackOffset += stackSize; + prologueOffset += 4; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_SMALL, uint8_t((stackSize - 8) / 8)}); + + LUAU_ASSERT(stackOffset % 16 == 0); + LUAU_ASSERT(prologueOffset == prologueSize); + + this->prologSize = prologueSize; +} + size_t UnwindBuilderWin::getSize() const { return sizeof(UnwindFunctionWin) * unwindFunctions.size() + size_t(rawDataPos - rawData); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index b5690acb..e2b769ec 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1701,8 +1701,6 @@ void BytecodeBuilder::dumpConstant(std::string& result, int k) const formatAppend(result, "'%s'", func.dumpname.c_str()); break; } - default: - LUAU_UNREACHABLE(); } } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index c8a184a1..b3edf2ba 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -913,7 +913,9 @@ reentry: // slow-path: not a function call if (LUAU_UNLIKELY(!ttisfunction(ra))) { - VM_PROTECT(luaV_tryfuncTM(L, ra)); + VM_PROTECT_PC(); // luaV_tryfuncTM may fail + + luaV_tryfuncTM(L, ra); argtop++; // __call adds an extra self } diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 01deddd3..df2fa36b 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -135,20 +135,9 @@ TEST_CASE("WindowsUnwindCodesX64") UnwindBuilderWin unwind; - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.spill(16, rdx); - unwind.spill(8, rcx); - unwind.save(rdi); - unwind.save(rsi); - unwind.save(rbx); - unwind.save(rbp); - unwind.save(r12); - unwind.save(r13); - unwind.save(r14); - unwind.save(r15); - unwind.allocStack(72); - unwind.setupFrameReg(rbp, 48); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); unwind.finishFunction(0x11223344, 0x55443322); unwind.finishInfo(); @@ -156,8 +145,8 @@ TEST_CASE("WindowsUnwindCodesX64") data.resize(unwind.getSize()); unwind.finalize(data.data(), 0, nullptr, 0); - std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, - 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, 0x30, 0x0e, 0x60, 0x0c, 0x70}; + std::vector expected{0x44, 0x33, 0x22, 0x11, 0x22, 0x33, 0x44, 0x55, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x17, 0x0a, 0x05, 0x17, 0x82, 0x13, + 0xf0, 0x11, 0xe0, 0x0f, 0xd0, 0x0d, 0xc0, 0x0b, 0x30, 0x09, 0x60, 0x07, 0x70, 0x05, 0x03, 0x02, 0x50}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -170,18 +159,9 @@ TEST_CASE("Dwarf2UnwindCodesX64") UnwindBuilderDwarf2 unwind; - unwind.startInfo(); + unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.save(rdi); - unwind.save(rsi); - unwind.save(rbx); - unwind.save(rbp); - unwind.save(r12); - unwind.save(r13); - unwind.save(r14); - unwind.save(r15); - unwind.allocStack(72); - unwind.setupFrameReg(rbp, 48); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); unwind.finishFunction(0, 0); unwind.finishInfo(); @@ -189,11 +169,36 @@ TEST_CASE("Dwarf2UnwindCodesX64") data.resize(unwind.getSize()); unwind.finalize(data.data(), 0, nullptr, 0); - std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, + std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, - 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, - 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x02, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00}; + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x86, 0x02, 0x02, 0x03, 0x02, 0x02, 0x0e, 0x18, 0x85, 0x03, 0x02, 0x02, 0x0e, + 0x20, 0x84, 0x04, 0x02, 0x02, 0x0e, 0x28, 0x83, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, + 0x0e, 0x40, 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +TEST_CASE("Dwarf2UnwindCodesA64") +{ + using namespace A64; + + UnwindBuilderDwarf2 unwind; + + unwind.startInfo(UnwindBuilder::A64); + unwind.startFunction(); + unwind.prologueA64(/* prologueSize= */ 28, /* stackSize= */ 64, {x29, x30, x19, x20, x21, x22, x23, x24}); + unwind.finishFunction(0, 32); + unwind.finishInfo(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), 0, nullptr, 0); + + std::vector expected{0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x1e, 0x0c, 0x1f, 0x00, 0x2c, 0x00, 0x00, + 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x04, + 0x0e, 0x40, 0x02, 0x18, 0x9d, 0x08, 0x9e, 0x07, 0x93, 0x06, 0x94, 0x05, 0x95, 0x04, 0x96, 0x03, 0x97, 0x02, 0x98, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -247,7 +252,7 @@ TEST_CASE("GeneratedCodeExecutionX64") CHECK(result == 210); } -void throwing(int64_t arg) +static void throwing(int64_t arg) { CHECK(arg == 25); @@ -266,27 +271,25 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->startInfo(); + unwind->startInfo(UnwindBuilder::X64); Label functionBegin = build.setLabel(); unwind->startFunction(); // Prologue - build.push(rNonVol1); - unwind->save(rNonVol1); - build.push(rNonVol2); - unwind->save(rNonVol2); build.push(rbp); - unwind->save(rbp); + build.mov(rbp, rsp); + build.push(rNonVol1); + build.push(rNonVol2); int stackSize = 32; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); // Body build.mov(rNonVol1, rArg1); @@ -297,10 +300,10 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") build.call(rNonVol2); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(rNonVol2); build.pop(rNonVol1); + build.pop(rbp); build.ret(); unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); @@ -349,7 +352,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->startInfo(); + unwind->startInfo(UnwindBuilder::X64); Label start1; Label start2; @@ -360,21 +363,19 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") unwind->startFunction(); // Prologue - build.push(rNonVol1); - unwind->save(rNonVol1); - build.push(rNonVol2); - unwind->save(rNonVol2); build.push(rbp); - unwind->save(rbp); + build.mov(rbp, rsp); + build.push(rNonVol1); + build.push(rNonVol2); int stackSize = 32; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location - start1.location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); // Body build.mov(rNonVol1, rArg1); @@ -385,41 +386,35 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") build.call(rNonVol2); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(rNonVol2); build.pop(rNonVol1); + build.pop(rbp); build.ret(); Label end1 = build.setLabel(); unwind->finishFunction(build.getLabelOffset(start1), build.getLabelOffset(end1)); } - // Second function with different layout + // Second function with different layout and no frame { build.setLabel(start2); unwind->startFunction(); // Prologue build.push(rNonVol1); - unwind->save(rNonVol1); build.push(rNonVol2); - unwind->save(rNonVol2); build.push(rNonVol3); - unwind->save(rNonVol3); build.push(rNonVol4); - unwind->save(rNonVol4); - build.push(rbp); - unwind->save(rbp); int stackSize = 32; - int localsSize = 32; + int localsSize = 24; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location - start2.location; + + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4}); // Body build.mov(rNonVol3, rArg1); @@ -430,8 +425,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") build.call(rNonVol4); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(rNonVol4); build.pop(rNonVol3); build.pop(rNonVol2); @@ -495,37 +489,29 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") std::unique_ptr unwind = std::make_unique(); #endif - unwind->startInfo(); + unwind->startInfo(UnwindBuilder::X64); Label functionBegin = build.setLabel(); unwind->startFunction(); // Prologue (some of these registers don't have to be saved, but we want to have a big prologue) - build.push(r10); - unwind->save(r10); - build.push(r11); - unwind->save(r11); - build.push(r12); - unwind->save(r12); - build.push(r13); - unwind->save(r13); - build.push(r14); - unwind->save(r14); - build.push(r15); - unwind->save(r15); build.push(rbp); - unwind->save(rbp); + build.mov(rbp, rsp); + build.push(r10); + build.push(r11); + build.push(r12); + build.push(r13); + build.push(r14); + build.push(r15); int stackSize = 64; int localsSize = 16; build.sub(rsp, stackSize + localsSize); - unwind->allocStack(stackSize + localsSize); - build.lea(rbp, addr[rsp + stackSize]); - unwind->setupFrameReg(rbp, stackSize); + uint32_t prologueSize = build.setLabel().location; - size_t prologueSize = build.setLabel().location; + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15}); // Body build.mov(rax, rArg1); @@ -535,14 +521,14 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") Label returnOffset = build.setLabel(); // Epilogue - build.lea(rsp, addr[rbp + localsSize]); - build.pop(rbp); + build.add(rsp, stackSize + localsSize); build.pop(r15); build.pop(r14); build.pop(r13); build.pop(r12); build.pop(r11); build.pop(r10); + build.pop(rbp); build.ret(); unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); @@ -650,6 +636,80 @@ TEST_CASE("GeneratedCodeExecutionA64") CHECK(result == 42); } +#if 0 +static void throwing(int64_t arg) +{ + CHECK(arg == 25); + + throw std::runtime_error("testing"); +} + +TEST_CASE("GeneratedCodeExecutionWithThrowA64") +{ + using namespace A64; + + AssemblyBuilderA64 build(/* logText= */ false); + + std::unique_ptr unwind = std::make_unique(); + + unwind->startInfo(UnwindBuilder::A64); + + build.sub(sp, sp, 32); + build.stp(x29, x30, mem(sp)); + build.str(x28, mem(sp, 16)); + build.mov(x29, sp); + + Label prologueEnd = build.setLabel(); + + build.add(x0, x0, 15); + build.blr(x1); + + build.ldr(x28, mem(sp, 16)); + build.ldp(x29, x30, mem(sp)); + build.add(sp, sp, 32); + + build.ret(); + + Label functionEnd = build.setLabel(); + + unwind->startFunction(); + unwind->prologueA64(build.getLabelOffset(prologueEnd), 32, {x29, x30, x28}); + unwind->finishFunction(0, build.getLabelOffset(functionEnd)); + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), reinterpret_cast(build.code.data()), build.code.size() * 4, nativeData, + sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} +#endif + #endif TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index f09f174a..e1213b93 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -532,6 +532,30 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "ReplacementPreservesUses") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp unk = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + build.inst(IrCmd::STORE_INT, build.vmReg(8), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(~0u))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ true) == R"( +bb_0: ; useCount: 0 + %0 = LOAD_INT R0 ; useCount: 1, lastUse: %0 + %1 = BITNOT_UINT %0 ; useCount: 1, lastUse: %0 + STORE_INT R8, %1 ; %2 + RETURN 0u ; %3 + +)"); +} + TEST_CASE_FIXTURE(IrBuilderFixture, "NumericNan") { IrOp block = build.block(IrBlockKind::Internal); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 6552a24d..26b3b00d 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -470,8 +470,6 @@ TEST_SUITE_END(); struct NormalizeFixture : Fixture { - ScopedFastFlag sff2{"LuauNegatedClassTypes", true}; - TypeArena arena; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; @@ -632,11 +630,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function") TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; - CHECK("(boolean | class | number | string | table | thread)?" == toString(normal(R"( Not )"))); @@ -649,11 +642,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated") TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; - // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function CHECK("(class | function | number | string | table | thread)?" == toString(normal(R"( Not @@ -723,8 +711,6 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Parent | Unrelated" == toString(normal("Parent | Unrelated"))); CHECK("Parent" == toString(normal("Parent | Child"))); @@ -733,8 +719,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "unions_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Child" == toString(normal("Parent & Child"))); CHECK("never" == toString(normal("Child & Unrelated"))); @@ -742,8 +726,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersections_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); } @@ -764,11 +746,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { - ScopedFastFlag sffs[] = { - {"LuauNegatedTableTypes", true}, - {"LuauNegatedClassTypes", true}, - }; - createSomeClasses(&frontend); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); CHECK("((class & ~Child) | boolean | function | number | string | table | thread)?" == toString(normal("Not"))); @@ -781,24 +758,18 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_unknown") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("Parent" == toString(normal("Parent & unknown"))); } TEST_CASE_FIXTURE(NormalizeFixture, "classes_and_never") { - ScopedFastFlag sff{"LuauNegatedClassTypes", true}; - createSomeClasses(&frontend); CHECK("never" == toString(normal("Parent & never"))); } TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type") { - ScopedFastFlag sff{"LuauNegatedTableTypes", true}; - CHECK("table" == toString(normal("{} | tbl"))); CHECK("{| |}" == toString(normal("{} & tbl"))); CHECK("never" == toString(normal("number & tbl"))); @@ -806,8 +777,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "top_table_type") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") { - ScopedFastFlag sff{"LuauNegatedTableTypes", true}; - CHECK(nullptr == toNormalizedType("Not<{}>")); CHECK("(boolean | class | function | number | string | thread)?" == toString(normal("Not"))); CHECK("table" == toString(normal("Not>"))); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index ef5aabbe..1335b6f4 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -112,14 +112,6 @@ TEST_CASE_FIXTURE(Fixture, "can_haz_annotations") REQUIRE(block != nullptr); } -TEST_CASE_FIXTURE(Fixture, "local_cannot_have_annotation_with_extensions_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("local foo: string = \"Hello Types!\"", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "local_with_annotation") { AstStatBlock* block = parse(R"( @@ -150,14 +142,6 @@ TEST_CASE_FIXTURE(Fixture, "type_names_can_contain_dots") REQUIRE(block != nullptr); } -TEST_CASE_FIXTURE(Fixture, "functions_cannot_have_return_annotations_if_extensions_are_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("function foo(): number return 55 end", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") { AstStatBlock* block = parse(R"( @@ -395,14 +379,6 @@ TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_ CHECK(returnAnnotation->types.data[1]->as()); } -TEST_CASE_FIXTURE(Fixture, "illegal_type_alias_if_extensions_are_disabled") -{ - Luau::ParseOptions options; - options.allowTypeAnnotations = false; - - CHECK_THROWS_AS(parse("type A = number", options), std::exception); -} - TEST_CASE_FIXTURE(Fixture, "type_alias_to_a_typeof") { AstStatBlock* block = parse(R"( @@ -2837,8 +2813,6 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_t TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter") { - ScopedFastFlag sff{"LuauParserErrorsOnMissingDefaultTypePackArgument", true}; - ParseResult result = tryParse(R"( type Foo = nil )"); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 687bc766..0f255f08 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -108,7 +108,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_ERROR_COUNT(2, result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("*error-type*", toString(requireType("a"))); } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 9086a604..94cf4b32 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -169,14 +169,27 @@ TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_ LUAU_REQUIRE_ERROR_COUNT(2, result); - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(builtinTypes->numberType, tm->wantedType); - CHECK_EQ(builtinTypes->stringType, tm->givenType); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + GenericError* g = get(result.errors[0]); + REQUIRE(g); + CHECK(g->message == "None of the overloads for function that accept 1 arguments are compatible."); + } + else + { + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); + } ExtraInformation* ei = get(result.errors[1]); REQUIRE(ei); - CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK("Available overloads: (number) -> number; and (number) -> string" == ei->message); + else + CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); } TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 7a134358..c3dbbc7d 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Frontend.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Type.h" @@ -31,6 +32,53 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") CHECK_EQ(*builtinTypes->numberType, *requireType("q")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_no_table_passed") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + CheckResult result = check(R"( + +type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> (any, number) -> (number, string) + } +)) + +local t: Iterable + +for a, b in t do end +)"); + + + LUAU_REQUIRE_ERROR_COUNT(1, result); + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "iteration_regression_issue_69967") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + CheckResult result = check(R"( + +type Iterable = typeof(setmetatable( + {}, + {}::{ + __iter: (self: Iterable) -> () -> (number, string) + } +)) + +local t: Iterable + +for a, b in t do end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("__iter metamethod must return (next[, table[, state]])", ge->message); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index f2b3d055..ee747252 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -26,9 +26,17 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defi someTable.Function1() -- Argument count mismatch )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments."); + CHECK(toString(result.errors[1]) == "Available overloads: (a) -> ()"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") @@ -42,9 +50,17 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_ someTable.Function2() -- Argument count mismatch )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "No overload for function accepts 0 arguments."); + CHECK(toString(result.errors[1]) == "Available overloads: (a, b) -> ()"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 7f21641d..58acef22 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -52,6 +52,43 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") CHECK_EQ(expected, decorateWithTypes(code)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.filter") +{ + // This test exercises the fact that we should reduce sealed/unsealed/free tables + // res is a unsealed table with type {((T & ~nil)?) & any} + // Because we do not reduce it fully, we cannot unify it with `Array = { [number] : T} + // TLDR; reduction needs to reduce the indexer on res so it unifies with Array + CheckResult result = check(R"( +--!strict +-- Implements Javascript's `Array.prototype.filter` as defined below +-- https://developer.cmozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/filter +type Array = { [number]: T } +type callbackFn = (element: T, index: number, array: Array) -> boolean +type callbackFnWithThisArg = (thisArg: U, element: T, index: number, array: Array) -> boolean +type Object = { [string]: any } +return function(t: Array, callback: callbackFn | callbackFnWithThisArg, thisArg: U?): Array + + local len = #t + local res = {} + if thisArg == nil then + for i = 1, len do + local kValue = t[i] + if kValue ~= nil then + if (callback :: callbackFn)(kValue, i, t) then + res[i] = kValue + end + end + end + else + end + + return res +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") { const std::string code = R"( diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 06cbe0cf..c55497ae 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,7 +8,6 @@ #include "doctest.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauNegatedClassTypes) using namespace Luau; @@ -64,7 +63,7 @@ struct RefinementClassFixture : BuiltinsFixture TypeArena& arena = frontend.globals.globalTypes; NotNull scope{frontend.globals.globalScope.get()}; - std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; + std::optional rootSuper = std::make_optional(builtinTypes->classType); unfreeze(arena); TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 23e49f58..f028e8e0 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -131,8 +131,16 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); - CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("None of the overloads for function that accept 2 arguments are compatible.", toString(result.errors[0])); + CHECK_EQ("Available overloads: (true, string) -> (); and (false, number) -> ()", toString(result.errors[1])); + } + else + { + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index fcf2c8a4..4b24fb22 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3625,4 +3625,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "top_table_type_is_isomorphic_to_empty_sealed )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.includes") +{ + + CheckResult result = check(R"( +type Array = { [number]: T } + +function indexOf(array: Array, searchElement: any, fromIndex: number?): number + return -1 +end + +return function(array: Array, searchElement: any, fromIndex: number?): boolean + return -1 ~= indexOf(array, searchElement, fromIndex) +end + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 38fa7f5f..655d094f 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,9 +1,5 @@ AnnotationTests.too_many_type_params AstQuery.last_argument_function_call_type -AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method -AstQuery::getDocumentationSymbolAtPosition.overloaded_fn -AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop -AutocompleteTest.autocomplete_response_perf1 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 @@ -54,6 +50,7 @@ ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean ProvisionalTests.free_options_cannot_be_unified_together ProvisionalTests.generic_type_leak_to_module_interface_variadic ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns +ProvisionalTests.luau-polyfill.Array.filter ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument @@ -146,7 +143,6 @@ TypeInferClasses.index_instance_property TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.warn_when_prop_almost_matches TypeInferFunctions.cannot_hoist_interior_defns_into_signature -TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -158,7 +154,6 @@ TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type -TypeInferFunctions.record_matching_overload TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.return_type_by_overload TypeInferFunctions.too_few_arguments_variadic @@ -205,6 +200,7 @@ TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.no_widening_from_callsites +TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton