// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Ast.h" #include "Luau/Constraint.h" #include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" #include "Luau/EqSatSimplification.h" #include "Luau/InsertionOrderedMap.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" #include "Luau/NotNull.h" #include "Luau/Refinement.h" #include "Luau/Symbol.h" #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Variant.h" #include #include #include namespace Luau { struct Scope; using ScopePtr = std::shared_ptr; struct DcrLogger; struct TypeFunctionRuntime; struct Inference { TypeId ty = nullptr; RefinementId refinement = nullptr; Inference() = default; explicit Inference(TypeId ty, RefinementId refinement = nullptr) : ty(ty) , refinement(refinement) { } }; struct InferencePack { TypePackId tp = nullptr; std::vector refinements; InferencePack() = default; explicit InferencePack(TypePackId tp, const std::vector& refinements = {}) : tp(tp) , refinements(refinements) { } }; struct ConstraintGenerator { // A list of all the scopes in the module. This vector holds ownership of the // scope pointers; the scopes themselves borrow pointers to other scopes to // define the scope hierarchy. std::vector> scopes; ModulePtr module; NotNull builtinTypes; const NotNull arena; // The root scope of the module we're generating constraints for. // This is null when the CG is initially constructed. Scope* rootScope; TypeContext typeContext = TypeContext::Default; struct InferredBinding { Scope* scope; Location location; TypeIds types; }; // Some locals have multiple type states. We wish for Scope::bindings to // map each local name onto the union of every type that the local can have // over its lifetime, so we use this map to accumulate the set of types it // might have. // // See the functions recordInferredBinding and fillInInferredBindings. DenseHashMap inferredBindings{{}}; // Constraints that go straight to the solver. std::vector constraints; // Constraints that do not go to the solver right away. Other constraints // will enqueue them during solving. std::vector unqueuedConstraints; // The private scope of type aliases for which the type parameters belong to. DenseHashMap astTypeAliasDefiningScopes{nullptr}; NotNull dfg; RefinementArena refinementArena; int recursionCount = 0; // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. std::vector errors; // Needed to be able to enable error-suppression preservation for immediate refinements. NotNull normalizer; NotNull simplifier; // Needed to register all available type functions for execution at later stages. NotNull typeFunctionRuntime; // Needed to resolve modules to make 'require' import types properly. NotNull moduleResolver; // Occasionally constraint generation needs to produce an ICE. const NotNull ice; ScopePtr globalScope; std::function prepareModuleScope; std::vector requireCycles; DenseHashMap localTypes{nullptr}; DcrLogger* logger; ConstraintGenerator( ModulePtr module, NotNull normalizer, NotNull simplifier, NotNull typeFunctionRuntime, NotNull moduleResolver, NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, DcrLogger* logger, NotNull dfg, std::vector requireCycles ); /** * The entry point to the ConstraintGenerator. This will construct a set * of scopes, constraints, and free types that can be solved later. * @param block the root block to generate constraints for. */ void visitModuleRoot(AstStatBlock* block); void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block); private: std::vector> interiorTypes; /** * Fabricates a new free type belonging to a given scope. * @param scope the scope the free type belongs to. */ TypeId freshType(const ScopePtr& scope); /** * Fabricates a new free type pack belonging to a given scope. * @param scope the scope the free type pack belongs to. */ TypePackId freshTypePack(const ScopePtr& scope); /** * Allocate a new TypePack with the given head and tail. * * Avoids allocating 0-length type packs: * * If the head is non-empty, allocate and return a type pack with the given * head and tail. * If the head is empty and tail is non-empty, return *tail. * If both the head and tail are empty, return an empty type pack. */ TypePackId addTypePack(std::vector head, std::optional tail); /** * Fabricates a scope that is a child of another scope. * @param node the lexical node that the scope belongs to. * @param parent the parent scope of the new scope. Must not be null. */ ScopePtr childScope(AstNode* node, const ScopePtr& parent); std::optional lookup(const ScopePtr& scope, Location location, DefId def, bool prototype = true); /** * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. * @return the pointer to the inserted constraint */ NotNull addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. * @return the pointer to the inserted constraint */ NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); struct RefinementPartition { // Types that we want to intersect against the type of the expression. std::vector discriminantTypes; // Sometimes the type we're discriminating against is implicitly nil. bool shouldAppendNilType = false; }; using RefinementContext = InsertionOrderedMap; void unionRefinements( const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector* constraints ); void computeRefinement( const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector* constraints ); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); LUAU_NOINLINE void checkAliases(const ScopePtr& scope, AstStatBlock* block); ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); ControlFlow visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block); ControlFlow visit(const ScopePtr& scope, AstStat* stat); ControlFlow visit(const ScopePtr& scope, AstStatBlock* block); ControlFlow visit(const ScopePtr& scope, AstStatLocal* local); ControlFlow visit(const ScopePtr& scope, AstStatFor* for_); ControlFlow visit(const ScopePtr& scope, AstStatForIn* forIn); ControlFlow visit(const ScopePtr& scope, AstStatWhile* while_); ControlFlow visit(const ScopePtr& scope, AstStatRepeat* repeat); ControlFlow visit(const ScopePtr& scope, AstStatLocalFunction* function); ControlFlow visit(const ScopePtr& scope, AstStatFunction* function); ControlFlow visit(const ScopePtr& scope, AstStatReturn* ret); ControlFlow visit(const ScopePtr& scope, AstStatAssign* assign); ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign); ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement); ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias); ControlFlow visit(const ScopePtr& scope, AstStatTypeFunction* function); ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); ControlFlow visit(const ScopePtr& scope, AstStatError* error); InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes = {}); InferencePack checkPack( const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes = {}, bool generalize = true ); InferencePack checkPack(const ScopePtr& scope, AstExprCall* call); /** * Checks an expression that is expected to evaluate to one type. * @param scope the scope the expression is contained within. * @param expr the expression to check. * @param expectedType the type of the expression that is expected from its * surrounding context. Used to implement bidirectional type checking. * @param generalize If true, generalize any lambdas that are encountered. * @return the type of the expression. */ Inference check( const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false, bool generalize = true ); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType, bool generalize); Inference check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprInterpString* interpString); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType); void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType); void visitLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId rhsType); void visitLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId rhsType); void visitLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId rhsType); struct FunctionSignature { // The type of the function. TypeId signature; // The scope that encompasses the function's signature. May be nullptr // if there was no need for a signature scope (the function has no // generics). ScopePtr signatureScope; // The scope that encompasses the function's body. Is a child scope of // signatureScope, if present. ScopePtr bodyScope; }; FunctionSignature checkFunctionSignature( const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}, std::optional originalName = {} ); /** * Checks the body of a function expression. * @param scope the interior scope of the body of the function. * @param fn the function expression to check. */ void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn); // Specializations of 'resolveType' below TypeId resolveReferenceType(const ScopePtr& scope, AstType* ty, AstTypeReference* ref, bool inTypeArguments, bool replaceErrorWithFresh); TypeId resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh); TypeId resolveFunctionType(const ScopePtr& scope, AstType* ty, AstTypeFunction* fn, bool inTypeArguments, bool replaceErrorWithFresh); /** * Resolves a type from its AST annotation. * @param scope the scope that the type annotation appears within. * @param ty the AST annotation to resolve. * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type of the AST annotation. **/ TypeId resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Resolves a type pack from its AST annotation. * @param scope the scope that the type annotation appears within. * @param tp the AST annotation to resolve. * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type pack of the AST annotation. **/ TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Resolves a type pack from its AST annotation. * @param scope the scope that the type annotation appears within. * @param list the AST annotation to resolve. * @param inTypeArguments whether we are resolving a type that's contained within type arguments, `<...>`. * @return the type pack of the AST annotation. **/ TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments, bool replaceErrorWithFresh = false); /** * Creates generic types given a list of AST definitions, resolving default * types as required. * @param scope the scope that the generics should belong to. * @param generics the AST generics to create types for. * @param useCache whether to use the generic type cache for the given * scope. * @param addTypes whether to add the types to the scope's * privateTypeBindings map. **/ std::vector> createGenerics( const ScopePtr& scope, AstArray generics, bool useCache = false, bool addTypes = true ); /** * Creates generic type packs given a list of AST definitions, resolving * default type packs as required. * @param scope the scope that the generic packs should belong to. * @param generics the AST generics to create type packs for. * @param useCache whether to use the generic type pack cache for the given * scope. * @param addTypes whether to add the types to the scope's * privateTypePackBindings map. **/ std::vector> createGenericPacks( const ScopePtr& scope, AstArray packs, bool useCache = false, bool addTypes = true ); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); // make a union type function of these two types TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); // make an intersect type function of these two types TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program); /** Scan the program for global definitions. * * ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an * initial scan of the AST and note what globals are defined. */ void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); bool recordPropertyAssignment(TypeId ty); // Record the fact that a particular local has a particular type in at least // one of its states. void recordInferredBinding(AstLocal* local, TypeId ty); void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block); /** Given a function type annotation, return a vector describing the expected types of the calls to the function * For example, calling a function with annotation ((number) -> string & ((string) -> number)) * yields a vector of size 1, with value: [number | string] */ std::vector> getExpectedCallTypesForFunctionOverloads(const TypeId fnType); TypeId createTypeFunctionInstance( const TypeFunction& function, std::vector typeArguments, std::vector packArguments, const ScopePtr& scope, Location location ); TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. */ std::vector> borrowConstraints(const std::vector& constraints); } // namespace Luau