diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 96bac9e4..b54f7a44 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -1,10 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/AutocompleteTypes.h" #include "Luau/Location.h" #include "Luau/Type.h" -#include #include #include #include @@ -16,90 +16,8 @@ struct Frontend; struct SourceModule; struct Module; struct TypeChecker; - -using ModulePtr = std::shared_ptr; - -enum class AutocompleteContext -{ - Unknown, - Expression, - Statement, - Property, - Type, - Keyword, - String, -}; - -enum class AutocompleteEntryKind -{ - Property, - Binding, - Keyword, - String, - Type, - Module, - GeneratedFunction, - RequirePath, -}; - -enum class ParenthesesRecommendation -{ - None, - CursorAfter, - CursorInside, -}; - -enum class TypeCorrectKind -{ - None, - Correct, - CorrectFunctionResult, -}; - -struct AutocompleteEntry -{ - AutocompleteEntryKind kind = AutocompleteEntryKind::Property; - // Nullopt if kind is Keyword - std::optional type = std::nullopt; - bool deprecated = false; - // Only meaningful if kind is Property. - bool wrongIndexType = false; - // Set if this suggestion matches the type expected in the context - TypeCorrectKind typeCorrect = TypeCorrectKind::None; - - std::optional containingClass = std::nullopt; - std::optional prop = std::nullopt; - std::optional documentationSymbol = std::nullopt; - Tags tags; - ParenthesesRecommendation parens = ParenthesesRecommendation::None; - std::optional insertText; - - // Only meaningful if kind is Property. - bool indexedWithSelf = false; -}; - -using AutocompleteEntryMap = std::unordered_map; -struct AutocompleteResult -{ - AutocompleteEntryMap entryMap; - std::vector ancestry; - AutocompleteContext context = AutocompleteContext::Unknown; - - AutocompleteResult() = default; - AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry, AutocompleteContext context) - : entryMap(std::move(entryMap)) - , ancestry(std::move(ancestry)) - , context(context) - { - } -}; - -using ModuleName = std::string; -using StringCompletionCallback = - std::function(std::string tag, std::optional ctx, std::optional contents)>; +struct FileResolver; AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); -constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; - } // namespace Luau diff --git a/Analysis/include/Luau/AutocompleteTypes.h b/Analysis/include/Luau/AutocompleteTypes.h new file mode 100644 index 00000000..37d45244 --- /dev/null +++ b/Analysis/include/Luau/AutocompleteTypes.h @@ -0,0 +1,92 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Type.h" + +#include + +namespace Luau +{ + +enum class AutocompleteContext +{ + Unknown, + Expression, + Statement, + Property, + Type, + Keyword, + String, +}; + +enum class AutocompleteEntryKind +{ + Property, + Binding, + Keyword, + String, + Type, + Module, + GeneratedFunction, + RequirePath, +}; + +enum class ParenthesesRecommendation +{ + None, + CursorAfter, + CursorInside, +}; + +enum class TypeCorrectKind +{ + None, + Correct, + CorrectFunctionResult, +}; + +struct AutocompleteEntry +{ + AutocompleteEntryKind kind = AutocompleteEntryKind::Property; + // Nullopt if kind is Keyword + std::optional type = std::nullopt; + bool deprecated = false; + // Only meaningful if kind is Property. + bool wrongIndexType = false; + // Set if this suggestion matches the type expected in the context + TypeCorrectKind typeCorrect = TypeCorrectKind::None; + + std::optional containingClass = std::nullopt; + std::optional prop = std::nullopt; + std::optional documentationSymbol = std::nullopt; + Tags tags; + ParenthesesRecommendation parens = ParenthesesRecommendation::None; + std::optional insertText; + + // Only meaningful if kind is Property. + bool indexedWithSelf = false; +}; + +using AutocompleteEntryMap = std::unordered_map; +struct AutocompleteResult +{ + AutocompleteEntryMap entryMap; + std::vector ancestry; + AutocompleteContext context = AutocompleteContext::Unknown; + + AutocompleteResult() = default; + AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry, AutocompleteContext context) + : entryMap(std::move(entryMap)) + , ancestry(std::move(ancestry)) + , context(context) + { + } +}; + +using StringCompletionCallback = + std::function(std::string tag, std::optional ctx, std::optional contents)>; + +constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 435c62fb..b3b35fc2 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -5,6 +5,7 @@ #include "Luau/Constraint.h" #include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" +#include "Luau/EqSatSimplification.h" #include "Luau/InsertionOrderedMap.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" @@ -15,7 +16,6 @@ #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Variant.h" -#include "Luau/Normalize.h" #include #include @@ -109,6 +109,9 @@ struct ConstraintGenerator // Needed to be able to enable error-suppression preservation for immediate refinements. NotNull normalizer; + + NotNull simplifier; + // Needed to register all available type functions for execution at later stages. NotNull typeFunctionRuntime; // Needed to resolve modules to make 'require' import types properly. @@ -128,6 +131,7 @@ struct ConstraintGenerator ConstraintGenerator( ModulePtr module, NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull moduleResolver, NotNull builtinTypes, @@ -405,6 +409,7 @@ private: TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); // make an intersect type function of these two types TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program); /** Scan the program for global definitions. * @@ -435,6 +440,8 @@ private: const ScopePtr& scope, Location location ); + + TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index c9336c1d..37042c75 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -5,6 +5,7 @@ #include "Luau/Constraint.h" #include "Luau/DataFlowGraph.h" #include "Luau/DenseHash.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" #include "Luau/Location.h" #include "Luau/Module.h" @@ -64,6 +65,7 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull simplifier; NotNull typeFunctionRuntime; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; @@ -117,6 +119,7 @@ struct ConstraintSolver explicit ConstraintSolver( NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, @@ -384,6 +387,10 @@ public: **/ void reproduceConstraints(NotNull scope, const Location& location, const Substitution& subst); + TypeId simplifyIntersection(NotNull scope, Location location, TypeId left, TypeId right); + TypeId simplifyIntersection(NotNull scope, Location location, std::set parts); + TypeId simplifyUnion(NotNull scope, Location location, TypeId left, TypeId right); + TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; diff --git a/Analysis/include/Luau/EqSatSimplification.h b/Analysis/include/Luau/EqSatSimplification.h new file mode 100644 index 00000000..16d00849 --- /dev/null +++ b/Analysis/include/Luau/EqSatSimplification.h @@ -0,0 +1,50 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/TypeFwd.h" +#include "Luau/NotNull.h" +#include "Luau/DenseHash.h" + +#include +#include +#include + +namespace Luau +{ +struct TypeArena; +} + +// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent +// the complexity from leaking outside its implementation sources. +namespace Luau::EqSatSimplification +{ + +struct Simplifier; + +using SimplifierPtr = std::unique_ptr; + +SimplifierPtr newSimplifier(NotNull arena, NotNull builtinTypes); + +} // namespace Luau::EqSatSimplification + +namespace Luau +{ + +struct EqSatSimplificationResult +{ + TypeId result; + + // New type function applications that were created by the reduction phase. + // We return these so that the ConstraintSolver can know to try to reduce + // them. + std::vector newTypeFunctions; +}; + +using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect. +using Luau::EqSatSimplification::Simplifier; // NOLINT +using Luau::EqSatSimplification::SimplifierPtr; + +std::optional eqSatSimplify(NotNull simplifier, TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/EqSatSimplificationImpl.h b/Analysis/include/Luau/EqSatSimplificationImpl.h new file mode 100644 index 00000000..24e8777a --- /dev/null +++ b/Analysis/include/Luau/EqSatSimplificationImpl.h @@ -0,0 +1,363 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" +#include "Luau/Lexer.h" // For Allocator +#include "Luau/NotNull.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ +struct TypeFunction; +} + +namespace Luau::EqSatSimplification +{ + +using StringId = uint32_t; +using Id = Luau::EqSat::Id; + +LUAU_EQSAT_UNIT(TNil); +LUAU_EQSAT_UNIT(TBoolean); +LUAU_EQSAT_UNIT(TNumber); +LUAU_EQSAT_UNIT(TString); +LUAU_EQSAT_UNIT(TThread); +LUAU_EQSAT_UNIT(TTopFunction); +LUAU_EQSAT_UNIT(TTopTable); +LUAU_EQSAT_UNIT(TTopClass); +LUAU_EQSAT_UNIT(TBuffer); + +// Used for any type that eqsat can't do anything interesting with. +LUAU_EQSAT_ATOM(TOpaque, TypeId); + +LUAU_EQSAT_ATOM(SBoolean, bool); +LUAU_EQSAT_ATOM(SString, StringId); + +LUAU_EQSAT_ATOM(TFunction, TypeId); + +LUAU_EQSAT_ATOM(TImportedTable, TypeId); + +LUAU_EQSAT_ATOM(TClass, TypeId); + +LUAU_EQSAT_UNIT(TAny); +LUAU_EQSAT_UNIT(TError); +LUAU_EQSAT_UNIT(TUnknown); +LUAU_EQSAT_UNIT(TNever); + +LUAU_EQSAT_NODE_SET(Union); +LUAU_EQSAT_NODE_SET(Intersection); + +LUAU_EQSAT_NODE_ARRAY(Negation, 1); + +LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*); + +LUAU_EQSAT_UNIT(TNoRefine); +LUAU_EQSAT_UNIT(Invalid); + +// enodes are immutable, but types are cyclic. We need a way to tie the knot. +// We handle this by generating TBound nodes at points where we encounter cycles. +// Each TBound has an ordinal that we later map onto the type. +// We use a substitution rule to replace all TBound nodes with their referrent. +LUAU_EQSAT_ATOM(TBound, size_t); + +// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it. +struct TTable +{ + explicit TTable(Id basis); + TTable(Id basis, std::vector propNames_, std::vector propTypes_); + + // All TTables extend some other table. This may be TTopTable. + // + // It will frequently be a TImportedTable, in which case we can reuse things + // like source location and documentation info. + Id getBasis() const; + EqSat::Slice propTypes() const; + // TODO: Also support read-only table props + // TODO: Indexer type, index result type. + + std::vector propNames; + + // The enode interface + EqSat::Slice mutableOperands(); + EqSat::Slice operands() const; + bool operator==(const TTable& rhs) const; + bool operator!=(const TTable& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const TTable& value) const; + }; + +private: + // The first element of this vector is the basis. Subsequent elements are + // property types. As we add other things like read-only properties and + // indexers, the structure of this array is likely to change. + // + // We encode our data in this way so that the operands() method can properly + // return a Slice. + std::vector storage; +}; + +using EType = EqSat::Language< + TNil, + TBoolean, + TNumber, + TString, + TThread, + TTopFunction, + TTopTable, + TTopClass, + TBuffer, + + TOpaque, + + SBoolean, + SString, + + TFunction, + TTable, + TImportedTable, + TClass, + + TAny, + TError, + TUnknown, + TNever, + + Union, + Intersection, + + Negation, + + TTypeFun, + + Invalid, + TNoRefine, + TBound>; + + +struct StringCache +{ + Allocator allocator; + DenseHashMap strings{{}}; + std::vector views; + + StringId add(std::string_view s); + std::string_view asStringView(StringId id) const; + std::string asString(StringId id) const; +}; + +using EGraph = Luau::EqSat::EGraph; + +struct Simplify +{ + using Data = bool; + + template + Data make(const EGraph&, const T&) const; + + void join(Data& left, const Data& right) const; +}; + +struct Subst +{ + Id eclass; + Id newClass; + + std::string desc; + + Subst(Id eclass, Id newClass, std::string desc = ""); +}; + +struct Simplifier +{ + NotNull arena; + NotNull builtinTypes; + EGraph egraph; + StringCache stringCache; + + // enodes are immutable but types can be cyclic, so we need some way to + // encode the cycle. This map is used to connect TBound nodes to the right + // eclass. + // + // The cyclicIntersection rewrite rule uses this to sense when a cycle can + // be deleted from an intersection or union. + std::unordered_map mappingIdToClass; + + std::vector substs; + + using RewriteRuleFn = void (Simplifier::*)(Id id); + + Simplifier(NotNull arena, NotNull builtinTypes); + + // Utilities + const EqSat::EClass& get(Id id) const; + Id find(Id id) const; + Id add(EType enode); + + template + const Tag* isTag(Id id) const; + + template + const Tag* isTag(const EType& enode) const; + + void subst(Id from, Id to); + void subst(Id from, Id to, const std::string& ruleName); + void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); + + void unionClasses(std::vector& hereParts, Id there); + + // Rewrite rules + void simplifyUnion(Id id); + void uninhabitedIntersection(Id id); + void intersectWithNegatedClass(Id id); + void intersectWithNoRefine(Id id); + void cyclicIntersectionOfUnion(Id id); + void cyclicUnionOfIntersection(Id id); + void expandNegation(Id id); + void intersectionOfUnion(Id id); + void intersectTableProperty(Id id); + void uninhabitedTable(Id id); + void unneededTableModification(Id id); + void builtinTypeFunctions(Id id); + void iffyTypeFunctions(Id id); +}; + +template +struct QueryIterator +{ + QueryIterator(); + QueryIterator(EGraph* egraph, Id eclass); + + bool operator==(const QueryIterator& other) const; + bool operator!=(const QueryIterator& other) const; + + std::pair operator*() const; + + QueryIterator& operator++(); + QueryIterator& operator++(int); + +private: + EGraph* egraph = nullptr; + Id eclass; + size_t index = 0; +}; + +template +struct Query +{ + EGraph* egraph; + Id eclass; + + Query(EGraph* egraph, Id eclass) + : egraph(egraph) + , eclass(eclass) + { + } + + QueryIterator begin() + { + return QueryIterator{egraph, eclass}; + } + + QueryIterator end() + { + return QueryIterator{}; + } +}; + +template +QueryIterator::QueryIterator() + : egraph(nullptr) + , eclass(Id{0}) + , index(0) +{ +} + +template +QueryIterator::QueryIterator(EGraph* egraph_, Id eclass) + : egraph(egraph_) + , eclass(eclass) + , index(0) +{ + const auto& ecl = (*egraph)[eclass]; + + static constexpr const int idx = EType::VariantTy::getTypeId(); + + for (const auto& enode : ecl.nodes) + { + if (enode.index() < idx) + ++index; + else + break; + } + + if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx) + { + egraph = nullptr; + index = 0; + } +} + +template +bool QueryIterator::operator==(const QueryIterator& rhs) const +{ + if (egraph == nullptr && rhs.egraph == nullptr) + return true; + + return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index; +} + +template +bool QueryIterator::operator!=(const QueryIterator& rhs) const +{ + return !(*this == rhs); +} + +template +std::pair QueryIterator::operator*() const +{ + LUAU_ASSERT(egraph != nullptr); + + EGraph::EClassT& ecl = (*egraph)[eclass]; + + LUAU_ASSERT(index < ecl.nodes.size()); + auto& enode = ecl.nodes[index]; + Tag* result = enode.template get(); + LUAU_ASSERT(result); + return {result, index}; +} + +// pre-increment +template +QueryIterator& QueryIterator::operator++() +{ + const auto& ecl = (*egraph)[eclass]; + + ++index; + if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId()) + { + egraph = nullptr; + index = 0; + } + + return *this; +} + +// post-increment +template +QueryIterator& QueryIterator::operator++(int) +{ + QueryIterator res = *this; + ++res; + return res; +} + +} // namespace Luau::EqSatSimplification diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 2f17e566..d3fc6ad3 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -32,7 +32,11 @@ struct ModuleInfo bool optional = false; }; -using RequireSuggestion = std::string; +struct RequireSuggestion +{ + std::string label; + std::string fullPath; +}; using RequireSuggestions = std::vector; struct FileResolver diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h index 671cbb69..50c456f1 100644 --- a/Analysis/include/Luau/FragmentAutocomplete.h +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -3,9 +3,10 @@ #include "Luau/Ast.h" #include "Luau/Parser.h" -#include "Luau/Autocomplete.h" +#include "Luau/AutocompleteTypes.h" #include "Luau/DenseHash.h" #include "Luau/Module.h" +#include "Luau/Frontend.h" #include #include @@ -27,13 +28,23 @@ struct FragmentParseResult std::string fragmentToParse; AstStatBlock* root = nullptr; std::vector ancestry; + AstStat* nearestStatement = nullptr; std::unique_ptr alloc = std::make_unique(); }; struct FragmentTypeCheckResult { ModulePtr incrementalModule = nullptr; - Scope* freshScope = nullptr; + ScopePtr freshScope; + std::vector ancestry; +}; + +struct FragmentAutocompleteResult +{ + ModulePtr incrementalModule; + Scope* freshScope; + TypeArena arenaForAutocomplete; + AutocompleteResult acResults; }; FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); @@ -48,11 +59,11 @@ FragmentTypeCheckResult typecheckFragment( std::string_view src ); -AutocompleteResult fragmentAutocomplete( +FragmentAutocompleteResult fragmentAutocomplete( Frontend& frontend, std::string_view src, const ModuleName& moduleName, - Position& cursorPosition, + Position cursorPosition, std::optional opts, StringCompletionCallback callback ); diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index f8001e08..4862e3b4 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -44,6 +44,7 @@ struct ToStringOptions bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self + bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index d100fa4d..0005605e 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -31,6 +31,7 @@ namespace Luau struct TypeArena; struct Scope; using ScopePtr = std::shared_ptr; +struct Module; struct TypeFunction; struct Constraint; @@ -598,6 +599,18 @@ struct ClassType } }; +// Data required to initialize a user-defined function and its environment +struct UserDefinedFunctionData +{ + // Store a weak module reference to ensure the lifetime requirements are preserved + std::weak_ptr owner; + + // References to AST elements are owned by the Module allocator which also stores this type + AstStatTypeFunction* definition = nullptr; + + DenseHashMap environment{""}; +}; + /** * An instance of a type function that has not yet been reduced to a more concrete * type. The constraint solver receives a constraint to reduce each @@ -613,17 +626,20 @@ struct TypeFunctionInstanceType std::vector packArguments; std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs + UserDefinedFunctionData userFuncData; TypeFunctionInstanceType( NotNull function, std::vector typeArguments, std::vector packArguments, - std::optional userFuncName = std::nullopt + std::optional userFuncName, + UserDefinedFunctionData userFuncData ) : function(function) , typeArguments(typeArguments) , packArguments(packArguments) , userFuncName(userFuncName) + , userFuncData(userFuncData) { } @@ -640,6 +656,13 @@ struct TypeFunctionInstanceType , packArguments(packArguments) { } + + TypeFunctionInstanceType(NotNull function, std::vector typeArguments, std::vector packArguments) + : function{function} + , typeArguments(typeArguments) + , packArguments(packArguments) + { + } }; /** Represents a pending type alias instantiation. diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 829f6bb7..eb7e2298 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -2,2000 +2,17 @@ #include "Luau/Autocomplete.h" #include "Luau/AstQuery.h" -#include "Luau/BuiltinDefinitions.h" -#include "Luau/Common.h" -#include "Luau/FileResolver.h" +#include "Luau/TypeArena.h" +#include "Luau/Module.h" #include "Luau/Frontend.h" -#include "Luau/ToString.h" -#include "Luau/Subtyping.h" -#include "Luau/TypeInfer.h" -#include "Luau/TypePack.h" -#include -#include -#include +#include "AutocompleteCore.h" LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions) - -LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) -LUAU_FASTINT(LuauTypeInferIterationLimit) -LUAU_FASTINT(LuauTypeInferRecursionLimit) - -static const std::unordered_set kStatementStartingKeywords = - {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; namespace Luau { - -static bool alreadyHasParens(const std::vector& nodes) -{ - auto iter = nodes.rbegin(); - while (iter != nodes.rend() && - ((*iter)->is() || (*iter)->is() || (*iter)->is() || (*iter)->is())) - { - iter++; - } - - if (iter == nodes.rend() || iter == nodes.rbegin()) - { - return false; - } - - if (AstExprCall* call = (*iter)->as()) - { - return call->func == *(iter - 1); - } - - return false; -} - -static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionType* func, const std::vector& nodes) -{ - if (alreadyHasParens(nodes)) - { - return ParenthesesRecommendation::None; - } - - auto idxExpr = nodes.back()->as(); - bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); - - if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) - return ParenthesesRecommendation::CursorInside; - - bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); - return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; -} - -static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionType* intersect, const std::vector& nodes) -{ - ParenthesesRecommendation rec = ParenthesesRecommendation::None; - for (Luau::TypeId partId : intersect->parts) - { - if (auto partFunc = Luau::get(partId)) - { - rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); - } - else - { - return ParenthesesRecommendation::None; - } - } - return rec; -} - -static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::vector& nodes, TypeCorrectKind typeCorrect) -{ - // If element is already type-correct, even a function should be inserted without parenthesis - if (typeCorrect == TypeCorrectKind::Correct) - return ParenthesesRecommendation::None; - - id = Luau::follow(id); - if (auto func = get(id)) - { - return getParenRecommendationForFunc(func, nodes); - } - else if (auto intersect = get(id)) - { - return getParenRecommendationForIntersect(intersect, nodes); - } - return ParenthesesRecommendation::None; -} - -static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) -{ - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; - - // Extra care for first function call argument location - // When we don't have anything inside () yet, we also don't have an AST node to base our lookup - if (AstExprCall* exprCall = expr->as()) - { - if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) - { - auto it = module.astTypes.find(exprCall->func); - - if (!it) - return std::nullopt; - - const FunctionType* ftv = get(follow(*it)); - - if (!ftv) - return std::nullopt; - - auto [head, tail] = flatten(ftv->argTypes); - unsigned index = exprCall->self ? 1 : 0; - - if (index < head.size()) - return head[index]; - - return std::nullopt; - } - } - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; - - return *it; -} - -static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull builtinTypes) -{ - InternalErrorReporter iceReporter; - UnifierSharedState unifierState(&iceReporter); - Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; - - if (FFlag::LuauSolverV2) - { - TypeCheckLimits limits; - TypeFunctionRuntime typeFunctionRuntime{ - NotNull{&iceReporter}, NotNull{&limits} - }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime - - unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; - unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; - - Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; - - return subtyping.isSubtype(subTy, superTy, scope).isSubtype; - } - else - { - Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); - - // Cost of normalization can be too high for autocomplete response time requirements - unifier.normalize = false; - unifier.checkInhabited = false; - - return unifier.canUnify(subTy, superTy).empty(); - } -} - -static TypeCorrectKind checkTypeCorrectKind( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - AstNode* node, - Position position, - TypeId ty -) -{ - ty = follow(ty); - - LUAU_ASSERT(module.hasModuleScope()); - - NotNull moduleScope{module.getModuleScope().get()}; - - auto typeAtPosition = findExpectedTypeAt(module, node, position); - - if (!typeAtPosition) - return TypeCorrectKind::None; - - TypeId expectedType = follow(*typeAtPosition); - - auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) - { - if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); - - return false; - }; - - // We also want to suggest functions that return compatible result - if (const FunctionType* ftv = get(ty); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - else if (const IntersectionType* itv = get(ty)) - { - for (TypeId id : itv->parts) - { - id = follow(id); - - if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - } - } - - return checkTypeMatch(ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; -} - -enum class PropIndexType -{ - Point, - Colon, - Key, -}; - -static void autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId rootTy, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes, - AutocompleteEntryMap& result, - std::unordered_set& seen, - std::optional containingClass = std::nullopt -) -{ - rootTy = follow(rootTy); - ty = follow(ty); - - if (seen.count(ty)) - return; - seen.insert(ty); - - auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) - { - if (indexType == PropIndexType::Key) - return false; - - bool calledWithSelf = indexType == PropIndexType::Colon; - - auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) - { - // Strong match with definition is a success - if (calledWithSelf == ftv->hasSelf) - return true; - - // Calls on classes require strict match between how function is declared and how it's called - if (get(rootTy)) - return false; - - // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all - // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible - if (std::optional firstArgTy = first(ftv->argTypes)) - { - if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) - return calledWithSelf; - } - - return !calledWithSelf; - }; - - if (const FunctionType* ftv = get(type)) - return !isCompatibleCall(ftv); - - // For intersections, any part that is successful makes the whole call successful - if (const IntersectionType* itv = get(type)) - { - for (auto subType : itv->parts) - { - if (const FunctionType* ftv = get(Luau::follow(subType))) - { - if (isCompatibleCall(ftv)) - return false; - } - } - } - - return calledWithSelf; - }; - - auto fillProps = [&](const ClassType::Props& props) - { - for (const auto& [name, prop] : props) - { - // We are walking up the class hierarchy, so if we encounter a property that we have - // already populated, it takes precedence over the property we found just now. - if (result.count(name) == 0 && name != kParseNameError) - { - Luau::TypeId type; - - if (FFlag::LuauSolverV2) - { - if (auto ty = prop.readTy) - type = follow(*ty); - else - continue; - } - else - type = follow(prop.type()); - - TypeCorrectKind typeCorrect = indexType == PropIndexType::Key - ? TypeCorrectKind::Correct - : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); - - ParenthesesRecommendation parens = - indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); - - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens, - {}, - indexType == PropIndexType::Colon - }; - } - } - }; - - auto fillMetatableProps = [&](const TableType* mtable) - { - auto indexIt = mtable->props.find("__index"); - if (indexIt != mtable->props.end()) - { - TypeId followed = follow(indexIt->second.type()); - if (get(followed) || get(followed)) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); - } - else if (auto indexFunction = get(followed)) - { - std::optional indexFunctionResult = first(indexFunction->retTypes); - if (indexFunctionResult) - autocompleteProps(module, typeArena, builtinTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); - } - } - }; - - if (auto cls = get(ty)) - { - containingClass = containingClass.value_or(cls); - fillProps(cls->props); - if (cls->parent) - autocompleteProps(module, typeArena, builtinTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); - } - else if (auto tbl = get(ty)) - fillProps(tbl->props); - else if (auto mt = get(ty)) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); - - if (auto mtable = get(follow(mt->metatable))) - fillMetatableProps(mtable); - } - else if (auto i = get(ty)) - { - // Complete all properties in every variant - for (TypeId ty : i->parts) - { - AutocompleteEntryMap inner; - std::unordered_set innerSeen = seen; - - autocompleteProps(module, typeArena, builtinTypes, rootTy, ty, indexType, nodes, inner, innerSeen); - - for (auto& pair : inner) - result.insert(pair); - } - } - else if (auto u = get(ty)) - { - // Complete all properties common to all variants - auto iter = begin(u); - auto endIter = end(u); - - while (iter != endIter) - { - if (isNil(*iter)) - ++iter; - else - break; - } - - if (iter == endIter) - return; - - autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, result, seen); - - ++iter; - - while (iter != endIter) - { - AutocompleteEntryMap inner; - std::unordered_set innerSeen; - - if (isNil(*iter)) - { - ++iter; - continue; - } - - autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); - - std::unordered_set toRemove; - - for (const auto& [k, v] : result) - { - (void)v; - if (!inner.count(k)) - toRemove.insert(k); - } - - for (const std::string& k : toRemove) - result.erase(k); - - ++iter; - } - } - else if (auto pt = get(ty)) - { - if (pt->metatable) - { - if (auto mtable = get(*pt->metatable)) - fillMetatableProps(mtable); - } - } - else if (get(get(ty))) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, builtinTypes->stringType, indexType, nodes, result, seen); - } -} - -static void autocompleteKeywords( - const SourceModule& sourceModule, - const std::vector& ancestry, - Position position, - AutocompleteEntryMap& result -) -{ - LUAU_ASSERT(!ancestry.empty()); - - AstNode* node = ancestry.back(); - - if (!node->is() && node->asExpr()) - { - // This is not strictly correct. We should recommend `and` and `or` only after - // another expression, not at the start of a new one. We should only recommend - // `not` at the start of an expression. Detecting either case reliably is quite - // complex, however; this is good enough for now. - - // These are not context-sensitive keywords, so we can unconditionally assign. - result["and"] = {AutocompleteEntryKind::Keyword}; - result["or"] = {AutocompleteEntryKind::Keyword}; - result["not"] = {AutocompleteEntryKind::Keyword}; - } -} - -static void autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes, - AutocompleteEntryMap& result -) -{ - std::unordered_set seen; - autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); -} - -AutocompleteEntryMap autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes -) -{ - AutocompleteEntryMap result; - autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); - return result; -} - -AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position position, std::string_view moduleName) -{ - AutocompleteEntryMap result; - - for (ScopePtr scope = findScopeAtPosition(module, position); scope; scope = scope->parent) - { - if (auto it = scope->importedTypeBindings.find(std::string(moduleName)); it != scope->importedTypeBindings.end()) - { - for (const auto& [name, ty] : it->second) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type}; - - break; - } - } - - return result; -} - -static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) -{ - if (position == node->location.begin || position == node->location.end) - { - if (auto str = node->as(); str && str->quoteStyle == AstExprConstantString::Quoted) - return; - else if (node->is()) - return; - } - - auto formatKey = [addQuotes](const std::string& key) - { - if (addQuotes) - return "\"" + escape(key) + "\""; - - return escape(key); - }; - - ty = follow(ty); - - if (auto ss = get(get(ty))) - { - result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; - } - else if (auto uty = get(ty)) - { - for (auto el : uty) - { - if (auto ss = get(get(el))) - result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; - } - } -}; - -static bool canSuggestInferredType(ScopePtr scope, TypeId ty) -{ - ty = follow(ty); - - // No point in suggesting 'any', invalid to suggest others - if (get(ty) || get(ty) || get(ty) || get(ty)) - return false; - - // No syntax for unnamed tables with a metatable - if (get(ty)) - return false; - - if (const TableType* ttv = get(ty)) - { - if (ttv->name) - return true; - - if (ttv->syntheticName) - return false; - } - - // We might still have a type with cycles or one that is too long, we'll check that later - return true; -} - -// Walk complex type trees to find the element that is being edited -static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position); - -static std::optional findTypeElementAt(const AstTypeList& astTypeList, TypePackId tp, Position position) -{ - for (size_t i = 0; i < astTypeList.types.size; i++) - { - AstType* type = astTypeList.types.data[i]; - - if (type->location.containsClosed(position)) - { - auto [head, _] = flatten(tp); - - if (i < head.size()) - return findTypeElementAt(type, head[i], position); - } - } - - if (AstTypePack* argTp = astTypeList.tailType) - { - if (auto variadic = argTp->as()) - { - if (variadic->location.containsClosed(position)) - { - auto [_, tail] = flatten(tp); - - if (tail) - { - if (const VariadicTypePack* vtp = get(follow(*tail))) - return findTypeElementAt(variadic->variadicType, vtp->ty, position); - } - } - } - } - - return {}; -} - -static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position) -{ - ty = follow(ty); - - if (astType->is()) - return ty; - - if (astType->is()) - return ty; - - if (AstTypeFunction* type = astType->as()) - { - const FunctionType* ftv = get(ty); - - if (!ftv) - return {}; - - if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) - return element; - - if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) - return element; - } - - // It's possible to walk through other types like intrsection and unions if we find value in doing that - return {}; -} - -std::optional getLocalTypeInScopeAt(const Module& module, Position position, AstLocal* local) -{ - if (ScopePtr scope = findScopeAtPosition(module, position)) - { - for (const auto& [name, binding] : scope->bindings) - { - if (name == local) - return binding.typeId; - } - } - - return {}; -} - -template -static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) -{ - ToStringOptions opts; - opts.useLineBreaks = false; - opts.hideTableKind = true; - opts.functionTypeArguments = functionTypeArguments; - opts.scope = scope; - ToStringResult name = toStringDetailed(ty, opts); - - if (name.error || name.invalid || name.cycle || name.truncated) - return std::nullopt; - - return name.name; -} - -static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) -{ - if (!canSuggestInferredType(scope, ty)) - return std::nullopt; - - return tryToStringDetailed(scope, ty, functionTypeArguments); -} - -static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) -{ - std::optional ty; - - if (topType) - ty = findTypeElementAt(topType, inferredType, position); - else - ty = inferredType; - - if (!ty) - return false; - - if (auto name = tryGetTypeNameInScope(scope, *ty)) - { - if (auto it = result.find(*name); it != result.end()) - it->second.typeCorrect = TypeCorrectKind::Correct; - else - result[*name] = AutocompleteEntry{AutocompleteEntryKind::Type, *ty, false, false, TypeCorrectKind::Correct}; - - return true; - } - - return false; -} - -static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) -{ - auto [tpHead, tpTail] = flatten(tp); - - if (index < tpHead.size()) - return tpHead[index]; - - // Infinite tail - if (tpTail) - { - if (const VariadicTypePack* vtp = get(follow(*tpTail))) - return vtp->ty; - } - - return {}; -} - -template -std::optional returnFirstNonnullOptionOfType(const UnionType* utv) -{ - std::optional ret; - for (TypeId subTy : utv) - { - if (isNil(subTy)) - continue; - - if (const T* ftv = get(follow(subTy))) - { - if (ret.has_value()) - { - return std::nullopt; - } - ret = ftv; - } - else - { - return std::nullopt; - } - } - return ret; -} - -static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) -{ - auto typeAtPosition = findExpectedTypeAt(module, node, position); - - if (!typeAtPosition) - return std::nullopt; - - TypeId expectedType = follow(*typeAtPosition); - - if (get(expectedType)) - return true; - - if (const IntersectionType* itv = get(expectedType)) - { - return std::all_of( - begin(itv->parts), - end(itv->parts), - [](auto&& ty) - { - return get(Luau::follow(ty)) != nullptr; - } - ); - } - - if (const UnionType* utv = get(expectedType)) - return returnFirstNonnullOptionOfType(utv).has_value(); - - return false; -} - -AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position position, const std::vector& ancestry) -{ - AutocompleteEntryMap result; - - ScopePtr startScope = findScopeAtPosition(module, position); - - for (ScopePtr scope = startScope; scope; scope = scope->parent) - { - for (const auto& [name, ty] : scope->exportedTypeBindings) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Type, - ty.type, - false, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - ty.type->documentationSymbol - }; - } - - for (const auto& [name, ty] : scope->privateTypeBindings) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Type, - ty.type, - false, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - ty.type->documentationSymbol - }; - } - - for (const auto& [name, _] : scope->importedTypeBindings) - { - if (auto binding = scope->linearSearchForBinding(name, true)) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Module, binding->typeId}; - } - } - } - - AstNode* parent = nullptr; - AstType* topType = nullptr; // TODO: rename? - - for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) - { - if (AstType* asType = (*it)->asType()) - { - topType = asType; - } - else - { - parent = *it; - break; - } - } - - if (!parent) - return result; - - if (AstStatLocal* node = parent->as()) // Try to provide inferred type of the local - { - // Look at which of the variable types we are defining - for (size_t i = 0; i < node->vars.size; i++) - { - AstLocal* var = node->vars.data[i]; - - if (var->annotation && var->annotation->location.containsClosed(position)) - { - if (node->values.size == 0) - break; - - unsigned tailPos = 0; - - // For multiple return values we will try to unpack last function call return type pack - if (i >= node->values.size) - { - tailPos = int(i) - int(node->values.size) + 1; - i = int(node->values.size) - 1; - } - - AstExpr* expr = node->values.data[i]->asExpr(); - - if (!expr) - break; - - TypeId inferredType = nullptr; - - if (AstExprCall* exprCall = expr->as()) - { - if (auto it = module.astTypes.find(exprCall->func)) - { - if (const FunctionType* ftv = get(follow(*it))) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) - inferredType = *ty; - } - } - } - else - { - if (tailPos != 0) - break; - - if (auto it = module.astTypes.find(expr)) - inferredType = *it; - } - - if (inferredType) - tryAddTypeCorrectSuggestion(result, startScope, topType, inferredType, position); - - break; - } - } - } - else if (AstExprFunction* node = parent->as()) - { - // For lookup inside expected function type if that's available - auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* - { - auto it = module.astExpectedTypes.find(expr); - - if (!it) - return nullptr; - - TypeId ty = follow(*it); - - if (const FunctionType* ftv = get(ty)) - return ftv; - - // Handle optional function type - if (const UnionType* utv = get(ty)) - { - return returnFirstNonnullOptionOfType(utv).value_or(nullptr); - } - - return nullptr; - }; - - // Find which argument type we are defining - for (size_t i = 0; i < node->args.size; i++) - { - AstLocal* arg = node->args.data[i]; - - if (arg->annotation && arg->annotation->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - // Otherwise, try to use the type inferred by typechecker - else if (auto inferredType = getLocalTypeInScopeAt(module, position, arg)) - { - tryAddTypeCorrectSuggestion(result, startScope, topType, *inferredType, position); - } - - break; - } - } - - if (AstTypePack* argTp = node->varargAnnotation) - { - if (auto variadic = argTp->as()) - { - if (variadic->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - } - } - } - - if (!node->returnAnnotation) - return result; - - for (size_t i = 0; i < node->returnAnnotation->types.size; i++) - { - AstType* ret = node->returnAnnotation->types.data[i]; - - if (ret->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - - // TODO: with additional type information, we could suggest inferred return type here - break; - } - } - - if (AstTypePack* retTp = node->returnAnnotation->tailType) - { - if (auto variadic = retTp->as()) - { - if (variadic->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - } - } - } - } - - return result; -} - -static bool isInLocalNames(const std::vector& ancestry, Position position) -{ - for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) - { - if (auto statLocal = (*iter)->as()) - { - for (auto var : statLocal->vars) - { - if (var->location.containsClosed(position)) - { - return true; - } - } - } - else if (auto funcExpr = (*iter)->as()) - { - if (funcExpr->argLocation && funcExpr->argLocation->contains(position)) - { - return true; - } - } - else if (auto localFunc = (*iter)->as()) - { - return localFunc->name->location.containsClosed(position); - } - else if (auto block = (*iter)->as()) - { - if (block->body.size > 0) - { - return false; - } - } - else if ((*iter)->asStat()) - { - return false; - } - } - return false; -} - -static bool isIdentifier(AstNode* node) -{ - return node->is() || node->is(); -} - -static bool isBeingDefined(const std::vector& ancestry, const Symbol& symbol) -{ - // Current set of rules only check for local binding match - if (!symbol.local) - return false; - - for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) - { - if (auto statLocal = (*iter)->as()) - { - for (auto var : statLocal->vars) - { - if (symbol.local == var) - return true; - } - } - } - - return false; -} - -template -T* extractStat(const std::vector& ancestry) -{ - AstNode* node = ancestry.size() >= 1 ? ancestry.rbegin()[0] : nullptr; - if (!node) - return nullptr; - - if (T* t = node->as()) - return t; - - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; - if (!parent) - return nullptr; - - AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; - AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; - - if (!grandParent) - return nullptr; - - if (T* t = parent->as(); t && grandParent->is()) - return t; - - if (!greatGrandParent) - return nullptr; - - if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) - return t; - - return nullptr; -} - -static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) -{ - if (symbol.local) - return binding.location.end < pos; - - // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it - return binding.location == Location() || !binding.location.containsClosed(pos); -} - -static AutocompleteEntryMap autocompleteStatement( - const SourceModule& sourceModule, - const Module& module, - const std::vector& ancestry, - Position position -) -{ - // This is inefficient. :( - ScopePtr scope = findScopeAtPosition(module, position); - - AutocompleteEntryMap result; - - if (isInLocalNames(ancestry, position)) - { - autocompleteKeywords(sourceModule, ancestry, position, result); - return result; - } - - while (scope) - { - for (const auto& [name, binding] : scope->bindings) - { - if (!isBindingLegalAtCurrentPosition(name, binding, position)) - continue; - - std::string n = toString(name); - if (!result.count(n)) - result[n] = { - AutocompleteEntryKind::Binding, - binding.typeId, - binding.deprecated, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - binding.documentationSymbol, - {}, - getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) - }; - } - - scope = scope->parent; - } - - for (const auto& kw : kStatementStartingKeywords) - result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); - - for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) - { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatIf* statIf = (*it)->as()) - { - bool hasEnd = statIf->thenbody->hasEnd; - if (statIf->elsebody) - { - if (AstStatBlock* elseBlock = statIf->elsebody->as()) - hasEnd = elseBlock->hasEnd; - } - - if (!hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - - if (ancestry.size() >= 2) - { - AstNode* parent = ancestry.rbegin()[1]; - if (AstStatIf* statIf = parent->as()) - { - if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) - { - result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - - if (ancestry.size() >= 4) - { - auto iter = ancestry.rbegin(); - if (AstStatIf* statIf = iter[3]->as(); - statIf != nullptr && !statIf->elsebody && iter[2]->is() && iter[1]->is() && isIdentifier(iter[0])) - { - result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - - return result; -} - -// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) -static bool autocompleteIfElseExpression( - const AstNode* node, - const std::vector& ancestry, - const Position& position, - AutocompleteEntryMap& outResult -) -{ - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; - if (!parent) - return false; - - if (node->is()) - { - // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else - // expression. - return true; - } - - AstExprIfElse* ifElseExpr = parent->as(); - if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) - { - return false; - } - else if (!ifElseExpr->hasThen) - { - outResult["then"] = {AutocompleteEntryKind::Keyword}; - return true; - } - else if (ifElseExpr->trueExpr->location.containsClosed(position)) - { - return false; - } - else if (!ifElseExpr->hasElse) - { - outResult["else"] = {AutocompleteEntryKind::Keyword}; - outResult["elseif"] = {AutocompleteEntryKind::Keyword}; - return true; - } - else - { - return false; - } -} - -static AutocompleteContext autocompleteExpression( - const SourceModule& sourceModule, - const Module& module, - NotNull builtinTypes, - TypeArena* typeArena, - const std::vector& ancestry, - Position position, - AutocompleteEntryMap& result -) -{ - LUAU_ASSERT(!ancestry.empty()); - - AstNode* node = ancestry.rbegin()[0]; - - if (node->is()) - { - if (auto it = module.astTypes.find(node->asExpr())) - autocompleteProps(module, typeArena, builtinTypes, *it, PropIndexType::Point, ancestry, result); - } - else if (autocompleteIfElseExpression(node, ancestry, position, result)) - return AutocompleteContext::Keyword; - else if (node->is()) - return AutocompleteContext::Unknown; - else - { - // This is inefficient. :( - ScopePtr scope = findScopeAtPosition(module, position); - - while (scope) - { - for (const auto& [name, binding] : scope->bindings) - { - if (!isBindingLegalAtCurrentPosition(name, binding, position)) - continue; - - if (isBeingDefined(ancestry, name)) - continue; - - std::string n = toString(name); - if (!result.count(n)) - { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); - - result[n] = { - AutocompleteEntryKind::Binding, - binding.typeId, - binding.deprecated, - false, - typeCorrect, - std::nullopt, - std::nullopt, - binding.documentationSymbol, - {}, - getParenRecommendation(binding.typeId, ancestry, typeCorrect) - }; - } - } - - scope = scope->parent; - } - - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->nilType); - TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->trueType); - TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->falseType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForTrue}; - result["false"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForFalse}; - result["nil"] = {AutocompleteEntryKind::Keyword, builtinTypes->nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; - - if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, node, position, result); - } - - return AutocompleteContext::Expression; -} - -static AutocompleteResult autocompleteExpression( - const SourceModule& sourceModule, - const Module& module, - NotNull builtinTypes, - TypeArena* typeArena, - const std::vector& ancestry, - Position position -) -{ - AutocompleteEntryMap result; - AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result); - return {result, ancestry, context}; -} - -static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) -{ - AstExpr* parentExpr = nullptr; - if (auto indexName = funcExpr->as()) - { - parentExpr = indexName->expr; - } - else if (auto indexExpr = funcExpr->as()) - { - parentExpr = indexExpr->expr; - } - else - { - return std::nullopt; - } - - auto parentIt = module->astTypes.find(parentExpr); - if (!parentIt) - { - return std::nullopt; - } - - Luau::TypeId parentType = Luau::follow(*parentIt); - - if (auto parentClass = Luau::get(parentType)) - { - return parentClass; - } - - if (auto parentUnion = Luau::get(parentType)) - { - return returnFirstNonnullOptionOfType(parentUnion); - } - - return std::nullopt; -} - -static bool stringPartOfInterpString(const AstNode* node, Position position) -{ - const AstExprInterpString* interpString = node->as(); - if (!interpString) - { - return false; - } - - for (const AstExpr* expression : interpString->expressions) - { - if (expression->location.containsClosed(position)) - { - return false; - } - } - - return true; -} - -static bool isSimpleInterpolatedString(const AstNode* node) -{ - const AstExprInterpString* interpString = node->as(); - return interpString != nullptr && interpString->expressions.size == 0; -} - -static std::optional getStringContents(const AstNode* node) -{ - if (const AstExprConstantString* string = node->as()) - { - return std::string(string->value.data, string->value.size); - } - else if (const AstExprInterpString* interpString = node->as(); interpString && interpString->expressions.size == 0) - { - LUAU_ASSERT(interpString->strings.size == 1); - return std::string(interpString->strings.data->data, interpString->strings.data->size); - } - else - { - return std::nullopt; - } -} - -static std::optional convertRequireSuggestionsToAutocompleteEntryMap(std::optional suggestions) -{ - if (!suggestions) - return std::nullopt; - - AutocompleteEntryMap result; - for (const RequireSuggestion& suggestion : *suggestions) - { - result[suggestion] = {AutocompleteEntryKind::RequirePath}; - } - return result; -} - -static std::optional autocompleteStringParams( - const SourceModule& sourceModule, - const ModulePtr& module, - const std::vector& nodes, - Position position, - FileResolver* fileResolver, - StringCompletionCallback callback -) -{ - if (nodes.size() < 2) - { - return std::nullopt; - } - - if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) - { - return std::nullopt; - } - - if (!nodes.back()->is()) - { - if (nodes.back()->location.end == position || nodes.back()->location.begin == position) - { - return std::nullopt; - } - } - - AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); - if (!candidate) - { - return std::nullopt; - } - - // HACK: All current instances of 'magic string' params are the first parameter of their functions, - // so we encode that here rather than putting a useless member on the FunctionType struct. - if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) - { - return std::nullopt; - } - - auto it = module->astTypes.find(candidate->func); - if (!it) - { - return std::nullopt; - } - - std::optional candidateString = getStringContents(nodes.back()); - - auto performCallback = [&](const FunctionType* funcType) -> std::optional - { - for (const std::string& tag : funcType->tags) - { - if (FFlag::AutocompleteRequirePathSuggestions) - { - if (tag == kRequireTagName && fileResolver) - { - return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString)); - } - } - if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) - { - return ret; - } - } - return std::nullopt; - }; - - auto followedId = Luau::follow(*it); - if (auto functionType = Luau::get(followedId)) - { - return performCallback(functionType); - } - - if (auto intersect = Luau::get(followedId)) - { - for (TypeId part : intersect->parts) - { - if (auto candidateFunctionType = Luau::get(part)) - { - if (std::optional ret = performCallback(candidateFunctionType)) - { - return ret; - } - } - } - } - - return std::nullopt; -} - -static AutocompleteResult autocompleteWhileLoopKeywords(std::vector ancestry) -{ - AutocompleteEntryMap ret; - ret["do"] = {AutocompleteEntryKind::Keyword}; - ret["and"] = {AutocompleteEntryKind::Keyword}; - ret["or"] = {AutocompleteEntryKind::Keyword}; - return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; -} - -static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) -{ - std::string result = "function("; - - auto [args, tail] = Luau::flatten(funcTy.argTypes); - - bool first = true; - // Skip the implicit 'self' argument if call is indexed with ':' - for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) - { - if (!first) - result += ", "; - else - first = false; - - std::string name; - if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) - name = funcTy.argNames[argIdx]->name; - else - name = "a" + std::to_string(argIdx); - - if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) - result += name + ": " + *type; - else - result += name; - } - - if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) - { - if (!first) - result += ", "; - - std::optional varArgType; - if (const VariadicTypePack* pack = get(follow(*tail))) - { - if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) - varArgType = std::move(res); - } - - if (varArgType) - result += "...: " + *varArgType; - else - result += "..."; - } - - result += ")"; - - auto [rets, retTail] = Luau::flatten(funcTy.retTypes); - if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) - { - if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) - { - result += ": "; - bool wrap = totalRetSize != 1; - if (wrap) - result += "("; - result += *returnTypes; - if (wrap) - result += ")"; - } - } - result += " end"; - return result; -} - -static std::optional makeAnonymousAutofilled( - const ModulePtr& module, - Position position, - const AstNode* node, - const std::vector& ancestry -) -{ - const AstExprCall* call = node->as(); - if (!call && ancestry.size() > 1) - call = ancestry[ancestry.size() - 2]->as(); - - if (!call) - return std::nullopt; - - if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) - return std::nullopt; - - TypeId* typeIter = module->astTypes.find(call->func); - if (!typeIter) - return std::nullopt; - - const FunctionType* outerFunction = get(follow(*typeIter)); - if (!outerFunction) - return std::nullopt; - - size_t argument = 0; - for (size_t i = 0; i < call->args.size; ++i) - { - if (call->args.data[i]->location.containsClosed(position)) - { - argument = i; - break; - } - } - - if (call->self) - argument++; - - std::optional argType; - auto [args, tail] = flatten(outerFunction->argTypes); - if (argument < args.size()) - argType = args[argument]; - - if (!argType) - return std::nullopt; - - TypeId followed = follow(*argType); - const FunctionType* type = get(followed); - if (!type) - { - if (const UnionType* unionType = get(followed)) - { - if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) - type = *nonnullFunction; - } - } - - if (!type) - return std::nullopt; - - const ScopePtr scope = findScopeAtPosition(*module, position); - if (!scope) - return std::nullopt; - - AutocompleteEntry entry; - entry.kind = AutocompleteEntryKind::GeneratedFunction; - entry.typeCorrect = TypeCorrectKind::Correct; - entry.type = argType; - entry.insertText = makeAnonymous(scope, *type); - return std::make_optional(std::move(entry)); -} - -static AutocompleteResult autocomplete( - const SourceModule& sourceModule, - const ModulePtr& module, - NotNull builtinTypes, - TypeArena* typeArena, - Scope* globalScope, - Position position, - FileResolver* fileResolver, - StringCompletionCallback callback -) -{ - if (isWithinComment(sourceModule, position)) - return {}; - - std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); - LUAU_ASSERT(!ancestry.empty()); - AstNode* node = ancestry.back(); - - AstExprConstantNil dummy{Location{}}; - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; - - // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node - if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) - { - ancestry.pop_back(); - - node = ancestry.back(); - parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; - } - - if (auto indexName = node->as()) - { - auto it = module->astTypes.find(indexName->expr); - if (!it) - return {}; - - TypeId ty = follow(*it); - PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - - return {autocompleteProps(*module, typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; - } - else if (auto typeReference = node->as()) - { - if (typeReference->prefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry, AutocompleteContext::Type}; - else - return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; - } - else if (node->is()) - { - return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; - } - else if (AstStatLocal* statLocal = node->as()) - { - if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) - return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; - else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else - return {}; - } - - else if (AstStatFor* statFor = extractStat(ancestry)) - { - if (!statFor->hasDo || position < statFor->doLocation.begin) - { - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - return {}; - } - - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) - { - if (!statForIn->hasIn || position <= statForIn->inLocation.begin) - { - AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; - if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) - { - // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or - // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer - // any suggestions. - return {}; - } - - return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - - if (!statForIn->hasDo || position <= statForIn->doLocation.begin) - { - LUAU_ASSERT(statForIn->values.size > 0); - AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; - - if (lastExpr->location.containsClosed(position)) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (position > lastExpr->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - - return {}; // Not sure what this means - } - } - else if (AstStatForIn* statForIn = extractStat(ancestry)) - { - // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. - // ex "for f in f do" - if (!statForIn->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) - { - if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - { - return autocompleteWhileLoopKeywords(ancestry); - } - - if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (statWhile->hasDo && position > statWhile->doLocation.end) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatWhile* statWhile = extractStat(ancestry); - (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && - !statWhile->condition->location.containsClosed(position))) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) - { - return { - {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - ancestry, - AutocompleteContext::Keyword - }; - } - else if (AstStatIf* statIf = parent->as(); statIf && node->is()) - { - if (statIf->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - else if (AstStatIf* statIf = extractStat(ancestry); statIf && - (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && - (statIf->condition && !statIf->condition->location.containsClosed(position))) - { - AutocompleteEntryMap ret; - ret["then"] = {AutocompleteEntryKind::Keyword}; - ret["and"] = {AutocompleteEntryKind::Keyword}; - ret["or"] = {AutocompleteEntryKind::Keyword}; - return {std::move(ret), ancestry, AutocompleteContext::Keyword}; - } - else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - else if (AstExprTable* exprTable = parent->as(); - exprTable && (node->is() || node->is() || node->is())) - { - for (const auto& [kind, key, value] : exprTable->items) - { - // If item doesn't have a key, maybe the value is actually the key - if (key ? key == node : node->is() && value == node) - { - if (auto it = module->astExpectedTypes.find(exprTable)) - { - auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - - if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); - - if (!key) - { - // If there is "no key," it may be that the user - // intends for the current token to be the key, but - // has yet to type the `=` sign. - // - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); - } - } - - // Remove keys that are already completed - for (const auto& item : exprTable->items) - { - if (!item.key) - continue; - - if (auto stringKey = item.key->as()) - result.erase(std::string(stringKey->value.data, stringKey->value.size)); - } - - // If we know for sure that a key is being written, do not offer general expression suggestions - if (!key) - autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); - - return {result, ancestry, AutocompleteContext::Property}; - } - - break; - } - } - } - else if (AstExprTable* exprTable = node->as()) - { - AutocompleteEntryMap result; - - if (auto it = module->astExpectedTypes.find(exprTable)) - { - result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); - } - - // Remove keys that are already completed - for (const auto& item : exprTable->items) - { - if (!item.key) - continue; - - if (auto stringKey = item.key->as()) - result.erase(std::string(stringKey->value.data, stringKey->value.size)); - } - } - - // Also offer general expression suggestions - autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); - - return {result, ancestry, AutocompleteContext::Property}; - } - else if (isIdentifier(node) && (parent->is() || parent->is())) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - - if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, fileResolver, callback)) - { - return {*ret, ancestry, AutocompleteContext::String}; - } - else if (node->is() || isSimpleInterpolatedString(node)) - { - AutocompleteEntryMap result; - - if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, node, position, result); - - if (ancestry.size() >= 2) - { - if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) - { - if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); - } - else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) - { - if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) - { - if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) - autocompleteStringSingleton(*it, false, node, position, result); - } - } - } - - return {result, ancestry, AutocompleteContext::String}; - } - else if (stringPartOfInterpString(node, position)) - { - // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we - // can't know what to format to - AutocompleteEntryMap map; - return {map, ancestry, AutocompleteContext::String}; - } - - if (node->is()) - return {}; - - if (node->asExpr()) - { - AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) - ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); - return ret; - } - else if (node->asStat()) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - - return {}; -} - AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { const SourceModule* sourceModule = frontend.getSourceModule(moduleName); @@ -2019,7 +36,13 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; - return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, frontend.fileResolver, callback); + if (isWithinComment(*sourceModule, position)) + return {}; + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*sourceModule, position); + LUAU_ASSERT(!ancestry.empty()); + ScopePtr startScope = findScopeAtPosition(*module, position); + return autocomplete_(module, builtinTypes, &typeArena, ancestry, globalScope, startScope, position, frontend.fileResolver, callback); } } // namespace Luau diff --git a/Analysis/src/AutocompleteCore.cpp b/Analysis/src/AutocompleteCore.cpp new file mode 100644 index 00000000..ee045771 --- /dev/null +++ b/Analysis/src/AutocompleteCore.cpp @@ -0,0 +1,2002 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "AutocompleteCore.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/AutocompleteTypes.h" + +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/FileResolver.h" +#include "Luau/Frontend.h" +#include "Luau/ToString.h" +#include "Luau/Subtyping.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" + +#include +#include +#include + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions2) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) + +LUAU_FASTFLAGVARIABLE(LuauAutocompleteRefactorsForIncrementalAutocomplete) + +static const std::unordered_set kStatementStartingKeywords = + {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; + +namespace Luau +{ + +static bool alreadyHasParens(const std::vector& nodes) +{ + auto iter = nodes.rbegin(); + while (iter != nodes.rend() && + ((*iter)->is() || (*iter)->is() || (*iter)->is() || (*iter)->is())) + { + iter++; + } + + if (iter == nodes.rend() || iter == nodes.rbegin()) + { + return false; + } + + if (AstExprCall* call = (*iter)->as()) + { + return call->func == *(iter - 1); + } + + return false; +} + +static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionType* func, const std::vector& nodes) +{ + if (alreadyHasParens(nodes)) + { + return ParenthesesRecommendation::None; + } + + auto idxExpr = nodes.back()->as(); + bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); + return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; +} + +static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionType* intersect, const std::vector& nodes) +{ + ParenthesesRecommendation rec = ParenthesesRecommendation::None; + for (Luau::TypeId partId : intersect->parts) + { + if (auto partFunc = Luau::get(partId)) + { + rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); + } + else + { + return ParenthesesRecommendation::None; + } + } + return rec; +} + +static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::vector& nodes, TypeCorrectKind typeCorrect) +{ + // If element is already type-correct, even a function should be inserted without parenthesis + if (typeCorrect == TypeCorrectKind::Correct) + return ParenthesesRecommendation::None; + + id = Luau::follow(id); + if (auto func = get(id)) + { + return getParenRecommendationForFunc(func, nodes); + } + else if (auto intersect = get(id)) + { + return getParenRecommendationForIntersect(intersect, nodes); + } + return ParenthesesRecommendation::None; +} + +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionType* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull builtinTypes) +{ + InternalErrorReporter iceReporter; + UnifierSharedState unifierState(&iceReporter); + Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; + + if (FFlag::LuauSolverV2) + { + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&iceReporter}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + + return subtyping.isSubtype(subTy, superTy, scope).isSubtype; + } + else + { + Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); + + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; + + return unifier.canUnify(subTy, superTy).empty(); + } +} + +static TypeCorrectKind checkTypeCorrectKind( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + AstNode* node, + Position position, + TypeId ty +) +{ + ty = follow(ty); + + LUAU_ASSERT(module.hasModuleScope()); + + NotNull moduleScope{module.getModuleScope().get()}; + + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return TypeCorrectKind::None; + + TypeId expectedType = follow(*typeAtPosition); + + auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) + { + if (std::optional firstRetTy = first(ftv->retTypes)) + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); + + return false; + }; + + // We also want to suggest functions that return compatible result + if (const FunctionType* ftv = get(ty); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + else if (const IntersectionType* itv = get(ty)) + { + for (TypeId id : itv->parts) + { + id = follow(id); + + if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + } + } + + return checkTypeMatch(ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; +} + +enum class PropIndexType +{ + Point, + Colon, + Key, +}; + +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId rootTy, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result, + std::unordered_set& seen, + std::optional containingClass = std::nullopt +) +{ + rootTy = follow(rootTy); + ty = follow(ty); + + if (seen.count(ty)) + return; + seen.insert(ty); + + auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) + { + if (indexType == PropIndexType::Key) + return false; + + bool calledWithSelf = indexType == PropIndexType::Colon; + + auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) + { + // Strong match with definition is a success + if (calledWithSelf == ftv->hasSelf) + return true; + + // Calls on classes require strict match between how function is declared and how it's called + if (get(rootTy)) + return false; + + // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all + // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) + return calledWithSelf; + } + + return !calledWithSelf; + }; + + if (const FunctionType* ftv = get(type)) + return !isCompatibleCall(ftv); + + // For intersections, any part that is successful makes the whole call successful + if (const IntersectionType* itv = get(type)) + { + for (auto subType : itv->parts) + { + if (const FunctionType* ftv = get(Luau::follow(subType))) + { + if (isCompatibleCall(ftv)) + return false; + } + } + } + + return calledWithSelf; + }; + + auto fillProps = [&](const ClassType::Props& props) + { + for (const auto& [name, prop] : props) + { + // We are walking up the class hierarchy, so if we encounter a property that we have + // already populated, it takes precedence over the property we found just now. + if (result.count(name) == 0 && name != kParseNameError) + { + Luau::TypeId type; + + if (FFlag::LuauSolverV2) + { + if (auto ty = prop.readTy) + type = follow(*ty); + else + continue; + } + else + type = follow(prop.type()); + + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key + ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); + + ParenthesesRecommendation parens = + indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); + + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Property, + type, + prop.deprecated, + isWrongIndexer(type), + typeCorrect, + containingClass, + &prop, + prop.documentationSymbol, + {}, + parens, + {}, + indexType == PropIndexType::Colon + }; + } + } + }; + + auto fillMetatableProps = [&](const TableType* mtable) + { + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) + { + TypeId followed = follow(indexIt->second.type()); + if (get(followed) || get(followed)) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); + } + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retTypes); + if (indexFunctionResult) + autocompleteProps(module, typeArena, builtinTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } + } + }; + + if (auto cls = get(ty)) + { + containingClass = containingClass.value_or(cls); + fillProps(cls->props); + if (cls->parent) + autocompleteProps(module, typeArena, builtinTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); + } + else if (auto tbl = get(ty)) + fillProps(tbl->props); + else if (auto mt = get(ty)) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); + + if (auto mtable = get(follow(mt->metatable))) + fillMetatableProps(mtable); + } + else if (auto i = get(ty)) + { + // Complete all properties in every variant + for (TypeId ty : i->parts) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen = seen; + + autocompleteProps(module, typeArena, builtinTypes, rootTy, ty, indexType, nodes, inner, innerSeen); + + for (auto& pair : inner) + result.insert(pair); + } + } + else if (auto u = get(ty)) + { + // Complete all properties common to all variants + auto iter = begin(u); + auto endIter = end(u); + + while (iter != endIter) + { + if (isNil(*iter)) + ++iter; + else + break; + } + + if (iter == endIter) + return; + + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, result, seen); + + ++iter; + + while (iter != endIter) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen; + + if (isNil(*iter)) + { + ++iter; + continue; + } + + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); + + std::unordered_set toRemove; + + for (const auto& [k, v] : result) + { + (void)v; + if (!inner.count(k)) + toRemove.insert(k); + } + + for (const std::string& k : toRemove) + result.erase(k); + + ++iter; + } + } + else if (auto pt = get(ty)) + { + if (pt->metatable) + { + if (auto mtable = get(*pt->metatable)) + fillMetatableProps(mtable); + } + } + else if (get(get(ty))) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, builtinTypes->stringType, indexType, nodes, result, seen); + } +} + +static void autocompleteKeywords(const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.back(); + + if (!node->is() && node->asExpr()) + { + // This is not strictly correct. We should recommend `and` and `or` only after + // another expression, not at the start of a new one. We should only recommend + // `not` at the start of an expression. Detecting either case reliably is quite + // complex, however; this is good enough for now. + + // These are not context-sensitive keywords, so we can unconditionally assign. + result["and"] = {AutocompleteEntryKind::Keyword}; + result["or"] = {AutocompleteEntryKind::Keyword}; + result["not"] = {AutocompleteEntryKind::Keyword}; + } +} + +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result +) +{ + std::unordered_set seen; + autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); +} + +AutocompleteEntryMap autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes +) +{ + AutocompleteEntryMap result; + autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); + return result; +} + +AutocompleteEntryMap autocompleteModuleTypes(const Module& module, const ScopePtr& scopeAtPosition, Position position, std::string_view moduleName) +{ + AutocompleteEntryMap result; + ScopePtr startScope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + for (ScopePtr& scope = startScope; scope; scope = scope->parent) + { + if (auto it = scope->importedTypeBindings.find(std::string(moduleName)); it != scope->importedTypeBindings.end()) + { + for (const auto& [name, ty] : it->second) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type}; + + break; + } + } + + return result; +} + +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) +{ + if (position == node->location.begin || position == node->location.end) + { + if (auto str = node->as(); str && str->isQuoted()) + return; + else if (node->is()) + return; + } + + auto formatKey = [addQuotes](const std::string& key) + { + if (addQuotes) + return "\"" + escape(key) + "\""; + + return escape(key); + }; + + ty = follow(ty); + + if (auto ss = get(get(ty))) + { + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + else if (auto uty = get(ty)) + { + for (auto el : uty) + { + if (auto ss = get(get(el))) + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + } +}; + +static bool canSuggestInferredType(ScopePtr scope, TypeId ty) +{ + ty = follow(ty); + + // No point in suggesting 'any', invalid to suggest others + if (get(ty) || get(ty) || get(ty) || get(ty)) + return false; + + // No syntax for unnamed tables with a metatable + if (get(ty)) + return false; + + if (const TableType* ttv = get(ty)) + { + if (ttv->name) + return true; + + if (ttv->syntheticName) + return false; + } + + // We might still have a type with cycles or one that is too long, we'll check that later + return true; +} + +// Walk complex type trees to find the element that is being edited +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position); + +static std::optional findTypeElementAt(const AstTypeList& astTypeList, TypePackId tp, Position position) +{ + for (size_t i = 0; i < astTypeList.types.size; i++) + { + AstType* type = astTypeList.types.data[i]; + + if (type->location.containsClosed(position)) + { + auto [head, _] = flatten(tp); + + if (i < head.size()) + return findTypeElementAt(type, head[i], position); + } + } + + if (AstTypePack* argTp = astTypeList.tailType) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + auto [_, tail] = flatten(tp); + + if (tail) + { + if (const VariadicTypePack* vtp = get(follow(*tail))) + return findTypeElementAt(variadic->variadicType, vtp->ty, position); + } + } + } + } + + return {}; +} + +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position) +{ + ty = follow(ty); + + if (astType->is()) + return ty; + + if (astType->is()) + return ty; + + if (AstTypeFunction* type = astType->as()) + { + const FunctionType* ftv = get(ty); + + if (!ftv) + return {}; + + if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) + return element; + + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) + return element; + } + + // It's possible to walk through other types like intrsection and unions if we find value in doing that + return {}; +} + +std::optional getLocalTypeInScopeAt(const Module& module, const ScopePtr& scopeAtPosition, Position position, AstLocal* local) +{ + if (ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position)) + { + for (const auto& [name, binding] : scope->bindings) + { + if (name == local) + return binding.typeId; + } + } + + return {}; +} + +template +static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) +{ + ToStringOptions opts; + opts.useLineBreaks = false; + opts.hideTableKind = true; + opts.functionTypeArguments = functionTypeArguments; + opts.scope = scope; + ToStringResult name = toStringDetailed(ty, opts); + + if (name.error || name.invalid || name.cycle || name.truncated) + return std::nullopt; + + return name.name; +} + +static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) +{ + if (!canSuggestInferredType(scope, ty)) + return std::nullopt; + + return tryToStringDetailed(scope, ty, functionTypeArguments); +} + +static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) +{ + std::optional ty; + + if (topType) + ty = findTypeElementAt(topType, inferredType, position); + else + ty = inferredType; + + if (!ty) + return false; + + if (auto name = tryGetTypeNameInScope(scope, *ty)) + { + if (auto it = result.find(*name); it != result.end()) + it->second.typeCorrect = TypeCorrectKind::Correct; + else + result[*name] = AutocompleteEntry{AutocompleteEntryKind::Type, *ty, false, false, TypeCorrectKind::Correct}; + + return true; + } + + return false; +} + +static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) +{ + auto [tpHead, tpTail] = flatten(tp); + + if (index < tpHead.size()) + return tpHead[index]; + + // Infinite tail + if (tpTail) + { + if (const VariadicTypePack* vtp = get(follow(*tpTail))) + return vtp->ty; + } + + return {}; +} + +template +std::optional returnFirstNonnullOptionOfType(const UnionType* utv) +{ + std::optional ret; + for (TypeId subTy : utv) + { + if (isNil(subTy)) + continue; + + if (const T* ftv = get(follow(subTy))) + { + if (ret.has_value()) + { + return std::nullopt; + } + ret = ftv; + } + else + { + return std::nullopt; + } + } + return ret; +} + +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) +{ + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return std::nullopt; + + TypeId expectedType = follow(*typeAtPosition); + + if (get(expectedType)) + return true; + + if (const IntersectionType* itv = get(expectedType)) + { + return std::all_of( + begin(itv->parts), + end(itv->parts), + [](auto&& ty) + { + return get(Luau::follow(ty)) != nullptr; + } + ); + } + + if (const UnionType* utv = get(expectedType)) + return returnFirstNonnullOptionOfType(utv).has_value(); + + return false; +} + +AutocompleteEntryMap autocompleteTypeNames( + const Module& module, + const ScopePtr& scopeAtPosition, + Position& position, + const std::vector& ancestry +) +{ + AutocompleteEntryMap result; + + ScopePtr startScope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + for (ScopePtr scope = startScope; scope; scope = scope->parent) + { + for (const auto& [name, ty] : scope->exportedTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; + } + + for (const auto& [name, ty] : scope->privateTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; + } + + for (const auto& [name, _] : scope->importedTypeBindings) + { + if (auto binding = scope->linearSearchForBinding(name, true)) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Module, binding->typeId}; + } + } + } + + AstNode* parent = nullptr; + AstType* topType = nullptr; // TODO: rename? + + for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) + { + if (AstType* asType = (*it)->asType()) + { + topType = asType; + } + else + { + parent = *it; + break; + } + } + + if (!parent) + return result; + + if (AstStatLocal* node = parent->as()) // Try to provide inferred type of the local + { + // Look at which of the variable types we are defining + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + if (var->annotation && var->annotation->location.containsClosed(position)) + { + if (node->values.size == 0) + break; + + unsigned tailPos = 0; + + // For multiple return values we will try to unpack last function call return type pack + if (i >= node->values.size) + { + tailPos = int(i) - int(node->values.size) + 1; + i = int(node->values.size) - 1; + } + + AstExpr* expr = node->values.data[i]->asExpr(); + + if (!expr) + break; + + TypeId inferredType = nullptr; + + if (AstExprCall* exprCall = expr->as()) + { + if (auto it = module.astTypes.find(exprCall->func)) + { + if (const FunctionType* ftv = get(follow(*it))) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) + inferredType = *ty; + } + } + } + else + { + if (tailPos != 0) + break; + + if (auto it = module.astTypes.find(expr)) + inferredType = *it; + } + + if (inferredType) + tryAddTypeCorrectSuggestion(result, startScope, topType, inferredType, position); + + break; + } + } + } + else if (AstExprFunction* node = parent->as()) + { + // For lookup inside expected function type if that's available + auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* + { + auto it = module.astExpectedTypes.find(expr); + + if (!it) + return nullptr; + + TypeId ty = follow(*it); + + if (const FunctionType* ftv = get(ty)) + return ftv; + + // Handle optional function type + if (const UnionType* utv = get(ty)) + { + return returnFirstNonnullOptionOfType(utv).value_or(nullptr); + } + + return nullptr; + }; + + // Find which argument type we are defining + for (size_t i = 0; i < node->args.size; i++) + { + AstLocal* arg = node->args.data[i]; + + if (arg->annotation && arg->annotation->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + // Otherwise, try to use the type inferred by typechecker + else if (auto inferredType = getLocalTypeInScopeAt(module, scopeAtPosition, position, arg)) + { + tryAddTypeCorrectSuggestion(result, startScope, topType, *inferredType, position); + } + + break; + } + } + + if (AstTypePack* argTp = node->varargAnnotation) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + + if (!node->returnAnnotation) + return result; + + for (size_t i = 0; i < node->returnAnnotation->types.size; i++) + { + AstType* ret = node->returnAnnotation->types.data[i]; + + if (ret->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + + // TODO: with additional type information, we could suggest inferred return type here + break; + } + } + + if (AstTypePack* retTp = node->returnAnnotation->tailType) + { + if (auto variadic = retTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + } + + return result; +} + +static bool isInLocalNames(const std::vector& ancestry, Position position) +{ + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (var->location.containsClosed(position)) + { + return true; + } + } + } + else if (auto funcExpr = (*iter)->as()) + { + if (funcExpr->argLocation && funcExpr->argLocation->contains(position)) + { + return true; + } + } + else if (auto localFunc = (*iter)->as()) + { + return localFunc->name->location.containsClosed(position); + } + else if (auto block = (*iter)->as()) + { + if (block->body.size > 0) + { + return false; + } + } + else if ((*iter)->asStat()) + { + return false; + } + } + return false; +} + +static bool isIdentifier(AstNode* node) +{ + return node->is() || node->is(); +} + +static bool isBeingDefined(const std::vector& ancestry, const Symbol& symbol) +{ + // Current set of rules only check for local binding match + if (!symbol.local) + return false; + + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (symbol.local == var) + return true; + } + } + } + + return false; +} + +template +T* extractStat(const std::vector& ancestry) +{ + AstNode* node = ancestry.size() >= 1 ? ancestry.rbegin()[0] : nullptr; + if (!node) + return nullptr; + + if (T* t = node->as()) + return t; + + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return nullptr; + + AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; + AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; + + if (!grandParent) + return nullptr; + + if (T* t = parent->as(); t && grandParent->is()) + return t; + + if (!greatGrandParent) + return nullptr; + + if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) + return t; + + return nullptr; +} + +static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) +{ + if (symbol.local) + return binding.location.end < pos; + + // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it + return binding.location == Location() || !binding.location.containsClosed(pos); +} + +static AutocompleteEntryMap autocompleteStatement( + const Module& module, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position& position +) +{ + // This is inefficient. :( + ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + AutocompleteEntryMap result; + + if (isInLocalNames(ancestry, position)) + { + autocompleteKeywords(ancestry, position, result); + return result; + } + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(name, binding, position)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) + }; + } + + scope = scope->parent; + } + + for (const auto& kw : kStatementStartingKeywords) + result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) + { + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatIf* statIf = (*it)->as()) + { + bool hasEnd = statIf->thenbody->hasEnd; + if (statIf->elsebody) + { + if (AstStatBlock* elseBlock = statIf->elsebody->as()) + hasEnd = elseBlock->hasEnd; + } + + if (!hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 2) + { + AstNode* parent = ancestry.rbegin()[1]; + if (AstStatIf* statIf = parent->as()) + { + if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 4) + { + auto iter = ancestry.rbegin(); + if (AstStatIf* statIf = iter[3]->as(); + statIf != nullptr && !statIf->elsebody && iter[2]->is() && iter[1]->is() && isIdentifier(iter[0])) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + return result; +} + +// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) +static bool autocompleteIfElseExpression( + const AstNode* node, + const std::vector& ancestry, + const Position& position, + AutocompleteEntryMap& outResult +) +{ + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return false; + + if (node->is()) + { + // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else + // expression. + return true; + } + + AstExprIfElse* ifElseExpr = parent->as(); + if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasThen) + { + outResult["then"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else if (ifElseExpr->trueExpr->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasElse) + { + outResult["else"] = {AutocompleteEntryKind::Keyword}; + outResult["elseif"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else + { + return false; + } +} + +static AutocompleteContext autocompleteExpression( + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position position, + AutocompleteEntryMap& result +) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.rbegin()[0]; + + if (node->is()) + { + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, builtinTypes, *it, PropIndexType::Point, ancestry, result); + } + else if (autocompleteIfElseExpression(node, ancestry, position, result)) + return AutocompleteContext::Keyword; + else if (node->is()) + return AutocompleteContext::Unknown; + else + { + // This is inefficient. :( + ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(name, binding, position)) + continue; + + if (isBeingDefined(ancestry, name)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + { + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); + + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + typeCorrect, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, typeCorrect) + }; + } + } + + scope = scope->parent; + } + + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, builtinTypes->nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, node, position, result); + } + + return AutocompleteContext::Expression; +} + +static AutocompleteResult autocompleteExpression( + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position position +) +{ + AutocompleteEntryMap result; + AutocompleteContext context = autocompleteExpression(module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + return {result, ancestry, context}; +} + +static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) +{ + AstExpr* parentExpr = nullptr; + if (auto indexName = funcExpr->as()) + { + parentExpr = indexName->expr; + } + else if (auto indexExpr = funcExpr->as()) + { + parentExpr = indexExpr->expr; + } + else + { + return std::nullopt; + } + + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) + { + return std::nullopt; + } + + Luau::TypeId parentType = Luau::follow(*parentIt); + + if (auto parentClass = Luau::get(parentType)) + { + return parentClass; + } + + if (auto parentUnion = Luau::get(parentType)) + { + return returnFirstNonnullOptionOfType(parentUnion); + } + + return std::nullopt; +} + +static bool stringPartOfInterpString(const AstNode* node, Position position) +{ + const AstExprInterpString* interpString = node->as(); + if (!interpString) + { + return false; + } + + for (const AstExpr* expression : interpString->expressions) + { + if (expression->location.containsClosed(position)) + { + return false; + } + } + + return true; +} + +static bool isSimpleInterpolatedString(const AstNode* node) +{ + const AstExprInterpString* interpString = node->as(); + return interpString != nullptr && interpString->expressions.size == 0; +} + +static std::optional getStringContents(const AstNode* node) +{ + if (const AstExprConstantString* string = node->as()) + { + return std::string(string->value.data, string->value.size); + } + else if (const AstExprInterpString* interpString = node->as(); interpString && interpString->expressions.size == 0) + { + LUAU_ASSERT(interpString->strings.size == 1); + return std::string(interpString->strings.data->data, interpString->strings.data->size); + } + else + { + return std::nullopt; + } +} + +static std::optional convertRequireSuggestionsToAutocompleteEntryMap(std::optional suggestions) +{ + if (!suggestions) + return std::nullopt; + + AutocompleteEntryMap result; + for (const RequireSuggestion& suggestion : *suggestions) + { + AutocompleteEntry entry = {AutocompleteEntryKind::RequirePath}; + entry.insertText = std::move(suggestion.fullPath); + result[std::move(suggestion.label)] = std::move(entry); + } + return result; +} + +static std::optional autocompleteStringParams( + const ModulePtr& module, + const std::vector& nodes, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) +{ + if (nodes.size() < 2) + { + return std::nullopt; + } + + if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) + { + return std::nullopt; + } + + if (!nodes.back()->is()) + { + if (nodes.back()->location.end == position || nodes.back()->location.begin == position) + { + return std::nullopt; + } + } + + AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); + if (!candidate) + { + return std::nullopt; + } + + // HACK: All current instances of 'magic string' params are the first parameter of their functions, + // so we encode that here rather than putting a useless member on the FunctionType struct. + if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) + { + return std::nullopt; + } + + auto it = module->astTypes.find(candidate->func); + if (!it) + { + return std::nullopt; + } + + std::optional candidateString = getStringContents(nodes.back()); + + auto performCallback = [&](const FunctionType* funcType) -> std::optional + { + for (const std::string& tag : funcType->tags) + { + if (FFlag::AutocompleteRequirePathSuggestions2) + { + if (tag == kRequireTagName && fileResolver) + { + return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString)); + } + } + if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) + { + return ret; + } + } + return std::nullopt; + }; + + auto followedId = Luau::follow(*it); + if (auto functionType = Luau::get(followedId)) + { + return performCallback(functionType); + } + + if (auto intersect = Luau::get(followedId)) + { + for (TypeId part : intersect->parts) + { + if (auto candidateFunctionType = Luau::get(part)) + { + if (std::optional ret = performCallback(candidateFunctionType)) + { + return ret; + } + } + } + } + + return std::nullopt; +} + +static AutocompleteResult autocompleteWhileLoopKeywords(std::vector ancestry) +{ + AutocompleteEntryMap ret; + ret["do"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; +} + +static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) +{ + std::string result = "function("; + + auto [args, tail] = Luau::flatten(funcTy.argTypes); + + bool first = true; + // Skip the implicit 'self' argument if call is indexed with ':' + for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) + { + if (!first) + result += ", "; + else + first = false; + + std::string name; + if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) + name = funcTy.argNames[argIdx]->name; + else + name = "a" + std::to_string(argIdx); + + if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) + result += name + ": " + *type; + else + result += name; + } + + if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) + { + if (!first) + result += ", "; + + std::optional varArgType; + if (const VariadicTypePack* pack = get(follow(*tail))) + { + if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) + varArgType = std::move(res); + } + + if (varArgType) + result += "...: " + *varArgType; + else + result += "..."; + } + + result += ")"; + + auto [rets, retTail] = Luau::flatten(funcTy.retTypes); + if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) + { + if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) + { + result += ": "; + bool wrap = totalRetSize != 1; + if (wrap) + result += "("; + result += *returnTypes; + if (wrap) + result += ")"; + } + } + result += " end"; + return result; +} + +static std::optional makeAnonymousAutofilled( + const ModulePtr& module, + const ScopePtr& scopeAtPosition, + Position position, + const AstNode* node, + const std::vector& ancestry +) +{ + const AstExprCall* call = node->as(); + if (!call && ancestry.size() > 1) + call = ancestry[ancestry.size() - 2]->as(); + + if (!call) + return std::nullopt; + + if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) + return std::nullopt; + + TypeId* typeIter = module->astTypes.find(call->func); + if (!typeIter) + return std::nullopt; + + const FunctionType* outerFunction = get(follow(*typeIter)); + if (!outerFunction) + return std::nullopt; + + size_t argument = 0; + for (size_t i = 0; i < call->args.size; ++i) + { + if (call->args.data[i]->location.containsClosed(position)) + { + argument = i; + break; + } + } + + if (call->self) + argument++; + + std::optional argType; + auto [args, tail] = flatten(outerFunction->argTypes); + if (argument < args.size()) + argType = args[argument]; + + if (!argType) + return std::nullopt; + + TypeId followed = follow(*argType); + const FunctionType* type = get(followed); + if (!type) + { + if (const UnionType* unionType = get(followed)) + { + if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) + type = *nonnullFunction; + } + } + + if (!type) + return std::nullopt; + + const ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(*module, position); + if (!scope) + return std::nullopt; + + AutocompleteEntry entry; + entry.kind = AutocompleteEntryKind::GeneratedFunction; + entry.typeCorrect = TypeCorrectKind::Correct; + entry.type = argType; + entry.insertText = makeAnonymous(scope, *type); + return std::make_optional(std::move(entry)); +} + +AutocompleteResult autocomplete_( + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + std::vector& ancestry, + Scope* globalScope, + const ScopePtr& scopeAtPosition, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) +{ + AstNode* node = ancestry.back(); + + AstExprConstantNil dummy{Location{}}; + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; + + // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node + if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) + { + ancestry.pop_back(); + + node = ancestry.back(); + parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; + } + + if (auto indexName = node->as()) + { + auto it = module->astTypes.find(indexName->expr); + if (!it) + return {}; + + TypeId ty = follow(*it); + PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; + + return {autocompleteProps(*module, typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + } + else if (auto typeReference = node->as()) + { + if (typeReference->prefix) + return {autocompleteModuleTypes(*module, scopeAtPosition, position, typeReference->prefix->value), ancestry, AutocompleteContext::Type}; + else + return {autocompleteTypeNames(*module, scopeAtPosition, position, ancestry), ancestry, AutocompleteContext::Type}; + } + else if (node->is()) + { + return {autocompleteTypeNames(*module, scopeAtPosition, position, ancestry), ancestry, AutocompleteContext::Type}; + } + else if (AstStatLocal* statLocal = node->as()) + { + if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; + else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else + return {}; + } + + else if (AstStatFor* statFor = extractStat(ancestry)) + { + if (!statFor->hasDo || position < statFor->doLocation.begin) + { + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + return {}; + } + + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) + { + if (!statForIn->hasIn || position <= statForIn->inLocation.begin) + { + AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; + if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) + { + // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or + // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer + // any suggestions. + return {}; + } + + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + + if (!statForIn->hasDo || position <= statForIn->doLocation.begin) + { + LUAU_ASSERT(statForIn->values.size > 0); + AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; + + if (lastExpr->location.containsClosed(position)) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (position > lastExpr->location.end) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + + return {}; // Not sure what this means + } + } + else if (AstStatForIn* statForIn = extractStat(ancestry)) + { + // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. + // ex "for f in f do" + if (!statForIn->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) + { + if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) + { + return autocompleteWhileLoopKeywords(ancestry); + } + + if (!statWhile->hasDo || position < statWhile->doLocation.begin) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (statWhile->hasDo && position > statWhile->doLocation.end) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatWhile* statWhile = extractStat(ancestry); + (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && + !statWhile->condition->location.containsClosed(position))) + { + return autocompleteWhileLoopKeywords(ancestry); + } + else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) + { + return { + {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, + ancestry, + AutocompleteContext::Keyword + }; + } + else if (AstStatIf* statIf = parent->as(); statIf && node->is()) + { + if (statIf->condition->is()) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + else if (AstStatIf* statIf = extractStat(ancestry); statIf && + (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && + (statIf->condition && !statIf->condition->location.containsClosed(position))) + { + AutocompleteEntryMap ret; + ret["then"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), ancestry, AutocompleteContext::Keyword}; + } + else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + else if (AstExprTable* exprTable = parent->as(); + exprTable && (node->is() || node->is() || node->is())) + { + for (const auto& [kind, key, value] : exprTable->items) + { + // If item doesn't have a key, maybe the value is actually the key + if (key ? key == node : node->is() && value == node) + { + if (auto it = module->astExpectedTypes.find(exprTable)) + { + auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); + + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + + // If we know for sure that a key is being written, do not offer general expression suggestions + if (!key) + autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } + + break; + } + } + } + else if (AstExprTable* exprTable = node->as()) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(exprTable)) + { + result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + } + + // Also offer general expression suggestions + autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } + else if (isIdentifier(node) && (parent->is() || parent->is())) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + + if (std::optional ret = autocompleteStringParams(module, ancestry, position, fileResolver, callback)) + { + return {*ret, ancestry, AutocompleteContext::String}; + } + else if (node->is() || isSimpleInterpolatedString(node)) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, node, position, result); + + if (ancestry.size() >= 2) + { + if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) + { + if (auto it = module->astTypes.find(idxExpr->expr)) + autocompleteProps(*module, typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); + } + else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) + { + if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) + { + if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) + autocompleteStringSingleton(*it, false, node, position, result); + } + } + } + + return {result, ancestry, AutocompleteContext::String}; + } + else if (stringPartOfInterpString(node, position)) + { + // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we + // can't know what to format to + AutocompleteEntryMap map; + return {map, ancestry, AutocompleteContext::String}; + } + + if (node->is()) + return {}; + + if (node->asExpr()) + { + AutocompleteResult ret = autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + if (std::optional generated = makeAnonymousAutofilled(module, scopeAtPosition, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; + } + else if (node->asStat()) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + + return {}; +} + + +} // namespace Luau diff --git a/Analysis/src/AutocompleteCore.h b/Analysis/src/AutocompleteCore.h new file mode 100644 index 00000000..d4264da2 --- /dev/null +++ b/Analysis/src/AutocompleteCore.h @@ -0,0 +1,27 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AutocompleteTypes.h" + +namespace Luau +{ +struct Module; +struct FileResolver; + +using ModulePtr = std::shared_ptr; +using ModuleName = std::string; + + +AutocompleteResult autocomplete_( + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + std::vector& ancestry, + Scope* globalScope, + const ScopePtr& scopeAtPosition, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +); + +} // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 30fc2696..3dacae04 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -33,7 +33,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) -LUAU_FASTFLAG(AutocompleteRequirePathSuggestions) +LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2); namespace Luau { @@ -426,7 +426,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); } - if (FFlag::AutocompleteRequirePathSuggestions) + if (FFlag::AutocompleteRequirePathSuggestions2) { TypeId requireTy = getGlobalBinding(globals, "require"); attachTag(requireTy, kRequireTagName); diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index a62879fa..a0b5fcf4 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -3,6 +3,8 @@ #include "Luau/Constraint.h" #include "Luau/VisitType.h" +LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions) + namespace Luau { @@ -46,6 +48,21 @@ struct ReferenceCountInitializer : TypeOnceVisitor // ClassTypes never contain free types. return false; } + + bool visit(TypeId, const TypeFunctionInstanceType&) override + { + // We do not consider reference counted types that are inside a type + // function to be part of the reachable reference counted types. + // Otherwise, code can be constructed in just the right way such + // that two type functions both claim to mutate a free type, which + // prevents either type function from trying to generalize it, so + // we potentially get stuck. + // + // The default behavior here is `true` for "visit the child types" + // of this type, hence: + return !FFlag::LuauDontRefCountTypesInTypeFunctions; + } + }; bool isReferenceCountedType(const TypeId typ) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index d05623a8..ee602999 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -10,6 +10,7 @@ #include "Luau/Def.h" #include "Luau/DenseHash.h" #include "Luau/ModuleResolver.h" +#include "Luau/NotNull.h" #include "Luau/RecursionCounter.h" #include "Luau/Refinement.h" #include "Luau/Scope.h" @@ -30,11 +31,13 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) namespace Luau @@ -172,6 +175,7 @@ bool hasFreeType(TypeId ty) ConstraintGenerator::ConstraintGenerator( ModulePtr module, NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull moduleResolver, NotNull builtinTypes, @@ -188,6 +192,7 @@ ConstraintGenerator::ConstraintGenerator( , rootScope(nullptr) , dfg(dfg) , normalizer(normalizer) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , moduleResolver(moduleResolver) , ice(ice) @@ -257,7 +262,7 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) d = follow(d); if (d == ty) continue; - domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + domainTy = simplifyUnion(scope, Location{}, domainTy, d); } LUAU_ASSERT(get(ty)); @@ -267,7 +272,15 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block) { + // We prepopulate global data in the resumeScope to avoid writing data into the old modules scopes + prepopulateGlobalScopeForFragmentTypecheck(globalScope, resumeScope, block); + // Pre + // We need to pop the interior types, + interiorTypes.emplace_back(); visitBlockWithoutChildScope(resumeScope, block); + // Post + interiorTypes.pop_back(); + fillInInferredBindings(resumeScope, block); if (logger) @@ -282,7 +295,7 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat d = follow(d); if (d == ty) continue; - domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + domainTy = simplifyUnion(resumeScope, resumeScope->location, domainTy, d); } LUAU_ASSERT(get(ty)); @@ -711,7 +724,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc continue; } - if (scope->parent != globalScope) + if (!FFlag::LuauUserTypeFunExportedAndLocal && scope->parent != globalScope) { reportError(function->location, GenericError{"Local user-defined functions are not supported yet"}); continue; @@ -740,17 +753,26 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc if (std::optional error = typeFunctionRuntime->registerFunction(function)) reportError(function->location, GenericError{*error}); - TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ - NotNull{&builtinTypeFunctions().userFunc}, - std::move(typeParams), - {}, - function->name, - }); + UserDefinedFunctionData udtfData; + + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + udtfData.owner = module; + udtfData.definition = function; + } + + TypeId typeFunctionTy = arena->addType( + TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData} + ); TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; // Set type bindings and definition locations for this user-defined type function - scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + if (FFlag::LuauUserTypeFunExportedAndLocal && function->exported) + scope->exportedTypeBindings[function->name.value] = std::move(typeFunction); + else + scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + aliasDefinitionLocations[function->name.value] = function->location; } else if (auto classDeclaration = stat->as()) @@ -780,6 +802,55 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location; } } + + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Additional pass for user-defined type functions to fill in their environments completely + for (AstStat* stat : block->body) + { + if (auto function = stat->as()) + { + // Find the type function we have already created + TypeFunctionInstanceType* mainTypeFun = nullptr; + + if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end()) + mainTypeFun = getMutable(it->second.type); + + if (!mainTypeFun) + { + if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end()) + mainTypeFun = getMutable(it->second.type); + } + + // Fill it with all visible type functions + if (mainTypeFun) + { + UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + + for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) + { + for (auto& [name, tf] : curr->privateTypeBindings) + { + if (userFuncData.environment.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = ty->userFuncData.definition; + } + + for (auto& [name, tf] : curr->exportedTypeBindings) + { + if (userFuncData.environment.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = ty->userFuncData.definition; + } + } + } + } + } + } } ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) @@ -900,12 +971,8 @@ ControlFlow ConstraintGenerator::visitBlockWithoutChildScope_DEPRECATED(const Sc if (std::optional error = typeFunctionRuntime->registerFunction(function)) reportError(function->location, GenericError{*error}); - TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ - NotNull{&builtinTypeFunctions().userFunc}, - std::move(typeParams), - {}, - function->name, - }); + TypeId typeFunctionTy = + arena->addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, {}}); TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; @@ -2807,7 +2874,7 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local case ErrorSuppression::DoNotSuppress: break; case ErrorSuppression::Suppress: - ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result; + ty = simplifyUnion(scope, local->location, *ty, builtinTypes->errorType); break; case ErrorSuppression::NormalizationFailed: reportError(local->local->annotation->location, NormalizationTooComplex{}); @@ -3673,6 +3740,32 @@ TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location locati return resultType; } +struct FragmentTypeCheckGlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull currentScope; + const NotNull dfg; + + FragmentTypeCheckGlobalPrepopulator(NotNull globalScope, NotNull currentScope, NotNull dfg) + : globalScope(globalScope) + , currentScope(currentScope) + , dfg(dfg) + { + } + + bool visit(AstExprGlobal* global) override + { + if (auto ty = globalScope->lookup(global->name)) + { + DefId def = dfg->getDef(global); + // We only want to write into the current scope the type of the global + currentScope->lvalueTypes[def] = *ty; + } + + return true; + } +}; + struct GlobalPrepopulator : AstVisitor { const NotNull globalScope; @@ -3719,6 +3812,14 @@ struct GlobalPrepopulator : AstVisitor } }; +void ConstraintGenerator::prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program) +{ + FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg}; + if (prepareModuleScope) + prepareModuleScope(module->name, resumeScope); + program->visit(&gp); +} + void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) { GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg}; @@ -3870,6 +3971,24 @@ TypeId ConstraintGenerator::createTypeFunctionInstance( return result; } +TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(UnionType{{left, right}}); + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId tyFun : res->newTypeFunctions) + addConstraint(scope, location, ReduceConstraint{tyFun}); + + return res->result; + } + else + return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; +} + std::vector> borrowConstraints(const std::vector& constraints) { std::vector> result; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 398f0aa5..2b7a7232 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) +LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) namespace Luau @@ -320,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor ConstraintSolver::ConstraintSolver( NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, @@ -333,6 +335,7 @@ ConstraintSolver::ConstraintSolver( : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , constraints(std::move(constraints)) , rootScope(rootScope) @@ -1802,7 +1805,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullprops[c.propName] = rhsType; // Food for thought: Could we block if simplification encounters a blocked type? - lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result; + lhsFree->upperBound = simplifyIntersection(constraint->scope, constraint->location, lhsFreeUpperBound, newUpperBound); bind(constraint, c.propType, rhsType); return true; @@ -2016,7 +2019,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullscope, constraint->location, std::move(parts)); unify(constraint, rhsType, res); } @@ -2596,9 +2599,9 @@ std::pair, std::optional> ConstraintSolver::lookupTa // if we're in an lvalue context, we need the _common_ type here. if (context == ValueContext::LValue) - return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)}; - return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; + return {{}, simplifyUnion(constraint->scope, constraint->location, one, two)}; } // if we're in an lvalue context, we need the _common_ type here. else if (context == ValueContext::LValue) @@ -2630,7 +2633,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa { TypeId one = *begin(options); TypeId two = *(++begin(options)); - return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)}; } else return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; @@ -3019,6 +3022,63 @@ bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) return false; } +TypeId ConstraintSolver::simplifyIntersection(NotNull scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(IntersectionType{{left, right}}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyIntersection(builtinTypes, arena, left, right).result; +} + +TypeId ConstraintSolver::simplifyIntersection(NotNull scope, Location location, std::set parts) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(IntersectionType{std::vector(parts.begin(), parts.end())}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyIntersection(builtinTypes, arena, std::move(parts)).result; +} + +TypeId ConstraintSolver::simplifyUnion(NotNull scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(UnionType{{left, right}}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; +} + TypeId ConstraintSolver::errorRecoveryType() const { return builtinTypes->errorRecoveryType(); diff --git a/Analysis/src/EqSatSimplification.cpp b/Analysis/src/EqSatSimplification.cpp new file mode 100644 index 00000000..41e87de2 --- /dev/null +++ b/Analysis/src/EqSatSimplification.cpp @@ -0,0 +1,2449 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/EqSatSimplification.h" +#include "Luau/EqSatSimplificationImpl.h" + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" + +#include "Luau/StringUtils.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFunction.h" +#include "Luau/VisitType.h" + +#include +#include +#include +#include +#include +#include +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplification) +LUAU_FASTFLAGVARIABLE(DebugLuauExtraEqSatSanityChecks) + +namespace Luau::EqSatSimplification +{ +using Id = Luau::EqSat::Id; + +using EGraph = Luau::EqSat::EGraph; +using Luau::EqSat::Slice; + +TTable::TTable(Id basis) +{ + storage.push_back(basis); +} + +// I suspect that this is going to become a performance hotspot. It would be +// nice to avoid allocating propTypes_ +TTable::TTable(Id basis, std::vector propNames_, std::vector propTypes_) + : propNames(std::move(propNames_)) +{ + storage.reserve(propTypes_.size() + 1); + storage.push_back(basis); + storage.insert(storage.end(), propTypes_.begin(), propTypes_.end()); + + LUAU_ASSERT(storage.size() == 1 + propTypes_.size()); +} + +Id TTable::getBasis() const +{ + LUAU_ASSERT(!storage.empty()); + return storage[0]; +} + +Slice TTable::propTypes() const +{ + LUAU_ASSERT(propNames.size() + 1 == storage.size()); + + return Slice{storage.data() + 1, propNames.size()}; +} + +Slice TTable::mutableOperands() +{ + return Slice{storage.data(), storage.size()}; +} + +Slice TTable::operands() const +{ + return Slice{storage.data(), storage.size()}; +} + +bool TTable::operator==(const TTable& rhs) const +{ + return storage == rhs.storage && propNames == rhs.propNames; +} + +size_t TTable::Hash::operator()(const TTable& value) const +{ + size_t hash = 0; + + // We're using pointers here, which does mean platform divergence. I think + // it's okay? (famous last words, I know) + for (StringId s : value.propNames) + EqSat::hashCombine(hash, EqSat::languageHash(s)); + + EqSat::hashCombine(hash, EqSat::languageHash(value.storage)); + + return hash; +} + +uint32_t StringCache::add(std::string_view s) +{ + size_t hash = std::hash()(s); + if (uint32_t* it = strings.find(hash)) + return *it; + + char* storage = static_cast(allocator.allocate(s.size())); + memcpy(storage, s.data(), s.size()); + + uint32_t result = uint32_t(views.size()); + views.emplace_back(storage, s.size()); + strings[hash] = result; + return result; +} + +std::string_view StringCache::asStringView(StringId id) const +{ + LUAU_ASSERT(id < views.size()); + return views[id]; +} + +std::string StringCache::asString(StringId id) const +{ + return std::string{asStringView(id)}; +} + +template +Simplify::Data Simplify::make(const EGraph&, const T&) const +{ + return true; +} + +void Simplify::join(Data& left, const Data& right) const +{ + left = left || right; +} + +using EClass = Luau::EqSat::EClass; + +// A terminal type is a type that does not contain any other types. +// Examples: any, unknown, number, string, boolean, nil, table, class, thread, function +// +// All class types are also terminal. +static bool isTerminal(const EType& node) +{ + return node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get(); +} + +static bool isTerminal(const EGraph& egraph, Id eclass) +{ + const auto& nodes = egraph[eclass].nodes; + return std::any_of( + nodes.begin(), + nodes.end(), + [](auto& a) + { + return isTerminal(a); + } + ); +} + +Id mkUnion(EGraph& egraph, std::vector parts) +{ + if (parts.size() == 0) + return egraph.add(TNever{}); + else if (parts.size() == 1) + return parts[0]; + else + return egraph.add(Union{std::move(parts)}); +} + +Id mkIntersection(EGraph& egraph, std::vector parts) +{ + if (parts.size() == 0) + return egraph.add(TUnknown{}); + else if (parts.size() == 1) + return parts[0]; + else + return egraph.add(Intersection{std::move(parts)}); +} + +struct ListRemover +{ + std::unordered_map>& mappings2; + TypeId ty; + + ~ListRemover() + { + mappings2.erase(ty); + } +}; + +/* + * Crucial subtlety: It is very extremely important that enodes and eclasses are + * immutable. Mutating an enode would mean that it is no longer equivalent to + * other nodes in the same eclass. + * + * At the same time, many TypeIds are NOT immutable! + * + * The thing that makes this navigable is that it is okay if the same TypeId is + * imported as a different Id at different times as type inference runs. For + * example, if we at one point import a BlockedType as a TOpaque, and later + * import that same TypeId as some other enode type, this is all completely + * okay. + * + * The main thing we have to be very cautious about, I think, is unsealed + * tables. Unsealed table types have properties imperatively inserted into them + * as type inference runs. If we were to encode that TypeId as part of an + * enode, we could run into a situation where the egraph makes incorrect + * assumptions about the table. + * + * The solution is pretty simple: Never use the contents of a mutable TypeId in + * any reduction rule. TOpaque is always okay because we never actually poke + * around inside the TypeId to do anything. + */ +Id toId( + EGraph& egraph, + NotNull builtinTypes, + std::unordered_map& mappingIdToClass, + std::unordered_map>& typeToMappingId, // (TypeId: (MappingId, count)) + std::unordered_set& boundNodes, + StringCache& strings, + TypeId ty +) +{ + ty = follow(ty); + + // First, handle types which do not contain other types. They obviously + // cannot participate in cycles, so we don't have to check for that. + + if (auto freeTy = get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (auto prim = get(ty)) + { + switch (prim->type) + { + case Luau::PrimitiveType::NilType: + return egraph.add(TNil{}); + case Luau::PrimitiveType::Boolean: + return egraph.add(TBoolean{}); + case Luau::PrimitiveType::Number: + return egraph.add(TNumber{}); + case Luau::PrimitiveType::String: + return egraph.add(TString{}); + case Luau::PrimitiveType::Thread: + return egraph.add(TThread{}); + case Luau::PrimitiveType::Function: + return egraph.add(TTopFunction{}); + case Luau::PrimitiveType::Table: + return egraph.add(TTopTable{}); + case Luau::PrimitiveType::Buffer: + return egraph.add(TBuffer{}); + default: + LUAU_ASSERT(!"Unimplemented"); + return egraph.add(Invalid{}); + } + } + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + return egraph.add(SBoolean{bs->value}); + else if (auto ss = get(s)) + return egraph.add(SString{strings.add(ss->value)}); + else + LUAU_ASSERT(!"Unexpected"); + } + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TFunction{ty}); + else if (ty == builtinTypes->classType) + return egraph.add(TTopClass{}); + else if (get(ty)) + return egraph.add(TClass{ty}); + else if (get(ty)) + return egraph.add(TAny{}); + else if (get(ty)) + return egraph.add(TError{}); + else if (get(ty)) + return egraph.add(TUnknown{}); + else if (get(ty)) + return egraph.add(TNever{}); + + // Now handle composite types. + + if (auto it = typeToMappingId.find(ty); it != typeToMappingId.end()) + { + auto& [mappingId, count] = it->second; + ++count; + Id res = egraph.add(TBound{mappingId}); + boundNodes.insert(res); + return res; + } + + typeToMappingId.emplace(ty, std::pair{mappingIdToClass.size(), 0}); + ListRemover lr{typeToMappingId, ty}; + + auto cache = [&](Id res) + { + const auto& [mappingId, count] = typeToMappingId.at(ty); + if (count > 0) + mappingIdToClass.emplace(mappingId, res); + return res; + }; + + if (auto tt = get(ty)) + return egraph.add(TImportedTable{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (auto ut = get(ty)) + { + std::vector parts; + for (TypeId part : ut) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + return cache(mkUnion(egraph, std::move(parts))); + } + else if (auto it = get(ty)) + { + std::vector parts; + for (TypeId part : it) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + LUAU_ASSERT(parts.size() > 1); + + return cache(mkIntersection(egraph, std::move(parts))); + } + else if (auto negation = get(ty)) + { + Id part = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, negation->ty); + return cache(egraph.add(Negation{std::array{part}})); + } + else if (auto tfun = get(ty)) + { + LUAU_ASSERT(tfun->packArguments.empty()); + + std::vector parts; + for (TypeId part : tfun->typeArguments) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + return cache(egraph.add(TTypeFun{tfun->function.get(), std::move(parts)})); + } + else if (get(ty)) + return egraph.add(TNoRefine{}); + else + { + LUAU_ASSERT(!"Unhandled Type"); + return cache(egraph.add(Invalid{})); + } +} + +Id toId(EGraph& egraph, NotNull builtinTypes, std::unordered_map& mappingIdToClass, StringCache& strings, TypeId ty) +{ + std::unordered_map> typeToMappingId; + std::unordered_set boundNodes; + Id id = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, ty); + + for (Id id : boundNodes) + { + for (const auto [tb, _index] : Query(&egraph, id)) + { + Id bindee = mappingIdToClass.at(tb->value()); + egraph.merge(id, bindee); + } + } + + egraph.rebuild(); + + return egraph.find(id); +} + +// We apply a penalty to cyclic types to guide the system away from them where +// possible. +static const int CYCLE_PENALTY = 5000; + +// Composite types have cost equal to the sum of the costs of their parts plus a +// constant factor. +static const int SET_TYPE_PENALTY = 1; +static const int TABLE_TYPE_PENALTY = 2; +static const int NEGATION_PENALTY = 2; +static const int TFUN_PENALTY = 2; + +// FIXME. We don't have an accurate way to score a TImportedTable table against +// a TTable. +static const int IMPORTED_TABLE_PENALTY = 50; + +// TBound shouldn't ever be selected as the best node of a class unless we are +// debugging eqsat itself and need to stringify eclasses. We thus penalize it +// so heavily that we'll use any other alternative. +static const int BOUND_PENALTY = 999999999; + +// TODO iteration count limit +// TODO also: accept an argument which is the maximum cost to consider before +// abandoning the count. +// TODO: the egraph should be the first parameter. +static size_t computeCost(std::unordered_map& bestNodes, const EGraph& egraph, std::unordered_map& costs, Id id) +{ + if (auto it = costs.find(id); it != costs.end()) + return it->second; + + const std::vector& nodes = egraph[id].nodes; + + size_t minCost = std::numeric_limits::max(); + size_t bestNode = std::numeric_limits::max(); + + const auto updateCost = [&](size_t cost, size_t node) + { + if (cost < minCost) + { + minCost = cost; + bestNode = node; + } + }; + + // First, quickly scan for a terminal type. If we can find one, it is obviously the best. + for (size_t index = 0; index < nodes.size(); ++index) + { + if (isTerminal(nodes[index])) + { + minCost = 1; + bestNode = index; + + costs[id] = 1; + const auto [iter, isFresh] = bestNodes.insert({id, index}); + + // If we are forcing the cost function to select a specific node, + // then we still need to traverse into that node, even if this + // particular node is the obvious choice under normal circumstances. + if (isFresh || iter->second == index) + return 1; + } + } + + // If we recur into this type before this call frame completes, it is + // because this type participates in a cycle. + costs[id] = CYCLE_PENALTY; + + auto computeChildren = [&](Slice parts, size_t maxCost) -> std::optional + { + size_t cost = 0; + for (Id part : parts) + { + cost += computeCost(bestNodes, egraph, costs, part); + + // Abandon this node if it is too costly + if (cost > maxCost) + return std::nullopt; + } + return cost; + }; + + size_t startIndex = 0; + size_t endIndex = nodes.size(); + + // FFlag::DebugLuauLogSimplification will sometimes stringify an Id and pass + // in a prepopulated bestNodes map. If that mapping already has an index + // for this Id, don't look at the other nodes of this class. + if (auto it = bestNodes.find(id); it != bestNodes.end()) + { + LUAU_ASSERT(it->second < nodes.size()); + + startIndex = it->second; + endIndex = startIndex + 1; + } + + for (size_t index = startIndex; index < endIndex; ++index) + { + const auto& node = nodes[index]; + + if (node.get()) + updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound. + else if (node.get()) + { + minCost = 1; + bestNode = index; + } + else if (auto tbl = node.get()) + { + // TODO: We could make the penalty a parameter to computeChildren. + std::optional maybeCost = computeChildren(tbl->operands(), minCost); + if (maybeCost) + updateCost(TABLE_TYPE_PENALTY + *maybeCost, index); + } + else if (node.get()) + { + minCost = IMPORTED_TABLE_PENALTY; + bestNode = index; + } + else if (auto u = node.get()) + { + std::optional maybeCost = computeChildren(u->operands(), minCost); + if (maybeCost) + updateCost(SET_TYPE_PENALTY + *maybeCost, index); + } + else if (auto i = node.get()) + { + std::optional maybeCost = computeChildren(i->operands(), minCost); + if (maybeCost) + updateCost(SET_TYPE_PENALTY + *maybeCost, index); + } + else if (auto negation = node.get()) + { + std::optional maybeCost = computeChildren(negation->operands(), minCost); + if (maybeCost) + updateCost(NEGATION_PENALTY + *maybeCost, index); + } + else if (auto tfun = node.get()) + { + std::optional maybeCost = computeChildren(tfun->operands(), minCost); + if (maybeCost) + updateCost(TFUN_PENALTY + *maybeCost, index); + } + } + + LUAU_ASSERT(bestNode < nodes.size()); + + costs[id] = minCost; + bestNodes.insert({id, bestNode}); + return minCost; +} + +static std::unordered_map computeBestResult(const EGraph& egraph, Id id, const std::unordered_map& forceNodes) +{ + std::unordered_map costs; + std::unordered_map bestNodes = forceNodes; + computeCost(bestNodes, egraph, costs, id); + return bestNodes; +} + +static std::unordered_map computeBestResult(const EGraph& egraph, Id id) +{ + std::unordered_map costs; + std::unordered_map bestNodes; + computeCost(bestNodes, egraph, costs, id); + return bestNodes; +} + +TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +); + +TypeId flattenTableNode( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +) +{ + std::vector stack; + std::unordered_set seenIds; + + Id id = rootId; + const TImportedTable* importedTable = nullptr; + while (true) + { + size_t index = bestNodes.at(id); + const auto& eclass = egraph[id]; + + const auto [_iter, isFresh] = seenIds.insert(id); + if (!isFresh) + { + // If a TTable is its own basis, it must be the case that some other + // node on this eclass is a TImportedTable. Let's use that. + + for (size_t i = 0; i < eclass.nodes.size(); ++i) + { + if (eclass.nodes[i].get()) + { + index = i; + break; + } + } + + // If we couldn't find one, we don't know what to do. Use ErrorType. + LUAU_ASSERT(0); + return builtinTypes->errorType; + } + + const auto& node = eclass.nodes[index]; + if (const TTable* ttable = node.get()) + { + stack.push_back(ttable); + id = ttable->getBasis(); + continue; + } + else if (const TImportedTable* ti = node.get()) + { + importedTable = ti; + break; + } + else + LUAU_ASSERT(0); + } + + TableType resultTable; + if (importedTable) + { + const TableType* t = Luau::get(importedTable->value()); + LUAU_ASSERT(t); + resultTable = *t; // Intentional shallow clone here + } + + while (!stack.empty()) + { + const TTable* t = stack.back(); + stack.pop_back(); + + for (size_t i = 0; i < t->propNames.size(); ++i) + { + StringId propName = t->propNames[i]; + const Id propType = t->propTypes()[i]; + + resultTable.props[strings.asString(propName)] = Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)}; + } + } + + return arena->addType(std::move(resultTable)); +} + +TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +) +{ + if (auto it = seen.find(rootId); it != seen.end()) + return it->second; + + size_t index = bestNodes.at(rootId); + LUAU_ASSERT(index <= egraph[rootId].nodes.size()); + + const EType& node = egraph[rootId].nodes[index]; + + if (node.get()) + return builtinTypes->nilType; + else if (node.get()) + return builtinTypes->booleanType; + else if (node.get()) + return builtinTypes->numberType; + else if (node.get()) + return builtinTypes->stringType; + else if (node.get()) + return builtinTypes->threadType; + else if (node.get()) + return builtinTypes->functionType; + else if (node.get()) + return builtinTypes->tableType; + else if (node.get()) + return builtinTypes->classType; + else if (node.get()) + return builtinTypes->bufferType; + else if (auto opaque = node.get()) + return opaque->value(); + else if (auto b = node.get()) + return b->value() ? builtinTypes->trueType : builtinTypes->falseType; + else if (auto s = node.get()) + return arena->addType(SingletonType{StringSingleton{strings.asString(s->value())}}); + else if (auto fun = node.get()) + return fun->value(); + else if (auto tbl = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + TypeId flattened = flattenTableNode(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); + + asMutable(res)->ty.emplace(flattened); + return flattened; + } + else if (auto tbl = node.get()) + return tbl->value(); + else if (auto cls = node.get()) + return cls->value(); + else if (node.get()) + return builtinTypes->anyType; + else if (node.get()) + return builtinTypes->errorType; + else if (node.get()) + return builtinTypes->unknownType; + else if (node.get()) + return builtinTypes->neverType; + else if (auto u = node.get()) + { + Slice parts = u->operands(); + + if (parts.empty()) + return builtinTypes->neverType; + else if (parts.size() == 1) + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + else + { + TypeId res = arena->addType(BlockedType{}); + + seen[rootId] = res; + + std::vector partTypes; + partTypes.reserve(parts.size()); + + for (Id part : parts) + partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(std::move(partTypes)); + + return res; + } + } + else if (auto i = node.get()) + { + Slice parts = i->operands(); + + if (parts.empty()) + return builtinTypes->neverType; + else if (parts.size() == 1) + { + LUAU_ASSERT(parts[0] != rootId); + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + } + else + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + std::vector partTypes; + partTypes.reserve(parts.size()); + + for (Id part : parts) + partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(std::move(partTypes)); + + return res; + } + } + else if (auto negation = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + TypeId ty = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, negation->operands()[0]); + + asMutable(res)->ty.emplace(ty); + + return res; + } + else if (auto tfun = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + std::vector args; + for (Id part : tfun->operands()) + args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(*tfun->value(), std::move(args)); + + newTypeFunctions.push_back(res); + + return res; + } + else if (node.get()) + return builtinTypes->errorType; + else if (node.get()) + return builtinTypes->noRefineType; + else + { + LUAU_ASSERT(!"Unimplemented"); + return nullptr; + } +} + +static TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& forceNodes, + std::vector& newTypeFunctions, + Id rootId +) +{ + const std::unordered_map bestNodes = computeBestResult(egraph, rootId, forceNodes); + std::unordered_map seen; + + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); +} + +static TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + std::vector& newTypeFunctions, + Id rootId +) +{ + const std::unordered_map bestNodes = computeBestResult(egraph, rootId); + std::unordered_map seen; + + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); +} + +Subst::Subst(Id eclass, Id newClass, std::string desc) + : eclass(std::move(eclass)) + , newClass(std::move(newClass)) + , desc(std::move(desc)) +{ +} + +std::string mkDesc( + EGraph& egraph, + const StringCache& strings, + NotNull arena, + NotNull builtinTypes, + Id from, + Id to, + const std::unordered_map& forceNodes, + const std::string& rule +) +{ + if (!FFlag::DebugLuauLogSimplification) + return ""; + + std::vector newTypeFunctions; + + TypeId fromTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, from); + TypeId toTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, to); + + ToStringOptions opts; + opts.useQuestionMarks = false; + + const int RULE_PADDING = 35; + const std::string rulePadding(std::max(0, RULE_PADDING - rule.size()), ' '); + const std::string fromIdStr = ""; // "(" + std::to_string(uint32_t(from)) + ") "; + const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") "; + + return rule + ":" + rulePadding + fromIdStr + toString(fromTy, opts) + " <=> " + toIdStr + toString(toTy, opts); +} + +std::string mkDesc(EGraph& egraph, const StringCache& strings, NotNull arena, NotNull builtinTypes, Id from, Id to, const std::string& rule) +{ + if (!FFlag::DebugLuauLogSimplification) + return ""; + + return mkDesc(egraph, strings, arena, builtinTypes, from, to, {}, rule); +} + +static std::string getNodeName(const StringCache& strings, const EType& node) +{ + if (node.get()) + return "nil"; + else if (node.get()) + return "boolean"; + else if (node.get()) + return "number"; + else if (node.get()) + return "string"; + else if (node.get()) + return "thread"; + else if (node.get()) + return "function"; + else if (node.get()) + return "table"; + else if (node.get()) + return "class"; + else if (node.get()) + return "buffer"; + else if (node.get()) + return "opaque"; + else if (auto b = node.get()) + return b->value() ? "true" : "false"; + else if (auto s = node.get()) + return "\"" + strings.asString(s->value()) + "\""; + else if (node.get()) + return "\xe2\x88\xaa"; + else if (node.get()) + return "\xe2\x88\xa9"; + else if (auto cls = node.get()) + { + const ClassType* ct = get(cls->value()); + LUAU_ASSERT(ct); + return ct->name; + } + else if (node.get()) + return "any"; + else if (node.get()) + return "error"; + else if (node.get()) + return "unknown"; + else if (node.get()) + return "never"; + else if (auto tfun = node.get()) + return "tfun " + tfun->value()->name; + else if (node.get()) + return "~"; + else if (node.get()) + return "invalid?"; + else if (node.get()) + return "bound"; + + return "???"; +} + +std::string toDot(const StringCache& strings, const EGraph& egraph) +{ + std::stringstream ss; + ss << "digraph G {" << '\n'; + ss << " graph [fontsize=10 fontname=\"Verdana\" compound=true];" << '\n'; + ss << " node [shape=record fontsize=10 fontname=\"Verdana\"];" << '\n'; + + std::set populated; + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + for (const auto& node : eclass.nodes) + { + if (!node.operands().empty()) + populated.insert(id); + for (Id op : node.operands()) + populated.insert(op); + } + } + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + if (!populated.count(id)) + continue; + + const std::string className = "cluster_" + std::to_string(uint32_t(id)); + ss << " subgraph " << className << " {" << '\n'; + ss << " node [style=\"rounded,filled\"];" << '\n'; + ss << " label = \"" << uint32_t(id) << "\";" << '\n'; + ss << " color = blue;" << '\n'; + + for (size_t index = 0; index < eclass.nodes.size(); ++index) + { + const auto& node = eclass.nodes[index]; + + const std::string label = getNodeName(strings, node); + const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index); + + ss << " " << nodeName << " [label=\"" << label << "\"];" << '\n'; + } + + ss << " }" << '\n'; + } + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + for (size_t index = 0; index < eclass.nodes.size(); ++index) + { + const auto& node = eclass.nodes[index]; + + const std::string label = getNodeName(strings, node); + const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index); + + for (Id op : node.operands()) + { + op = egraph.find(op); + const std::string destNodeName = "n" + std::to_string(uint32_t(op)) + "_0"; + ss << " " << nodeName << " -> " << destNodeName << " [lhead=cluster_" << uint32_t(op) << "];" << '\n'; + } + } + } + + ss << "}" << '\n'; + + return ss.str(); +} + +template +static Tag const* isTag(const EType& node) +{ + return node.get(); +} + +/// Important: Only use this to test for leaf node types like TUnknown and +/// TNumber. Things that we know cannot be simplified any further and are safe +/// to short-circuit on. +/// +/// It does a linear scan and exits early, so if a particular eclass has +/// multiple "interesting" representations, this function can surprise you. +template +static Tag const* isTag(const EGraph& egraph, Id id) +{ + for (const auto& node : egraph[id].nodes) + { + if (auto n = isTag(node)) + return n; + } + return nullptr; +} + +struct RewriteRule +{ + explicit RewriteRule(EGraph* egraph) + : egraph(egraph) + { + } + + virtual void read(std::vector& substs, Id eclass, const EType* enode) = 0; + +protected: + const EqSat::EClass& get(Id id) + { + return (*egraph)[id]; + } + + Id find(Id id) + { + return egraph->find(id); + } + + Id add(EType enode) + { + return egraph->add(std::move(enode)); + } + + template + const Tag* isTag(Id id) + { + for (const auto& node : (*egraph)[id].nodes) + { + if (auto n = node.get()) + return n; + } + return nullptr; + } + + template + bool isTag(const EType& enode) + { + return enode.get(); + } + +public: + EGraph* egraph; +}; + +enum SubclassRelationship +{ + LeftSuper, + RightSuper, + Unrelated +}; + +static SubclassRelationship relateClasses(const TClass* leftClass, const TClass* rightClass) +{ + const ClassType* leftClassType = Luau::get(leftClass->value()); + const ClassType* rightClassType = Luau::get(rightClass->value()); + + if (isSubclass(leftClassType, rightClassType)) + return RightSuper; + else if (isSubclass(rightClassType, leftClassType)) + return LeftSuper; + else + return Unrelated; +} + +// Entirely analogous to NormalizedType except that it operates on eclasses instead of TypeIds. +struct CanonicalizedType +{ + std::optional nilPart; + std::optional truePart; + std::optional falsePart; + std::optional numberPart; + std::optional stringPart; + std::vector stringSingletons; + std::optional threadPart; + std::optional functionPart; + std::optional tablePart; + std::vector classParts; + std::optional bufferPart; + std::optional errorPart; + + // Functions that have been union'd into the type + std::unordered_set functionParts; + + // Anything that isn't canonical: Intersections, unions, free types, and so on. + std::unordered_set otherParts; + + bool isUnknown() const + { + return nilPart && truePart && falsePart && numberPart && stringPart && threadPart && functionPart && tablePart && bufferPart; + } +}; + +void unionUnknown(EGraph& egraph, CanonicalizedType& ct) +{ + ct.nilPart = egraph.add(TNil{}); + ct.truePart = egraph.add(SBoolean{true}); + ct.falsePart = egraph.add(SBoolean{false}); + ct.numberPart = egraph.add(TNumber{}); + ct.stringPart = egraph.add(TString{}); + ct.threadPart = egraph.add(TThread{}); + ct.functionPart = egraph.add(TTopFunction{}); + ct.tablePart = egraph.add(TTopTable{}); + ct.bufferPart = egraph.add(TBuffer{}); + + ct.functionParts.clear(); + ct.otherParts.clear(); +} + +void unionAny(EGraph& egraph, CanonicalizedType& ct) +{ + unionUnknown(egraph, ct); + ct.errorPart = egraph.add(TError{}); +} + +void unionClasses(EGraph& egraph, std::vector& hereParts, Id there) +{ + if (1 == hereParts.size() && isTag(egraph, hereParts[0])) + return; + + const auto thereClass = isTag(egraph, there); + if (!thereClass) + return; + + for (size_t index = 0; index < hereParts.size(); ++index) + { + const Id herePart = hereParts[index]; + + if (auto partClass = isTag(egraph, herePart)) + { + switch (relateClasses(partClass, thereClass)) + { + case LeftSuper: + return; + case RightSuper: + hereParts[index] = there; + std::sort(hereParts.begin(), hereParts.end()); + return; + case Unrelated: + continue; + } + } + } + + hereParts.push_back(there); + std::sort(hereParts.begin(), hereParts.end()); +} + +void unionWithType(EGraph& egraph, CanonicalizedType& ct, Id part) +{ + if (isTag(egraph, part)) + ct.nilPart = part; + else if (isTag(egraph, part)) + ct.truePart = ct.falsePart = part; + else if (auto b = isTag(egraph, part)) + { + if (b->value()) + ct.truePart = part; + else + ct.falsePart = part; + } + else if (isTag(egraph, part)) + ct.numberPart = part; + else if (isTag(egraph, part)) + ct.stringPart = part; + else if (isTag(egraph, part)) + ct.stringSingletons.push_back(part); + else if (isTag(egraph, part)) + ct.threadPart = part; + else if (isTag(egraph, part)) + { + ct.functionPart = part; + ct.functionParts.clear(); + } + else if (isTag(egraph, part)) + ct.tablePart = part; + else if (isTag(egraph, part)) + ct.classParts = {part}; + else if (isTag(egraph, part)) + ct.bufferPart = part; + else if (isTag(egraph, part)) + { + if (!ct.functionPart) + ct.functionParts.insert(part); + } + else if (auto tclass = isTag(egraph, part)) + unionClasses(egraph, ct.classParts, part); + else if (isTag(egraph, part)) + { + unionAny(egraph, ct); + return; + } + else if (isTag(egraph, part)) + ct.errorPart = part; + else if (isTag(egraph, part)) + unionUnknown(egraph, ct); + else if (isTag(egraph, part)) + { + // Nothing + } + else + ct.otherParts.insert(part); +} + +// Find an enode under the given eclass which is simple enough that it could be +// subtracted from a CanonicalizedType easily. +// +// A union is "simple enough" if it is acyclic and is only comprised of terminal +// types and unions that are themselves subtractable +const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set& seen, Id id) +{ + if (seen.count(id)) + return nullptr; + + const EType* bestUnion = nullptr; + std::optional unionSize; + + for (const auto& node : egraph[id].nodes) + { + if (isTerminal(node)) + return &node; + + if (const auto u = node.get()) + { + seen.insert(id); + + for (Id part : u->operands()) + { + if (!findSubtractableClass(egraph, seen, part)) + return nullptr; + } + + // If multiple unions in this class are all simple enough, prefer + // the shortest one. + if (!unionSize || u->operands().size() < unionSize) + { + unionSize = u->operands().size(); + bestUnion = &node; + } + } + } + + return bestUnion; +} + +const EType* findSubtractableClass(const EGraph& egraph, Id id) +{ + std::unordered_set seen; + + return findSubtractableClass(egraph, seen, id); +} + +// Subtract the type 'part' from 'ct' +// Returns true if the subtraction succeeded. This function will fail if 'part` is too complicated. +bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part) +{ + const EType* etype = findSubtractableClass(egraph, part); + if (!etype) + return false; + + if (etype->get()) + ct.nilPart.reset(); + else if (etype->get()) + { + ct.truePart.reset(); + ct.falsePart.reset(); + } + else if (auto b = etype->get()) + { + if (b->value()) + ct.truePart.reset(); + else + ct.falsePart.reset(); + } + else if (etype->get()) + ct.numberPart.reset(); + else if (etype->get()) + ct.stringPart.reset(); + else if (etype->get()) + return false; + else if (etype->get()) + ct.threadPart.reset(); + else if (etype->get()) + ct.functionPart.reset(); + else if (etype->get()) + ct.tablePart.reset(); + else if (etype->get()) + ct.classParts.clear(); + else if (auto tclass = etype->get()) + { + auto it = std::find(ct.classParts.begin(), ct.classParts.end(), part); + if (it != ct.classParts.end()) + ct.classParts.erase(it); + else + return false; + } + else if (etype->get()) + ct.bufferPart.reset(); + else if (etype->get()) + ct = {}; + else if (etype->get()) + ct.errorPart.reset(); + else if (etype->get()) + { + std::optional errorPart = ct.errorPart; + ct = {}; + ct.errorPart = errorPart; + } + else if (etype->get()) + { + // Nothing + } + else if (auto u = etype->get()) + { + // TODO cycles + // TODO this is super promlematic because 'part' represents a whole group of equivalent enodes. + for (Id unionPart : u->operands()) + { + // TODO: This recursive call will require that we re-traverse this + // eclass to find the subtractible enode. It would be nice to do the + // work just once and reuse it. + bool ok = subtract(egraph, ct, unionPart); + if (!ok) + return false; + } + } + else if (etype->get()) + return false; + else + return false; + + return true; +} + +Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) +{ + if (ct.isUnknown()) + { + if (ct.errorPart) + return egraph.add(TAny{}); + else + return egraph.add(TUnknown{}); + } + + std::vector parts; + + if (ct.nilPart) + parts.push_back(*ct.nilPart); + + if (ct.truePart && ct.falsePart) + parts.push_back(egraph.add(TBoolean{})); + else if (ct.truePart) + parts.push_back(*ct.truePart); + else if (ct.falsePart) + parts.push_back(*ct.falsePart); + + if (ct.numberPart) + parts.push_back(*ct.numberPart); + + if (ct.stringPart) + parts.push_back(*ct.stringPart); + else if (!ct.stringSingletons.empty()) + parts.insert(parts.end(), ct.stringSingletons.begin(), ct.stringSingletons.end()); + + if (ct.threadPart) + parts.push_back(*ct.threadPart); + if (ct.functionPart) + parts.push_back(*ct.functionPart); + if (ct.tablePart) + parts.push_back(*ct.tablePart); + parts.insert(parts.end(), ct.classParts.begin(), ct.classParts.end()); + if (ct.bufferPart) + parts.push_back(*ct.bufferPart); + if (ct.errorPart) + parts.push_back(*ct.errorPart); + + parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end()); + parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end()); + + return mkUnion(egraph, std::move(parts)); +} + +void addChildren(const EGraph& egraph, const EType* enode, VecDeque& worklist) +{ + for (Id id : enode->operands()) + worklist.push_back(id); +} + +static bool occurs(EGraph& egraph, Id outerId, Slice operands) +{ + for (const Id i : operands) + { + if (egraph.find(i) == outerId) + return true; + } + return false; +} + +Simplifier::Simplifier(NotNull arena, NotNull builtinTypes) + : arena(arena) + , builtinTypes(builtinTypes) + , egraph(Simplify{}) +{ +} + +const EqSat::EClass& Simplifier::get(Id id) const +{ + return egraph[id]; +} + +Id Simplifier::find(Id id) const +{ + return egraph.find(id); +} + +Id Simplifier::add(EType enode) +{ + return egraph.add(std::move(enode)); +} + +template +const Tag* Simplifier::isTag(Id id) const +{ + for (const auto& node : get(id).nodes) + { + if (const Tag* ty = node.get()) + return ty; + } + + return nullptr; +} + +template +const Tag* Simplifier::isTag(const EType& enode) const +{ + return enode.get(); +} + +void Simplifier::subst(Id from, Id to) +{ + substs.emplace_back(from, to, " - "); +} + +void Simplifier::subst(Id from, Id to, const std::string& ruleName) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, std::move(ruleName)); + substs.emplace_back(from, to, desc); +} + +void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName); + substs.emplace_back(from, to, desc); +} + +void Simplifier::unionClasses(std::vector& hereParts, Id there) +{ + if (1 == hereParts.size() && isTag(hereParts[0])) + return; + + const auto thereClass = isTag(there); + if (!thereClass) + return; + + for (size_t index = 0; index < hereParts.size(); ++index) + { + const Id herePart = hereParts[index]; + + if (auto partClass = isTag(herePart)) + { + switch (relateClasses(partClass, thereClass)) + { + case LeftSuper: + return; + case RightSuper: + hereParts[index] = there; + std::sort(hereParts.begin(), hereParts.end()); + return; + case Unrelated: + continue; + } + } + } + + hereParts.push_back(there); + std::sort(hereParts.begin(), hereParts.end()); +} + +void Simplifier::simplifyUnion(Id id) +{ + id = find(id); + + for (const auto [u, unionIndex] : Query(&egraph, id)) + { + std::vector newParts; + std::unordered_set seen; + + CanonicalizedType canonicalized; + + if (occurs(egraph, id, u->operands())) + continue; + + for (Id part : u->operands()) + unionWithType(egraph, canonicalized, find(part)); + + Id resultId = fromCanonicalized(egraph, canonicalized); + + subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); + } +} + +// If one of the nodes matches the given Tag, succeed and return the id and node for the other half. +// If neither matches, return nullopt. +template +static std::optional> matchOne(Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) +{ + if (hereNode->get()) + return std::pair{thereId, thereNode}; + else if (thereNode->get()) + return std::pair{hereId, hereNode}; + else + return std::nullopt; +} + +// If the two nodes can be intersected into a "simple" type, return that, else return nullopt. +std::optional intersectOne(EGraph& egraph, Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) +{ + hereId = egraph.find(hereId); + thereId = egraph.find(thereId); + + if (hereId == thereId) + return *hereNode; + + if (hereNode->get() || thereNode->get()) + return TNever{}; + + if (hereNode->get() || hereNode->get() || hereNode->get() || thereNode->get() || + thereNode->get() || thereNode->get() || hereNode->get() || thereNode->get()) + return std::nullopt; + + if (hereNode->get()) + return *thereNode; + if (thereNode->get()) + return *hereNode; + + if (hereNode->get()) + return *thereNode; + if (thereNode->get()) + return *hereNode; + + if (hereNode->get() || thereNode->get()) + return std::nullopt; + + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get() || otherNode->get()) + return *otherNode; + else + return TNever{}; + } + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get() || otherNode->get()) + return *otherNode; + } + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get()) + return std::nullopt; // TODO + else + return TNever{}; + } + if (auto hereClass = hereNode->get()) + { + if (auto thereClass = thereNode->get()) + { + switch (relateClasses(hereClass, thereClass)) + { + case LeftSuper: + return *thereNode; + case RightSuper: + return *hereNode; + case Unrelated: + return TNever{}; + } + } + else + return TNever{}; + } + if (auto hereBool = hereNode->get()) + { + if (auto thereBool = thereNode->get()) + { + if (hereBool->value() == thereBool->value()) + return *hereNode; + else + return TNever{}; + } + else if (thereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (auto thereBool = thereNode->get()) + { + if (auto hereBool = hereNode->get()) + { + if (thereBool->value() == hereBool->value()) + return *thereNode; + else + return TNever{}; + } + else if (hereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get()) + return TBoolean{}; + else if (thereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get()) + return TBoolean{}; + else if (hereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get() || thereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get() || hereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (hereNode->get() && thereNode->get()) + return std::nullopt; + if (hereNode->get() && isTerminal(*thereNode)) + return TNever{}; + if (thereNode->get() && isTerminal(*hereNode)) + return TNever{}; + if (isTerminal(*hereNode) && isTerminal(*thereNode)) + { + // We already know that 'here' and 'there' are different classes. + return TNever{}; + } + + return std::nullopt; +} + +void Simplifier::uninhabitedIntersection(Id id) +{ + for (const auto [intersection, index] : Query(&egraph, id)) + { + Slice parts = intersection->operands(); + + if (parts.empty()) + { + Id never = egraph.add(TNever{}); + subst(id, never, "uninhabitedIntersection"); + return; + } + else if (1 == parts.size()) + { + subst(id, parts[0], "uninhabitedIntersection"); + return; + } + + Id accumulator = egraph.add(TUnknown{}); + EType accumulatorNode = TUnknown{}; + + std::vector unsimplified; + + if (occurs(egraph, id, parts)) + continue; + + for (Id partId : parts) + { + if (isTag(partId)) + return; + + bool found = false; + + const auto& partNodes = egraph[partId].nodes; + for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex) + { + const EType& N = partNodes[partIndex]; + if (std::optional intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N)) + { + if (isTag(*intersection)) + { + subst(id, egraph.add(TNever{}), "uninhabitedIntersection", {{id, index}, {partId, partIndex}}); + return; + } + + accumulator = egraph.add(*intersection); + accumulatorNode = *intersection; + found = true; + break; + } + } + + if (!found) + unsimplified.push_back(partId); + } + + if ((unsimplified.empty() || !isTag(accumulator)) && find(accumulator) != id) + unsimplified.push_back(accumulator); + + const Id result = mkIntersection(egraph, std::move(unsimplified)); + + subst(id, result, "uninhabitedIntersection", {{id, index}}); + } +} + +void Simplifier::intersectWithNegatedClass(Id id) +{ + for (const auto pair : Query(&egraph, id)) + { + const Intersection* intersection = pair.first; + const size_t intersectionIndex = pair.second; + + auto trySubst = [&](size_t i, size_t j) + { + Id iId = intersection->operands()[i]; + Id jId = intersection->operands()[j]; + + for (const auto [negation, negationIndex] : Query(&egraph, jId)) + { + const Id negated = negation->operands()[0]; + + if (iId == negated) + { + subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {jId, negationIndex}}); + return; + } + + for (const auto [negatedClass, negatedClassIndex] : Query(&egraph, negated)) + { + const auto& iNodes = egraph[iId].nodes; + for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex) + { + const EType& iNode = iNodes[iIndex]; + if (isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || + isTag(iNode) || + // isTag(iNode) || // I'm not sure about this one. + isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode)) + { + // eg string & ~SomeClass + subst(id, iId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + return; + } + + if (const TClass* class_ = iNode.get()) + { + switch (relateClasses(class_, negatedClass)) + { + case LeftSuper: + // eg Instance & ~Part + // This cannot be meaningfully reduced. + continue; + case RightSuper: + subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + return; + case Unrelated: + // Part & ~Folder == Part + { + std::vector newParts; + newParts.reserve(intersection->operands().size() - 1); + for (Id part : intersection->operands()) + { + if (part != jId) + newParts.push_back(part); + } + + Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()}); + subst(id, substId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + } + } + } + } + } + } + }; + + if (2 != intersection->operands().size()) + continue; + + trySubst(0, 1); + trySubst(1, 0); + } +} + +void Simplifier::intersectWithNoRefine(Id id) +{ + for (const auto pair : Query(&egraph, id)) + { + const Intersection* intersection = pair.first; + const size_t intersectionIndex = pair.second; + + const Slice intersectionOperands = intersection->operands(); + + for (size_t index = 0; index < intersectionOperands.size(); ++index) + { + const auto replace = [&]() + { + std::vector newOperands{intersectionOperands.begin(), intersectionOperands.end()}; + newOperands.erase(newOperands.begin() + index); + + Id substId = egraph.add(Intersection{std::move(newOperands)}); + + subst(id, substId, "intersectWithNoRefine", {{id, intersectionIndex}}); + }; + + if (isTag(intersectionOperands[index])) + replace(); + else + { + for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[index])) + { + if (isTag(negation->operands()[0])) + { + replace(); + break; + } + } + } + } + } +} + +/* + * Replace x where x = A & (B | x) with A + * + * Important subtlety: The egraph is routinely going to create cyclic unions and + * intersections. We can't arbitrarily remove things from a union just because + * it can be referred to in a cyclic way. We must only do this for things that + * can only be expressed in a cyclic way. + * + * As an example, we will bind the following type to true: + * + * (true | buffer | class | function | number | string | table | thread) & + * boolean + * + * The egraph represented by this type will indeed be cyclic as the 'true' class + * includes both 'true' itself and the above type, but removing true from the + * union will result is an incorrect judgment! + * + * The solution (for now) is only to consider a type to be cyclic if it was + * cyclic on its original import. + * + * FIXME: I still don't think this is quite right, but I don't know how to + * articulate what the actual rule ought to be. + */ +void Simplifier::cyclicIntersectionOfUnion(Id id) +{ + // FIXME: This has pretty terrible runtime complexity. + + for (const auto [i, intersectionIndex] : Query(&egraph, id)) + { + Slice intersectionParts = i->operands(); + for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionParts.size(); ++intersectionOperandIndex) + { + const Id intersectionPart = find(intersectionParts[intersectionOperandIndex]); + + for (const auto [bound, _boundIndex] : Query(&egraph, intersectionPart)) + { + const Id pointee = find(mappingIdToClass.at(bound->value())); + + for (const auto [u, unionIndex] : Query(&egraph, pointee)) + { + const Slice& unionOperands = u->operands(); + for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) + { + Id unionOperand = find(unionOperands[unionOperandIndex]); + if (unionOperand == id) + { + std::vector newIntersectionParts(intersectionParts.begin(), intersectionParts.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); + + subst( + id, + mkIntersection(egraph, std::move(newIntersectionParts)), + "cyclicIntersectionOfUnion", + {{id, intersectionIndex}, {pointee, unionIndex}} + ); + } + } + } + } + } + } +} + +void Simplifier::cyclicUnionOfIntersection(Id id) +{ + // FIXME: This has pretty terrible runtime complexity. + + for (const auto [union_, unionIndex] : Query(&egraph, id)) + { + Slice unionOperands = union_->operands(); + for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) + { + const Id unionPart = find(unionOperands[unionOperandIndex]); + + for (const auto [bound, _boundIndex] : Query(&egraph, unionPart)) + { + const Id pointee = find(mappingIdToClass.at(bound->value())); + + for (const auto [intersection, intersectionIndex] : Query(&egraph, pointee)) + { + Slice intersectionOperands = intersection->operands(); + for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionOperands.size(); ++intersectionOperandIndex) + { + const Id intersectionPart = find(intersectionOperands[intersectionOperandIndex]); + if (intersectionPart == id) + { + std::vector newIntersectionParts(intersectionOperands.begin(), intersectionOperands.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); + + if (!newIntersectionParts.empty()) + { + Id newIntersection = mkIntersection(egraph, std::move(newIntersectionParts)); + + std::vector newIntersectionParts(unionOperands.begin(), unionOperands.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + unionOperandIndex); + newIntersectionParts.push_back(newIntersection); + + subst( + id, + mkUnion(egraph, std::move(newIntersectionParts)), + "cyclicUnionOfIntersection", + {{id, unionIndex}, {pointee, intersectionIndex}} + ); + } + } + } + } + } + } + } +} + +void Simplifier::expandNegation(Id id) +{ + for (const auto [negation, index] : Query{&egraph, id}) + { + if (isTag(negation->operands()[0])) + return; + + CanonicalizedType canonicalized; + unionUnknown(egraph, canonicalized); + + const bool ok = subtract(egraph, canonicalized, negation->operands()[0]); + if (!ok) + continue; + + subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}}); + } +} + +/** + * Let A be a class-node having the form B & C1 & ... & Cn + * And B be a class-node having the form (D | E) + * + * Create a class containing the node (C1 & ... & Cn & D) | (C1 & ... & Cn & E) + * + * This function does nothing and returns nullopt if A and B are cyclic. + */ +static std::optional distributeIntersectionOfUnion( + EGraph& egraph, + Id outerClass, + const Intersection* outerIntersection, + Id innerClass, + const Union* innerUnion +) +{ + Slice outerOperands = outerIntersection->operands(); + + std::vector newOperands; + newOperands.reserve(innerUnion->operands().size()); + for (Id innerOperand : innerUnion->operands()) + { + if (isTag(egraph, innerOperand)) + continue; + + if (innerOperand == outerClass) + { + // Skip cyclic intersections of unions. There's a separate + // rule to get rid of those. + return std::nullopt; + } + + std::vector intersectionParts; + intersectionParts.reserve(outerOperands.size()); + intersectionParts.push_back(innerOperand); + + for (const Id op : outerOperands) + { + if (isTag(egraph, op)) + { + break; + } + if (op != innerClass) + intersectionParts.push_back(op); + } + + newOperands.push_back(mkIntersection(egraph, intersectionParts)); + } + + return mkUnion(egraph, std::move(newOperands)); +} + +// A & (B | C) -> (A & B) | (A & C) +// +// A & B & (C | D) -> A & (B & (C | D)) +// -> A & ((B & C) | (B & D)) +// -> (A & B & C) | (A & B & D) +void Simplifier::intersectionOfUnion(Id id) +{ + id = find(id); + + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + // For each operand O + // For each node N + // If N is a union U + // Create a new union comprised of every operand except O intersected with every operand of U + const Slice operands = intersection->operands(); + + if (operands.size() < 2) + return; + + if (occurs(egraph, id, operands)) + continue; + + for (Id operand : operands) + { + operand = find(operand); + if (operand == id) + break; + // Optimization: Decline to distribute any unions on an eclass that + // also contains a terminal node. + if (isTerminal(egraph, operand)) + continue; + + for (const auto [operandUnion, unionIndex] : Query(&egraph, operand)) + { + if (occurs(egraph, id, operandUnion->operands())) + continue; + + std::optional distributed = distributeIntersectionOfUnion(egraph, id, intersection, operand, operandUnion); + + if (distributed) + subst(id, *distributed, "intersectionOfUnion", {{id, intersectionIndex}, {operand, unionIndex}}); + } + } + } +} + +// {"a": b} & {"a": c, ...} => {"a": b & c, ...} +void Simplifier::intersectTableProperty(Id id) +{ + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + const Slice intersectionParts = intersection->operands(); + for (size_t i = 0; i < intersection->operands().size(); ++i) + { + const Id iId = intersection->operands()[i]; + + for (size_t j = 0; j < intersection->operands().size(); ++j) + { + if (i == j) + continue; + + const Id jId = intersection->operands()[j]; + + if (iId == jId) + continue; + + for (const auto [table1, table1Index] : Query(&egraph, iId)) + { + const TableType* table1Ty = Luau::get(table1->value()); + LUAU_ASSERT(table1Ty); + + if (table1Ty->props.size() != 1) + continue; + + for (const auto [table2, table2Index] : Query(&egraph, jId)) + { + const TableType* table2Ty = Luau::get(table2->value()); + LUAU_ASSERT(table2Ty); + + auto it = table2Ty->props.find(table1Ty->props.begin()->first); + if (it != table2Ty->props.end()) + { + std::vector newIntersectionParts; + newIntersectionParts.reserve(intersectionParts.size() - 1); + + for (size_t index = 0; index < intersectionParts.size(); ++index) + { + if (index != i && index != j) + newIntersectionParts.push_back(intersectionParts[index]); + } + + Id newTableProp = egraph.add(Intersection{ + toId(egraph, builtinTypes, mappingIdToClass, stringCache, it->second.type()), + toId(egraph, builtinTypes, mappingIdToClass, stringCache, table1Ty->props.begin()->second.type()) + }); + + newIntersectionParts.push_back(egraph.add(TTable{jId, {stringCache.add(it->first)}, {newTableProp}})); + + subst( + id, + egraph.add(Intersection{std::move(newIntersectionParts)}), + "intersectTableProperty", + {{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}} + ); + } + } + } + } + } + } +} + +// { prop: never } == never +void Simplifier::uninhabitedTable(Id id) +{ + for (const auto [table, tableIndex] : Query(&egraph, id)) + { + const TableType* tt = Luau::get(table->value()); + LUAU_ASSERT(tt); + + for (const auto& [propName, prop] : tt->props) + { + if (prop.readTy && Luau::get(follow(*prop.readTy))) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + + if (prop.writeTy && Luau::get(follow(*prop.writeTy))) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + } + } + + for (const auto [table, tableIndex] : Query(&egraph, id)) + { + for (Id propType : table->propTypes()) + { + if (isTag(propType)) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + } + } +} + +void Simplifier::unneededTableModification(Id id) +{ + for (const auto [tbl, tblIndex] : Query(&egraph, id)) + { + const Id basis = tbl->getBasis(); + for (const auto [importedTbl, importedTblIndex] : Query(&egraph, basis)) + { + const TableType* tt = Luau::get(importedTbl->value()); + LUAU_ASSERT(tt); + + bool skip = false; + + for (size_t i = 0; i < tbl->propNames.size(); ++i) + { + StringId propName = tbl->propNames[i]; + const Id propType = tbl->propTypes()[i]; + + Id importedProp = toId(egraph, builtinTypes, mappingIdToClass, stringCache, tt->props.at(stringCache.asString(propName)).type()); + + if (find(importedProp) != find(propType)) + { + skip = true; + break; + } + } + + if (!skip) + subst(id, basis, "unneededTableModification", {{id, tblIndex}, {basis, importedTblIndex}}); + } + } +} + +void Simplifier::builtinTypeFunctions(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + if (args.size() != 2) + continue; + + const std::string& name = tfun->value()->name; + if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod") + { + if (isTag(args[0]) && isTag(args[1])) + { + subst(id, add(TNumber{}), "builtinTypeFunctions", {{id, index}}); + } + } + } +} + +// Replace union<>, intersect<>, and refine<> with unions or intersections. +// These type functions exist primarily to cause simplification to defer until +// particular points in execution, so it is safe to get rid of them here. +// +// It's not clear that these type functions should exist at all. +void Simplifier::iffyTypeFunctions(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + const std::string& name = tfun->value()->name; + + if (name == "union") + subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); + else if (name == "intersect" || name == "refine") + subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); + } +} + +static void deleteSimplifier(Simplifier* s) +{ + delete s; +} + +SimplifierPtr newSimplifier(NotNull arena, NotNull builtinTypes) +{ + return SimplifierPtr{new Simplifier(arena, builtinTypes), &deleteSimplifier}; +} + +} // namespace Luau::EqSatSimplification + +namespace Luau +{ + +std::optional eqSatSimplify(NotNull simplifier, TypeId ty) +{ + using namespace Luau::EqSatSimplification; + + std::unordered_map newMappings; + Id rootId = toId(simplifier->egraph, simplifier->builtinTypes, newMappings, simplifier->stringCache, ty); + simplifier->mappingIdToClass.insert(newMappings.begin(), newMappings.end()); + + Simplifier::RewriteRuleFn rules[] = { + &Simplifier::simplifyUnion, + &Simplifier::uninhabitedIntersection, + &Simplifier::intersectWithNegatedClass, + &Simplifier::intersectWithNoRefine, + &Simplifier::cyclicIntersectionOfUnion, + &Simplifier::cyclicUnionOfIntersection, + &Simplifier::expandNegation, + &Simplifier::intersectionOfUnion, + &Simplifier::intersectTableProperty, + &Simplifier::uninhabitedTable, + &Simplifier::unneededTableModification, + &Simplifier::builtinTypeFunctions, + &Simplifier::iffyTypeFunctions, + }; + + std::unordered_set seen; + VecDeque worklist; + + bool progressed = true; + + int count = 0; + const int MAX_COUNT = 1000; + + if (FFlag::DebugLuauLogSimplification) + std::ofstream("begin.dot") << toDot(simplifier->stringCache, simplifier->egraph); + + auto& egraph = simplifier->egraph; + const auto& builtinTypes = simplifier->builtinTypes; + auto& arena = simplifier->arena; + + if (FFlag::DebugLuauLogSimplification) + printf(">> simplify %s\n", toString(ty).c_str()); + + while (progressed && count < MAX_COUNT) + { + progressed = false; + worklist.clear(); + seen.clear(); + + rootId = egraph.find(rootId); + + worklist.push_back(rootId); + + if (FFlag::DebugLuauLogSimplification) + { + std::vector newTypeFunctions; + const TypeId t = fromId(egraph, simplifier->stringCache, builtinTypes, arena, newTypeFunctions, rootId); + + std::cout << "Begin (" << uint32_t(egraph.find(rootId)) << ")\t" << toString(t) << '\n'; + } + + while (!worklist.empty() && count < MAX_COUNT) + { + Id id = egraph.find(worklist.front()); + worklist.pop_front(); + + const bool isFresh = seen.insert(id).second; + if (!isFresh) + continue; + + simplifier->substs.clear(); + + // Optimization: If this class alraedy has a terminal node, don't + // try to run any rules on it. + bool shouldAbort = false; + + for (const EType& enode : egraph[id].nodes) + { + if (isTerminal(enode)) + { + shouldAbort = true; + break; + } + } + + if (shouldAbort) + continue; + + for (const EType& enode : egraph[id].nodes) + addChildren(egraph, &enode, worklist); + + for (Simplifier::RewriteRuleFn rule : rules) + (simplifier.get()->*rule)(id); + + if (simplifier->substs.empty()) + continue; + + for (const Subst& subst : simplifier->substs) + { + if (subst.newClass == subst.eclass) + continue; + + if (FFlag::DebugLuauExtraEqSatSanityChecks) + { + const Id never = egraph.find(egraph.add(TNever{})); + const Id str = egraph.find(egraph.add(TString{})); + const Id unk = egraph.find(egraph.add(TUnknown{})); + LUAU_ASSERT(never != str); + LUAU_ASSERT(never != unk); + } + + const bool isFresh = egraph.merge(subst.newClass, subst.eclass); + + ++count; + + if (FFlag::DebugLuauLogSimplification) + { + if (isFresh) + std::cout << "count=" << std::setw(3) << count << "\t" << subst.desc << '\n'; + + std::string filename = format("step%03d.dot", count); + std::ofstream(filename) << toDot(simplifier->stringCache, egraph); + } + + if (FFlag::DebugLuauExtraEqSatSanityChecks) + { + const Id never = egraph.find(egraph.add(TNever{})); + const Id str = egraph.find(egraph.add(TString{})); + const Id unk = egraph.find(egraph.add(TUnknown{})); + const Id trueId = egraph.find(egraph.add(SBoolean{true})); + + LUAU_ASSERT(never != str); + LUAU_ASSERT(never != unk); + LUAU_ASSERT(never != trueId); + } + + progressed |= isFresh; + } + + egraph.rebuild(); + } + } + + EqSatSimplificationResult result; + result.result = fromId(egraph, simplifier->stringCache, builtinTypes, arena, result.newTypeFunctions, rootId); + + if (FFlag::DebugLuauLogSimplification) + printf("<< simplify %s\n", toString(result.result).c_str()); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index d4f3ebd9..3395f125 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Common.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Parser.h" #include "Luau/ParseOptions.h" #include "Luau/Module.h" @@ -18,11 +19,14 @@ #include "Luau/ParseOptions.h" #include "Luau/Module.h" +#include "AutocompleteCore.h" + LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauStoreDFGOnModule2); +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) namespace { @@ -41,7 +45,6 @@ void copyModuleMap(Luau::DenseHashMap& result, const Luau::DenseHashMap getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) +/** + * Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that + * document and attempts to get the concrete text between those points. It returns a pair of: + * - start offset that represents an index in the source `char*` corresponding to startPos + * - length, that represents how many more bytes to read to get to endPos. + * Example - your document is "foo bar baz" and getDocumentOffsets is passed (1, 4) - (1, 8). This function returns the pair {3, 7}, + * which corresponds to the string " bar " + */ +std::pair getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) { - unsigned int lineCount = 0; - unsigned int colCount = 0; + size_t lineCount = 0; + size_t colCount = 0; - unsigned int docOffset = 0; - unsigned int startOffset = 0; - unsigned int endOffset = 0; + size_t docOffset = 0; + size_t startOffset = 0; + size_t endOffset = 0; bool foundStart = false; bool foundEnd = false; for (char c : src) @@ -115,6 +126,13 @@ std::pair getDocumentOffsets(const std::string_view& foundEnd = true; } + // We put a cursor position that extends beyond the extents of the current line + if (foundStart && !foundEnd && (lineCount > endPos.line)) + { + foundEnd = true; + endOffset = docOffset - 1; + } + if (c == '\n') { lineCount++; @@ -125,20 +143,24 @@ std::pair getDocumentOffsets(const std::string_view& docOffset++; } + if (foundStart && !foundEnd) + endOffset = src.length(); - unsigned int min = std::min(startOffset, endOffset); - unsigned int len = std::max(startOffset, endOffset) - min; + size_t min = std::min(startOffset, endOffset); + size_t len = std::max(startOffset, endOffset) - min; return {min, len}; } -ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos) +ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement) { LUAU_ASSERT(module->hasModuleScope()); ScopePtr closest = module->getModuleScope(); + + // find the scope the nearest statement belonged to. for (auto [loc, sc] : module->scopes) { - if (loc.begin <= cursorPos && closest->location.begin <= loc.begin) + if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin) closest = sc; } @@ -152,13 +174,27 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie opts.allowDeclarationSyntax = false; opts.captureComments = false; opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)}; - AstStat* enclosingStatement = result.nearestStatement; + AstStat* nearestStatement = result.nearestStatement; - const Position& endPos = cursorPos; - // If the statement starts on a previous line, grab the statement beginning - // otherwise, grab the statement end to whatever is being typed right now - const Position& startPos = - enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end; + const Location& rootSpan = srcModule.root->location; + // Did we append vs did we insert inline + bool appended = cursorPos >= rootSpan.end; + // statement spans multiple lines + bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line; + + const Position endPos = cursorPos; + + // We start by re-parsing everything (we'll refine this as we go) + Position startPos = srcModule.root->location.begin; + + // If we added to the end of the sourceModule, use the end of the nearest location + if (appended && multiline) + startPos = nearestStatement->location.end; + // Statement spans one line && cursorPos is on a different line + else if (!multiline && cursorPos.line != nearestStatement->location.end.line) + startPos = nearestStatement->location.end; + else + startPos = nearestStatement->location.begin; auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); @@ -173,10 +209,11 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie std::vector fabricatedAncestry = std::move(result.ancestry); std::vector fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end); fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); - if (enclosingStatement == nullptr) - enclosingStatement = p.root; + if (nearestStatement == nullptr) + nearestStatement = p.root; fragmentResult.root = std::move(p.root); fragmentResult.ancestry = std::move(fabricatedAncestry); + fragmentResult.nearestStatement = nearestStatement; return fragmentResult; } @@ -205,7 +242,7 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) return incrementalModule; } -FragmentTypeCheckResult typeCheckFragmentHelper( +FragmentTypeCheckResult typecheckFragment_( Frontend& frontend, AstStatBlock* root, const ModulePtr& stale, @@ -245,15 +282,18 @@ FragmentTypeCheckResult typeCheckFragmentHelper( /// Create a DataFlowGraph just for the surrounding context auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler); + SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes); + /// Contraint Generator ConstraintGenerator cg{ incrementalModule, NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull{&frontend.moduleResolver}, frontend.builtinTypes, iceHandler, - frontend.globals.globalScope, + stale->getModuleScope(), nullptr, nullptr, NotNull{&updatedDfg}, @@ -262,7 +302,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper( cg.rootScope = stale->getModuleScope().get(); // Any additions to the scope must occur in a fresh scope auto freshChildOfNearestScope = std::make_shared(closestScope); - incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope}); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); // closest Scope -> children = { ...., freshChildOfNearestScope} // We need to trim nearestChild from the scope hierarcy @@ -274,9 +314,11 @@ FragmentTypeCheckResult typeCheckFragmentHelper( LUAU_ASSERT(back == freshChildOfNearestScope.get()); closestScope->children.pop_back(); + /// Initialize the constraint solver and run it ConstraintSolver cs{ NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), @@ -307,7 +349,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper( freeze(incrementalModule->internalTypes); freeze(incrementalModule->interfaceTypes); - return {std::move(incrementalModule), freshChildOfNearestScope.get()}; + return {std::move(incrementalModule), std::move(freshChildOfNearestScope)}; } @@ -327,27 +369,51 @@ FragmentTypeCheckResult typecheckFragment( } ModulePtr module = frontend.moduleResolver.getModule(moduleName); - const ScopePtr& closestScope = findClosestScope(module, cursorPos); - - - FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos); + FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos); FrontendOptions frontendOptions = opts.value_or(frontend.options); - return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions); + const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement); + FragmentTypeCheckResult result = + typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions); + result.ancestry = std::move(parseResult.ancestry); + return result; } -AutocompleteResult fragmentAutocomplete( + +FragmentAutocompleteResult fragmentAutocomplete( Frontend& frontend, std::string_view src, const ModuleName& moduleName, - Position& cursorPosition, - const FrontendOptions& opts, + Position cursorPosition, + std::optional opts, StringCompletionCallback callback ) { LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauAllowFragmentParsing); LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2); - return {}; + LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete); + + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); + if (!sourceModule) + { + LUAU_ASSERT(!"Expected Source Module for fragment typecheck"); + return {}; + } + + auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src); + TypeArena arenaForFragmentAutocomplete; + auto result = Luau::autocomplete_( + tcResult.incrementalModule, + frontend.builtinTypes, + &arenaForFragmentAutocomplete, + tcResult.ancestry, + frontend.globals.globalScope.get(), + tcResult.freshScope, + cursorPosition, + frontend.fileResolver, + callback + ); + return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)}; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index e94b4a29..261e3781 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -10,6 +10,7 @@ #include "Luau/ConstraintSolver.h" #include "Luau/DataFlowGraph.h" #include "Luau/DcrLogger.h" +#include "Luau/EqSatSimplification.h" #include "Luau/FileResolver.h" #include "Luau/NonStrictTypeChecker.h" #include "Luau/Parser.h" @@ -46,7 +47,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) -LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection) LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2) @@ -287,8 +287,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector getRequireCycles( const FileResolver* resolver, const std::unordered_map>& sourceNodes, - const SourceNode* start, - bool stopAtFirst = false + const SourceNode* start ) { std::vector result; @@ -358,9 +357,6 @@ std::vector getRequireCycles( { result.push_back({depLocation, std::move(cycle)}); - if (stopAtFirst) - return result; - // note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start // so it's safe to *only* clear seen vector when we find a cycle // if we don't do it, we will not have correct reporting for some cycles @@ -884,18 +880,11 @@ void Frontend::addBuildQueueItems( data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; - const Mode mode = sourceModule->mode.value_or(data.config.mode); - // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - { - if (FFlag::LuauMoreThoroughCycleDetection) - data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), false); - else - data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); - } + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get()); data.options = frontendOptions; @@ -1334,6 +1323,7 @@ ModulePtr check( unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; + SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes); TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}}; if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) @@ -1342,6 +1332,7 @@ ModulePtr check( ConstraintGenerator cg{ result, NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, moduleResolver, builtinTypes, @@ -1358,6 +1349,7 @@ ModulePtr check( ConstraintSolver cs{ NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index dd5a2f85..1618b78f 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -132,7 +132,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a return dest.addType(NegationType{a.ty}); else if constexpr (std::is_same_v) { - TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName}; + TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData}; return dest.addType(std::move(clone)); } else diff --git a/Analysis/src/Symbol.cpp b/Analysis/src/Symbol.cpp index 5e5b9d8c..a5117608 100644 --- a/Analysis/src/Symbol.cpp +++ b/Analysis/src/Symbol.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauSymbolEquality) namespace Luau { @@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const return local == rhs.local; else if (global.value) return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - else if (FFlag::LuauSolverV2) + else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality) return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. else return false; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 0bb7344a..60ed3027 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -870,6 +870,8 @@ struct TypeStringifier return; } + LUAU_ASSERT(uv.options.size() > 1); + bool optional = false; bool hasNonNilDisjunct = false; @@ -878,7 +880,7 @@ struct TypeStringifier { el = follow(el); - if (isNil(el)) + if (state.opts.useQuestionMarks && isNil(el)) { optional = true; continue; diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 0193f4f1..d0ad82ec 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -51,6 +51,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAG(LuauUserTypeFunFixRegister) LUAU_FASTFLAG(LuauRemoveNotAnyHack) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState) +LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) @@ -610,10 +611,29 @@ TypeFunctionReductionResult userDefinedTypeFunction( NotNull ctx ) { - if (!ctx->userFuncName) + auto typeFunction = getMutable(instance); + + if (FFlag::LuauUserTypeFunExportedAndLocal) { - ctx->ice->ice("all user-defined type functions must have an associated function definition"); - return {std::nullopt, true, {}, {}}; + if (typeFunction->userFuncData.owner.expired()) + { + ctx->ice->ice("user-defined type function module has expired"); + return {std::nullopt, true, {}, {}}; + } + + if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition) + { + ctx->ice->ice("all user-defined type functions must have an associated function definition"); + return {std::nullopt, true, {}, {}}; + } + } + else + { + if (!ctx->userFuncName) + { + ctx->ice->ice("all user-defined type functions must have an associated function definition"); + return {std::nullopt, true, {}, {}}; + } } if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) @@ -632,7 +652,22 @@ TypeFunctionReductionResult userDefinedTypeFunction( return {std::nullopt, false, {ty}, {}}; } - AstName name = *ctx->userFuncName; + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Ensure that whole type function environment is registered + for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + if (std::optional error = ctx->typeFunctionRuntime->registerFunction(definition)) + { + // Failure to register at this point means that original definition had to error out and should not have been present in the + // environment + ctx->ice->ice("user-defined type function reference cannot be registered"); + return {std::nullopt, true, {}, {}}; + } + } + } + + AstName name = FFlag::LuauUserTypeFunExportedAndLocal ? typeFunction->userFuncData.definition->name : *ctx->userFuncName; lua_State* global = ctx->typeFunctionRuntime->state.get(); @@ -643,8 +678,44 @@ TypeFunctionReductionResult userDefinedTypeFunction( lua_State* L = lua_newthread(global); LuauTempThreadPopper popper(global); - lua_getglobal(global, name.value); - lua_xmove(global, L, 1); + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Fetch the function we want to evaluate + lua_pushlightuserdata(L, typeFunction->userFuncData.definition); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, true, {}, {}}; + } + + // Build up the environment + lua_getfenv(L, -1); + lua_setreadonly(L, -1, false); + + for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + lua_pushlightuserdata(L, definition); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, true, {}, {}}; + } + + lua_setfield(L, -2, name.c_str()); + } + + lua_setreadonly(L, -1, true); + lua_pop(L, 1); + } + else + { + lua_getglobal(global, name.value); + lua_xmove(global, L, 1); + } if (FFlag::LuauUserDefinedTypeFunctionResetState) resetTypeFunctionState(L); @@ -693,7 +764,7 @@ TypeFunctionReductionResult userDefinedTypeFunction( TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); - // At least 1 error occured while deserializing + // At least 1 error occurred while deserializing if (runtimeBuilder->errors.size() > 0) return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; @@ -935,6 +1006,23 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc prepareState(); + lua_State* global = state.get(); + + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Fetch to check if function is already registered + lua_pushlightuserdata(global, function); + lua_gettable(global, LUA_REGISTRYINDEX); + + if (!lua_isnil(global, -1)) + { + lua_pop(global, 1); + return std::nullopt; + } + + lua_pop(global, 1); + } + AstName name = function->name; // Construct ParseResult containing the type function @@ -961,7 +1049,6 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc std::string bytecode = builder.getBytecode(); - lua_State* global = state.get(); // Separate sandboxed thread for individual execution and private globals lua_State* L = lua_newthread(global); @@ -989,9 +1076,19 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc return format("Could not find '%s' type function in the global scope", name.value); } - // Store resulting function in the global environment - lua_xmove(L, global, 1); - lua_setglobal(global, name.value); + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Store resulting function in the registry + lua_pushlightuserdata(global, function); + lua_xmove(L, global, 1); + lua_settable(global, LUA_REGISTRYINDEX); + } + else + { + // Store resulting function in the global environment + lua_xmove(L, global, 1); + lua_setglobal(global, name.value); + } return std::nullopt; } diff --git a/Ast/include/Luau/Allocator.h b/Ast/include/Luau/Allocator.h new file mode 100644 index 00000000..7fd951ae --- /dev/null +++ b/Ast/include/Luau/Allocator.h @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Location.h" +#include "Luau/DenseHash.h" +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +class Allocator +{ +public: + Allocator(); + Allocator(Allocator&&); + + Allocator& operator=(Allocator&&) = delete; + + ~Allocator(); + + void* allocate(size_t size); + + template + T* alloc(Args&&... args) + { + static_assert(std::is_trivially_destructible::value, "Objects allocated with this allocator will never have their destructors run!"); + + T* t = static_cast(allocate(sizeof(T))); + new (t) T(std::forward(args)...); + return t; + } + +private: + struct Page + { + Page* next; + + char data[8192]; + }; + + Page* root; + size_t offset; +}; + +} diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 7845cca2..736f24a2 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -316,16 +316,18 @@ public: enum QuoteStyle { - Quoted, + QuotedSimple, + QuotedRaw, Unquoted }; - AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle = Quoted); + AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle); void visit(AstVisitor* visitor) override; + bool isQuoted() const; AstArray value; - QuoteStyle quoteStyle = Quoted; + QuoteStyle quoteStyle; }; class AstExprLocal : public AstExpr @@ -876,13 +878,14 @@ class AstStatTypeFunction : public AstStat public: LUAU_RTTI(AstStatTypeFunction); - AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); + AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported); void visit(AstVisitor* visitor) override; AstName name; Location nameLocation; AstExprFunction* body; + bool exported; }; class AstStatDeclareGlobal : public AstStat diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index f6ac28ad..6c8f21c1 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Allocator.h" #include "Luau/Ast.h" #include "Luau/Location.h" #include "Luau/DenseHash.h" @@ -11,40 +12,6 @@ namespace Luau { -class Allocator -{ -public: - Allocator(); - Allocator(Allocator&&); - - Allocator& operator=(Allocator&&) = delete; - - ~Allocator(); - - void* allocate(size_t size); - - template - T* alloc(Args&&... args) - { - static_assert(std::is_trivially_destructible::value, "Objects allocated with this allocator will never have their destructors run!"); - - T* t = static_cast(allocate(sizeof(T))); - new (t) T(std::forward(args)...); - return t; - } - -private: - struct Page - { - Page* next; - - char data[8192]; - }; - - Page* root; - size_t offset; -}; - struct Lexeme { enum Type diff --git a/Ast/src/Allocator.cpp b/Ast/src/Allocator.cpp new file mode 100644 index 00000000..f8a99db4 --- /dev/null +++ b/Ast/src/Allocator.cpp @@ -0,0 +1,66 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Allocator.h" + +namespace Luau +{ + +Allocator::Allocator() + : root(static_cast(operator new(sizeof(Page)))) + , offset(0) +{ + root->next = nullptr; +} + +Allocator::Allocator(Allocator&& rhs) + : root(rhs.root) + , offset(rhs.offset) +{ + rhs.root = nullptr; + rhs.offset = 0; +} + +Allocator::~Allocator() +{ + Page* page = root; + + while (page) + { + Page* next = page->next; + + operator delete(page); + + page = next; + } +} + +void* Allocator::allocate(size_t size) +{ + constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double); + + if (root) + { + uintptr_t data = reinterpret_cast(root->data); + uintptr_t result = (data + offset + align - 1) & ~(align - 1); + if (result + size <= data + sizeof(root->data)) + { + offset = result - data + size; + return reinterpret_cast(result); + } + } + + // allocate new page + size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data); + void* pageData = operator new(offsetof(Page, data) + pageSize); + + Page* page = static_cast(pageData); + + page->next = root; + + root = page; + offset = size; + + return page->data; +} + +} diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index a72aca86..a06fcb09 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -92,6 +92,11 @@ void AstExprConstantString::visit(AstVisitor* visitor) visitor->visit(this); } +bool AstExprConstantString::isQuoted() const +{ + return quoteStyle == QuoteStyle::QuotedSimple || quoteStyle == QuoteStyle::QuotedRaw; +} + AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue) : AstExpr(ClassIndex(), location) , local(local) @@ -760,11 +765,18 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) } } -AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body) +AstStatTypeFunction::AstStatTypeFunction( + const Location& location, + const AstName& name, + const Location& nameLocation, + AstExprFunction* body, + bool exported +) : AstStat(ClassIndex(), location) , name(name) , nameLocation(nameLocation) , body(body) + , exported(exported) { } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 54540215..4fb9c936 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Lexer.h" +#include "Luau/Allocator.h" #include "Luau/Common.h" #include "Luau/Confusables.h" #include "Luau/StringUtils.h" @@ -10,64 +11,6 @@ namespace Luau { -Allocator::Allocator() - : root(static_cast(operator new(sizeof(Page)))) - , offset(0) -{ - root->next = nullptr; -} - -Allocator::Allocator(Allocator&& rhs) - : root(rhs.root) - , offset(rhs.offset) -{ - rhs.root = nullptr; - rhs.offset = 0; -} - -Allocator::~Allocator() -{ - Page* page = root; - - while (page) - { - Page* next = page->next; - - operator delete(page); - - page = next; - } -} - -void* Allocator::allocate(size_t size) -{ - constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double); - - if (root) - { - uintptr_t data = reinterpret_cast(root->data); - uintptr_t result = (data + offset + align - 1) & ~(align - 1); - if (result + size <= data + sizeof(root->data)) - { - offset = result - data + size; - return reinterpret_cast(result); - } - } - - // allocate new page - size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data); - void* pageData = operator new(offsetof(Page, data) + pageSize); - - Page* page = static_cast(pageData); - - page->next = root; - - root = page; - offset = size; - - return page->data; -} - Lexeme::Lexeme(const Location& location, Type type) : type(type) , location(location) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 1ca028f2..02d17c1d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -21,6 +21,7 @@ LUAU_FASTFLAGVARIABLE(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck) @@ -943,8 +944,11 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported) Lexeme matchFn = lexer.current(); nextLexeme(); - if (exported) - report(start, "Type function cannot be exported"); + if (!FFlag::LuauUserDefinedTypeFunParseExport) + { + if (exported) + report(start, "Type function cannot be exported"); + } // parse the name of the type function std::optional fnName = parseNameOpt("type function name"); @@ -962,7 +966,7 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported) matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; - return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body); + return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -3012,8 +3016,23 @@ std::optional> Parser::parseCharArray() AstExpr* Parser::parseString() { Location location = lexer.current().location; + + AstExprConstantString::QuoteStyle style; + switch (lexer.current().type) + { + case Lexeme::QuotedString: + case Lexeme::InterpStringSimple: + style = AstExprConstantString::QuotedSimple; + break; + case Lexeme::RawString: + style = AstExprConstantString::QuotedRaw; + break; + default: + LUAU_ASSERT(false && "Invalid string type"); + } + if (std::optional> value = parseCharArray()) - return allocator.alloc(location, *value); + return allocator.alloc(location, *value, style); else return reportExprError(location, {}, "String literal contains malformed escape sequence"); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index be1f23f0..80ede2d0 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Config.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" @@ -224,7 +225,14 @@ struct CliConfigResolver : Luau::ConfigResolver if (std::optional contents = readFile(configPath)) { - std::optional error = Luau::parseConfig(*contents, result); + Luau::ConfigOptions::AliasOptions aliasOpts; + aliasOpts.configLocation = configPath; + aliasOpts.overwriteAliases = true; + + Luau::ConfigOptions opts; + opts.aliasOptions = std::move(aliasOpts); + + std::optional error = Luau::parseConfig(*contents, result, opts); if (error) configErrors.push_back({configPath, *error}); } diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index e9f40a09..4906d55a 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -181,6 +181,16 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath) return resolvedPath; } +bool hasFileExtension(std::string_view name, const std::vector& extensions) +{ + for (const std::string& extension : extensions) + { + if (name.size() >= extension.size() && name.substr(name.size() - extension.size()) == extension) + return true; + } + return false; +} + std::optional readFile(const std::string& name) { #ifdef _WIN32 diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index dce94ace..f723c765 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -15,6 +15,8 @@ std::string resolvePath(std::string_view relativePath, std::string_view baseFile std::optional readFile(const std::string& name); std::optional readStdin(); +bool hasFileExtension(std::string_view name, const std::vector& extensions); + bool isAbsolutePath(std::string_view path); bool isFile(const std::string& path); bool isDirectory(const std::string& path); diff --git a/CLI/Require.cpp b/CLI/Require.cpp index 9a00597a..2c45d0ac 100644 --- a/CLI/Require.cpp +++ b/CLI/Require.cpp @@ -3,6 +3,7 @@ #include "FileUtils.h" #include "Luau/Common.h" +#include "Luau/Config.h" #include #include @@ -83,6 +84,9 @@ RequireResolver::ModuleStatus RequireResolver::findModuleImpl() absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix } + if (hasFileExtension(absolutePath, {".luau", ".lua"}) && isFile(absolutePath)) + luaL_argerrorL(L, 1, "error requiring module: consider removing the file extension"); + return ModuleStatus::NotFound; } @@ -235,14 +239,15 @@ std::optional RequireResolver::getAlias(std::string alias) return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; } ); - while (!config.aliases.count(alias) && !isConfigFullyResolved) + while (!config.aliases.contains(alias) && !isConfigFullyResolved) { parseNextConfig(); } - if (!config.aliases.count(alias) && isConfigFullyResolved) + if (!config.aliases.contains(alias) && isConfigFullyResolved) return std::nullopt; // could not find alias - return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName)); + const Luau::Config::AliasInfo& aliasInfo = config.aliases[alias]; + return resolvePath(aliasInfo.value, aliasInfo.configLocation); } void RequireResolver::parseNextConfig() @@ -275,9 +280,16 @@ void RequireResolver::parseConfigInDirectory(const std::string& directory) { std::string configPath = joinPaths(directory, Luau::kConfigName); + Luau::ConfigOptions::AliasOptions aliasOpts; + aliasOpts.configLocation = configPath; + aliasOpts.overwriteAliases = false; + + Luau::ConfigOptions opts; + opts.aliasOptions = std::move(aliasOpts); + if (std::optional contents = readFile(configPath)) { - std::optional error = Luau::parseConfig(*contents, config); + std::optional error = Luau::parseConfig(*contents, config, opts); if (error) luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str()); } diff --git a/Common/include/Luau/Variant.h b/Common/include/Luau/Variant.h index 88722257..14eb8c4e 100644 --- a/Common/include/Luau/Variant.h +++ b/Common/include/Luau/Variant.h @@ -19,7 +19,7 @@ class Variant static_assert(std::disjunction_v...> == false, "variant does not allow references as an alternative type"); static_assert(std::disjunction_v...> == false, "variant does not allow arrays as an alternative type"); -private: +public: template static constexpr int getTypeId() { @@ -35,6 +35,7 @@ private: return -1; } +private: template struct First { diff --git a/Config/include/Luau/Config.h b/Config/include/Luau/Config.h index 3866547b..d6016229 100644 --- a/Config/include/Luau/Config.h +++ b/Config/include/Luau/Config.h @@ -1,12 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/LinterConfig.h" #include "Luau/ParseOptions.h" +#include #include #include -#include +#include #include namespace Luau @@ -19,6 +21,10 @@ constexpr const char* kConfigName = ".luaurc"; struct Config { Config(); + Config(const Config& other) noexcept; + Config& operator=(const Config& other) noexcept; + Config(Config&& other) noexcept = default; + Config& operator=(Config&& other) noexcept = default; Mode mode = Mode::Nonstrict; @@ -32,7 +38,19 @@ struct Config std::vector globals; - std::unordered_map aliases; + struct AliasInfo + { + std::string value; + std::string_view configLocation; + }; + + DenseHashMap aliases{""}; + + void setAlias(std::string alias, const std::string& value, const std::string configLocation); + +private: + // Prevents making unnecessary copies of the same config location string. + DenseHashMap> configLocationCache{""}; }; struct ConfigResolver @@ -60,6 +78,18 @@ std::optional parseLintRuleString( bool isValidAlias(const std::string& alias); -std::optional parseConfig(const std::string& contents, Config& config, bool compat = false); +struct ConfigOptions +{ + bool compat = false; + + struct AliasOptions + { + std::string configLocation; + bool overwriteAliases; + }; + std::optional aliasOptions = std::nullopt; +}; + +std::optional parseConfig(const std::string& contents, Config& config, const ConfigOptions& options = ConfigOptions{}); } // namespace Luau diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index cf7d4b22..3760fd9e 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -4,7 +4,8 @@ #include "Luau/Lexer.h" #include "Luau/StringUtils.h" #include -#include +#include +#include namespace Luau { @@ -16,6 +17,50 @@ Config::Config() enabledLint.setDefaults(); } +Config::Config(const Config& other) noexcept + : mode(other.mode) + , parseOptions(other.parseOptions) + , enabledLint(other.enabledLint) + , fatalLint(other.fatalLint) + , lintErrors(other.lintErrors) + , typeErrors(other.typeErrors) + , globals(other.globals) +{ + for (const auto& [alias, aliasInfo] : other.aliases) + { + std::string configLocation = std::string(aliasInfo.configLocation); + + if (!configLocationCache.contains(configLocation)) + configLocationCache[configLocation] = std::make_unique(configLocation); + + AliasInfo newAliasInfo; + newAliasInfo.value = aliasInfo.value; + newAliasInfo.configLocation = *configLocationCache[configLocation]; + aliases[alias] = std::move(newAliasInfo); + } +} + +Config& Config::operator=(const Config& other) noexcept +{ + if (this != &other) + { + Config copy(other); + std::swap(*this, copy); + } + return *this; +} + +void Config::setAlias(std::string alias, const std::string& value, const std::string configLocation) +{ + AliasInfo& info = aliases[alias]; + info.value = value; + + if (!configLocationCache.contains(configLocation)) + configLocationCache[configLocation] = std::make_unique(configLocation); + + info.configLocation = *configLocationCache[configLocation]; +} + static Error parseBoolean(bool& result, const std::string& value) { if (value == "true") @@ -136,7 +181,12 @@ bool isValidAlias(const std::string& alias) return true; } -Error parseAlias(std::unordered_map& aliases, std::string aliasKey, const std::string& aliasValue) +static Error parseAlias( + Config& config, + std::string aliasKey, + const std::string& aliasValue, + const std::optional& aliasOptions +) { if (!isValidAlias(aliasKey)) return Error{"Invalid alias " + aliasKey}; @@ -150,8 +200,12 @@ Error parseAlias(std::unordered_map& aliases, std::str return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; } ); - if (!aliases.count(aliasKey)) - aliases[std::move(aliasKey)] = aliasValue; + + if (!aliasOptions) + return Error("Cannot parse aliases without alias options"); + + if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey)) + config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation); return std::nullopt; } @@ -285,16 +339,16 @@ static Error parseJson(const std::string& contents, Action action) return {}; } -Error parseConfig(const std::string& contents, Config& config, bool compat) +Error parseConfig(const std::string& contents, Config& config, const ConfigOptions& options) { return parseJson( contents, [&](const std::vector& keys, const std::string& value) -> Error { if (keys.size() == 1 && keys[0] == "languageMode") - return parseModeString(config.mode, value, compat); + return parseModeString(config.mode, value, options.compat); else if (keys.size() == 2 && keys[0] == "lint") - return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, compat); + return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, options.compat); else if (keys.size() == 1 && keys[0] == "lintErrors") return parseBoolean(config.lintErrors, value); else if (keys.size() == 1 && keys[0] == "typeErrors") @@ -305,9 +359,9 @@ Error parseConfig(const std::string& contents, Config& config, bool compat) return std::nullopt; } else if (keys.size() == 2 && keys[0] == "aliases") - return parseAlias(config.aliases, keys[1], value); - else if (compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") - return parseModeString(config.mode, value, compat); + return parseAlias(config, keys[1], value, options.aliasOptions); + else if (options.compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") + return parseModeString(config.mode, value, options.compat); else { std::vector keysv(keys.begin(), keys.end()); diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 480aa07d..924da974 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -23,6 +23,13 @@ struct Analysis final using D = typename N::Data; + Analysis() = default; + + Analysis(N a) + : analysis(std::move(a)) + { + } + template static D fnMake(const N& analysis, const EGraph& egraph, const L& enode) { @@ -59,6 +66,15 @@ struct EClass final template struct EGraph final { + using EClassT = EClass; + + EGraph() = default; + + explicit EGraph(N analysis) + : analysis(std::move(analysis)) + { + } + Id find(Id id) const { return unionfind.find(id); @@ -85,33 +101,59 @@ struct EGraph final return id; } - void merge(Id id1, Id id2) + // Returns true if the two IDs were not previously merged. + bool merge(Id id1, Id id2) { id1 = find(id1); id2 = find(id2); if (id1 == id2) - return; + return false; - unionfind.merge(id1, id2); + const Id mergedId = unionfind.merge(id1, id2); - EClass& eclass1 = get(id1); - EClass eclass2 = std::move(get(id2)); + // Ensure that id1 is the Id that we keep, and id2 is the id that we drop. + if (mergedId == id2) + std::swap(id1, id2); + + EClassT& eclass1 = get(id1); + EClassT eclass2 = std::move(get(id2)); classes.erase(id2); - worklist.reserve(worklist.size() + eclass2.parents.size()); - for (auto [enode, id] : eclass2.parents) - worklist.push_back({std::move(enode), id}); + eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end()); + eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end()); + + std::sort( + eclass1.nodes.begin(), + eclass1.nodes.end(), + [](const L& left, const L& right) + { + return left.index() < right.index(); + } + ); + + worklist.reserve(worklist.size() + eclass1.parents.size()); + for (const auto& [eclass, id] : eclass1.parents) + worklist.push_back(id); analysis.join(eclass1.data, eclass2.data); + + return true; } void rebuild() { + std::unordered_set seen; + while (!worklist.empty()) { - auto [enode, id] = worklist.back(); + Id id = worklist.back(); worklist.pop_back(); - repair(get(find(id))); + + const bool isFresh = seen.insert(id).second; + if (!isFresh) + continue; + + repair(find(id)); } } @@ -120,16 +162,21 @@ struct EGraph final return classes.size(); } - EClass& operator[](Id id) + EClassT& operator[](Id id) { return get(find(id)); } - const EClass& operator[](Id id) const + const EClassT& operator[](Id id) const { return const_cast(this)->get(find(id)); } + const std::unordered_map& getAllClasses() const + { + return classes; + } + private: Analysis analysis; @@ -139,19 +186,19 @@ private: /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the /// e-class 𝑀[find(𝑎)]. - std::unordered_map> classes; + std::unordered_map classes; /// The hashcons 𝐻 is a map from e-nodes to e-class ids. std::unordered_map hashcons; - std::vector> worklist; + std::vector worklist; private: void canonicalize(L& enode) { // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). - for (Id& id : enode.operands()) + for (Id& id : enode.mutableOperands()) id = find(id); } @@ -171,7 +218,7 @@ private: classes.insert_or_assign( id, - EClass{ + EClassT{ id, {enode}, analysis.make(*this, enode), @@ -182,7 +229,7 @@ private: for (Id operand : enode.operands()) get(operand).parents.push_back({enode, id}); - worklist.emplace_back(enode, id); + worklist.emplace_back(id); hashcons.insert_or_assign(enode, id); return id; @@ -190,12 +237,13 @@ private: // Looks up for an eclass from a given non-canonicalized `id`. // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. - EClass& get(Id id) + EClassT& get(Id id) { + LUAU_ASSERT(classes.count(id)); return classes.at(id); } - void repair(EClass& eclass) + void repair(Id id) { // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents` // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id. @@ -204,26 +252,54 @@ private: // Here, we unify the two loops. I think it's equivalent? // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. - std::unordered_map map; - for (auto& [enode, id] : eclass.parents) + std::unordered_map newParents; + + // The eclass can be deallocated if it is merged into another eclass, so + // we take what we need from it and avoid retaining a pointer. + std::vector> parents = get(id).parents; + for (auto& pair : parents) { + L& enode = pair.first; + Id id = pair.second; + // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. hashcons.erase(enode); canonicalize(enode); hashcons.insert_or_assign(enode, find(id)); - if (auto it = map.find(enode); it != map.end()) + if (auto it = newParents.find(enode); it != newParents.end()) merge(id, it->second); - map.insert_or_assign(enode, find(id)); + newParents.insert_or_assign(enode, find(id)); } - eclass.parents.clear(); - for (auto it = map.begin(); it != map.end();) + // We reacquire the pointer because the prior loop potentially merges + // the eclass into another, which might move it around in memory. + EClassT* eclass = &get(find(id)); + + eclass->parents.clear(); + + for (const auto& [node, id] : newParents) + eclass->parents.emplace_back(std::move(node), std::move(id)); + + std::unordered_set newNodes; + for (L node : eclass->nodes) { - auto node = map.extract(it++); - eclass.parents.emplace_back(std::move(node.key()), node.mapped()); + canonicalize(node); + newNodes.insert(std::move(node)); } + + eclass->nodes.assign(newNodes.begin(), newNodes.end()); + + // FIXME: Extract into sortByTag() + std::sort( + eclass->nodes.begin(), + eclass->nodes.end(), + [](const L& left, const L& right) + { + return left.index() < right.index(); + } + ); } }; diff --git a/EqSat/include/Luau/Id.h b/EqSat/include/Luau/Id.h index c56a6ab6..7069f23c 100644 --- a/EqSat/include/Luau/Id.h +++ b/EqSat/include/Luau/Id.h @@ -2,6 +2,7 @@ #pragma once #include +#include #include namespace Luau::EqSat @@ -9,15 +10,17 @@ namespace Luau::EqSat struct Id final { - explicit Id(size_t id); + explicit Id(uint32_t id); - explicit operator size_t() const; + explicit operator uint32_t() const; bool operator==(Id rhs) const; bool operator!=(Id rhs) const; + bool operator<(Id rhs) const; + private: - size_t id; + uint32_t id; }; } // namespace Luau::EqSat diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index 8855d851..56fc7202 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -6,9 +6,19 @@ #include "Luau/Slice.h" #include "Luau/Variant.h" +#include #include #include +#include #include +#include + +#define LUAU_EQSAT_UNIT(name) \ + struct name : ::Luau::EqSat::Unit \ + { \ + static constexpr const char* tag = #name; \ + using Unit::Unit; \ + } #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ @@ -31,21 +41,57 @@ using NodeVector::NodeVector; \ } -#define LUAU_EQSAT_FIELD(name) \ - struct name : public ::Luau::EqSat::Field \ - { \ - } - -#define LUAU_EQSAT_NODE_FIELDS(name, ...) \ - struct name : public ::Luau::EqSat::NodeFields \ +#define LUAU_EQSAT_NODE_SET(name) \ + struct name : public ::Luau::EqSat::NodeSet> \ { \ static constexpr const char* tag = #name; \ - using NodeFields::NodeFields; \ + using NodeSet::NodeSet; \ + } + +#define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \ + struct name : public ::Luau::EqSat::NodeAtomAndVector> \ + { \ + static constexpr const char* tag = #name; \ + using NodeAtomAndVector::NodeAtomAndVector; \ } namespace Luau::EqSat { +template +struct Unit +{ + Slice mutableOperands() + { + return {}; + } + + Slice operands() const + { + return {}; + } + + bool operator==(const Unit& rhs) const + { + return true; + } + + bool operator!=(const Unit& rhs) const + { + return false; + } + + struct Hash + { + size_t operator()(const Unit& value) const + { + // chosen by fair dice roll. + // guaranteed to be random. + return 4; + } + }; +}; + template struct Atom { @@ -60,7 +106,7 @@ struct Atom } public: - Slice operands() + Slice mutableOperands() { return {}; } @@ -92,6 +138,62 @@ private: T _value; }; +template +struct NodeAtomAndVector +{ + template + NodeAtomAndVector(const X& value, Args&&... args) + : _value(value) + , vector{std::forward(args)...} + { + } + + Id operator[](size_t i) const + { + return vector[i]; + } + +public: + const X& value() const + { + return _value; + } + + Slice mutableOperands() + { + return Slice{vector.data(), vector.size()}; + } + + Slice operands() const + { + return Slice{vector.data(), vector.size()}; + } + + bool operator==(const NodeAtomAndVector& rhs) const + { + return _value == rhs._value && vector == rhs.vector; + } + + bool operator!=(const NodeAtomAndVector& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const NodeAtomAndVector& value) const + { + size_t result = languageHash(value._value); + hashCombine(result, languageHash(value.vector)); + return result; + } + }; + +private: + X _value; + T vector; +}; + template struct NodeVector { @@ -107,7 +209,7 @@ struct NodeVector } public: - Slice operands() + Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } @@ -139,90 +241,61 @@ private: T vector; }; -/// Empty base class just for static_asserts. -struct FieldBase +template +struct NodeSet { - FieldBase() = delete; - - FieldBase(FieldBase&&) = delete; - FieldBase& operator=(FieldBase&&) = delete; - - FieldBase(const FieldBase&) = delete; - FieldBase& operator=(const FieldBase&) = delete; -}; - -template -struct Field : FieldBase -{ -}; - -template -struct NodeFields -{ - static_assert(std::conjunction...>::value); - - template - static constexpr int getIndex() + template + NodeSet(Args&&... args) + : vector{std::forward(args)...} { - constexpr int N = sizeof...(Fields); - constexpr bool is[N] = {std::is_same_v, Fields>...}; + std::sort(begin(vector), end(vector)); + auto it = std::unique(begin(vector), end(vector)); + vector.erase(it, end(vector)); + } - for (int i = 0; i < N; ++i) - if (is[i]) - return i; - - return -1; + Id operator[](size_t i) const + { + return vector[i]; } public: - template - NodeFields(Args&&... args) - : array{std::forward(args)...} + Slice mutableOperands() { - } - - Slice operands() - { - return Slice{array}; + return Slice{vector.data(), vector.size()}; } Slice operands() const { - return Slice{array.data(), array.size()}; + return Slice{vector.data(), vector.size()}; } - template - Id field() const + bool operator==(const NodeSet& rhs) const { - static_assert(std::disjunction_v, Fields>...>); - return array[getIndex()]; + return vector == rhs.vector; } - bool operator==(const NodeFields& rhs) const - { - return array == rhs.array; - } - - bool operator!=(const NodeFields& rhs) const + bool operator!=(const NodeSet& rhs) const { return !(*this == rhs); } struct Hash { - size_t operator()(const NodeFields& value) const + size_t operator()(const NodeSet& value) const { - return languageHash(value.array); + return languageHash(value.vector); } }; -private: - std::array array; +protected: + T vector; }; template struct Language final { + using VariantTy = Luau::Variant; + template using WithinDomain = std::disjunction, Ts>...>; @@ -237,14 +310,14 @@ struct Language final return v.index(); } - /// You should never call this function with the intention of mutating the `Id`. - /// Reading is ok, but you should also never assume that these `Id`s are stable. - Slice operands() noexcept + /// This should only be used in canonicalization! + /// Always prefer operands() + Slice mutableOperands() noexcept { return visit( [](auto&& v) -> Slice { - return v.operands(); + return v.mutableOperands(); }, v ); @@ -306,7 +379,7 @@ public: }; private: - Variant v; + VariantTy v; }; } // namespace Luau::EqSat diff --git a/EqSat/include/Luau/LanguageHash.h b/EqSat/include/Luau/LanguageHash.h index 506f352b..cfc33b83 100644 --- a/EqSat/include/Luau/LanguageHash.h +++ b/EqSat/include/Luau/LanguageHash.h @@ -3,6 +3,7 @@ #include #include +#include #include namespace Luau::EqSat diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index 559ee119..22a61628 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -14,7 +14,9 @@ struct UnionFind final Id makeSet(); Id find(Id id) const; Id find(Id id); - void merge(Id a, Id b); + + // Merge aSet with bSet and return the canonicalized Id into the merged set. + Id merge(Id aSet, Id bSet); private: std::vector parents; diff --git a/EqSat/src/Id.cpp b/EqSat/src/Id.cpp index 960249ba..eae6a974 100644 --- a/EqSat/src/Id.cpp +++ b/EqSat/src/Id.cpp @@ -1,15 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Id.h" +#include "Luau/Common.h" namespace Luau::EqSat { -Id::Id(size_t id) +Id::Id(uint32_t id) : id(id) { } -Id::operator size_t() const +Id::operator uint32_t() const { return id; } @@ -24,9 +25,14 @@ bool Id::operator!=(Id rhs) const return id != rhs.id; } +bool Id::operator<(Id rhs) const +{ + return id < rhs.id; +} + } // namespace Luau::EqSat size_t std::hash::operator()(Luau::EqSat::Id id) const { - return std::hash()(size_t(id)); + return std::hash()(uint32_t(id)); } diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 619c3f47..6a952999 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -3,12 +3,16 @@ #include "Luau/Common.h" +#include + namespace Luau::EqSat { Id UnionFind::makeSet() { - Id id{parents.size()}; + LUAU_ASSERT(parents.size() < std::numeric_limits::max()); + + Id id{uint32_t(parents.size())}; parents.push_back(id); ranks.push_back(0); @@ -25,42 +29,44 @@ Id UnionFind::find(Id id) Id set = canonicalize(id); // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. - while (id != parents[size_t(id)]) + while (id != parents[uint32_t(id)]) { // Note: we don't update the ranks here since a rank // represents the upper bound on the maximum depth of a tree - Id parent = parents[size_t(id)]; - parents[size_t(id)] = set; + Id parent = parents[uint32_t(id)]; + parents[uint32_t(id)] = set; id = parent; } return set; } -void UnionFind::merge(Id a, Id b) +Id UnionFind::merge(Id a, Id b) { Id aSet = find(a); Id bSet = find(b); if (aSet == bSet) - return; + return aSet; // Ensure that the rank of set A is greater than the rank of set B - if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) + if (ranks[uint32_t(aSet)] > ranks[uint32_t(bSet)]) std::swap(aSet, bSet); - parents[size_t(bSet)] = aSet; + parents[uint32_t(bSet)] = aSet; - if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) - ranks[size_t(aSet)]++; + if (ranks[uint32_t(aSet)] == ranks[uint32_t(bSet)]) + ranks[uint32_t(aSet)]++; + + return aSet; } Id UnionFind::canonicalize(Id id) const { - LUAU_ASSERT(size_t(id) < parents.size()); + LUAU_ASSERT(uint32_t(id) < parents.size()); // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. - while (id != parents[size_t(id)]) - id = parents[size_t(id)]; + while (id != parents[uint32_t(id)]) + id = parents[uint32_t(id)]; return id; } diff --git a/Sources.cmake b/Sources.cmake index 4b99e867..1299b119 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -14,6 +14,7 @@ endif() # Luau.Ast Sources target_sources(Luau.Ast PRIVATE + Ast/include/Luau/Allocator.h Ast/include/Luau/Ast.h Ast/include/Luau/Confusables.h Ast/include/Luau/Lexer.h @@ -24,6 +25,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/StringUtils.h Ast/include/Luau/TimeTrace.h + Ast/src/Allocator.cpp Ast/src/Ast.cpp Ast/src/Confusables.cpp Ast/src/Lexer.cpp @@ -168,6 +170,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h + Analysis/include/Luau/AutocompleteTypes.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Clone.h @@ -181,6 +184,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Differ.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h + Analysis/include/Luau/EqSatSimplification.h Analysis/include/Luau/FileResolver.h Analysis/include/Luau/FragmentAutocomplete.h Analysis/include/Luau/Frontend.h @@ -245,6 +249,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/AstJsonEncoder.cpp Analysis/src/AstQuery.cpp Analysis/src/Autocomplete.cpp + Analysis/src/AutocompleteCore.cpp Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp Analysis/src/Constraint.cpp @@ -256,6 +261,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Differ.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp + Analysis/src/EqSatSimplification.cpp Analysis/src/FragmentAutocomplete.cpp Analysis/src/Frontend.cpp Analysis/src/Generalization.cpp @@ -417,7 +423,7 @@ endif() if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE - tests/AnyTypeSummary.test.cpp + tests/AnyTypeSummary.test.cpp tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp @@ -444,6 +450,7 @@ if(TARGET Luau.UnitTest) tests/EqSat.language.test.cpp tests/EqSat.propositional.test.cpp tests/EqSat.slice.test.cpp + tests/EqSatSimplification.test.cpp tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index d382a924..052d8c82 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -39,7 +39,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; -const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" +const char* luau_ident = "$Luau: Copyright (C) 2019-2024 Roblox Corporation $\n" "$URL: luau.org $\n"; #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index e170e9bc..e6e67020 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -67,7 +67,7 @@ TEST_CASE("encode_constants") charString.data = const_cast("a\x1d\0\\\"b"); charString.size = 6; - AstExprConstantString needsEscaping{Location(), charString}; + AstExprConstantString needsEscaping{Location(), charString, AstExprConstantString::QuotedSimple}; CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); @@ -83,7 +83,7 @@ TEST_CASE("basic_escaping") { std::string s = "hello \"world\""; AstArray theString{s.data(), s.size()}; - AstExprConstantString str{Location(), theString}; + AstExprConstantString str{Location(), theString, AstExprConstantString::QuotedSimple}; std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; CHECK_EQ(expected, toJson(&str)); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index de4049a9..0424e3df 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -151,40 +151,6 @@ struct ACBuiltinsFixture : ACFixtureImpl { }; -#define LUAU_CHECK_HAS_KEY(map, key) \ - do \ - { \ - auto&& _m = (map); \ - auto&& _k = (key); \ - const size_t count = _m.count(_k); \ - CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ - if (!count) \ - { \ - MESSAGE("Keys: (count " << _m.size() << ")"); \ - for (const auto& [k, v] : _m) \ - { \ - MESSAGE("\tkey: " << k); \ - } \ - } \ - } while (false) - -#define LUAU_CHECK_HAS_NO_KEY(map, key) \ - do \ - { \ - auto&& _m = (map); \ - auto&& _k = (key); \ - const size_t count = _m.count(_k); \ - CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ - if (count) \ - { \ - MESSAGE("Keys: (count " << _m.size() << ")"); \ - for (const auto& [k, v] : _m) \ - { \ - MESSAGE("\tkey: " << k); \ - } \ - } \ - } while (false) - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") diff --git a/tests/Config.test.cpp b/tests/Config.test.cpp index 70d6d6d7..690c4c37 100644 --- a/tests/Config.test.cpp +++ b/tests/Config.test.cpp @@ -58,7 +58,11 @@ TEST_CASE("report_a_syntax_error") TEST_CASE("noinfer_is_still_allowed") { Config config; - auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, true); + + ConfigOptions opts; + opts.compat = true; + + auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, opts); REQUIRE(!err); CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode)); @@ -147,6 +151,10 @@ TEST_CASE("extra_globals") TEST_CASE("lint_rules_compat") { Config config; + + ConfigOptions opts; + opts.compat = true; + auto err = parseConfig( R"( {"lint": { @@ -156,7 +164,7 @@ TEST_CASE("lint_rules_compat") }} )", config, - true + opts ); REQUIRE(!err); diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index 1b84d4c9..ef91fdf7 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -10,6 +10,7 @@ namespace Luau ConstraintGeneratorFixture::ConstraintGeneratorFixture() : Fixture() , mainModule(new Module) + , simplifier(newSimplifier(NotNull{&arena}, builtinTypes)) , forceTheFlag{FFlag::LuauSolverV2, true} { mainModule->name = "MainModule"; @@ -25,6 +26,7 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code) cg = std::make_unique( mainModule, NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull(&moduleResolver), builtinTypes, @@ -44,8 +46,19 @@ void ConstraintGeneratorFixture::solve(const std::string& code) { generateConstraints(code); ConstraintSolver cs{ - NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {} + NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull{rootScope}, + constraints, + "MainModule", + NotNull(&moduleResolver), + {}, + &logger, + NotNull{dfg.get()}, + {} }; + cs.run(); } diff --git a/tests/ConstraintGeneratorFixture.h b/tests/ConstraintGeneratorFixture.h index 782747c7..800bf873 100644 --- a/tests/ConstraintGeneratorFixture.h +++ b/tests/ConstraintGeneratorFixture.h @@ -4,8 +4,9 @@ #include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" -#include "Luau/TypeArena.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Module.h" +#include "Luau/TypeArena.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -20,6 +21,7 @@ struct ConstraintGeneratorFixture : Fixture DcrLogger logger; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + SimplifierPtr simplifier; TypeCheckLimits limits; TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}}; diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp index 282d4ad2..fd1bde57 100644 --- a/tests/EqSat.language.test.cpp +++ b/tests/EqSat.language.test.cpp @@ -11,9 +11,7 @@ LUAU_EQSAT_ATOM(I32, int); LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_ATOM(Str, std::string); -LUAU_EQSAT_FIELD(Left); -LUAU_EQSAT_FIELD(Right); -LUAU_EQSAT_NODE_FIELDS(Add, Left, Right); +LUAU_EQSAT_NODE_ARRAY(Add, 2); using namespace Luau; @@ -117,8 +115,8 @@ TEST_CASE("node_field") Add add{left, right}; - EqSat::Id left2 = add.field(); - EqSat::Id right2 = add.field(); + EqSat::Id left2 = add.operands()[0]; + EqSat::Id right2 = add.operands()[1]; CHECK(left == left2); CHECK(left != right2); @@ -135,10 +133,10 @@ TEST_CASE("language_operands") const Add* add = v2.get(); REQUIRE(add); - EqSat::Slice actual = v2.operands(); + EqSat::Slice actual = v2.operands(); CHECK(actual.size() == 2); - CHECK(actual[0] == add->field()); - CHECK(actual[1] == add->field()); + CHECK(actual[0] == add->operands()[0]); + CHECK(actual[1] == add->operands()[1]); } TEST_SUITE_END(); diff --git a/tests/EqSatSimplification.test.cpp b/tests/EqSatSimplification.test.cpp new file mode 100644 index 00000000..aaaec456 --- /dev/null +++ b/tests/EqSatSimplification.test.cpp @@ -0,0 +1,728 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/EqSatSimplification.h" + +using namespace Luau; + +struct ESFixture : Fixture +{ + ScopedFastFlag newSolverOnly{FFlag::LuauSolverV2, true}; + + TypeArena arena_; + const NotNull arena{&arena_}; + + SimplifierPtr simplifier; + + TypeId parentClass; + TypeId childClass; + TypeId anotherChild; + TypeId unrelatedClass; + + TypeId genericT = arena_.addType(GenericType{"T"}); + TypeId genericU = arena_.addType(GenericType{"U"}); + + TypeId numberToString = arena_.addType(FunctionType{ + arena_.addTypePack({builtinTypes->numberType}), + arena_.addTypePack({builtinTypes->stringType}) + }); + + TypeId stringToNumber = arena_.addType(FunctionType{ + arena_.addTypePack({builtinTypes->stringType}), + arena_.addTypePack({builtinTypes->numberType}) + }); + + ESFixture() + : simplifier(newSimplifier(arena, builtinTypes)) + { + createSomeClasses(&frontend); + + ScopePtr moduleScope = frontend.globals.globalScope; + + parentClass = moduleScope->linearSearchForBinding("Parent")->typeId; + childClass = moduleScope->linearSearchForBinding("Child")->typeId; + anotherChild = moduleScope->linearSearchForBinding("AnotherChild")->typeId; + unrelatedClass = moduleScope->linearSearchForBinding("Unrelated")->typeId; + } + + std::optional simplifyStr(TypeId ty) + { + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + LUAU_ASSERT(res); + return toString(res->result); + } + + TypeId tbl(TableType::Props props) + { + return arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, TableState::Sealed}); + } +}; + +TEST_SUITE_BEGIN("EqSatSimplification"); + +TEST_CASE_FIXTURE(ESFixture, "primitive") +{ + CHECK("number" == simplifyStr(builtinTypes->numberType)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | number") +{ + TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->numberType}}); + + CHECK("number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | string") +{ + CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1") +{ + TypeId ty = arena->freshType(nullptr); + asMutable(ty)->ty.emplace(std::vector{builtinTypes->numberType, ty}); + + CHECK("number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | string | number") +{ + TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->numberType}}); + + CHECK("number | string" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string | (number | string) | number") +{ + TypeId u1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + TypeId u2 = arena->addType(UnionType{{builtinTypes->stringType, u1, builtinTypes->numberType}}); + + CHECK("number | string" == simplifyStr(u2)); +} + +TEST_CASE_FIXTURE(ESFixture, "string | any") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->anyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "any | string") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "any | never") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | unknown") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "unknown | string") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "unknown | never") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | never") +{ + CHECK("string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | never | number") +{ + CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType, builtinTypes->numberType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & string") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & number") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->numberType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & unknown") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never & string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & (unknown | never)") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->stringType, + arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "true | false") +{ + CHECK("boolean" == simplifyStr(arena->addType(UnionType{{builtinTypes->trueType, builtinTypes->falseType}}))); +} + +/* + * Intuitively, if we have a type like + * + * x where x = A & B & (C | D | x) + * + * We know that x is certainly not larger than A & B. + * We also know that the union (C | D | x) can be rewritten `(C | D | (A & B & (C | D | x))) + * This tells us that the union part is not smaller than A & B. + * We can therefore discard the union entirely and simplify this type to A & B + */ +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (number | t1)") +{ + TypeId intersectionTy = arena->addType(BlockedType{}); + TypeId unionTy = arena->addType(UnionType{{builtinTypes->numberType, intersectionTy}}); + + asMutable(intersectionTy)->ty.emplace(std::vector{builtinTypes->stringType, unionTy}); + + CHECK("string" == simplifyStr(intersectionTy)); +} + +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (unknown | t1)") +{ + TypeId intersectionTy = arena->addType(BlockedType{}); + TypeId unionTy = arena->addType(UnionType{{builtinTypes->unknownType, intersectionTy}}); + + asMutable(intersectionTy)->ty.emplace(std::vector{builtinTypes->stringType, unionTy}); + + CHECK("string" == simplifyStr(intersectionTy)); +} + +TEST_CASE_FIXTURE(ESFixture, "error | unknown") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->errorType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "\"hello\" | string") +{ + CHECK("string" == simplifyStr(arena->addType(UnionType{{ + arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "\"hello\" | \"world\" | \"hello\"") +{ + CHECK("\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{ + arena->addType(SingletonType{StringSingleton{"hello"}}), + arena->addType(SingletonType{StringSingleton{"world"}}), + arena->addType(SingletonType{StringSingleton{"hello"}}), + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "nil | boolean | number | string | thread | function | table | class | buffer") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->nilType, + builtinTypes->booleanType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->threadType, + builtinTypes->functionType, + builtinTypes->tableType, + builtinTypes->classType, + builtinTypes->bufferType, + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent & number") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + parentClass, builtinTypes->numberType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & Parent") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ + childClass, parentClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & Unrelated") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + childClass, unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child | Parent") +{ + CHECK("Parent" == simplifyStr(arena->addType(UnionType{{ + childClass, parentClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "class | Child") +{ + CHECK("class" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->classType, childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent | class | Child") +{ + CHECK("class" == simplifyStr(arena->addType(UnionType{{ + parentClass, builtinTypes->classType, childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ + parentClass, unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never | Parent | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->neverType, parentClass, unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never | Parent | (number & string) | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->neverType, parentClass, + arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}), + unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "T & U") +{ + CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{ + genericT, genericU + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & true") +{ + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->booleanType, builtinTypes->trueType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | function | table | class | buffer)") +{ + TypeId truthy = arena->addType(UnionType{{ + builtinTypes->trueType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->threadType, + builtinTypes->functionType, + builtinTypes->tableType, + builtinTypes->classType, + builtinTypes->bufferType, + }}); + + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->booleanType, truthy + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & ~(false?)") +{ + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->booleanType, builtinTypes->truthyType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "false & ~(false?)") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->falseType, builtinTypes->truthyType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (number) -> string") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, numberToString}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (number) -> string") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(UnionType{{numberToString, numberToString}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & function") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->functionType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & boolean") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->booleanType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & ~function") +{ + TypeId notFunction = arena->addType(NegationType{builtinTypes->functionType}); + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, notFunction}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | function") +{ + CHECK("function" == simplifyStr(arena->addType(UnionType{{numberToString, builtinTypes->functionType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (string) -> number") +{ + CHECK("((number) -> string) & ((string) -> number)" == simplifyStr(arena->addType(IntersectionType{{numberToString, stringToNumber}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (string) -> number") +{ + CHECK("((number) -> string) | ((string) -> number)" == simplifyStr(arena->addType(UnionType{{numberToString, stringToNumber}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "add") +{ + CHECK("number" == simplifyStr(arena->addType( + TypeFunctionInstanceType{builtinTypeFunctions().addFunc, { + builtinTypes->numberType, builtinTypes->numberType + }} + ))); +} + +TEST_CASE_FIXTURE(ESFixture, "union") +{ + CHECK("number" == simplifyStr(arena->addType( + TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, { + builtinTypes->numberType, builtinTypes->numberType + }} + ))); +} + +TEST_CASE_FIXTURE(ESFixture, "never & ~string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->neverType, + arena->addType(NegationType{builtinTypes->stringType}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "blocked & never") +{ + const TypeId blocked = arena->addType(BlockedType{}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{blocked, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "blocked & ~number & function") +{ + const TypeId blocked = arena->addType(BlockedType{}); + const TypeId notNumber = arena->addType(NegationType{builtinTypes->numberType}); + + const TypeId ty = arena->addType(IntersectionType{{blocked, notNumber, builtinTypes->functionType}}); + + std::string expected = toString(blocked) + " & function"; + + CHECK(expected == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "(number | boolean | string | nil | table) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number | boolean | nil) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->nilType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->booleanType, builtinTypes->nilType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +// (('a & false) | ('a & nil)) | number + +// Child & ~Parent +// ~Parent & Child +// ~Child & Parent +// Parent & ~Child +// ~Child & ~Parent +// ~Parent & ~Child + +TEST_CASE_FIXTURE(ESFixture, "free & string & number") +{ + Scope scope{builtinTypes->anyTypePack}; + const TypeId freeTy = arena->addType(FreeType{&scope}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(blocked & number) | (blocked & number)") +{ + const TypeId blocked = arena->addType(BlockedType{}); + const TypeId u = arena->addType(IntersectionType{{blocked, builtinTypes->numberType}}); + const TypeId ty = arena->addType(UnionType{{u, u}}); + + const std::string blockedStr = toString(blocked); + + CHECK(blockedStr + " & number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & unknown") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ + tbl({}), + builtinTypes->unknownType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & table") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ + tbl({}), + builtinTypes->tableType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & ~(false?)") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ + tbl({}), + builtinTypes->truthyType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: number}") +{ + const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId hasX = tbl({{"x", builtinTypes->numberType}}); + + const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}}); + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + + CHECK("{ x: number }" == toString(res->result)); + + // Also assert that we don't allocate a fresh TableType in this case. + CHECK(follow(res->result) == hasX); +} + +TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: ~(false?)}") +{ + const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId hasX = tbl({{"x", builtinTypes->truthyType}}); + + const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}}); + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + + CHECK("{ x: number }" == toString(res->result)); +} + +TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) }") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + const TypeId ty = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "never | (({ x: number? }?) & { x: ~(false?) })") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + // ({x: number?}?) & {x: ~(false?)} + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}}); + + const TypeId ty = arena->addType(UnionType{{builtinTypes->neverType, intersectionTy}}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "({ x: number? }?) & { x: ~(false?) } & ~(false?)") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + // ({x: number?}?) & {x: ~(false?)} & ~(false?) + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}}); + + CHECK("{ x: number }" == simplifyStr(intersectionTy)); +} + +#if 0 +// TODO +TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) } & ~(false?)) | number") +{ + // ({ x: number? }?) & { x: ~(false?) } & ~(false?) + const TypeId xWithOptionalNumber = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}}); + const TypeId ty = arena->addType(UnionType{{intersectionTy, builtinTypes->numberType}}); + + CHECK("{ x: number } | number" == simplifyStr(ty)); +} +#endif + +TEST_CASE_FIXTURE(ESFixture, "number & no-refine") +{ + CHECK("number" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->noRefineType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{ x: number } & ~boolean") +{ + const TypeId tblTy = tbl(TableType::Props{{"x", builtinTypes->numberType}}); + + const TypeId ty = arena->addType(IntersectionType{{ + tblTy, + arena->addType(NegationType{builtinTypes->booleanType}) + }}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "(nil & string)?") +{ + const TypeId nilAndString = arena->addType(IntersectionType{{builtinTypes->nilType, builtinTypes->stringType}}); + const TypeId ty = arena->addType(UnionType{{nilAndString, builtinTypes->nilType}}); + + CHECK("nil" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string & \"hi\"") +{ + const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}}); + + CHECK("\"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, hi}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")") +{ + const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}}); + const TypeId bye = arena->addType(SingletonType{StringSingleton{"bye"}}); + + CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->stringType, + arena->addType(UnionType{{hi, bye}}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child") +{ + const TypeId ty = arena->addType(IntersectionType{{ + arena->addType(UnionType{{childClass, unrelatedClass}}), + arena->addType(NegationType{childClass}) + }}); + + CHECK("Unrelated" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string & ~Child") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->stringType, + arena->addType(NegationType{childClass}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & Child") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ + arena->addType(UnionType{{childClass, unrelatedClass}}), + childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | AnotherChild) & ~Child") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ + arena->addType(UnionType{{childClass, anotherChild}}), + childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: never }") +{ + const TypeId ty = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->neverType}}); + + CHECK("never" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: number? } & { x: string }") +{ + const TypeId leftTable = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->optionalNumberType}}); + const TypeId rightTable = tbl({{"x", builtinTypes->stringType}}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{leftTable, rightTable}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & add") +{ + const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{ + builtinTypeFunctions().addFunc, + {u, parentClass}, + {} + }); + + const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); + + CHECK("Child & add" == simplifyStr(intersection)); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & intersect") +{ + const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{ + builtinTypeFunctions().intersectFunc, + {u, parentClass}, + {} + }); + + const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); + + CHECK("Child" == simplifyStr(intersection)); +} + +// {someKey: ~any} +// +// Maybe something we could do here is to try to reduce the key, get the +// class->node mapping, and skip the extraction process if the class corresponds +// to TNever. + +// t1 where t1 = add, number> + +TEST_SUITE_END(); diff --git a/tests/Fixture.h b/tests/Fixture.h index 0db208d9..39222a25 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -293,3 +293,37 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric; } while (false) #define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) + +#define LUAU_CHECK_HAS_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ + if (!count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + +#define LUAU_CHECK_HAS_NO_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ + if (count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index de2e9832..81e42f87 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -4,19 +4,37 @@ #include "Fixture.h" #include "Luau/Ast.h" #include "Luau/AstQuery.h" +#include "Luau/Autocomplete.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" #include "Luau/Frontend.h" +#include "Luau/AutocompleteTypes.h" using namespace Luau; LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauStoreDFGOnModule2); +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) +{ + return std::nullopt; +} struct FragmentAutocompleteFixture : Fixture { - ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}}; + ScopedFastFlag sffs[4] = { + {FFlag::LuauAllowFragmentParsing, true}, + {FFlag::LuauSolverV2, true}, + {FFlag::LuauStoreDFGOnModule2, true}, + {FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true} + }; + FragmentAutocompleteFixture() + { + addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType}); + } FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) { ParseResult p = tryParse(source); // We don't care about parsing incomplete asts @@ -26,7 +44,6 @@ struct FragmentAutocompleteFixture : Fixture CheckResult checkBase(const std::string& document) { - ScopedFastFlag sff{FFlag::LuauSolverV2, true}; FrontendOptions opts; opts.retainFullTypeGraphs = true; return this->frontend.check("MainModule", opts); @@ -48,6 +65,16 @@ struct FragmentAutocompleteFixture : Fixture options.runLintChecks = false; return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document); } + + FragmentAutocompleteResult autocompleteFragment(const std::string& document, Position cursorPos) + { + FrontendOptions options; + options.retainFullTypeGraphs = true; + // Don't strictly need this in the new solver + options.forAutocomplete = true; + options.runLintChecks = false; + return Luau::fragmentAutocomplete(frontend, document, "MainModule", cursorPos, options, nullCallback); + } }; TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); @@ -172,6 +199,13 @@ TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") +{ + check("local a ="); + auto fragment = parseFragment("local a =", Position(0, 10)); + CHECK_EQ("local a =", fragment.fragmentToParse); +} + TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") { auto res = check(R"( @@ -278,6 +312,33 @@ local y = 5 CHECK_EQ("y", std::string(rhs->name.value)); } +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope") +{ + + check(R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )"); + + auto fragment = parseFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )", + Position{6, 0} + ); + + + + CHECK_EQ("function abc()\n local myInnerLocal = 1\n\n end\n", fragment.fragmentToParse); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); @@ -302,7 +363,7 @@ local z = x + y Position{3, 15} ); - auto opt = linearSearchForBinding(fragment.freshScope, "z"); + auto opt = linearSearchForBinding(fragment.freshScope.get(), "z"); REQUIRE(opt); CHECK_EQ("number", toString(*opt)); } @@ -326,9 +387,222 @@ local y = 5 Position{2, 11} ); - auto correct = linearSearchForBinding(fragment.freshScope, "z"); + auto correct = linearSearchForBinding(fragment.freshScope.get(), "z"); REQUIRE(correct); CHECK_EQ("number", toString(*correct)); } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("FragmentAutocompleteTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_simple_property_access") +{ + auto res = check( + R"( +local tbl = { abc = 1234} +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +local tbl = { abc = 1234} +tbl. +)", + Position{2, 5} + ); + + LUAU_ASSERT(fragment.freshScope); + + CHECK_EQ(1, fragment.acResults.entryMap.size()); + CHECK(fragment.acResults.entryMap.count("abc")); + CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_nested_property_access") +{ + auto res = check( + R"( +local tbl = { abc = { def = 1234, egh = false } } +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +local tbl = { abc = { def = 1234, egh = false } } +tbl.abc. +)", + Position{2, 8} + ); + + LUAU_ASSERT(fragment.freshScope); + + CHECK_EQ(2, fragment.acResults.entryMap.size()); + CHECK(fragment.acResults.entryMap.count("def")); + CHECK(fragment.acResults.entryMap.count("egh")); + CHECK_EQ(fragment.acResults.context, AutocompleteContext::Property); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "inline_autocomplete_picks_the_right_scope") +{ + auto res = check( + R"( +type Table = { a: number, b: number } +do + type Table = { x: string, y: string } +end +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +type Table = { a: number, b: number } +do + type Table = { x: string, y: string } + local a : T +end +)", + Position{4, 15} + ); + + LUAU_ASSERT(fragment.freshScope); + + REQUIRE(fragment.acResults.entryMap.count("Table")); + REQUIRE(fragment.acResults.entryMap["Table"].type); + const TableType* tv = get(follow(*fragment.acResults.entryMap["Table"].type)); + REQUIRE(tv); + CHECK(tv->props.count("x")); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nested_recursive_function") +{ + auto res = check(R"( +function foo() +end +)"); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +function foo() +end +)", + Position{2, 0} + ); + + CHECK(fragment.acResults.entryMap.count("foo")); + CHECK_EQ(AutocompleteContext::Statement, fragment.acResults.context); +} + + +// Start compatibility tests! + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "empty_program") +{ + check(""); + + auto frag = autocompleteFragment(" ", Position{0, 1}); + auto ac = frag.acResults; + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") +{ + check("local a ="); + auto frag = autocompleteFragment("local a =", Position{0, 9}); + auto ac = frag.acResults; + + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "leave_numbers_alone") +{ + check("local a = 3."); + + auto frag = autocompleteFragment("local a = 3.", Position{0, 12}); + auto ac = frag.acResults; + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "user_defined_globals") +{ + check("local myLocal = 4; "); + + auto frag = autocompleteFragment("local myLocal = 4; ", Position{0, 18}); + auto ac = frag.acResults; + + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "dont_suggest_local_before_its_definition") +{ + check(R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )"); + + // autocomplete after abc but before myInnerLocal + auto fragment = autocompleteFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end +)", + Position{3, 0} + ); + auto ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); + + // autocomplete after my inner local + fragment = autocompleteFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )", + Position{4, 0} + ); + ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("myInnerLocal")); + + fragment = autocompleteFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )", + Position{6, 0} + ); + + ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 0ab402b5..69330057 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -18,6 +18,7 @@ LUAU_FASTINT(LuauParseErrorLimit) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) namespace { @@ -2377,10 +2378,15 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") { ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag sff2{FFlag::LuauUserDefinedTypeFunParseExport, true}; AstStat* stat = parse(R"( type function foo() - return + return types.number + end + + export type function bar() + return types.string end )"); @@ -2417,7 +2423,6 @@ TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions") { ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - matchParseError("export type function foo() end", "Type function cannot be exported"); matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); } diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index 641323c2..f9bc3afb 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -424,6 +424,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath") assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); } +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithExtension") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau"; + runProtectedRequire(path); + assertOutputContainsAll({"false", "error requiring module: consider removing the file extension"}); +} + TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") { std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer"; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 422315f9..dedf8824 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -964,6 +964,7 @@ TEST_CASE_FIXTURE(Fixture, "correct_stringification_user_defined_type_functions" std::vector{builtinTypes->numberType}, // Type Function Arguments {}, {AstName{"woohoo"}}, // Type Function Name + {}, }; Type tv{tftt}; diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index 29d7e8a7..eca633a8 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite) LUAU_FASTFLAG(LuauUserTypeFunFixMetatable) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAG(LuauUserTypeFunNonstrict) +LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) +LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); @@ -1298,4 +1300,92 @@ local a: foo<> = "a" LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; + + fileResolver.source["game/A"] = R"( +type function concat(a, b) + return types.singleton(a:value() .. b:value()) +end +export type Concat = concat +local a: concat<'first', 'second'> +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")"); + + CheckResult bResult = check(R"( +local Test = require(game.A); +local b: Test.Concat<'third', 'fourth'> + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + CHECK(toString(requireType("b")) == R"("thirdfourth")"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; + + CheckResult result = check(R"( +type function foo() + return "hi" +end +local function test() + type function bar() + return types.singleton(foo()) + end + + return ("" :: any) :: bar<> +end +local a = test() + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(toString(requireType("a")) == R"("hi")"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; + ScopedFastFlag luauUserDefinedTypeFunParseExport{FFlag::LuauUserDefinedTypeFunParseExport, true}; + + fileResolver.source["game/A"] = R"( +export type function concat(a, b) + return types.singleton(a:value() .. b:value()) +end +local a: concat<'first', 'second'> +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")"); + + CheckResult bResult = check(R"( +local Test = require(game.A); +local b: Test.concat<'third', 'fourth'> + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + CHECK(toString(requireType("b")) == R"("thirdfourth")"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index ad4f9a85..3686f2d4 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) +LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -681,6 +682,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") { + // CLI-114134: this code *probably* wants the egraph in order + // to work properly. The new solver either falls over or + // forces so many constraints as to be unreliable. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) local i, j = left, mid @@ -743,6 +749,11 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") { + // CLI-114134: this code *probably* wants the egraph in order + // to work properly. The new solver either falls over or + // forces so many constraints as to be unreliable. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) local i, j = left, mid @@ -2554,8 +2565,17 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") { - if (!FFlag::LuauSolverV2) - return; + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontRefCountTypesInTypeFunctions, true} + }; + + // CLI-114134: This test: + // a) Has a kind of weird result (suggesting `number | false` is not great); + // b) Is force solving some constraints. + // We end up with a weird recursive type that, if you roughly look at it, is + // clearly `number`. Hopefully the egraph will be able to unfold this. + CheckResult result = check(R"( function fib(n) return n < 2 and 1 or fib(n-1) + fib(n-2) @@ -2565,9 +2585,7 @@ end LUAU_REQUIRE_ERRORS(result); auto err = get(result.errors.back()); LUAU_ASSERT(err); - CHECK("number" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("number" == toString(err->recommendedArgs[0].second)); + CHECK("false | number" == toString(err->recommendedReturn)); } TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") @@ -2862,6 +2880,8 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun") TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") { + ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true}; + CheckResult result = check(R"( function foo(player) local success,result = player:thing() @@ -2889,7 +2909,7 @@ TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK(toString(tm2->wantedTp) == "string"); - CHECK(toString(tm2->givenTp) == "buffer | class | function | number | string | table | thread | true"); + CHECK(toString(tm2->givenTp) == "(buffer | class | function | number | string | table | thread | true) & unknown"); } else { diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index fd8e06a7..80dddc67 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -24,6 +24,7 @@ LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) +LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) using namespace Luau; @@ -1730,4 +1731,36 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") )")); } +TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontRefCountTypesInTypeFunctions, true} + }; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!strict + local function foo(a : string?) + local b = a or "" + return b:upper() + end + )")); +} + +TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontRefCountTypesInTypeFunctions, true} + }; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!strict + local function wtf(name: string?) + local message + message = "invalid alternate fiber: " .. (name or "UNNAMED alternate") + end + )")); +} + TEST_SUITE_END();