diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index bc709c7f..96bac9e4 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -39,6 +39,7 @@ enum class AutocompleteEntryKind Type, Module, GeneratedFunction, + RequirePath, }; enum class ParenthesesRecommendation diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 71e50580..45015459 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -9,6 +9,8 @@ namespace Luau { +static constexpr char kRequireTagName[] = "require"; + struct Frontend; struct GlobalTypes; struct TypeChecker; diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index a84561dd..718a0350 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -68,7 +68,6 @@ private: DenseHashMap compoundAssignDefs{nullptr}; DenseHashMap astRefinementKeys{nullptr}; - friend struct DataFlowGraphBuilder; }; @@ -83,6 +82,7 @@ struct DfgScope DfgScope* parent; ScopeType scopeType; + Location location; using Bindings = DenseHashMap; using Props = DenseHashMap>; @@ -105,10 +105,44 @@ struct DataFlowResult const RefinementKey* parent = nullptr; }; +using ScopeStack = std::vector; + struct DataFlowGraphBuilder { static DataFlowGraph build(AstStatBlock* root, NotNull handle); + /** + * This method is identical to the build method above, but returns a pair of dfg, scopes as the data flow graph + * here is intended to live on the module between runs of typechecking. Before, the DFG only needed to live as + * long as the typecheck, but in a world with incremental typechecking, we need the information on the dfg to incrementally + * typecheck small fragments of code. + * @param block - pointer to the ast to build the dfg for + * @param handle - for raising internal errors while building the dfg + */ + static std::pair, std::vector>> buildShared( + AstStatBlock* block, + NotNull handle + ); + + /** + * Takes a stale graph along with a list of scopes, a small fragment of the ast, and a cursor position + * and constructs the DataFlowGraph for just that fragment. This method will fabricate defs in the final + * DFG for things that have been referenced and exist in the stale dfg. + * For example, the fragment local z = x + y will populate defs for x and y from the stale graph. + * @param staleGraph - the old DFG + * @param scopes - the old DfgScopes in the graph + * @param fragment - the Ast Fragment to re-build the root for + * @param cursorPos - the current location of the cursor - used to determine which scope we are currently in + * @param handle - for internal compiler errors + */ + static DataFlowGraph updateGraph( + const DataFlowGraph& staleGraph, + const std::vector>& scopes, + AstStatBlock* fragment, + const Position& cursorPos, + NotNull handle + ); + private: DataFlowGraphBuilder() = default; @@ -120,10 +154,15 @@ private: NotNull keyArena{&graph.keyArena}; struct InternalErrorReporter* handle = nullptr; - DfgScope* moduleScope = nullptr; + /// The arena owning all of the scope allocations for the dataflow graph being built. std::vector> scopes; + /// A stack of scopes used by the visitor to see where we are. + ScopeStack scopeStack; + + DfgScope* currentScope(); + struct FunctionCapture { std::vector captureDefs; @@ -134,81 +173,81 @@ private: DenseHashMap captures{Symbol{}}; void resolveCaptures(); - DfgScope* childScope(DfgScope* scope, DfgScope::ScopeType scopeType = DfgScope::Linear); + DfgScope* makeChildScope(Location loc, DfgScope::ScopeType scopeType = DfgScope::Linear); void join(DfgScope* p, DfgScope* a, DfgScope* b); void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b); void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b); - DefId lookup(DfgScope* scope, Symbol symbol); - DefId lookup(DfgScope* scope, DefId def, const std::string& key); + DefId lookup(Symbol symbol); + DefId lookup(DefId def, const std::string& key); - ControlFlow visit(DfgScope* scope, AstStatBlock* b); - ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); + ControlFlow visit(AstStatBlock* b); + ControlFlow visitBlockWithoutChildScope(AstStatBlock* b); - ControlFlow visit(DfgScope* scope, AstStat* s); - ControlFlow visit(DfgScope* scope, AstStatIf* i); - ControlFlow visit(DfgScope* scope, AstStatWhile* w); - ControlFlow visit(DfgScope* scope, AstStatRepeat* r); - ControlFlow visit(DfgScope* scope, AstStatBreak* b); - ControlFlow visit(DfgScope* scope, AstStatContinue* c); - ControlFlow visit(DfgScope* scope, AstStatReturn* r); - ControlFlow visit(DfgScope* scope, AstStatExpr* e); - ControlFlow visit(DfgScope* scope, AstStatLocal* l); - ControlFlow visit(DfgScope* scope, AstStatFor* f); - ControlFlow visit(DfgScope* scope, AstStatForIn* f); - ControlFlow visit(DfgScope* scope, AstStatAssign* a); - ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c); - ControlFlow visit(DfgScope* scope, AstStatFunction* f); - ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l); - ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t); - ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f); - ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d); - ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d); - ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d); - ControlFlow visit(DfgScope* scope, AstStatError* error); + ControlFlow visit(AstStat* s); + ControlFlow visit(AstStatIf* i); + ControlFlow visit(AstStatWhile* w); + ControlFlow visit(AstStatRepeat* r); + ControlFlow visit(AstStatBreak* b); + ControlFlow visit(AstStatContinue* c); + ControlFlow visit(AstStatReturn* r); + ControlFlow visit(AstStatExpr* e); + ControlFlow visit(AstStatLocal* l); + ControlFlow visit(AstStatFor* f); + ControlFlow visit(AstStatForIn* f); + ControlFlow visit(AstStatAssign* a); + ControlFlow visit(AstStatCompoundAssign* c); + ControlFlow visit(AstStatFunction* f); + ControlFlow visit(AstStatLocalFunction* l); + ControlFlow visit(AstStatTypeAlias* t); + ControlFlow visit(AstStatTypeFunction* f); + ControlFlow visit(AstStatDeclareGlobal* d); + ControlFlow visit(AstStatDeclareFunction* d); + ControlFlow visit(AstStatDeclareClass* d); + ControlFlow visit(AstStatError* error); - DataFlowResult visitExpr(DfgScope* scope, AstExpr* e); - DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group); - DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l); - DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g); - DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c); - DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f); - DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t); - DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u); - DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b); - DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t); - DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprError* error); + DataFlowResult visitExpr(AstExpr* e); + DataFlowResult visitExpr(AstExprGroup* group); + DataFlowResult visitExpr(AstExprLocal* l); + DataFlowResult visitExpr(AstExprGlobal* g); + DataFlowResult visitExpr(AstExprCall* c); + DataFlowResult visitExpr(AstExprIndexName* i); + DataFlowResult visitExpr(AstExprIndexExpr* i); + DataFlowResult visitExpr(AstExprFunction* f); + DataFlowResult visitExpr(AstExprTable* t); + DataFlowResult visitExpr(AstExprUnary* u); + DataFlowResult visitExpr(AstExprBinary* b); + DataFlowResult visitExpr(AstExprTypeAssertion* t); + DataFlowResult visitExpr(AstExprIfElse* i); + DataFlowResult visitExpr(AstExprInterpString* i); + DataFlowResult visitExpr(AstExprError* error); - void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef); + void visitLValue(AstExpr* e, DefId incomingDef); + DefId visitLValue(AstExprLocal* l, DefId incomingDef); + DefId visitLValue(AstExprGlobal* g, DefId incomingDef); + DefId visitLValue(AstExprIndexName* i, DefId incomingDef); + DefId visitLValue(AstExprIndexExpr* i, DefId incomingDef); + DefId visitLValue(AstExprError* e, DefId incomingDef); - void visitType(DfgScope* scope, AstType* t); - void visitType(DfgScope* scope, AstTypeReference* r); - void visitType(DfgScope* scope, AstTypeTable* t); - void visitType(DfgScope* scope, AstTypeFunction* f); - void visitType(DfgScope* scope, AstTypeTypeof* t); - void visitType(DfgScope* scope, AstTypeUnion* u); - void visitType(DfgScope* scope, AstTypeIntersection* i); - void visitType(DfgScope* scope, AstTypeError* error); + void visitType(AstType* t); + void visitType(AstTypeReference* r); + void visitType(AstTypeTable* t); + void visitType(AstTypeFunction* f); + void visitType(AstTypeTypeof* t); + void visitType(AstTypeUnion* u); + void visitType(AstTypeIntersection* i); + void visitType(AstTypeError* error); - void visitTypePack(DfgScope* scope, AstTypePack* p); - void visitTypePack(DfgScope* scope, AstTypePackExplicit* e); - void visitTypePack(DfgScope* scope, AstTypePackVariadic* v); - void visitTypePack(DfgScope* scope, AstTypePackGeneric* g); + void visitTypePack(AstTypePack* p); + void visitTypePack(AstTypePackExplicit* e); + void visitTypePack(AstTypePackVariadic* v); + void visitTypePack(AstTypePackGeneric* g); - void visitTypeList(DfgScope* scope, AstTypeList l); + void visitTypeList(AstTypeList l); - void visitGenerics(DfgScope* scope, AstArray g); - void visitGenericPacks(DfgScope* scope, AstArray g); + void visitGenerics(AstArray g); + void visitGenericPacks(AstArray g); }; } // namespace Luau diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 0fdcce16..2f17e566 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -3,6 +3,7 @@ #include #include +#include namespace Luau { @@ -31,6 +32,9 @@ struct ModuleInfo bool optional = false; }; +using RequireSuggestion = std::string; +using RequireSuggestions = std::vector; + struct FileResolver { virtual ~FileResolver() {} @@ -51,6 +55,11 @@ struct FileResolver { return std::nullopt; } + + virtual std::optional getRequireSuggestions(const ModuleName& requirer, const std::optional& pathString) const + { + return std::nullopt; + } }; struct NullFileResolver : FileResolver diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h index 53e301c1..bfc5f6e6 100644 --- a/Analysis/include/Luau/FragmentAutocomplete.h +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -1,12 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/DenseHash.h" #include "Luau/Ast.h" +#include "Luau/Parser.h" +#include "Luau/Autocomplete.h" +#include "Luau/DenseHash.h" +#include "Luau/Module.h" +#include #include - namespace Luau { @@ -15,9 +18,28 @@ struct FragmentAutocompleteAncestryResult DenseHashMap localMap{AstName()}; std::vector localStack; std::vector ancestry; - AstStat* nearestStatement; + AstStat* nearestStatement = nullptr; +}; + +struct FragmentParseResult +{ + std::string fragmentToParse; + AstStatBlock* root = nullptr; + std::vector ancestry; + std::unique_ptr alloc = std::make_unique(); }; FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); +FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos); + +AutocompleteResult fragmentAutocomplete( + Frontend& frontend, + std::string_view src, + const ModuleName& moduleName, + Position& cursorPosition, + StringCompletionCallback callback +); + + } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index f909deb8..82c189aa 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -9,6 +9,7 @@ #include "Luau/Scope.h" #include "Luau/TypeArena.h" #include "Luau/AnyTypeSummary.h" +#include "Luau/DataFlowGraph.h" #include #include @@ -131,6 +132,9 @@ struct Module TypePackId returnType = nullptr; std::unordered_map exportedTypeBindings; + // We also need to keep DFG data alive between runs + std::shared_ptr dataFlowGraph = nullptr; + std::vector> dfgScopes; bool hasModuleScope() const; ScopePtr getModuleScope() const; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 3a7aefd8..55089aa3 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -667,6 +667,11 @@ struct AnyType { }; +// A special, trivial type for the refinement system that is always eliminated from intersections. +struct NoRefineType +{ +}; + // `T | U` struct UnionType { @@ -755,6 +760,7 @@ using TypeVariant = Unifiable::Variant< UnknownType, NeverType, NegationType, + NoRefineType, TypeFunctionInstanceType>; struct Type final @@ -949,6 +955,7 @@ public: const TypeId unknownType; const TypeId neverType; const TypeId errorType; + const TypeId noRefineType; const TypeId falsyType; const TypeId truthyType; diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 92be19d1..de9660ef 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -248,4 +248,36 @@ std::optional follow(std::optional ty) return std::nullopt; } +/** + * Returns whether or not expr is a literal expression, for example: + * - Scalar literals (numbers, booleans, strings, nil) + * - Table literals + * - Lambdas (a "function literal") + */ +bool isLiteral(const AstExpr* expr); + +/** + * Given a table literal and a mapping from expression to type, determine + * whether any literal expression in this table depends on any blocked types. + * This is used as a precondition for bidirectional inference: be warned that + * the behavior of this algorithm is tightly coupled to that of bidirectional + * inference. + * @param expr Expression to search + * @param astTypes Mapping from AST node to TypeID + * @returns A vector of blocked types + */ +std::vector findBlockedTypesIn(AstExprTable* expr, NotNull> astTypes); + +/** + * Given a function call and a mapping from expression to type, determine + * whether the type of any argument in said call in depends on a blocked types. + * This is used as a precondition for bidirectional inference: be warned that + * the behavior of this algorithm is tightly coupled to that of bidirectional + * inference. + * @param expr Expression to search + * @param astTypes Mapping from AST node to TypeID + * @returns A vector of blocked types + */ +std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull> astTypes); + } // namespace Luau diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index e943cced..7202c100 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -133,6 +133,10 @@ struct GenericTypeVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const NoRefineType& nrt) + { + return visit(ty); + } virtual bool visit(TypeId ty, const UnknownType& utv) { return visit(ty); @@ -345,6 +349,8 @@ struct GenericTypeVisitor } else if (auto atv = get(ty)) visit(ty, *atv); + else if (auto nrt = get(ty)) + visit(ty, *nrt); else if (auto utv = get(ty)) { if (visit(ty, *utv)) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 521c7948..f2235bb9 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -3,6 +3,8 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/ToString.h" #include "Luau/Subtyping.h" @@ -15,6 +17,7 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauAutocompleteNewSolverLimit) +LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions, false) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTINT(LuauTypeInferIterationLimit) @@ -215,8 +218,7 @@ static TypeCorrectKind checkTypeCorrectKind( { for (TypeId id : itv->parts) { - if (DFInt::LuauTypeSolverRelease >= 644) - id = follow(id); + id = follow(id); if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) { @@ -1444,11 +1446,25 @@ static std::optional getStringContents(const AstNode* node) } } +static std::optional convertRequireSuggestionsToAutocompleteEntryMap(std::optional suggestions) +{ + if (!suggestions) + return std::nullopt; + + AutocompleteEntryMap result; + for (const RequireSuggestion& suggestion : *suggestions) + { + result[suggestion] = {AutocompleteEntryKind::RequirePath}; + } + return result; +} + static std::optional autocompleteStringParams( const SourceModule& sourceModule, const ModulePtr& module, const std::vector& nodes, Position position, + FileResolver* fileResolver, StringCompletionCallback callback ) { @@ -1495,6 +1511,13 @@ static std::optional autocompleteStringParams( { for (const std::string& tag : funcType->tags) { + if (FFlag::AutocompleteRequirePathSuggestions) + { + if (tag == kRequireTagName && fileResolver) + { + return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString)); + } + } if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) { return ret; @@ -1679,6 +1702,7 @@ static AutocompleteResult autocomplete( TypeArena* typeArena, Scope* globalScope, Position position, + FileResolver* fileResolver, StringCompletionCallback callback ) { @@ -1922,7 +1946,7 @@ static AutocompleteResult autocomplete( else if (isIdentifier(node) && (parent->is() || parent->is())) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) + if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, fileResolver, callback)) { return {*ret, ancestry, AutocompleteContext::String}; } @@ -1999,7 +2023,7 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; - return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); + return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, frontend.fileResolver, callback); } } // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 32692a6e..041d1bed 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -27,6 +27,8 @@ LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(AutocompleteRequirePathSuggestions); + namespace Luau { @@ -413,8 +415,18 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); } - attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); + if (FFlag::AutocompleteRequirePathSuggestions) + { + TypeId requireTy = getGlobalBinding(globals, "require"); + attachTag(requireTy, kRequireTagName); + attachMagicFunction(requireTy, magicFunctionRequire); + attachDcrMagicFunction(requireTy, dcrMagicFunctionRequire); + } + else + { + attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); + } } static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 4af3e7f8..d5793c93 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -359,6 +359,11 @@ private: // noop. } + void cloneChildren(NoRefineType* t) + { + // noop. + } + void cloneChildren(UnionType* t) { for (TypeId& ty : t->options) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index efa023bf..9c30668c 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -26,10 +26,10 @@ #include #include -LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTFLAG(DebugLuauLogSolverToJson); -LUAU_FASTFLAG(DebugLuauMagicTypes); -LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease); +LUAU_FASTINT(LuauCheckRecursionLimit) +LUAU_FASTFLAG(DebugLuauLogSolverToJson) +LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) namespace Luau { @@ -2883,9 +2883,45 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, { Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; std::vector toBlock; - matchLiteralType( - NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock - ); + if (DFInt::LuauTypeSolverRelease >= 648) + { + // This logic is incomplete as we want to re-run this + // _after_ blocked types have resolved, but this + // allows us to do some bidirectional inference. + toBlock = findBlockedTypesIn(expr, NotNull{&module->astTypes}); + if (toBlock.empty()) + { + matchLiteralType( + NotNull{&module->astTypes}, + NotNull{&module->astExpectedTypes}, + builtinTypes, + arena, + NotNull{&unifier}, + *expectedType, + ty, + expr, + toBlock + ); + // The visitor we ran prior should ensure that there are no + // blocked types that we would encounter while matching on + // this expression. + LUAU_ASSERT(toBlock.empty()); + } + } + else + { + matchLiteralType( + NotNull{&module->astTypes}, + NotNull{&module->astExpectedTypes}, + builtinTypes, + arena, + NotNull{&unifier}, + *expectedType, + ty, + expr, + toBlock + ); + } } return Inference{ty}; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 7cb545f2..c0d30137 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -32,6 +32,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack, false) namespace Luau { @@ -1238,14 +1239,22 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(asMutable(follow(*ty)), builtinTypes->anyType); + if (FFlag::LuauRemoveNotAnyHack) + { + // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. + emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); + } + else + { + // We use `any` here because the discriminant type may be pointed at by both branches, + // where the discriminant type is not negated, and the other where it is negated, i.e. + // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` + // v.s. + // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` + // + // In practice, users cannot negate `any`, so this is an implementation detail we can always change. + emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); + } } OverloadResolver resolver{ @@ -1322,6 +1331,22 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull= 648) + { + // This is expensive as we need to traverse a (potentially large) + // literal up front in order to determine if there are any blocked + // types, otherwise we may run `matchTypeLiteral` multiple times, + // which right now may fail due to being non-idempotent (it + // destructively updates the underlying literal type). + auto blockedTypes = findBlockedArgTypesIn(c.callSite, c.astTypes); + for (const auto ty : blockedTypes) + { + block(ty, constraint); + } + if (!blockedTypes.empty()) + return false; + } + // We know the type of the function and the arguments it expects to receive. // We also know the TypeIds of the actual arguments that will be passed. // @@ -1384,7 +1409,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullargs.data[i]); + AstExpr* expr = unwrapGroup(c.callSite->args.data[i]); (*c.astExpectedTypes)[expr] = expectedArgTy; @@ -1416,10 +1441,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullscope, NotNull{&iceReporter}}; std::vector toBlock; (void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock); - for (auto t : toBlock) - block(t, constraint); - if (!toBlock.empty()) - return false; + if (DFInt::LuauTypeSolverRelease >= 648) + { + LUAU_ASSERT(toBlock.empty()); + } + else + { + for (auto t : toBlock) + block(t, constraint); + if (!toBlock.empty()) + return false; + } } } @@ -1748,8 +1780,9 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(lhsType)) { - if (get(lhsFree->upperBound) || get(lhsFree->upperBound)) - lhsType = lhsFree->upperBound; + auto lhsFreeUpperBound = DFInt::LuauTypeSolverRelease >= 648 ? follow(lhsFree->upperBound) : lhsFree->upperBound; + if (get(lhsFreeUpperBound) || get(lhsFreeUpperBound)) + lhsType = lhsFreeUpperBound; else { TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); @@ -1759,7 +1792,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullprops[c.propName] = rhsType; // Food for thought: Could we block if simplification encounters a blocked type? - lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; + lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result; bind(constraint, c.propType, rhsType); return true; diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 9c42e4d8..1307da8d 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -7,6 +7,7 @@ #include "Luau/Error.h" #include "Luau/TimeTrace.h" +#include #include LUAU_FASTFLAG(DebugLuauFreezeArena) @@ -17,6 +18,38 @@ namespace Luau bool doesCallError(const AstExprCall* call); // TypeInfer.cpp +struct ReferencedDefFinder : public AstVisitor +{ + bool visit(AstExprLocal* local) override + { + referencedLocalDefs.push_back(local->local); + return true; + } + // ast defs is just a mapping from expr -> def in general + // will get built up by the dfg builder + + // localDefs, we need to copy over + std::vector referencedLocalDefs; +}; + +struct PushScope +{ + ScopeStack& stack; + + PushScope(ScopeStack& stack, DfgScope* scope) + : stack(stack) + { + // `scope` should never be `nullptr` here. + LUAU_ASSERT(scope); + stack.push_back(scope); + } + + ~PushScope() + { + stack.pop_back(); + } +}; + const RefinementKey* RefinementKeyArena::leaf(DefId def) { return allocator.allocate(RefinementKey{nullptr, def, std::nullopt}); @@ -143,8 +176,9 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNulllocation); + PushScope ps{builder.scopeStack, moduleScope}; + builder.visitBlockWithoutChildScope(block); builder.resolveCaptures(); if (FFlag::DebugLuauFreezeArena) @@ -156,6 +190,82 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull, std::vector>> DataFlowGraphBuilder::buildShared( + AstStatBlock* block, + NotNull handle +) +{ + + LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); + + LUAU_ASSERT(FFlag::LuauSolverV2); + + DataFlowGraphBuilder builder; + builder.handle = handle; + DfgScope* moduleScope = builder.makeChildScope(block->location); + PushScope ps{builder.scopeStack, moduleScope}; + builder.visitBlockWithoutChildScope(block); + builder.resolveCaptures(); + + if (FFlag::DebugLuauFreezeArena) + { + builder.defArena->allocator.freeze(); + builder.keyArena->allocator.freeze(); + } + + return {std::make_shared(std::move(builder.graph)), std::move(builder.scopes)}; +} + +DataFlowGraph DataFlowGraphBuilder::updateGraph( + const DataFlowGraph& staleGraph, + const std::vector>& scopes, + AstStatBlock* fragment, + const Position& cursorPos, + NotNull handle +) +{ + LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); + LUAU_ASSERT(FFlag::LuauSolverV2); + + DataFlowGraphBuilder builder; + builder.handle = handle; + // Generate a list of prepopulated locals + ReferencedDefFinder finder; + fragment->visit(&finder); + for (AstLocal* loc : finder.referencedLocalDefs) + { + if (staleGraph.localDefs.contains(loc)) + { + builder.graph.localDefs[loc] = *staleGraph.localDefs.find(loc); + } + } + + // Figure out which scope we should start re-accumulating DFG information from again + DfgScope* nearest = nullptr; + for (auto& sc : scopes) + { + if (nearest == nullptr || (sc->location.begin <= cursorPos && nearest->location.begin < sc->location.begin)) + nearest = sc.get(); + } + + // The scope stack should start with the nearest enclosing scope so we can resume DFG'ing correctly + PushScope ps{builder.scopeStack, nearest}; + // Conspire for the current scope in the scope stack to be a fresh dfg scope, parented to the above nearest enclosing scope, so any insertions are + // isolated there + DfgScope* scope = builder.makeChildScope(fragment->location); + PushScope psAgain{builder.scopeStack, scope}; + + builder.visitBlockWithoutChildScope(fragment); + + if (FFlag::DebugLuauFreezeArena) + { + builder.defArena->allocator.freeze(); + builder.keyArena->allocator.freeze(); + } + + return std::move(builder.graph); +} + void DataFlowGraphBuilder::resolveCaptures() { for (const auto& [_, capture] : captures) @@ -174,9 +284,16 @@ void DataFlowGraphBuilder::resolveCaptures() } } -DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope, DfgScope::ScopeType scopeType) +DfgScope* DataFlowGraphBuilder::currentScope() { - return scopes.emplace_back(new DfgScope{scope, scopeType}).get(); + if (scopeStack.empty()) + return nullptr; // nullptr is the root DFG scope. + return scopeStack.back(); +} + +DfgScope* DataFlowGraphBuilder::makeChildScope(Location loc, DfgScope::ScopeType scopeType) +{ + return scopes.emplace_back(new DfgScope{currentScope(), scopeType, loc}).get(); } void DataFlowGraphBuilder::join(DfgScope* p, DfgScope* a, DfgScope* b) @@ -251,8 +368,10 @@ void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const } } -DefId DataFlowGraphBuilder::lookup(DfgScope* scope, Symbol symbol) +DefId DataFlowGraphBuilder::lookup(Symbol symbol) { + DfgScope* scope = currentScope(); + // true if any of the considered scopes are a loop. bool outsideLoopScope = false; for (DfgScope* current = scope; current; current = current->parent) @@ -282,8 +401,9 @@ DefId DataFlowGraphBuilder::lookup(DfgScope* scope, Symbol symbol) return result; } -DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string& key) +DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key) { + DfgScope* scope = currentScope(); for (DfgScope* current = scope; current; current = current->parent) { if (auto props = current->props.find(def)) @@ -303,7 +423,7 @@ DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string { std::vector defs; for (DefId operand : phi->operands) - defs.push_back(lookup(scope, operand, key)); + defs.push_back(lookup(operand, key)); DefId result = defArena->phi(defs); scope->props[def][key] = result; @@ -319,20 +439,26 @@ DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string handle->ice("Inexhaustive lookup cases in DataFlowGraphBuilder::lookup"); } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visit(AstStatBlock* b) { - DfgScope* child = childScope(scope); - ControlFlow cf = visitBlockWithoutChildScope(child, b); - scope->inherit(child); + DfgScope* child = makeChildScope(b->location); + + ControlFlow cf; + { + PushScope ps{scopeStack, child}; + cf = visitBlockWithoutChildScope(b); + } + + currentScope()->inherit(child); return cf; } -ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(AstStatBlock* b) { std::optional firstControlFlow; for (AstStat* stat : b->body) { - ControlFlow cf = visit(scope, stat); + ControlFlow cf = visit(stat); if (cf != ControlFlow::None && !firstControlFlow) firstControlFlow = cf; } @@ -340,66 +466,75 @@ ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, A return firstControlFlow.value_or(ControlFlow::None); } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +ControlFlow DataFlowGraphBuilder::visit(AstStat* s) { if (auto b = s->as()) - return visit(scope, b); + return visit(b); else if (auto i = s->as()) - return visit(scope, i); + return visit(i); else if (auto w = s->as()) - return visit(scope, w); + return visit(w); else if (auto r = s->as()) - return visit(scope, r); + return visit(r); else if (auto b = s->as()) - return visit(scope, b); + return visit(b); else if (auto c = s->as()) - return visit(scope, c); + return visit(c); else if (auto r = s->as()) - return visit(scope, r); + return visit(r); else if (auto e = s->as()) - return visit(scope, e); + return visit(e); else if (auto l = s->as()) - return visit(scope, l); + return visit(l); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto a = s->as()) - return visit(scope, a); + return visit(a); else if (auto c = s->as()) - return visit(scope, c); + return visit(c); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto l = s->as()) - return visit(scope, l); + return visit(l); else if (auto t = s->as()) - return visit(scope, t); + return visit(t); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto error = s->as()) - return visit(scope, error); + return visit(error); else handle->ice("Unknown AstStat in DataFlowGraphBuilder::visit"); } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +ControlFlow DataFlowGraphBuilder::visit(AstStatIf* i) { - visitExpr(scope, i->condition); + visitExpr(i->condition); - DfgScope* thenScope = childScope(scope); - DfgScope* elseScope = childScope(scope); + DfgScope* thenScope = makeChildScope(i->thenbody->location); + DfgScope* elseScope = makeChildScope(i->elsebody ? i->elsebody->location : i->location); + + ControlFlow thencf; + { + PushScope ps{scopeStack, thenScope}; + thencf = visit(i->thenbody); + } - ControlFlow thencf = visit(thenScope, i->thenbody); ControlFlow elsecf = ControlFlow::None; if (i->elsebody) - elsecf = visit(elseScope, i->elsebody); + { + PushScope ps{scopeStack, elseScope}; + elsecf = visit(i->elsebody); + } + DfgScope* scope = currentScope(); if (thencf != ControlFlow::None && elsecf == ControlFlow::None) join(scope, scope, elseScope); else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) @@ -415,70 +550,78 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +ControlFlow DataFlowGraphBuilder::visit(AstStatWhile* w) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* whileScope = childScope(scope, DfgScope::Loop); - visitExpr(whileScope, w->condition); - visit(whileScope, w->body); + DfgScope* whileScope = makeChildScope(w->location, DfgScope::Loop); - scope->inherit(whileScope); + { + PushScope ps{scopeStack, whileScope}; + visitExpr(w->condition); + visit(w->body); + } + + currentScope()->inherit(whileScope); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +ControlFlow DataFlowGraphBuilder::visit(AstStatRepeat* r) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* repeatScope = childScope(scope, DfgScope::Loop); - visitBlockWithoutChildScope(repeatScope, r->body); - visitExpr(repeatScope, r->condition); + DfgScope* repeatScope = makeChildScope(r->location, DfgScope::Loop); - scope->inherit(repeatScope); + { + PushScope ps{scopeStack, repeatScope}; + visitBlockWithoutChildScope(r->body); + visitExpr(r->condition); + } + + currentScope()->inherit(repeatScope); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +ControlFlow DataFlowGraphBuilder::visit(AstStatBreak* b) { return ControlFlow::Breaks; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +ControlFlow DataFlowGraphBuilder::visit(AstStatContinue* c) { return ControlFlow::Continues; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +ControlFlow DataFlowGraphBuilder::visit(AstStatReturn* r) { for (AstExpr* e : r->list) - visitExpr(scope, e); + visitExpr(e); return ControlFlow::Returns; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +ControlFlow DataFlowGraphBuilder::visit(AstStatExpr* e) { - visitExpr(scope, e->expr); + visitExpr(e->expr); if (auto call = e->expr->as(); call && doesCallError(call)) return ControlFlow::Throws; else return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +ControlFlow DataFlowGraphBuilder::visit(AstStatLocal* l) { // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) std::vector defs; defs.reserve(l->values.size); for (AstExpr* e : l->values) - defs.push_back(visitExpr(scope, e).def); + defs.push_back(visitExpr(e).def); for (size_t i = 0; i < l->vars.size; ++i) { AstLocal* local = l->vars.data[i]; if (local->annotation) - visitType(scope, local->annotation); + visitType(local->annotation); // We need to create a new def to intentionally avoid alias tracking, but we'd like to // make sure that the non-aliased defs are also marked as a subscript for refinements. @@ -493,90 +636,98 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) } } graph.localDefs[local] = def; - scope->bindings[local] = def; + currentScope()->bindings[local] = def; captures[local].allVersions.push_back(def); } return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatFor* f) { - DfgScope* forScope = childScope(scope, DfgScope::Loop); + DfgScope* forScope = makeChildScope(f->location, DfgScope::Loop); - visitExpr(scope, f->from); - visitExpr(scope, f->to); + visitExpr(f->from); + visitExpr(f->to); if (f->step) - visitExpr(scope, f->step); + visitExpr(f->step); - if (f->var->annotation) - visitType(forScope, f->var->annotation); - - DefId def = defArena->freshCell(); - graph.localDefs[f->var] = def; - scope->bindings[f->var] = def; - captures[f->var].allVersions.push_back(def); - - // TODO(controlflow): entry point has a back edge from exit point - visit(forScope, f->body); - - scope->inherit(forScope); - - return ControlFlow::None; -} - -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) -{ - DfgScope* forScope = childScope(scope, DfgScope::Loop); - - for (AstLocal* local : f->vars) { - if (local->annotation) - visitType(forScope, local->annotation); + PushScope ps{scopeStack, forScope}; + + if (f->var->annotation) + visitType(f->var->annotation); DefId def = defArena->freshCell(); - graph.localDefs[local] = def; - forScope->bindings[local] = def; - captures[local].allVersions.push_back(def); + graph.localDefs[f->var] = def; + currentScope()->bindings[f->var] = def; + captures[f->var].allVersions.push_back(def); + + // TODO(controlflow): entry point has a back edge from exit point + visit(f->body); } - // TODO(controlflow): entry point has a back edge from exit point - // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) - for (AstExpr* e : f->values) - visitExpr(forScope, e); - - visit(forScope, f->body); - - scope->inherit(forScope); + currentScope()->inherit(forScope); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +ControlFlow DataFlowGraphBuilder::visit(AstStatForIn* f) +{ + DfgScope* forScope = makeChildScope(f->location, DfgScope::Loop); + + { + PushScope ps{scopeStack, forScope}; + + for (AstLocal* local : f->vars) + { + if (local->annotation) + visitType(local->annotation); + + DefId def = defArena->freshCell(); + graph.localDefs[local] = def; + currentScope()->bindings[local] = def; + captures[local].allVersions.push_back(def); + } + + // TODO(controlflow): entry point has a back edge from exit point + // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) + for (AstExpr* e : f->values) + visitExpr(e); + + visit(f->body); + } + + currentScope()->inherit(forScope); + + return ControlFlow::None; +} + +ControlFlow DataFlowGraphBuilder::visit(AstStatAssign* a) { std::vector defs; defs.reserve(a->values.size); for (AstExpr* e : a->values) - defs.push_back(visitExpr(scope, e).def); + defs.push_back(visitExpr(e).def); for (size_t i = 0; i < a->vars.size; ++i) { AstExpr* v = a->vars.data[i]; - visitLValue(scope, v, i < defs.size() ? defs[i] : defArena->freshCell()); + visitLValue(v, i < defs.size() ? defs[i] : defArena->freshCell()); } return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +ControlFlow DataFlowGraphBuilder::visit(AstStatCompoundAssign* c) { - (void) visitExpr(scope, c->value); - (void) visitExpr(scope, c->var); + (void)visitExpr(c->value); + (void)visitExpr(c->var); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatFunction* f) { // In the old solver, we assumed that the name of the function is always a function in the body // but this isn't true, e.g. the following example will print `5`, not a function address. @@ -588,8 +739,8 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) // // which is evidence that references to variables must be a phi node of all possible definitions, // but for bug compatibility, we'll assume the same thing here. - visitLValue(scope, f->name, defArena->freshCell()); - visitExpr(scope, f->func); + visitLValue(f->name, defArena->freshCell()); + visitExpr(f->func); if (auto local = f->name->as()) { @@ -606,87 +757,97 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +ControlFlow DataFlowGraphBuilder::visit(AstStatLocalFunction* l) { DefId def = defArena->freshCell(); graph.localDefs[l->name] = def; - scope->bindings[l->name] = def; + currentScope()->bindings[l->name] = def; captures[l->name].allVersions.push_back(def); - visitExpr(scope, l->func); + visitExpr(l->func); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) +ControlFlow DataFlowGraphBuilder::visit(AstStatTypeAlias* t) { - DfgScope* unreachable = childScope(scope); - visitGenerics(unreachable, t->generics); - visitGenericPacks(unreachable, t->genericPacks); - visitType(unreachable, t->type); + DfgScope* unreachable = makeChildScope(t->location); + PushScope ps{scopeStack, unreachable}; + + visitGenerics(t->generics); + visitGenericPacks(t->genericPacks); + visitType(t->type); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeFunction* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatTypeFunction* f) { - DfgScope* unreachable = childScope(scope); - visitExpr(unreachable, f->body); + DfgScope* unreachable = makeChildScope(f->location); + PushScope ps{scopeStack, unreachable}; + + visitExpr(f->body); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareGlobal* d) { DefId def = defArena->freshCell(); graph.declaredDefs[d] = def; - scope->bindings[d->name] = def; + currentScope()->bindings[d->name] = def; captures[d->name].allVersions.push_back(def); - visitType(scope, d->type); + visitType(d->type); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareFunction* d) { DefId def = defArena->freshCell(); graph.declaredDefs[d] = def; - scope->bindings[d->name] = def; + currentScope()->bindings[d->name] = def; captures[d->name].allVersions.push_back(def); - DfgScope* unreachable = childScope(scope); - visitGenerics(unreachable, d->generics); - visitGenericPacks(unreachable, d->genericPacks); - visitTypeList(unreachable, d->params); - visitTypeList(unreachable, d->retTypes); + DfgScope* unreachable = makeChildScope(d->location); + PushScope ps{scopeStack, unreachable}; + + visitGenerics(d->generics); + visitGenericPacks(d->genericPacks); + visitTypeList(d->params); + visitTypeList(d->retTypes); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareClass* d) { // This declaration does not "introduce" any bindings in value namespace, // so there's no symbolic value to begin with. We'll traverse the properties // because their type annotations may depend on something in the value namespace. - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(d->location); + PushScope ps{scopeStack, unreachable}; + for (AstDeclaredClassProp prop : d->props) - visitType(unreachable, prop.ty); + visitType(prop.ty); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error) +ControlFlow DataFlowGraphBuilder::visit(AstStatError* error) { - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(error->location); + PushScope ps{scopeStack, unreachable}; + for (AstStat* s : error->statements) - visit(unreachable, s); + visit(s); for (AstExpr* e : error->expressions) - visitExpr(unreachable, e); + visitExpr(e); return ControlFlow::None; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExpr* e) { // Some subexpressions could be visited two times. If we've already seen it, just extract it. if (auto def = graph.astDefs.find(e)) @@ -698,7 +859,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) auto go = [&]() -> DataFlowResult { if (auto g = e->as()) - return visitExpr(scope, g); + return visitExpr(g); else if (auto c = e->as()) return {defArena->freshCell(), nullptr}; // ok else if (auto c = e->as()) @@ -708,33 +869,33 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) else if (auto c = e->as()) return {defArena->freshCell(), nullptr}; // ok else if (auto l = e->as()) - return visitExpr(scope, l); + return visitExpr(l); else if (auto g = e->as()) - return visitExpr(scope, g); + return visitExpr(g); else if (auto v = e->as()) return {defArena->freshCell(), nullptr}; // ok else if (auto c = e->as()) - return visitExpr(scope, c); + return visitExpr(c); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto f = e->as()) - return visitExpr(scope, f); + return visitExpr(f); else if (auto t = e->as()) - return visitExpr(scope, t); + return visitExpr(t); else if (auto u = e->as()) - return visitExpr(scope, u); + return visitExpr(u); else if (auto b = e->as()) - return visitExpr(scope, b); + return visitExpr(b); else if (auto t = e->as()) - return visitExpr(scope, t); + return visitExpr(t); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto error = e->as()) - return visitExpr(scope, error); + return visitExpr(error); else handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); }; @@ -746,64 +907,65 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) return {def, key}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGroup* group) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGroup* group) { - return visitExpr(scope, group->expr); + return visitExpr(group->expr); } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprLocal* l) { - DefId def = lookup(scope, l->local); + DefId def = lookup(l->local); const RefinementKey* key = keyArena->leaf(def); return {def, key}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGlobal* g) { - DefId def = lookup(scope, g->name); + DefId def = lookup(g->name); return {def, keyArena->leaf(def)}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c) { - visitExpr(scope, c->func); + visitExpr(c->func); for (AstExpr* arg : c->args) - visitExpr(scope, arg); + visitExpr(arg); // calls should be treated as subscripted. return {defArena->freshCell(/* subscripted */ true), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i) { - auto [parentDef, parentKey] = visitExpr(scope, i->expr); + auto [parentDef, parentKey] = visitExpr(i->expr); std::string index = i->index.value; - DefId def = lookup(scope, parentDef, index); + DefId def = lookup(parentDef, index); return {def, keyArena->node(parentKey, def, index)}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexExpr* i) { - auto [parentDef, parentKey] = visitExpr(scope, i->expr); - visitExpr(scope, i->index); + auto [parentDef, parentKey] = visitExpr(i->expr); + visitExpr(i->index); if (auto string = i->index->as()) { std::string index{string->value.data, string->value.size}; - DefId def = lookup(scope, parentDef, index); + DefId def = lookup(parentDef, index); return {def, keyArena->node(parentKey, def, index)}; } return {defArena->freshCell(/* subscripted= */ true), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprFunction* f) { - DfgScope* signatureScope = childScope(scope, DfgScope::Function); + DfgScope* signatureScope = makeChildScope(f->location, DfgScope::Function); + PushScope ps{scopeStack, signatureScope}; if (AstLocal* self = f->self) { @@ -819,7 +981,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* for (AstLocal* param : f->args) { if (param->annotation) - visitType(signatureScope, param->annotation); + visitType(param->annotation); DefId def = defArena->freshCell(); graph.localDefs[param] = def; @@ -828,10 +990,10 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* } if (f->varargAnnotation) - visitTypePack(scope, f->varargAnnotation); + visitTypePack(f->varargAnnotation); if (f->returnAnnotation) - visitTypeList(signatureScope, *f->returnAnnotation); + visitTypeList(*f->returnAnnotation); // TODO: function body can be re-entrant, as in mutations that occurs at the end of the function can also be // visible to the beginning of the function, so statically speaking, the body of the function has an exit point @@ -841,92 +1003,94 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* // local g = f // g() --> function: address // g() --> 5 - visit(signatureScope, f->body); + visit(f->body); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTable* t) { DefId tableCell = defArena->freshCell(); - scope->props[tableCell] = {}; + currentScope()->props[tableCell] = {}; for (AstExprTable::Item item : t->items) { - DataFlowResult result = visitExpr(scope, item.value); + DataFlowResult result = visitExpr(item.value); if (item.key) { - visitExpr(scope, item.key); + visitExpr(item.key); if (auto string = item.key->as()) - scope->props[tableCell][string->value.data] = result.def; + currentScope()->props[tableCell][string->value.data] = result.def; } } return {tableCell, nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprUnary* u) { - visitExpr(scope, u->expr); + visitExpr(u->expr); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprBinary* b) { - visitExpr(scope, b->left); - visitExpr(scope, b->right); + visitExpr(b->left); + visitExpr(b->right); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTypeAssertion* t) { - auto [def, key] = visitExpr(scope, t->expr); - visitType(scope, t->annotation); + auto [def, key] = visitExpr(t->expr); + visitType(t->annotation); return {def, key}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIfElse* i) { - visitExpr(scope, i->condition); - visitExpr(scope, i->trueExpr); - visitExpr(scope, i->falseExpr); + visitExpr(i->condition); + visitExpr(i->trueExpr); + visitExpr(i->falseExpr); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprInterpString* i) { for (AstExpr* e : i->expressions) - visitExpr(scope, e); + visitExpr(e); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprError* error) { - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(error->location); + PushScope ps{scopeStack, unreachable}; + for (AstExpr* e : error->expressions) - visitExpr(unreachable, e); + visitExpr(e); return {defArena->freshCell(), nullptr}; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef) +void DataFlowGraphBuilder::visitLValue(AstExpr* e, DefId incomingDef) { auto go = [&]() { if (auto l = e->as()) - return visitLValue(scope, l, incomingDef); + return visitLValue(l, incomingDef); else if (auto g = e->as()) - return visitLValue(scope, g, incomingDef); + return visitLValue(g, incomingDef); else if (auto i = e->as()) - return visitLValue(scope, i, incomingDef); + return visitLValue(i, incomingDef); else if (auto i = e->as()) - return visitLValue(scope, i, incomingDef); + return visitLValue(i, incomingDef); else if (auto error = e->as()) - return visitLValue(scope, error, incomingDef); + return visitLValue(error, incomingDef); else handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); }; @@ -934,8 +1098,10 @@ void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomi graph.astDefs[e] = go(); } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprLocal* l, DefId incomingDef) { + DfgScope* scope = currentScope(); + // In order to avoid alias tracking, we need to clip the reference to the parent def. if (scope->canUpdateDefinition(l->local)) { @@ -945,11 +1111,13 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId return updated; } else - return visitExpr(scope, static_cast(l)).def; + return visitExpr(static_cast(l)).def; } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprGlobal* g, DefId incomingDef) { + DfgScope* scope = currentScope(); + // In order to avoid alias tracking, we need to clip the reference to the parent def. if (scope->canUpdateDefinition(g->name)) { @@ -959,13 +1127,14 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId return updated; } else - return visitExpr(scope, static_cast(g)).def; + return visitExpr(static_cast(g)).def; } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprIndexName* i, DefId incomingDef) { - DefId parentDef = visitExpr(scope, i->expr).def; + DefId parentDef = visitExpr(i->expr).def; + DfgScope* scope = currentScope(); if (scope->canUpdateDefinition(parentDef, i->index.value)) { DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); @@ -973,14 +1142,15 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, De return updated; } else - return visitExpr(scope, static_cast(i)).def; + return visitExpr(static_cast(i)).def; } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprIndexExpr* i, DefId incomingDef) { - DefId parentDef = visitExpr(scope, i->expr).def; - visitExpr(scope, i->index); + DefId parentDef = visitExpr(i->expr).def; + visitExpr(i->index); + DfgScope* scope = currentScope(); if (auto string = i->index->as()) { if (scope->canUpdateDefinition(parentDef, string->value.data)) @@ -990,33 +1160,33 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, De return updated; } else - return visitExpr(scope, static_cast(i)).def; + return visitExpr(static_cast(i)).def; } else return defArena->freshCell(/*subscripted=*/true); } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprError* error, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprError* error, DefId incomingDef) { - return visitExpr(scope, error).def; + return visitExpr(error).def; } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) +void DataFlowGraphBuilder::visitType(AstType* t) { if (auto r = t->as()) - return visitType(scope, r); + return visitType(r); else if (auto table = t->as()) - return visitType(scope, table); + return visitType(table); else if (auto f = t->as()) - return visitType(scope, f); + return visitType(f); else if (auto tyof = t->as()) - return visitType(scope, tyof); + return visitType(tyof); else if (auto u = t->as()) - return visitType(scope, u); + return visitType(u); else if (auto i = t->as()) - return visitType(scope, i); + return visitType(i); else if (auto e = t->as()) - return visitType(scope, e); + return visitType(e); else if (auto s = t->as()) return; // ok else if (auto s = t->as()) @@ -1025,106 +1195,106 @@ void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeReference* r) +void DataFlowGraphBuilder::visitType(AstTypeReference* r) { for (AstTypeOrPack param : r->parameters) { if (param.type) - visitType(scope, param.type); + visitType(param.type); else - visitTypePack(scope, param.typePack); + visitTypePack(param.typePack); } } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTable* t) +void DataFlowGraphBuilder::visitType(AstTypeTable* t) { for (AstTableProp p : t->props) - visitType(scope, p.type); + visitType(p.type); if (t->indexer) { - visitType(scope, t->indexer->indexType); - visitType(scope, t->indexer->resultType); + visitType(t->indexer->indexType); + visitType(t->indexer->resultType); } } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeFunction* f) +void DataFlowGraphBuilder::visitType(AstTypeFunction* f) { - visitGenerics(scope, f->generics); - visitGenericPacks(scope, f->genericPacks); - visitTypeList(scope, f->argTypes); - visitTypeList(scope, f->returnTypes); + visitGenerics(f->generics); + visitGenericPacks(f->genericPacks); + visitTypeList(f->argTypes); + visitTypeList(f->returnTypes); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTypeof* t) +void DataFlowGraphBuilder::visitType(AstTypeTypeof* t) { - visitExpr(scope, t->expr); + visitExpr(t->expr); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeUnion* u) +void DataFlowGraphBuilder::visitType(AstTypeUnion* u) { for (AstType* t : u->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeIntersection* i) +void DataFlowGraphBuilder::visitType(AstTypeIntersection* i) { for (AstType* t : i->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeError* error) +void DataFlowGraphBuilder::visitType(AstTypeError* error) { for (AstType* t : error->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePack* p) +void DataFlowGraphBuilder::visitTypePack(AstTypePack* p) { if (auto e = p->as()) - return visitTypePack(scope, e); + return visitTypePack(e); else if (auto v = p->as()) - return visitTypePack(scope, v); + return visitTypePack(v); else if (auto g = p->as()) return; // ok else handle->ice("Unknown AstTypePack in DataFlowGraphBuilder::visitTypePack"); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackExplicit* e) +void DataFlowGraphBuilder::visitTypePack(AstTypePackExplicit* e) { - visitTypeList(scope, e->typeList); + visitTypeList(e->typeList); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackVariadic* v) +void DataFlowGraphBuilder::visitTypePack(AstTypePackVariadic* v) { - visitType(scope, v->variadicType); + visitType(v->variadicType); } -void DataFlowGraphBuilder::visitTypeList(DfgScope* scope, AstTypeList l) +void DataFlowGraphBuilder::visitTypeList(AstTypeList l) { for (AstType* t : l.types) - visitType(scope, t); + visitType(t); if (l.tailType) - visitTypePack(scope, l.tailType); + visitTypePack(l.tailType); } -void DataFlowGraphBuilder::visitGenerics(DfgScope* scope, AstArray g) +void DataFlowGraphBuilder::visitGenerics(AstArray g) { for (AstGenericType generic : g) { if (generic.defaultValue) - visitType(scope, generic.defaultValue); + visitType(generic.defaultValue); } } -void DataFlowGraphBuilder::visitGenericPacks(DfgScope* scope, AstArray g) +void DataFlowGraphBuilder::visitGenericPacks(AstArray g) { for (AstGenericTypePack generic : g) { if (generic.defaultValue) - visitTypePack(scope, generic.defaultValue); + visitTypePack(generic.defaultValue); } } diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 25687e11..b2cebc0b 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -13,6 +13,7 @@ namespace Luau { + std::string DiffPathNode::toString() const { switch (kind) @@ -944,12 +945,14 @@ std::vector>::const_reverse_iterator DifferEnvironment return visitingStack.crend(); } + DifferResult diff(TypeId ty1, TypeId ty2) { DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt}; return diffUsingEnv(differEnv, ty1, ty2); } + DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2) { DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index e539661a..50e090ca 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,10 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(LuauMathMap) + namespace Luau { -static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( +// TODO: there has to be a better way, like splitting up per library +static const std::string kBuiltinDefinitionLuaSrcChecked_DEPRECATED = R"BUILTIN_SRC( declare bit32: { band: @checked (...number) -> number, @@ -195,6 +198,228 @@ declare utf8: { declare function unpack(tab: {V}, i: number?, j: number?): ...V +--- Buffer API +declare buffer: { + create: @checked (size: number) -> buffer, + fromstring: @checked (str: string) -> buffer, + tostring: @checked (b: buffer) -> string, + len: @checked (b: buffer) -> number, + copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (), + fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (), + readi8: @checked (b: buffer, offset: number) -> number, + readu8: @checked (b: buffer, offset: number) -> number, + readi16: @checked (b: buffer, offset: number) -> number, + readu16: @checked (b: buffer, offset: number) -> number, + readi32: @checked (b: buffer, offset: number) -> number, + readu32: @checked (b: buffer, offset: number) -> number, + readf32: @checked (b: buffer, offset: number) -> number, + readf64: @checked (b: buffer, offset: number) -> number, + writei8: @checked (b: buffer, offset: number, value: number) -> (), + writeu8: @checked (b: buffer, offset: number, value: number) -> (), + writei16: @checked (b: buffer, offset: number, value: number) -> (), + writeu16: @checked (b: buffer, offset: number, value: number) -> (), + writei32: @checked (b: buffer, offset: number, value: number) -> (), + writeu32: @checked (b: buffer, offset: number, value: number) -> (), + writef32: @checked (b: buffer, offset: number, value: number) -> (), + writef64: @checked (b: buffer, offset: number, value: number) -> (), + readstring: @checked (b: buffer, offset: number, count: number) -> string, + writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( + +declare bit32: { + band: @checked (...number) -> number, + bor: @checked (...number) -> number, + bxor: @checked (...number) -> number, + btest: @checked (number, ...number) -> boolean, + rrotate: @checked (x: number, disp: number) -> number, + lrotate: @checked (x: number, disp: number) -> number, + lshift: @checked (x: number, disp: number) -> number, + arshift: @checked (x: number, disp: number) -> number, + rshift: @checked (x: number, disp: number) -> number, + bnot: @checked (x: number) -> number, + extract: @checked (n: number, field: number, width: number?) -> number, + replace: @checked (n: number, v: number, field: number, width: number?) -> number, + countlz: @checked (n: number) -> number, + countrz: @checked (n: number) -> number, + byteswap: @checked (n: number) -> number, +} + +declare math: { + frexp: @checked (n: number) -> (number, number), + ldexp: @checked (s: number, e: number) -> number, + fmod: @checked (x: number, y: number) -> number, + modf: @checked (n: number) -> (number, number), + pow: @checked (x: number, y: number) -> number, + exp: @checked (n: number) -> number, + + ceil: @checked (n: number) -> number, + floor: @checked (n: number) -> number, + abs: @checked (n: number) -> number, + sqrt: @checked (n: number) -> number, + + log: @checked (n: number, base: number?) -> number, + log10: @checked (n: number) -> number, + + rad: @checked (n: number) -> number, + deg: @checked (n: number) -> number, + + sin: @checked (n: number) -> number, + cos: @checked (n: number) -> number, + tan: @checked (n: number) -> number, + sinh: @checked (n: number) -> number, + cosh: @checked (n: number) -> number, + tanh: @checked (n: number) -> number, + atan: @checked (n: number) -> number, + acos: @checked (n: number) -> number, + asin: @checked (n: number) -> number, + atan2: @checked (y: number, x: number) -> number, + + min: @checked (number, ...number) -> number, + max: @checked (number, ...number) -> number, + + pi: number, + huge: number, + + randomseed: @checked (seed: number) -> (), + random: @checked (number?, number?) -> number, + + sign: @checked (n: number) -> number, + clamp: @checked (n: number, min: number, max: number) -> number, + noise: @checked (x: number, y: number?, z: number?) -> number, + round: @checked (n: number) -> number, + map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number, +} + +type DateTypeArg = { + year: number, + month: number, + day: number, + hour: number?, + min: number?, + sec: number?, + isdst: boolean?, +} + +type DateTypeResult = { + year: number, + month: number, + wday: number, + yday: number, + day: number, + hour: number, + min: number, + sec: number, + isdst: boolean, +} + +declare os: { + time: (time: DateTypeArg?) -> number, + date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, + clock: () -> number, +} + +@checked declare function require(target: any): any + +@checked declare function getfenv(target: any): { [string]: any } + +declare _G: any +declare _VERSION: string + +declare function gcinfo(): number + +declare function print(...: T...) + +declare function type(value: T): string +declare function typeof(value: T): string + +-- `assert` has a magic function attached that will give more detailed type information +declare function assert(value: T, errorMessage: string?): T +declare function error(message: T, level: number?): never + +declare function tostring(value: T): string +declare function tonumber(value: T, radix: number?): number? + +declare function rawequal(a: T1, b: T2): boolean +declare function rawget(tab: {[K]: V}, k: K): V +declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} +declare function rawlen(obj: {[K]: V} | string): number + +declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? + +declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number) + +declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) + +-- FIXME: The actual type of `xpcall` is: +-- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) +-- Since we can't represent the return value, we use (boolean, R1...). +declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) + +-- `select` has a magic function attached to provide more detailed type information +declare function select(i: string | number, ...: A...): ...any + +-- FIXME: This type is not entirely correct - `loadstring` returns a function or +-- (nil, string). +declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) + +@checked declare function newproxy(mt: boolean?): any + +declare coroutine: { + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), + running: () -> thread, + status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended", + wrap: (f: (A...) -> R...) -> ((A...) -> R...), + yield: (A...) -> R..., + isyieldable: () -> boolean, + close: @checked (co: thread) -> (boolean, any) +} + +declare table: { + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, + + unpack: (list: {V}, i: number?, j: number?) -> ...V, + pack: (...V) -> { n: number, [number]: V }, + + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), + foreachi: ({V}, (number, V) -> ()) -> (), + + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), + + isfrozen: (t: {[K]: V}) -> boolean, +} + +declare debug: { + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), +} + +declare utf8: { + char: @checked (...number) -> string, + charpattern: string, + codes: @checked (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: @checked (str: string, i: number?, j: number?) -> ...number, + len: @checked (s: string, i: number?, j: number?) -> (number?, number?), + offset: @checked (s: string, n: number?, i: number?) -> number, +} + +-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. +declare function unpack(tab: {V}, i: number?, j: number?): ...V + + --- Buffer API declare buffer: { create: @checked (size: number) -> buffer, @@ -227,7 +452,7 @@ declare buffer: { std::string getBuiltinDefinitionSource() { - std::string result = kBuiltinDefinitionLuaSrcChecked; + std::string result = FFlag::LuauMathMap ? kBuiltinDefinitionLuaSrcChecked : kBuiltinDefinitionLuaSrcChecked_DEPRECATED; return result; } diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index 4088c500..853a3d89 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -3,6 +3,11 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" +#include "Luau/Common.h" +#include "Luau/Frontend.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" +#include "Luau/Module.h" namespace Luau { @@ -10,6 +15,8 @@ namespace Luau FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos) { std::vector ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos); + // Should always contain the root AstStat + LUAU_ASSERT(ancestry.size() >= 1); DenseHashMap localMap{AstName()}; std::vector localStack; AstStat* nearestStatement = nullptr; @@ -21,7 +28,7 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro { if (stat->location.begin <= cursorPos) nearestStatement = stat; - if (stat->location.begin <= cursorPos) + if (stat->location.begin < cursorPos && stat->location.begin.line < cursorPos.line) { // This statement precedes the current one if (auto loc = stat->as()) @@ -42,7 +49,116 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro } } + if (!nearestStatement) + nearestStatement = ancestry[0]->asStat(); + LUAU_ASSERT(nearestStatement); return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; } +std::pair getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) +{ + unsigned int lineCount = 0; + unsigned int colCount = 0; + + unsigned int docOffset = 0; + unsigned int startOffset = 0; + unsigned int endOffset = 0; + bool foundStart = false; + bool foundEnd = false; + for (char c : src) + { + if (foundStart && foundEnd) + break; + + if (startPos.line == lineCount && startPos.column == colCount) + { + foundStart = true; + startOffset = docOffset; + } + + if (endPos.line == lineCount && endPos.column == colCount) + { + endOffset = docOffset; + foundEnd = true; + } + + if (c == '\n') + { + lineCount++; + colCount = 0; + } + else + colCount++; + docOffset++; + } + + + unsigned int min = std::min(startOffset, endOffset); + unsigned int len = std::max(startOffset, endOffset) - min; + return {min, len}; +} + +ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos) +{ + LUAU_ASSERT(module->hasModuleScope()); + + ScopePtr closest = module->getModuleScope(); + for (auto [loc, sc] : module->scopes) + { + if (loc.begin <= cursorPos && closest->location.begin <= loc.begin) + closest = sc; + } + + return closest; +} + +FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos) +{ + FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos); + ParseOptions opts; + opts.allowDeclarationSyntax = false; + opts.captureComments = false; + opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)}; + AstStat* enclosingStatement = result.nearestStatement; + + const Position& endPos = cursorPos; + // If the statement starts on a previous line, grab the statement beginning + // otherwise, grab the statement end to whatever is being typed right now + const Position& startPos = + enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end; + + auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); + + const char* srcStart = src.data() + offsetStart; + std::string_view dbg = src.substr(offsetStart, parseLength); + const std::shared_ptr& nameTbl = srcModule.names; + FragmentParseResult fragmentResult; + fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength); + // For the duration of the incremental parse, we want to allow the name table to re-use duplicate names + ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts); + + std::vector fabricatedAncestry = std::move(result.ancestry); + std::vector fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end); + fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); + if (enclosingStatement == nullptr) + enclosingStatement = p.root; + fragmentResult.root = std::move(p.root); + fragmentResult.ancestry = std::move(fabricatedAncestry); + return fragmentResult; +} + + +AutocompleteResult fragmentAutocomplete( + Frontend& frontend, + std::string_view src, + const ModuleName& moduleName, + Position& cursorPosition, + StringCompletionCallback callback +) +{ + LUAU_ASSERT(FFlag::LuauSolverV2); + // TODO + return {}; +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b2325d37..95ad58f3 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -49,6 +49,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection, false) LUAU_FASTFLAG(StudioReportLuauAny2) +LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule, false) namespace Luau { @@ -1315,6 +1316,18 @@ ModulePtr check( } DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); + DataFlowGraph* dfgForConstraintGeneration = nullptr; + if (FFlag::LuauStoreDFGOnModule) + { + auto [dfg, scopes] = DataFlowGraphBuilder::buildShared(sourceModule.root, iceHandler); + result->dataFlowGraph = std::move(dfg); + result->dfgScopes = std::move(scopes); + dfgForConstraintGeneration = result->dataFlowGraph.get(); + } + else + { + dfgForConstraintGeneration = &dfg; + } UnifierSharedState unifierState{iceHandler}; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; @@ -1336,7 +1349,7 @@ ModulePtr check( parentScope, std::move(prepareModuleScope), logger.get(), - NotNull{&dfg}, + NotNull{dfgForConstraintGeneration}, requireCycles }; diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index 8c6cf378..8dce95f9 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -801,6 +801,12 @@ struct TypeCacher : TypeOnceVisitor return false; } + bool visit(TypeId ty, const NoRefineType&) override + { + cache(ty); + return false; + } + bool visit(TypeId ty, const UnionType& ut) override { if (isUncacheable(ty) || isCached(ty)) diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index 7d5859ce..0ebc573d 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -594,7 +594,7 @@ struct NonStrictTypeChecker std::shared_ptr norm = normalizer.normalize(expectedArgType); DefId def = dfg->getDef(arg); TypeId runTimeErrorTy; - // If we're dealing with any, negating any will cause all subtype tests to fail, since ~any is any + // If we're dealing with any, negating any will cause all subtype tests to fail // However, when someone calls this function, they're going to want to be able to pass it anything, // for that reason, we manually inject never into the context so that the runtime test will always pass. if (!norm) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 475f45c3..c5e3496f 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1872,7 +1872,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t if (res != NormalizationResult::True) return res; } - else if (get(there) || get(there)) + else if (get(there) || get(there) || get(there)) { // nothing } @@ -3217,6 +3217,11 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type // assumption that it is the same as any. return NormalizationResult::True; } + else if (get(t)) + { + // `*no-refine*` means we will never do anything to affect the intersection. + return NormalizationResult::True; + } else if (get(t)) { // if we're intersecting with `~never`, this is equivalent to intersecting with `unknown` @@ -3243,6 +3248,11 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { here.classes.resetToNever(); } + else if (get(there)) + { + // `*no-refine*` means we will never do anything to affect the intersection. + return NormalizationResult::True; + } else LUAU_ASSERT(!"Unreachable"); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 634f5241..dd5a2f85 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -50,6 +50,11 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a LUAU_ASSERT(ty->persistent); return ty; } + else if constexpr (std::is_same_v) + { + LUAU_ASSERT(ty->persistent); + return ty; + } else if constexpr (std::is_same_v) { LUAU_ASSERT(ty->persistent); diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index f3571b7a..cc1ed7cf 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -261,92 +261,50 @@ SubtypingResult SubtypingResult::any(const std::vector& results struct ApplyMappedGenerics : Substitution { - using MappedGenerics = DenseHashMap; - using MappedGenericPacks = DenseHashMap; - NotNull builtinTypes; NotNull arena; SubtypingEnvironment& env; - MappedGenerics& mappedGenerics_DEPRECATED; - MappedGenericPacks& mappedGenericPacks_DEPRECATED; - - ApplyMappedGenerics( - NotNull builtinTypes, - NotNull arena, - SubtypingEnvironment& env, - MappedGenerics& mappedGenerics, - MappedGenericPacks& mappedGenericPacks - ) + ApplyMappedGenerics(NotNull builtinTypes, NotNull arena, SubtypingEnvironment& env) : Substitution(TxnLog::empty(), arena) , builtinTypes(builtinTypes) , arena(arena) , env(env) - , mappedGenerics_DEPRECATED(mappedGenerics) - , mappedGenericPacks_DEPRECATED(mappedGenericPacks) { } bool isDirty(TypeId ty) override { - if (DFInt::LuauTypeSolverRelease >= 644) - return env.containsMappedType(ty); - else - return mappedGenerics_DEPRECATED.contains(ty); + return env.containsMappedType(ty); } bool isDirty(TypePackId tp) override { - if (DFInt::LuauTypeSolverRelease >= 644) - return env.containsMappedPack(tp); - else - return mappedGenericPacks_DEPRECATED.contains(tp); + return env.containsMappedPack(tp); } TypeId clean(TypeId ty) override { - if (DFInt::LuauTypeSolverRelease >= 644) - { - const auto& bounds = env.getMappedTypeBounds(ty); + const auto& bounds = env.getMappedTypeBounds(ty); - if (bounds.upperBound.empty()) - return builtinTypes->unknownType; + if (bounds.upperBound.empty()) + return builtinTypes->unknownType; - if (bounds.upperBound.size() == 1) - return *begin(bounds.upperBound); + if (bounds.upperBound.size() == 1) + return *begin(bounds.upperBound); - return arena->addType(IntersectionType{std::vector(begin(bounds.upperBound), end(bounds.upperBound))}); - } - else - { - const auto& bounds = mappedGenerics_DEPRECATED[ty]; - - if (bounds.upperBound.empty()) - return builtinTypes->unknownType; - - if (bounds.upperBound.size() == 1) - return *begin(bounds.upperBound); - - return arena->addType(IntersectionType{std::vector(begin(bounds.upperBound), end(bounds.upperBound))}); - } + return arena->addType(IntersectionType{std::vector(begin(bounds.upperBound), end(bounds.upperBound))}); } TypePackId clean(TypePackId tp) override { - if (DFInt::LuauTypeSolverRelease >= 644) - { - if (auto it = env.getMappedPackBounds(tp)) - return *it; + if (auto it = env.getMappedPackBounds(tp)) + return *it; - // Clean is only called when isDirty found a pack bound - LUAU_ASSERT(!"Unreachable"); - return nullptr; - } - else - { - return mappedGenericPacks_DEPRECATED[tp]; - } + // Clean is only called when isDirty found a pack bound + LUAU_ASSERT(!"Unreachable"); + return nullptr; } bool ignoreChildren(TypeId ty) override @@ -364,7 +322,7 @@ struct ApplyMappedGenerics : Substitution std::optional SubtypingEnvironment::applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty) { - ApplyMappedGenerics amg{builtinTypes, arena, *this, mappedGenerics, mappedGenericPacks}; + ApplyMappedGenerics amg{builtinTypes, arena, *this}; return amg.substitute(ty); } @@ -489,22 +447,12 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy, NotNull= 644) - { - SubtypingEnvironment boundsEnv; - boundsEnv.parent = &env; - SubtypingResult boundsResult = isCovariantWith(boundsEnv, lowerBound, upperBound, scope); - boundsResult.reasoning.clear(); + SubtypingEnvironment boundsEnv; + boundsEnv.parent = &env; + SubtypingResult boundsResult = isCovariantWith(boundsEnv, lowerBound, upperBound, scope); + boundsResult.reasoning.clear(); - result.andAlso(boundsResult); - } - else - { - SubtypingResult boundsResult = isCovariantWith(env, lowerBound, upperBound, scope); - boundsResult.reasoning.clear(); - - result.andAlso(boundsResult); - } + result.andAlso(boundsResult); } /* TODO: We presently don't store subtype test results in the persistent @@ -582,18 +530,17 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub subTy = follow(subTy); superTy = follow(superTy); - if (const TypeId* subIt = (DFInt::LuauTypeSolverRelease >= 644 ? env.tryFindSubstitution(subTy) : env.substitutions.find(subTy)); subIt && *subIt) + if (const TypeId* subIt = env.tryFindSubstitution(subTy); subIt && *subIt) subTy = *subIt; - if (const TypeId* superIt = (DFInt::LuauTypeSolverRelease >= 644 ? env.tryFindSubstitution(superTy) : env.substitutions.find(superTy)); - superIt && *superIt) + if (const TypeId* superIt = env.tryFindSubstitution(superTy); superIt && *superIt) superTy = *superIt; const SubtypingResult* cachedResult = resultCache.find({subTy, superTy}); if (cachedResult) return *cachedResult; - cachedResult = DFInt::LuauTypeSolverRelease >= 644 ? env.tryFindSubtypingResult({subTy, superTy}) : env.ephemeralCache.find({subTy, superTy}); + cachedResult = env.tryFindSubtypingResult({subTy, superTy}); if (cachedResult) return *cachedResult; @@ -838,8 +785,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId std::vector headSlice(begin(superHead), begin(superHead) + headSize); TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); - if (TypePackId* other = - (DFInt::LuauTypeSolverRelease >= 644 ? env.getMappedPackBounds(*subTail) : env.mappedGenericPacks.find(*subTail))) + if (TypePackId* other = env.getMappedPackBounds(*subTail)) // TODO: TypePath can't express "slice of a pack + its tail". results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail)); else @@ -894,8 +840,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId std::vector headSlice(begin(subHead), begin(subHead) + headSize); TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); - if (TypePackId* other = - (DFInt::LuauTypeSolverRelease >= 644 ? env.getMappedPackBounds(*superTail) : env.mappedGenericPacks.find(*superTail))) + if (TypePackId* other = env.getMappedPackBounds(*superTail)) // TODO: TypePath can't express "slice of a pack + its tail". results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail)); else @@ -1837,11 +1782,8 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe if (!get(subTy)) return false; - if (DFInt::LuauTypeSolverRelease >= 644) - { - if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy)) - iceReporter->ice("attempting to modify bounds of a potentially visited generic"); - } + if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy)) + iceReporter->ice("attempting to modify bounds of a potentially visited generic"); env.mappedGenerics[subTy].upperBound.insert(superTy); } @@ -1850,11 +1792,8 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe if (!get(superTy)) return false; - if (DFInt::LuauTypeSolverRelease >= 644) - { - if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy)) - iceReporter->ice("attempting to modify bounds of a potentially visited generic"); - } + if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy)) + iceReporter->ice("attempting to modify bounds of a potentially visited generic"); env.mappedGenerics[superTy].lowerBound.insert(subTy); } @@ -1901,7 +1840,7 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePac if (!get(subTp)) return false; - if (TypePackId* m = (DFInt::LuauTypeSolverRelease >= 644 ? env.getMappedPackBounds(subTp) : env.mappedGenericPacks.find(subTp))) + if (TypePackId* m = env.getMappedPackBounds(subTp)) return *m == superTp; env.mappedGenericPacks[subTp] = superTp; diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index 177396a7..bcd5c1d8 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -6,19 +6,14 @@ #include "Luau/Type.h" #include "Luau/ToString.h" #include "Luau/TypeArena.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) + namespace Luau { -static bool isLiteral(const AstExpr* expr) -{ - return ( - expr->is() || expr->is() || expr->is() || expr->is() || - expr->is() || expr->is() - ); -} - // A fast approximation of subTy <: superTy static bool fastIsSubtype(TypeId subTy, TypeId superTy) { @@ -381,15 +376,21 @@ TypeId matchLiteralType( const TypeId* keyTy = astTypes->find(item.key); LUAU_ASSERT(keyTy); TypeId tKey = follow(*keyTy); - if (get(tKey)) + if (DFInt::LuauTypeSolverRelease >= 648) + { + LUAU_ASSERT(!is(tKey)); + } + else if (get(tKey)) toBlock.push_back(tKey); - const TypeId* propTy = astTypes->find(item.value); LUAU_ASSERT(propTy); TypeId tProp = follow(*propTy); - if (get(tProp)) + if (DFInt::LuauTypeSolverRelease >= 648) + { + LUAU_ASSERT(!is(tKey)); + } + else if (get(tProp)) toBlock.push_back(tProp); - // Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings) if (!item.key->as() && expectedTableTy->indexer) (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 4408063f..e3f4fd3b 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -269,6 +269,12 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NoRefineType %d", index); + finishNodeLabel(ty); + finishNode(); + } else if constexpr (std::is_same_v) { formatAppend(result, "UnknownType %d", index); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 66d037ed..5b191d30 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -856,6 +856,11 @@ struct TypeStringifier state.emit("any"); } + void operator()(TypeId, const NoRefineType&) + { + state.emit("*no-refine*"); + } + void operator()(TypeId, const UnionType& uv) { if (state.hasSeen(&uv)) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index b024fdd2..a77836c5 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -1030,6 +1030,7 @@ BuiltinTypes::BuiltinTypes() , unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true})) , neverType(arena->addType(Type{NeverType{}, /*persistent*/ true})) , errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true})) + , noRefineType(arena->addType(Type{NoRefineType{}, /*persistent*/ true})) , falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true})) , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index a288cfbe..a28ff987 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -145,6 +145,12 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("any"), std::nullopt, Location()); } + + AstType* operator()(const NoRefineType&) + { + return allocator->alloc(Location(), std::nullopt, AstName("*no-refine*"), std::nullopt, Location()); + } + AstType* operator()(const TableType& ttv) { RecursionCounter counter(&count); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 0a47f0df..2634b89e 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -3022,20 +3022,9 @@ PropertyType TypeChecker2::hasIndexTypeFromType( if (tt->indexer) { TypeId indexType = follow(tt->indexer->indexType); - if (DFInt::LuauTypeSolverRelease >= 644) - { - TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); - if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice)) - return {NormalizationResult::True, {tt->indexer->indexResultType}}; - } - else - { - if (isPrim(indexType, PrimitiveType::String)) - return {NormalizationResult::True, {tt->indexer->indexResultType}}; - // If the indexer looks like { [any] : _} - the prop lookup should be allowed! - else if (get(indexType) || get(indexType)) - return {NormalizationResult::True, {tt->indexer->indexResultType}}; - } + TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); + if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, *ice)) + return {NormalizationResult::True, {tt->indexer->indexResultType}}; } diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 82bfeca9..d5eac1f2 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -49,6 +49,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions2, false) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAG(LuauUserTypeFunFixRegister) +LUAU_FASTFLAG(LuauRemoveNotAnyHack) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) @@ -777,16 +778,8 @@ TypeFunctionReductionResult lenTypeFunction( if (normTy->hasTopTable() || get(normalizedOperand)) return {ctx->builtins->numberType, false, {}, {}}; - if (DFInt::LuauTypeSolverRelease >= 644) - { - if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx)) - return *result; - } - else - { - if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) - return *result; - } + if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx)) + return *result; // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. @@ -874,16 +867,8 @@ TypeFunctionReductionResult unmTypeFunction( if (normTy->isExactlyNumber()) return {ctx->builtins->numberType, false, {}, {}}; - if (DFInt::LuauTypeSolverRelease >= 644) - { - if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx)) - return *result; - } - else - { - if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) - return *result; - } + if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx)) + return *result; // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. @@ -1810,7 +1795,6 @@ struct FindRefinementBlockers : TypeOnceVisitor } }; - TypeFunctionReductionResult refineTypeFunction( TypeId instance, const std::vector& typeParams, @@ -1878,8 +1862,18 @@ TypeFunctionReductionResult refineTypeFunction( * We need to treat T & ~any as T in this case. */ if (auto nt = get(discriminant)) - if (get(follow(nt->ty))) - return {target, {}}; + { + if (FFlag::LuauRemoveNotAnyHack) + { + if (get(follow(nt->ty))) + return {target, {}}; + } + else + { + if (get(follow(nt->ty))) + return {target, {}}; + } + } // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. @@ -2059,6 +2053,15 @@ TypeFunctionReductionResult intersectTypeFunction( for (auto ty : typeParams) types.emplace_back(follow(ty)); + if (FFlag::LuauRemoveNotAnyHack) + { + // if we only have two parameters and one is `*no-refine*`, we're all done. + if (types.size() == 2 && get(types[1])) + return {types[0], false, {}, {}}; + else if (types.size() == 2 && get(types[0])) + return {types[1], false, {}, {}}; + } + // check to see if the operand types are resolved enough, and wait to reduce if not // if any of them are `never`, the intersection will always be `never`, so we can reduce directly. for (auto ty : types) @@ -2073,6 +2076,10 @@ TypeFunctionReductionResult intersectTypeFunction( TypeId resultTy = ctx->builtins->unknownType; for (auto ty : types) { + // skip any `*no-refine*` types. + if (FFlag::LuauRemoveNotAnyHack && get(ty)) + continue; + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); if (!result.blockedTypes.empty()) return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp index 768288ad..84fa0fea 100644 --- a/Analysis/src/TypeFunctionRuntime.cpp +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -16,6 +16,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister, false) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite, false) namespace Luau { @@ -634,6 +635,8 @@ static int readTableProp(lua_State* L) auto prop = tftt->props.at(tfsst->value); if (prop.readTy) allocTypeUserData(L, (*prop.readTy)->type); + else if (FFlag::LuauUserTypeFunFixNoReadWrite) + lua_pushnil(L); else luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); @@ -672,6 +675,8 @@ static int writeTableProp(lua_State* L) auto prop = tftt->props.at(tfsst->value); if (prop.writeTy) allocTypeUserData(L, (*prop.writeTy)->type); + else if (FFlag::LuauUserTypeFunFixNoReadWrite) + lua_pushnil(L); else luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index f1c60f06..1ed1b9e0 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -479,4 +479,68 @@ ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId return result; } +bool isLiteral(const AstExpr* expr) +{ + return ( + expr->is() || expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() + ); +} +/** + * Visitor which, given an expression and a mapping from expression to TypeId, + * determines if there are any literal expressions that contain blocked types. + * This is used for bi-directional inference: we want to "apply" a type from + * a function argument or a type annotation to a literal. + */ +class BlockedTypeInLiteralVisitor : public AstVisitor +{ +public: + explicit BlockedTypeInLiteralVisitor(NotNull> astTypes, NotNull> toBlock) + : astTypes_{astTypes} + , toBlock_{toBlock} + { + } + bool visit(AstNode*) override + { + return false; + } + + bool visit(AstExpr* e) override + { + auto ty = astTypes_->find(e); + if (ty && (get(follow(*ty)) != nullptr)) + { + toBlock_->push_back(*ty); + } + return isLiteral(e) || e->is(); + } + +private: + NotNull> astTypes_; + NotNull> toBlock_; +}; + +std::vector findBlockedTypesIn(AstExprTable* expr, NotNull> astTypes) +{ + std::vector toBlock; + BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}}; + expr->visit(&v); + return toBlock; +} + +std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull> astTypes) +{ + std::vector toBlock; + BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}}; + for (auto arg: expr->args) + { + if (isLiteral(arg) || arg->is()) + { + arg->visit(&v); + } + } + return toBlock; +} + + } // namespace Luau diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 5f6fcf5e..76ed2a5a 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -188,9 +188,18 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc functionStack.reserve(8); functionStack.push_back(top); - nameSelf = names.addStatic("self"); - nameNumber = names.addStatic("number"); - nameError = names.addStatic(kParseNameError); + if (FFlag::LuauAllowFragmentParsing) + { + nameSelf = names.getOrAdd("self"); + nameNumber = names.getOrAdd("number"); + nameError = names.getOrAdd(kParseNameError); + } + else + { + nameSelf = names.addStatic("self"); + nameNumber = names.addStatic("number"); + nameError = names.addStatic(kParseNameError); + } nameNil = names.getOrAdd("nil"); // nil is a reserved keyword matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 0cffec40..400654b7 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauErrorResumeCleanupArgs, false) - /* ** {====================================================== ** Error-recovery functions @@ -430,11 +428,7 @@ static void resume_handle(lua_State* L, void* ud) static int resume_error(lua_State* L, const char* msg, int narg) { - if (FFlag::LuauErrorResumeCleanupArgs) - L->top -= narg; - else - L->top = L->ci->base; - + L->top -= narg; setsvalue(L, L->top, luaS_new(L, msg)); incr_top(L); return LUA_ERRRUN; diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 7adaf0b4..879a9538 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauMathMap, false) + #undef PI #define PI (3.14159265358979323846) #define RADIANS_PER_DEGREE (PI / 180.0) @@ -403,6 +405,19 @@ static int math_round(lua_State* L) return 1; } +static int math_map(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double inmin = luaL_checknumber(L, 2); + double inmax = luaL_checknumber(L, 3); + double outmin = luaL_checknumber(L, 4); + double outmax = luaL_checknumber(L, 5); + + double result = outmin + (x - inmin) * (outmax - outmin) / (inmax - inmin); + lua_pushnumber(L, result); + return 1; +} + static const luaL_Reg mathlib[] = { {"abs", math_abs}, {"acos", math_acos}, @@ -455,5 +470,12 @@ int luaopen_math(lua_State* L) lua_setfield(L, -2, "pi"); lua_pushnumber(L, HUGE_VAL); lua_setfield(L, -2, "huge"); + + if (FFlag::LuauMathMap) + { + lua_pushcfunction(L, math_map, "map"); + lua_setfield(L, -2, "map"); + } + return 1; } diff --git a/bench/bench.py b/bench/bench.py index bc06b663..002dfadb 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -508,9 +508,6 @@ def runTest(subdir, filename, filepath): filepath = os.path.abspath(filepath) mainVm = os.path.abspath(arguments.vm) - if not os.path.isfile(mainVm): - print(f"{colored(Color.RED, 'ERROR')}: VM executable '{mainVm}' does not exist.") - sys.exit(1) # Process output will contain the test name and execution times mainOutput = getVmOutput(substituteArguments(mainVm, getExtraArguments(filepath)) + " " + filepath) @@ -890,11 +887,9 @@ def run(args, argsubcb): analyzeResult('', mainResult, compareResults) else: all_files = [subdir + os.sep + filename for subdir, dirs, files in os.walk(arguments.folder) for filename in files] - if len(all_files) == 0: - print(f"{colored(Color.YELLOW, 'WARNING')}: No test files found in '{arguments.folder}'.") for filepath in sorted(all_files): subdir, filename = os.path.split(filepath) - if filename.endswith(".lua") or filename.endswith(".luau"): + if filename.endswith(".lua"): if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): runTest(subdir, filename, filepath) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index df6e5332..e135cc52 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -31,6 +31,7 @@ extern int optimizationLevel; void luaC_fullgc(lua_State* L); void luaC_validate(lua_State* L); +LUAU_FASTFLAG(LuauMathMap) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauNativeAttribute) @@ -652,6 +653,8 @@ TEST_CASE("Buffers") TEST_CASE("Math") { + ScopedFastFlag LuauMathMap{FFlag::LuauMathMap, true}; + runConformance("math.lua"); } diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index b8b7829d..692d6e0f 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -4,10 +4,13 @@ #include "Fixture.h" #include "Luau/Ast.h" #include "Luau/AstQuery.h" +#include "Luau/Common.h" using namespace Luau; +LUAU_FASTFLAG(LuauAllowFragmentParsing); + struct FragmentAutocompleteFixture : Fixture { @@ -17,9 +20,25 @@ struct FragmentAutocompleteFixture : Fixture REQUIRE(p.root); return findAncestryForFragmentParse(p.root, cursorPos); } + + CheckResult checkBase(const std::string& document) + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + FrontendOptions opts; + opts.retainFullTypeGraphs = true; + return this->frontend.check("MainModule", opts); + } + + FragmentParseResult parseFragment(const std::string& document, const Position& cursorPos) + { + ScopedFastFlag sffs[]{{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}}; + SourceModule* srcModule = this->getMainSourceModule(); + std::string_view srcString = document; + return Luau::parseFragment(*srcModule, srcString, cursorPos); + } }; -TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTest"); +TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") { @@ -32,7 +51,7 @@ local y = 5 ); CHECK_EQ(3, result.ancestry.size()); - CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(1, result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size()); REQUIRE(result.nearestStatement); @@ -56,10 +75,10 @@ end ); CHECK_EQ(5, result.ancestry.size()); - CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(2, result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size()); REQUIRE(result.nearestStatement); - CHECK_EQ("e", std::string(result.localStack.back()->name.value)); + CHECK_EQ("y", std::string(result.localStack.back()->name.value)); AstStatLocal* local = result.nearestStatement->as(); REQUIRE(local); @@ -85,10 +104,10 @@ end ); CHECK_EQ(6, result.ancestry.size()); - CHECK_EQ(4, result.localStack.size()); + CHECK_EQ(3, result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size()); REQUIRE(result.nearestStatement); - CHECK_EQ("q", std::string(result.localStack.back()->name.value)); + CHECK_EQ("z", std::string(result.localStack.back()->name.value)); AstStatLocal* local = result.nearestStatement->as(); REQUIRE(local); @@ -129,11 +148,122 @@ local function bar() return x + foo() end ); CHECK_EQ(8, result.ancestry.size()); - CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(2, result.localStack.size()); CHECK_EQ(result.localMap.size(), result.localStack.size()); - CHECK_EQ("bar", std::string(result.localStack.back()->name.value)); + CHECK_EQ("x", std::string(result.localStack.back()->name.value)); auto returnSt = result.nearestStatement->as(); CHECK(returnSt != nullptr); } TEST_SUITE_END(); + + +TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") +{ + auto res = check(R"( + +)"); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( + +)", + Position(1, 0) + ); + CHECK_EQ("\n", fragment.fragmentToParse); + CHECK_EQ(2, fragment.ancestry.size()); + REQUIRE(fragment.root); + CHECK_EQ(0, fragment.root->body.size); + auto statBody = fragment.root->as(); + CHECK(statBody != nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_complete_fragments") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( +local x = 4 +local y = 5 +local z = x + y +)", + Position{3, 15} + ); + + CHECK_EQ("\nlocal z = x + y", fragment.fragmentToParse); + CHECK_EQ(5, fragment.ancestry.size()); + REQUIRE(fragment.root); + CHECK_EQ(1, fragment.root->body.size); + auto stat = fragment.root->body.data[0]->as(); + REQUIRE(stat); + CHECK_EQ(1, stat->vars.size); + CHECK_EQ(1, stat->values.size); + CHECK_EQ("z", std::string(stat->vars.data[0]->name.value)); + + auto bin = stat->values.data[0]->as(); + REQUIRE(bin); + CHECK_EQ(AstExprBinary::Op::Add, bin->op); + + auto lhs = bin->left->as(); + auto rhs = bin->right->as(); + REQUIRE(lhs); + REQUIRE(rhs); + CHECK_EQ("x", std::string(lhs->local->name.value)); + CHECK_EQ("y", std::string(rhs->local->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_fragments_in_line") +{ + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( +local x = 4 +local z = x + y +local y = 5 +)", + Position{2, 15} + ); + + CHECK_EQ("local z = x + y", fragment.fragmentToParse); + CHECK_EQ(5, fragment.ancestry.size()); + REQUIRE(fragment.root); + CHECK_EQ(1, fragment.root->body.size); + auto stat = fragment.root->body.data[0]->as(); + REQUIRE(stat); + CHECK_EQ(1, stat->vars.size); + CHECK_EQ(1, stat->values.size); + CHECK_EQ("z", std::string(stat->vars.data[0]->name.value)); + + auto bin = stat->values.data[0]->as(); + REQUIRE(bin); + CHECK_EQ(AstExprBinary::Op::Add, bin->op); + + auto lhs = bin->left->as(); + auto rhs = bin->right->as(); + REQUIRE(lhs); + REQUIRE(rhs); + CHECK_EQ("x", std::string(lhs->local->name.value)); + CHECK_EQ("y", std::string(rhs->name.value)); +} + +TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index a0de6f10..71a46878 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -3,6 +3,7 @@ #include "lualib.h" #include "Repl.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -12,6 +13,8 @@ #include #include +LUAU_FASTFLAG(LuauMathMap) + struct Completion { std::string completion; @@ -172,15 +175,17 @@ TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") CHECK(checkCompletion(completions, prefix, "myvariable1")); CHECK(checkCompletion(completions, prefix, "myvariable2")); } + if (FFlag::LuauMathMap) { // Try completing some builtin functions CompletionSet completions = getCompletionSet("math.m"); std::string prefix = "math."; - CHECK(completions.size() == 3); + CHECK(completions.size() == 4); CHECK(checkCompletion(completions, prefix, "max(")); CHECK(checkCompletion(completions, prefix, "min(")); CHECK(checkCompletion(completions, prefix, "modf(")); + CHECK(checkCompletion(completions, prefix, "map(")); } } diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index 5a3dca10..b6160b3c 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -12,6 +12,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAG(LuauUserTypeFunFixRegister) +LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite) TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); @@ -674,6 +675,36 @@ TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") CHECK(toString(tpm->givenTp) == "{ BaseField: number, read BaseMethod: (BaseClass, number) -> (), read Touched: Connection }"); } +TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag udtfRwFix{FFlag::LuauUserTypeFunFixNoReadWrite, true}; + + + CheckResult result = check(R"( + type function getclass(arg) + local props = arg:properties() + local table = types.newtable(props) + local singleton = types.singleton("BaseMethod") + + if table:writeproperty(singleton) then + return types.singleton(true) + else + return types.singleton(false) + end + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getclass): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "false"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index c324901d..b2510fe0 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4891,4 +4891,41 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_union_type") ); } +TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + // CLI-121540: All of these examples should have no errors. + + LUAU_CHECK_ERROR_COUNT(3, check(R"( + local function doTheThing(_: { [string]: unknown }) end + doTheThing({ + ['foo'] = 5, + ['bar'] = 'heyo', + }) + )")); + + LUAU_CHECK_ERROR_COUNT(1, check(R"( + type Input = { [string]: unknown } + + local i : Input = { + [('%s'):format('3.14')]=5, + ['stringField']='Heyo' + } + )")); + + // This example previously asserted due to eagerly mutating the underlying + // table type. + LUAU_CHECK_ERROR_COUNT(3, check(R"( + type Input = { [string]: unknown } + + local function doTheThing(_: Input) end + + doTheThing({ + [('%s'):format('3.14')]=5, + ['stringField']='Heyo' + }) + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 09c6c05b..fbc03213 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1683,4 +1683,27 @@ TEST_CASE_FIXTURE(Fixture, "leading_ampersand_no_type") CHECK("*error-type*" == toString(requireTypeAlias("Amp"))); } +TEST_CASE_FIXTURE(Fixture, "react_lua_follow_free_type_ub") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + return function(Roact) + local Tree = Roact.Component:extend("Tree") + + function Tree:render() + local breadth, components, depth, id, wrap = + self.props.breadth, self.props.components, self.props.depth, self.props.id, self.props.wrap + local Box = components.Box + if depth == 0 then + Roact.createElement(Box, {}) + else + Roact.createElement(Tree, {}) + end + + end + end + )")); +} + TEST_SUITE_END(); diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 98d5b317..97c44462 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -388,6 +388,20 @@ assert(math.pow(noinline(2), 2) == 4) assert(math.pow(noinline(4), 0.5) == 2) assert(math.pow(noinline(-2), 2) == 4) +-- map +assert(math.map(0, -1, 1, 0, 2) == 1) +assert(math.map(1, 1, 4, 0, 2) == 0) +assert(math.map(2.5, 1, 4, 0, 2) == 1) +assert(math.map(4, 1, 4, 0, 2) == 2) +assert(math.map(1, 1, 4, 2, 0) == 2) +assert(math.map(2.5, 1, 4, 2, 0) == 1) +assert(math.map(4, 1, 4, 2, 0) == 0) +assert(math.map(1, 4, 1, 2, 0) == 0) +assert(math.map(2.5, 4, 1, 2, 0) == 1) +assert(math.map(4, 4, 1, 2, 0) == 2) +assert(math.map(-8, 0, 4, 0, 2) == -4) +assert(math.map(16, 0, 4, 0, 2) == 8) + assert(tostring(math.pow(-2, 0.5)) == "nan") -- test that fastcalls return correct number of results