Sync to upstream/release/637 (#1354)

# What's Changed?

- Code refactoring with a new clang-format
- More bug fixes / test case fixes in the new solver

## New Solver

- More precise telemetry collection of `any` types
- Simplification of two completely disjoint tables combines them into a
single table that inherits all properties / indexers
- Refining a `never & <anything>` does not produce type family types nor
constraints
- Silence "inference failed to complete" error when it is the only error
reported

---
### Internal Contributors

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Dibri Nsofor <dnsofor@roblox.com>
Co-authored-by: Jeremy Yoo <jyoo@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Vighnesh <vvijay@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: David Cope <dcope@roblox.com>
Co-authored-by: Lily Brown <lbrown@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
This commit is contained in:
Junseo Yoo 2024-08-02 07:30:04 -07:00 committed by GitHub
parent 58f8c24ddb
commit ce8495a69e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
218 changed files with 11074 additions and 5108 deletions

View File

@ -16,10 +16,14 @@ SortIncludes: false
IndentWidth: 4 IndentWidth: 4
TabWidth: 4 TabWidth: 4
ObjCBlockIndentWidth: 4 ObjCBlockIndentWidth: 4
AlignAfterOpenBracket: DontAlign
UseTab: Never UseTab: Never
PointerAlignment: Left PointerAlignment: Left
SpaceAfterTemplateKeyword: false SpaceAfterTemplateKeyword: false
AlignEscapedNewlines: DontAlign AlignEscapedNewlines: DontAlign
AlwaysBreakTemplateDeclarations: Yes AlwaysBreakTemplateDeclarations: Yes
MaxEmptyLinesToKeep: 10 MaxEmptyLinesToKeep: 10
AllowAllParametersOfDeclarationOnNextLine: false
AlignAfterOpenBracket: BlockIndent
BinPackArguments: false
BinPackParameters: false
PenaltyReturnTypeOnItsOwnLine: 10000

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/AstQuery.h"
#include "Luau/Config.h" #include "Luau/Config.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
@ -36,6 +37,7 @@ struct AnyTypeSummary
{ {
TypeArena arena; TypeArena arena;
AstStatBlock* rootSrc = nullptr;
DenseHashSet<TypeId> seenTypeFamilyInstances{nullptr}; DenseHashSet<TypeId> seenTypeFamilyInstances{nullptr};
int recursionCount = 0; int recursionCount = 0;
@ -47,31 +49,28 @@ struct AnyTypeSummary
AnyTypeSummary(); AnyTypeSummary();
void traverse(Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes); void traverse(const Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes);
std::pair<bool, TypeId> checkForAnyCast(Scope* scope, AstExprTypeAssertion* expr); std::pair<bool, TypeId> checkForAnyCast(const Scope* scope, AstExprTypeAssertion* expr);
// Todo: errors resolved by anys
void reportError(Location location, TypeErrorData err);
bool containsAny(TypePackId typ); bool containsAny(TypePackId typ);
bool containsAny(TypeId typ); bool containsAny(TypeId typ);
bool isAnyCast(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes); bool isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool isAnyCall(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes); bool isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasVariadicAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes); bool hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasArgAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes); bool hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasAnyReturns(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes); bool hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
TypeId checkForFamilyInhabitance(TypeId instance, Location location); TypeId checkForFamilyInhabitance(const TypeId instance, Location location);
TypeId lookupType(AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes); TypeId lookupType(const AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
TypePackId reconstructTypePack(AstArray<AstExpr*> exprs, Module* module, NotNull<BuiltinTypes> builtinTypes); TypePackId reconstructTypePack(const AstArray<AstExpr*> exprs, const Module* module, NotNull<BuiltinTypes> builtinTypes);
DenseHashSet<TypeId> seenTypeFunctionInstances{nullptr}; DenseHashSet<TypeId> seenTypeFunctionInstances{nullptr};
TypeId lookupAnnotation(AstType* annotation, Module* module, NotNull<BuiltinTypes> builtintypes); TypeId lookupAnnotation(AstType* annotation, const Module* module, NotNull<BuiltinTypes> builtintypes);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation, Module* module); std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation, const Module* module);
TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location); TypeId checkForTypeFunctionInhabitance(const TypeId instance, const Location location);
enum Pattern: uint64_t enum Pattern: uint64_t
{ {
@ -91,11 +90,25 @@ struct AnyTypeSummary
Pattern code; Pattern code;
std::string node; std::string node;
TelemetryTypePair type; TelemetryTypePair type;
std::string debug;
explicit TypeInfo(Pattern code, std::string node, TelemetryTypePair type); explicit TypeInfo(Pattern code, std::string node, TelemetryTypePair type);
}; };
struct FindReturnAncestry final : public AstVisitor
{
AstNode* currNode{nullptr};
AstNode* stat{nullptr};
Position rootEnd;
bool found = false;
explicit FindReturnAncestry(AstNode* stat, Position rootEnd);
bool visit(AstType* node) override;
bool visit(AstNode* node) override;
bool visit(AstStatFunction* node) override;
bool visit(AstStatLocalFunction* node) override;
};
std::vector<TypeInfo> typeInfo; std::vector<TypeInfo> typeInfo;
/** /**
@ -103,29 +116,32 @@ struct AnyTypeSummary
* @param node the lexical node that the scope belongs to. * @param node the lexical node that the scope belongs to.
* @param parent the parent scope of the new scope. Must not be null. * @param parent the parent scope of the new scope. Must not be null.
*/ */
Scope* childScope(AstNode* node, const Scope* parent); const Scope* childScope(const AstNode* node, const Scope* parent);
Scope* findInnerMostScope(Location location, Module* module); std::optional<AstExpr*> matchRequire(const AstExprCall& call);
AstNode* getNode(AstStatBlock* root, AstNode* node);
const Scope* findInnerMostScope(const Location location, const Module* module);
const AstNode* findAstAncestryAtLocation(const AstStatBlock* root, AstNode* node);
void visit(Scope* scope, AstStat* stat, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStat* stat, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatBlock* block, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatIf* ifStatement, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatWhile* while_, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatRepeat* repeat, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatReturn* ret, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatLocal* local, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatFor* for_, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatForIn* forIn, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatCompoundAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatLocalFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatTypeAlias* alias, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatDeclareGlobal* declareGlobal, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatDeclareClass* declareClass, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatDeclareFunction* declareFunction, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(Scope* scope, AstStatError* error, Module* module, NotNull<BuiltinTypes> builtinTypes); void visit(const Scope* scope, AstStatError* error, const Module* module, NotNull<BuiltinTypes> builtinTypes);
}; };
} // namespace Luau } // namespace Luau

View File

@ -19,10 +19,22 @@ using ScopePtr = std::shared_ptr<Scope>;
// A substitution which replaces free types by any // A substitution which replaces free types by any
struct Anyification : Substitution struct Anyification : Substitution
{ {
Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, Anyification(
TypePackId anyTypePack); TypeArena* arena,
Anyification(TypeArena* arena, const ScopePtr& scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler, TypeId anyType, NotNull<Scope> scope,
TypePackId anyTypePack); NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
);
Anyification(
TypeArena* arena,
const ScopePtr& scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
);
NotNull<Scope> scope; NotNull<Scope> scope;
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter* iceHandler; InternalErrorReporter* iceHandler;

View File

@ -25,21 +25,42 @@ TypeId makeOption(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId t
/** Small utility function for building up type definitions from C++. /** Small utility function for building up type definitions from C++.
*/ */
TypeId makeFunction( // Monomorphic TypeId makeFunction( // Monomorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, TypeArena& arena,
bool checked = false); std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
TypeId makeFunction( // Polymorphic TypeId makeFunction( // Polymorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks, TypeArena& arena,
std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked = false); std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
TypeId makeFunction( // Monomorphic TypeId makeFunction( // Monomorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, TypeArena& arena,
std::initializer_list<TypeId> retTypes, bool checked = false); std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
TypeId makeFunction( // Polymorphic TypeId makeFunction( // Polymorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks, TypeArena& arena,
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes, std::optional<TypeId> selfType,
bool checked = false); std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked = false
);
void attachMagicFunction(TypeId ty, MagicFunction fn); void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);

View File

@ -256,9 +256,24 @@ struct ReducePackConstraint
TypePackId tp; TypePackId tp;
}; };
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, IterableConstraint, NameConstraint, using ConstraintV = Variant<
TypeAliasExpansionConstraint, FunctionCallConstraint, FunctionCheckConstraint, PrimitiveTypeConstraint, HasPropConstraint, HasIndexerConstraint, SubtypeConstraint,
AssignPropConstraint, AssignIndexConstraint, UnpackConstraint, ReduceConstraint, ReducePackConstraint, EqualityConstraint>; PackSubtypeConstraint,
GeneralizationConstraint,
IterableConstraint,
NameConstraint,
TypeAliasExpansionConstraint,
FunctionCallConstraint,
FunctionCheckConstraint,
PrimitiveTypeConstraint,
HasPropConstraint,
HasIndexerConstraint,
AssignPropConstraint,
AssignIndexConstraint,
UnpackConstraint,
ReduceConstraint,
ReducePackConstraint,
EqualityConstraint>;
struct Constraint struct Constraint
{ {

View File

@ -122,9 +122,18 @@ struct ConstraintGenerator
DcrLogger* logger; DcrLogger* logger;
ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes, ConstraintGenerator(
NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, ModulePtr module,
DcrLogger* logger, NotNull<DataFlowGraph> dfg, std::vector<RequireCycle> requireCycles); NotNull<Normalizer> normalizer,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger,
NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles
);
/** /**
* The entry point to the ConstraintGenerator. This will construct a set * The entry point to the ConstraintGenerator. This will construct a set
@ -195,10 +204,23 @@ private:
}; };
using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>; using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>;
void unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs, void unionRefinements(
RefinementContext& dest, std::vector<ConstraintV>* constraints); const ScopePtr& scope,
void computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, Location location,
std::vector<ConstraintV>* constraints); const RefinementContext& lhs,
const RefinementContext& rhs,
RefinementContext& dest,
std::vector<ConstraintV>* constraints
);
void computeRefinement(
const ScopePtr& scope,
Location location,
RefinementId refinement,
RefinementContext* refis,
bool sense,
bool eq,
std::vector<ConstraintV>* constraints
);
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
@ -217,6 +239,7 @@ private:
ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign); ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement); ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement);
ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias); ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias);
ControlFlow visit(const ScopePtr& scope, AstStatTypeFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
@ -224,7 +247,11 @@ private:
InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<std::optional<TypeId>>& expectedTypes = {}); InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<std::optional<TypeId>>& expectedTypes = {});
InferencePack checkPack( InferencePack checkPack(
const ScopePtr& scope, AstExpr* expr, const std::vector<std::optional<TypeId>>& expectedTypes = {}, bool generalize = true); const ScopePtr& scope,
AstExpr* expr,
const std::vector<std::optional<TypeId>>& expectedTypes = {},
bool generalize = true
);
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call); InferencePack checkPack(const ScopePtr& scope, AstExprCall* call);
@ -238,7 +265,12 @@ private:
* @return the type of the expression. * @return the type of the expression.
*/ */
Inference check( Inference check(
const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false, bool generalize = true); const ScopePtr& scope,
AstExpr* expr,
std::optional<TypeId> expectedType = {},
bool forceSingleton = false,
bool generalize = true
);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
@ -276,7 +308,11 @@ private:
}; };
FunctionSignature checkFunctionSignature( FunctionSignature checkFunctionSignature(
const ScopePtr& parent, AstExprFunction* fn, std::optional<TypeId> expectedType = {}, std::optional<Location> originalName = {}); const ScopePtr& parent,
AstExprFunction* fn,
std::optional<TypeId> expectedType = {},
std::optional<Location> originalName = {}
);
/** /**
* Checks the body of a function expression. * Checks the body of a function expression.
@ -323,7 +359,11 @@ private:
* privateTypeBindings map. * privateTypeBindings map.
**/ **/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics( std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache = false, bool addTypes = true); const ScopePtr& scope,
AstArray<AstGenericType> generics,
bool useCache = false,
bool addTypes = true
);
/** /**
* Creates generic type packs given a list of AST definitions, resolving * Creates generic type packs given a list of AST definitions, resolving
@ -336,7 +376,11 @@ private:
* privateTypePackBindings map. * privateTypePackBindings map.
**/ **/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks( std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> packs, bool useCache = false, bool addTypes = true); const ScopePtr& scope,
AstArray<AstGenericTypePack> packs,
bool useCache = false,
bool addTypes = true
);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
@ -371,7 +415,12 @@ private:
std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType); std::vector<std::optional<TypeId>> getExpectedCallTypesForFunctionOverloads(const TypeId fnType);
TypeId createTypeFunctionInstance( TypeId createTypeFunctionInstance(
const TypeFunction& function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments, const ScopePtr& scope, Location location); const TypeFunction& function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
const ScopePtr& scope,
Location location
);
}; };
/** Borrow a vector of pointers from a vector of owning pointers to constraints. /** Borrow a vector of pointers from a vector of owning pointers to constraints.

View File

@ -109,9 +109,16 @@ struct ConstraintSolver
DenseHashMap<TypeId, const Constraint*> typeFunctionsToFinalize{nullptr}; DenseHashMap<TypeId, const Constraint*> typeFunctionsToFinalize{nullptr};
explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints, explicit ConstraintSolver(
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger, NotNull<Normalizer> normalizer,
TypeCheckLimits limits); NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles,
DcrLogger* logger,
TypeCheckLimits limits
);
// Randomize the order in which to dispatch constraints // Randomize the order in which to dispatch constraints
void randomize(unsigned seed); void randomize(unsigned seed);
@ -170,7 +177,13 @@ public:
bool tryDispatchHasIndexer( bool tryDispatchHasIndexer(
int& recursionDepth, NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set<TypeId>& seen); int& recursionDepth,
NotNull<const Constraint> constraint,
TypeId subjectType,
TypeId indexType,
TypeId resultType,
Set<TypeId>& seen
);
bool tryDispatch(const HasIndexerConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const HasIndexerConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const AssignPropConstraint& c, NotNull<const Constraint> constraint);
@ -187,10 +200,23 @@ public:
// for a, ... in next_function, t, ... do // for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType, std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
const std::string& propName, ValueContext context, bool inConditional = false, bool suppressSimplification = false); NotNull<const Constraint> constraint,
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType, TypeId subjectType,
const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet<TypeId>& seen); const std::string& propName,
ValueContext context,
bool inConditional = false,
bool suppressSimplification = false
);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
ValueContext context,
bool inConditional,
bool suppressSimplification,
DenseHashSet<TypeId>& seen
);
/** /**
* Generate constraints to unpack the types of srcTypes and assign each * Generate constraints to unpack the types of srcTypes and assign each

View File

@ -162,6 +162,7 @@ private:
ControlFlow visit(DfgScope* scope, AstStatFunction* f); ControlFlow visit(DfgScope* scope, AstStatFunction* f);
ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l); ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l);
ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t); ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t);
ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f);
ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d); ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d); ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d); ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d);

View File

@ -126,7 +126,11 @@ struct DcrLogger
void captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints); void captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
StepSnapshot prepareStepSnapshot( StepSnapshot prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints); const Scope* rootScope,
NotNull<const Constraint> current,
bool force,
const std::vector<NotNull<const Constraint>>& unsolvedConstraints
);
void commitStepSnapshot(StepSnapshot snapshot); void commitStepSnapshot(StepSnapshot snapshot);
void captureFinalSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints); void captureFinalSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);

View File

@ -62,7 +62,12 @@ struct DiffPathNodeLeaf
// TODO: Rename to anonymousIndex, for both union and Intersection // TODO: Rename to anonymousIndex, for both union and Intersection
std::optional<size_t> unionIndex; std::optional<size_t> unionIndex;
DiffPathNodeLeaf( DiffPathNodeLeaf(
std::optional<TypeId> ty, std::optional<Name> tableProperty, std::optional<int> minLength, bool isVariadic, std::optional<size_t> unionIndex) std::optional<TypeId> ty,
std::optional<Name> tableProperty,
std::optional<int> minLength,
bool isVariadic,
std::optional<size_t> unionIndex
)
: ty(ty) : ty(ty)
, tableProperty(tableProperty) , tableProperty(tableProperty)
, minLength(minLength) , minLength(minLength)
@ -159,7 +164,11 @@ struct DifferEnvironment
DenseHashMap<TypePackId, TypePackId> genericTpMatchedPairs; DenseHashMap<TypePackId, TypePackId> genericTpMatchedPairs;
DifferEnvironment( DifferEnvironment(
TypeId rootLeft, TypeId rootRight, std::optional<std::string> externalSymbolLeft, std::optional<std::string> externalSymbolRight) TypeId rootLeft,
TypeId rootRight,
std::optional<std::string> externalSymbolLeft,
std::optional<std::string> externalSymbolRight
)
: rootLeft(rootLeft) : rootLeft(rootLeft)
, rootRight(rootRight) , rootRight(rootRight)
, externalSymbolLeft(externalSymbolLeft) , externalSymbolLeft(externalSymbolLeft)

View File

@ -194,6 +194,11 @@ struct InternalError
bool operator==(const InternalError& rhs) const; bool operator==(const InternalError& rhs) const;
}; };
struct ConstraintSolvingIncompleteError
{
bool operator==(const ConstraintSolvingIncompleteError& rhs) const;
};
struct CannotCallNonFunction struct CannotCallNonFunction
{ {
TypeId ty; TypeId ty;
@ -443,15 +448,55 @@ struct UnexpectedTypePackInSubtyping
bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; bool operator==(const UnexpectedTypePackInSubtyping& rhs) const;
}; };
using TypeErrorData = using TypeErrorData = Variant<
Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, DuplicateTypeDefinition, TypeMismatch,
CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, IncorrectGenericParameterCount, SyntaxError, UnknownSymbol,
CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError, CannotCallNonFunction, ExtraInformation, UnknownProperty,
DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, DuplicateGenericParameter, CannotAssignToNever, NotATable,
CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated, CannotExtendTable,
NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe, UninhabitedTypeFunction, UninhabitedTypePackFunction, OnlyTablesCanHaveMethods,
WhereClauseNeeded, PackWhereClauseNeeded, CheckedFunctionCallError, NonStrictFunctionDefinitionError, PropertyAccessViolation, DuplicateTypeDefinition,
CheckedFunctionIncorrectArgs, UnexpectedTypeInSubtyping, UnexpectedTypePackInSubtyping, ExplicitFunctionAnnotationRecommended>; CountMismatch,
FunctionDoesNotTakeSelf,
FunctionRequiresSelf,
OccursCheckFailed,
UnknownRequire,
IncorrectGenericParameterCount,
SyntaxError,
CodeTooComplex,
UnificationTooComplex,
UnknownPropButFoundLikeProp,
GenericError,
InternalError,
ConstraintSolvingIncompleteError,
CannotCallNonFunction,
ExtraInformation,
DeprecatedApiUsed,
ModuleHasCyclicDependency,
IllegalRequire,
FunctionExitsWithoutReturning,
DuplicateGenericParameter,
CannotAssignToNever,
CannotInferBinaryOperation,
MissingProperties,
SwappedGenericTypeParameter,
OptionalValueAccess,
MissingUnionProperty,
TypesAreUnrelated,
NormalizationTooComplex,
TypePackMismatch,
DynamicPropertyLookupOnClassesUnsafe,
UninhabitedTypeFunction,
UninhabitedTypePackFunction,
WhereClauseNeeded,
PackWhereClauseNeeded,
CheckedFunctionCallError,
NonStrictFunctionDefinitionError,
PropertyAccessViolation,
CheckedFunctionIncorrectArgs,
UnexpectedTypeInSubtyping,
UnexpectedTypePackInSubtyping,
ExplicitFunctionAnnotationRecommended>;
struct TypeErrorSummary struct TypeErrorSummary
{ {

View File

@ -185,30 +185,55 @@ struct Frontend
void registerBuiltinDefinition(const std::string& name, std::function<void(Frontend&, GlobalTypes&, ScopePtr)>); void registerBuiltinDefinition(const std::string& name, std::function<void(Frontend&, GlobalTypes&, ScopePtr)>);
void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, LoadDefinitionFileResult loadDefinitionFile(
bool captureComments, bool typeCheckForAutocomplete = false); GlobalTypes& globals,
ScopePtr targetScope,
std::string_view source,
const std::string& packageName,
bool captureComments,
bool typeCheckForAutocomplete = false
);
// Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult' // Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult'
// If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete // If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete
void queueModuleCheck(const std::vector<ModuleName>& names); void queueModuleCheck(const std::vector<ModuleName>& names);
void queueModuleCheck(const ModuleName& name); void queueModuleCheck(const ModuleName& name);
std::vector<ModuleName> checkQueuedModules(std::optional<FrontendOptions> optionOverride = {}, std::vector<ModuleName> checkQueuedModules(
std::function<void(std::function<void()> task)> executeTask = {}, std::function<bool(size_t done, size_t total)> progress = {}); std::optional<FrontendOptions> optionOverride = {},
std::function<void(std::function<void()> task)> executeTask = {},
std::function<bool(size_t done, size_t total)> progress = {}
);
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
private: private:
ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector<RequireCycle> requireCycles, std::optional<ScopePtr> environmentScope, ModulePtr check(
bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits); const SourceModule& sourceModule,
Mode mode,
std::vector<RequireCycle> requireCycles,
std::optional<ScopePtr> environmentScope,
bool forAutocomplete,
bool recordJsonLog,
TypeCheckLimits typeCheckLimits
);
std::pair<SourceNode*, SourceModule*> getSourceNode(const ModuleName& name); std::pair<SourceNode*, SourceModule*> getSourceNode(const ModuleName& name);
SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions);
bool parseGraph( bool parseGraph(
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip = {}); std::vector<ModuleName>& buildQueue,
const ModuleName& root,
bool forAutocomplete,
std::function<bool(const ModuleName&)> canSkip = {}
);
void addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected, void addBuildQueueItems(
DenseHashSet<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions); std::vector<BuildQueueItem>& items,
std::vector<ModuleName>& buildQueue,
bool cycleDetected,
DenseHashSet<Luau::ModuleName>& seen,
const FrontendOptions& frontendOptions
);
void checkBuildQueueItem(BuildQueueItem& item); void checkBuildQueueItem(BuildQueueItem& item);
void checkBuildQueueItems(std::vector<BuildQueueItem>& items); void checkBuildQueueItems(std::vector<BuildQueueItem>& items);
void recordItemResult(const BuildQueueItem& item); void recordItemResult(const BuildQueueItem& item);
@ -248,14 +273,34 @@ public:
std::vector<ModuleName> moduleQueue; std::vector<ModuleName> moduleQueue;
}; };
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes, ModulePtr check(
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver, const SourceModule& sourceModule,
const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options, Mode mode,
TypeCheckLimits limits); const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits
);
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes, ModulePtr check(
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver, const SourceModule& sourceModule,
const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options, Mode mode,
TypeCheckLimits limits, bool recordJsonLog, std::function<void(const ModuleName&, std::string)> writeJsonLog); const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits,
bool recordJsonLog,
std::function<void(const ModuleName&, std::string)> writeJsonLog
);
} // namespace Luau } // namespace Luau

View File

@ -8,6 +8,12 @@
namespace Luau namespace Luau
{ {
std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, std::optional<TypeId> generalize(
NotNull<DenseHashSet<TypeId>> bakedTypes, TypeId ty, /* avoid sealing tables*/ bool avoidSealingTables = false); NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> bakedTypes,
TypeId ty,
/* avoid sealing tables*/ bool avoidSealingTables = false
);
} }

View File

@ -17,8 +17,15 @@ struct TypeCheckLimits;
// A substitution which replaces generic types in a given set by free types. // A substitution which replaces generic types in a given set by free types.
struct ReplaceGenerics : Substitution struct ReplaceGenerics : Substitution
{ {
ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope, ReplaceGenerics(
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks) const TxnLog* log,
TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
TypeLevel level,
Scope* scope,
const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks
)
: Substitution(log, arena) : Substitution(log, arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, level(level) , level(level)
@ -28,8 +35,15 @@ struct ReplaceGenerics : Substitution
{ {
} }
void resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope, void resetState(
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks); const TxnLog* log,
TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
TypeLevel level,
Scope* scope,
const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks
);
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
@ -141,6 +155,11 @@ struct GenericTypeFinder : TypeOnceVisitor
* limits to be exceeded. * limits to be exceeded.
*/ */
std::optional<TypeId> instantiate( std::optional<TypeId> instantiate(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty); NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
TypeId ty
);
} // namespace Luau } // namespace Luau

View File

@ -75,8 +75,16 @@ struct Instantiation2 : Substitution
}; };
std::optional<TypeId> instantiate2( std::optional<TypeId> instantiate2(
TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions, DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypeId ty); TypeArena* arena,
std::optional<TypePackId> instantiate2(TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions, DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypePackId tp); DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypeId ty
);
std::optional<TypePackId> instantiate2(
TypeArena* arena,
DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypePackId tp
);
} // namespace Luau } // namespace Luau

View File

@ -25,8 +25,14 @@ struct LintResult
std::vector<LintWarning> warnings; std::vector<LintWarning> warnings;
}; };
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, std::vector<LintWarning> lint(
const std::vector<HotComment>& hotcomments, const LintOptions& options); AstStat* root,
const AstNameTable& names,
const ScopePtr& env,
const Module* module,
const std::vector<HotComment>& hotcomments,
const LintOptions& options
);
std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names); std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names);

View File

@ -12,8 +12,15 @@ struct BuiltinTypes;
struct UnifierSharedState; struct UnifierSharedState;
struct TypeCheckLimits; struct TypeCheckLimits;
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState, void checkNonStrict(
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module); NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
NotNull<TypeCheckLimits> limits,
const SourceModule& sourceModule,
Module* module
);
} // namespace Luau } // namespace Luau

View File

@ -31,8 +31,15 @@ struct OverloadResolver
OverloadIsNonviable, // Arguments were incompatible with the overloads parameters but were otherwise compatible by arity OverloadIsNonviable, // Arguments were incompatible with the overloads parameters but were otherwise compatible by arity
}; };
OverloadResolver(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Normalizer> normalizer, NotNull<Scope> scope, OverloadResolver(
NotNull<InternalErrorReporter> reporter, NotNull<TypeCheckLimits> limits, Location callLocation); NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
Location callLocation
);
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
@ -58,11 +65,21 @@ private:
std::optional<ErrorVec> testIsSubtype(const Location& location, TypeId subTy, TypeId superTy); std::optional<ErrorVec> testIsSubtype(const Location& location, TypeId subTy, TypeId superTy);
std::optional<ErrorVec> testIsSubtype(const Location& location, TypePackId subTy, TypePackId superTy); std::optional<ErrorVec> testIsSubtype(const Location& location, TypePackId subTy, TypePackId superTy);
std::pair<Analysis, ErrorVec> checkOverload( std::pair<Analysis, ErrorVec> checkOverload(
TypeId fnTy, const TypePack* args, AstExpr* fnLoc, const std::vector<AstExpr*>* argExprs, bool callMetamethodOk = true); TypeId fnTy,
const TypePack* args,
AstExpr* fnLoc,
const std::vector<AstExpr*>* argExprs,
bool callMetamethodOk = true
);
static bool isLiteral(AstExpr* expr); static bool isLiteral(AstExpr* expr);
LUAU_NOINLINE LUAU_NOINLINE
std::pair<Analysis, ErrorVec> checkOverload_( std::pair<Analysis, ErrorVec> checkOverload_(
TypeId fnTy, const FunctionType* fn, const TypePack* args, AstExpr* fnExpr, const std::vector<AstExpr*>* argExprs); TypeId fnTy,
const FunctionType* fn,
const TypePack* args,
AstExpr* fnExpr,
const std::vector<AstExpr*>* argExprs
);
size_t indexof(Analysis analysis); size_t indexof(Analysis analysis);
void add(Analysis analysis, TypeId ty, ErrorVec&& errors); void add(Analysis analysis, TypeId ty, ErrorVec&& errors);
}; };
@ -88,8 +105,16 @@ struct SolveResult
// Helper utility, presently used for binary operator type functions. // Helper utility, presently used for binary operator type functions.
// //
// Given a function and a set of arguments, select a suitable overload. // Given a function and a set of arguments, select a suitable overload.
SolveResult solveFunctionCall(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Normalizer> normalizer, SolveResult solveFunctionCall(
NotNull<InternalErrorReporter> iceReporter, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, const Location& location, TypeId fn, NotNull<TypeArena> arena,
TypePackId argsPack); NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
const Location& location,
TypeId fn,
TypePackId argsPack
);
} // namespace Luau } // namespace Luau

View File

@ -140,8 +140,13 @@ struct Subtyping
SeenSet seenTypes{{}}; SeenSet seenTypes{{}};
Subtyping(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> typeArena, NotNull<Normalizer> normalizer, Subtyping(
NotNull<InternalErrorReporter> iceReporter, NotNull<Scope> scope); NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<Scope> scope
);
Subtyping(const Subtyping&) = delete; Subtyping(const Subtyping&) = delete;
Subtyping& operator=(const Subtyping&) = delete; Subtyping& operator=(const Subtyping&) = delete;
@ -209,13 +214,19 @@ private:
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const Property& subProperty, const Property& superProperty, const std::string& name); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const Property& subProperty, const Property& superProperty, const std::string& name);
SubtypingResult isCovariantWith( SubtypingResult isCovariantWith(
SubtypingEnvironment& env, const std::shared_ptr<const NormalizedType>& subNorm, const std::shared_ptr<const NormalizedType>& superNorm); SubtypingEnvironment& env,
const std::shared_ptr<const NormalizedType>& subNorm,
const std::shared_ptr<const NormalizedType>& superNorm
);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const NormalizedClassType& superClass); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const NormalizedClassType& superClass);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const NormalizedStringType& superString); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const NormalizedStringType& superString);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const TypeIds& superTables); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const TypeIds& superTables);
SubtypingResult isCovariantWith( SubtypingResult isCovariantWith(
SubtypingEnvironment& env, const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction); SubtypingEnvironment& env,
const NormalizedFunctionType& subFunction,
const NormalizedFunctionType& superFunction
);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic);

View File

@ -14,7 +14,15 @@ struct BuiltinTypes;
struct Unifier2; struct Unifier2;
class AstExpr; class AstExpr;
TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes, NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes, TypeId matchLiteralType(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Unifier2> unifier, TypeId expectedType, TypeId exprType, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
const AstExpr* expr, std::vector<TypeId>& toBlock); NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Unifier2> unifier,
TypeId expectedType,
TypeId exprType,
const AstExpr* expr,
std::vector<TypeId>& toBlock
);
} // namespace Luau } // namespace Luau

View File

@ -276,8 +276,8 @@ struct WithPredicate
} }
}; };
using MagicFunction = std::function<std::optional<WithPredicate<TypePackId>>( using MagicFunction = std::function<std::optional<
struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>; WithPredicate<TypePackId>>(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext struct MagicFunctionCallContext
{ {
@ -305,19 +305,46 @@ struct FunctionType
FunctionType(TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); FunctionType(TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Global polymorphic function // Global polymorphic function
FunctionType(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes, FunctionType(
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
// Local monomorphic function // Local monomorphic function
FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionType( FunctionType(
TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); TypeLevel level,
Scope* scope,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
// Local polymorphic function // Local polymorphic function
FunctionType(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes, FunctionType(
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); TypeLevel level,
FunctionType(TypeLevel level, Scope* scope, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, std::vector<TypeId> generics,
TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
FunctionType(
TypeLevel level,
Scope* scope,
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn = {},
bool hasSelf = false
);
std::optional<FunctionDefinition> definition; std::optional<FunctionDefinition> definition;
/// These should all be generic /// These should all be generic
@ -398,9 +425,15 @@ struct Property
// DEPRECATED // DEPRECATED
// TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends. // TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends.
Property(); Property();
Property(TypeId readTy, bool deprecated = false, const std::string& deprecatedSuggestion = "", std::optional<Location> location = std::nullopt, Property(
const Tags& tags = {}, const std::optional<std::string>& documentationSymbol = std::nullopt, TypeId readTy,
std::optional<Location> typeLocation = std::nullopt); bool deprecated = false,
const std::string& deprecatedSuggestion = "",
std::optional<Location> location = std::nullopt,
const Tags& tags = {},
const std::optional<std::string>& documentationSymbol = std::nullopt,
std::optional<Location> typeLocation = std::nullopt
);
// DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt. // DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt.
// TODO: Kill once we don't have non-RWP. // TODO: Kill once we don't have non-RWP.
@ -502,8 +535,16 @@ struct ClassType
std::optional<Location> definitionLocation; std::optional<Location> definitionLocation;
std::optional<TableIndexer> indexer; std::optional<TableIndexer> indexer;
ClassType(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags, ClassType(
std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName, std::optional<Location> definitionLocation) Name name,
Props props,
std::optional<TypeId> parent,
std::optional<TypeId> metatable,
Tags tags,
std::shared_ptr<ClassUserData> userData,
ModuleName definitionModuleName,
std::optional<Location> definitionLocation
)
: name(name) : name(name)
, props(props) , props(props)
, parent(parent) , parent(parent)
@ -515,9 +556,17 @@ struct ClassType
{ {
} }
ClassType(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags, ClassType(
std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName, std::optional<Location> definitionLocation, Name name,
std::optional<TableIndexer> indexer) Props props,
std::optional<TypeId> parent,
std::optional<TypeId> metatable,
Tags tags,
std::shared_ptr<ClassUserData> userData,
ModuleName definitionModuleName,
std::optional<Location> definitionLocation,
std::optional<TableIndexer> indexer
)
: name(name) : name(name)
, props(props) , props(props)
, parent(parent) , parent(parent)
@ -661,9 +710,26 @@ struct NegationType
using ErrorType = Unifiable::Error; using ErrorType = Unifiable::Error;
using TypeVariant = using TypeVariant = Unifiable::Variant<
Unifiable::Variant<TypeId, FreeType, GenericType, PrimitiveType, SingletonType, BlockedType, PendingExpansionType, FunctionType, TableType, TypeId,
MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFunctionInstanceType>; FreeType,
GenericType,
PrimitiveType,
SingletonType,
BlockedType,
PendingExpansionType,
FunctionType,
TableType,
MetatableType,
ClassType,
AnyType,
UnionType,
IntersectionType,
LazyType,
UnknownType,
NeverType,
NegationType,
TypeFunctionInstanceType>;
struct Type final struct Type final
{ {

View File

@ -14,7 +14,13 @@ struct UnifierSharedState;
struct SourceModule; struct SourceModule;
struct Module; struct Module;
void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState, NotNull<TypeCheckLimits> limits, DcrLogger* logger, void check(
const SourceModule& sourceModule, Module* module); NotNull<BuiltinTypes> builtinTypes,
NotNull<UnifierSharedState> sharedState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
const SourceModule& sourceModule,
Module* module
);
} // namespace Luau } // namespace Luau

View File

@ -44,8 +44,14 @@ struct TypeFunctionContext
{ {
} }
TypeFunctionContext(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtins, NotNull<Scope> scope, NotNull<Normalizer> normalizer, TypeFunctionContext(
NotNull<InternalErrorReporter> ice, NotNull<TypeCheckLimits> limits) NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtins,
NotNull<Scope> scope,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> ice,
NotNull<TypeCheckLimits> limits
)
: arena(arena) : arena(arena)
, builtins(builtins) , builtins(builtins)
, scope(scope) , scope(scope)

View File

@ -62,7 +62,11 @@ struct HashBoolNamePair
struct TypeChecker struct TypeChecker
{ {
explicit TypeChecker( explicit TypeChecker(
const ScopePtr& globalScope, ModuleResolver* resolver, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler); const ScopePtr& globalScope,
ModuleResolver* resolver,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler
);
TypeChecker(const TypeChecker&) = delete; TypeChecker(const TypeChecker&) = delete;
TypeChecker& operator=(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete;
@ -85,6 +89,7 @@ struct TypeChecker
ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias); ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias);
ControlFlow check(const ScopePtr& scope, const AstStatTypeFunction& typefunction);
ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
@ -96,7 +101,11 @@ struct TypeChecker
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted); void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);
WithPredicate<TypeId> checkExpr( WithPredicate<TypeId> checkExpr(
const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt, bool forceSingleton = false); const ScopePtr& scope,
const AstExpr& expr,
std::optional<TypeId> expectedType = std::nullopt,
bool forceSingleton = false
);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr);
@ -107,17 +116,31 @@ struct TypeChecker
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr);
TypeId checkRelationalOperation( TypeId checkRelationalOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates = {}
);
TypeId checkBinaryOperation( TypeId checkBinaryOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates = {}
);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional<TypeId> expectedType = std::nullopt); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprInterpString& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprInterpString& expr);
TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes, TypeId checkExprTable(
std::optional<TypeId> expectedType); const ScopePtr& scope,
const AstExprTable& expr,
const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType
);
// Returns the type of the lvalue. // Returns the type of the lvalue.
TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx); TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx);
@ -130,34 +153,79 @@ struct TypeChecker
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx);
TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level);
std::pair<TypeId, ScopePtr> checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::pair<TypeId, ScopePtr> checkFunctionSignature(
std::optional<Location> originalNameLoc, std::optional<TypeId> selfType, std::optional<TypeId> expectedType); const ScopePtr& scope,
int subLevel,
const AstExprFunction& expr,
std::optional<Location> originalNameLoc,
std::optional<TypeId> selfType,
std::optional<TypeId> expectedType
);
void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function);
void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack, void checkArgumentList(
const std::vector<Location>& argLocations); const ScopePtr& scope,
const AstExpr& funName,
Unifier& state,
TypePackId paramPack,
TypePackId argPack,
const std::vector<Location>& argLocations
);
WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr); WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr);
WithPredicate<TypePackId> checkExprPackHelper2( WithPredicate<TypePackId> checkExprPackHelper2(
const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack); const ScopePtr& scope,
const AstExprCall& expr,
TypeId selfType,
TypeId actualFunctionType,
TypeId functionType,
TypePackId retPack
);
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall); std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::unique_ptr<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, std::unique_ptr<WithPredicate<TypePackId>> checkCallOverload(
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult, const ScopePtr& scope,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors); const AstExprCall& expr,
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations, TypeId fn,
const std::vector<OverloadErrorEntry>& errors); TypePackId retPack,
void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount, TypePack* args,
std::vector<OverloadErrorEntry>& errors); const std::vector<Location>* argLocations,
const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<TypeId>& overloadsThatDont,
std::vector<OverloadErrorEntry>& errors
);
bool handleSelfCallMismatch(
const ScopePtr& scope,
const AstExprCall& expr,
TypePack* args,
const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors
);
void reportOverloadResolutionError(
const ScopePtr& scope,
const AstExprCall& expr,
TypePackId retPack,
TypePackId argPack,
const std::vector<Location>& argLocations,
const std::vector<TypeId>& overloads,
const std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<OverloadErrorEntry>& errors
);
WithPredicate<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs, WithPredicate<TypePackId> checkExprList(
bool substituteFreeForNil = false, const std::vector<bool>& lhsAnnotations = {}, const ScopePtr& scope,
const std::vector<std::optional<TypeId>>& expectedTypes = {}); const Location& location,
const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil = false,
const std::vector<bool>& lhsAnnotations = {},
const std::vector<std::optional<TypeId>>& expectedTypes = {}
);
static std::optional<AstExpr*> matchRequire(const AstExprCall& call); static std::optional<AstExpr*> matchRequire(const AstExprCall& call);
TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location); TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location);
@ -175,8 +243,13 @@ struct TypeChecker
*/ */
bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options);
bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, bool unify(
CountMismatch::Context ctx = CountMismatch::Context::Arg); TypePackId subTy,
TypePackId superTy,
const ScopePtr& scope,
const Location& location,
CountMismatch::Context ctx = CountMismatch::Context::Arg
);
/** Attempt to unify the types. /** Attempt to unify the types.
* If this fails, and the subTy type can be instantiated, do so and try unification again. * If this fails, and the subTy type can be instantiated, do so and try unification again.
@ -313,12 +386,23 @@ private:
TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, TypeId instantiateTypeFun(
const std::vector<TypePackId>& typePackParams, const Location& location); const ScopePtr& scope,
const TypeFun& tf,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams,
const Location& location
);
// Note: `scope` must be a fresh scope. // Note: `scope` must be a fresh scope.
GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, GenericTypeDefinitions createGenericTypes(
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames, bool useCache = false); const ScopePtr& scope,
std::optional<TypeLevel> levelOpt,
const AstNode& node,
const AstArray<AstGenericType>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames,
bool useCache = false
);
public: public:
void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);

View File

@ -56,14 +56,35 @@ struct InConditionalContext
using ScopePtr = std::shared_ptr<struct Scope>; using ScopePtr = std::shared_ptr<struct Scope>;
std::optional<Property> findTableProperty( std::optional<Property> findTableProperty(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
Location location
);
std::optional<TypeId> findMetatableEntry( std::optional<TypeId> findMetatableEntry(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId type,
const std::string& entry,
Location location
);
std::optional<TypeId> findTablePropertyRespectingMeta( std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
Location location
);
std::optional<TypeId> findTablePropertyRespectingMeta( std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, ValueContext context, Location location); NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
ValueContext context,
Location location
);
bool occursCheck(TypeId needle, TypeId haystack); bool occursCheck(TypeId needle, TypeId haystack);
@ -73,7 +94,12 @@ std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log,
// Extend the provided pack to at least `length` types. // Extend the provided pack to at least `length` types.
// Returns a temporary TypePack that contains those types plus a tail. // Returns a temporary TypePack that contains those types plus a tail.
TypePack extendTypePack( TypePack extendTypePack(
TypeArena& arena, NotNull<BuiltinTypes> builtinTypes, TypePackId pack, size_t length, std::vector<std::optional<TypeId>> overrides = {}); TypeArena& arena,
NotNull<BuiltinTypes> builtinTypes,
TypePackId pack,
size_t length,
std::vector<std::optional<TypeId>> overrides = {}
);
/** /**
* Reduces a union by decomposing to the any/error type if it appears in the * Reduces a union by decomposing to the any/error type if it appears in the

View File

@ -106,11 +106,21 @@ struct Unifier
* Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt.
*/ */
void tryUnify( void tryUnify(
TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); TypeId subTy,
TypeId superTy,
bool isFunctionCall = false,
bool isIntersection = false,
const LiteralProperties* aliasableMap = nullptr
);
private: private:
void tryUnify_( void tryUnify_(
TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); TypeId subTy,
TypeId superTy,
bool isFunctionCall = false,
bool isIntersection = false,
const LiteralProperties* aliasableMap = nullptr
);
void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy);
// Traverse the two types provided and block on any BlockedTypes we find. // Traverse the two types provided and block on any BlockedTypes we find.
@ -120,8 +130,14 @@ private:
void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall);
void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv);
void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall);
void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, void tryUnifyNormalizedTypes(
std::optional<TypeError> error = std::nullopt); TypeId subTy,
TypeId superTy,
const NormalizedType& subNorm,
const NormalizedType& superNorm,
std::string reason,
std::optional<TypeError> error = std::nullopt
);
void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifyPrimitives(TypeId subTy, TypeId superTy);
void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy);
void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false);

View File

@ -52,8 +52,13 @@ struct Unifier2
DenseHashSet<const void*>* uninhabitedTypeFunctions; DenseHashSet<const void*>* uninhabitedTypeFunctions;
Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice); Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice);
Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice, Unifier2(
DenseHashSet<const void*>* uninhabitedTypeFunctions); NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> ice,
DenseHashSet<const void*>* uninhabitedTypeFunctions
);
/** Attempt to commit the subtype relation subTy <: superTy to the type /** Attempt to commit the subtype relation subTy <: superTy to the type
* graph. * graph.

View File

@ -46,33 +46,12 @@ LUAU_FASTFLAG(DebugLuauMagicTypes);
namespace Luau namespace Luau
{ {
// TODO: instead of pair just type for solver? generated type void AnyTypeSummary::traverse(const Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes)
// TODO: see lookupAnnotation in typechecker2. is cleaner than resolvetype
// or delay containsAny() check and do not return pair.
// quick flag in typeid saying was annotation or inferred, would be solid
std::optional<TypeOrPack> getInferredType(AstExpr* expr, Module* module)
{ {
std::optional<TypeOrPack> inferredType; visit(findInnerMostScope(src->location, module), src, module, builtinTypes);
if (module->astTypePacks.contains(expr))
{
inferredType = *module->astTypePacks.find(expr);
}
else if (module->astTypes.contains(expr))
{
inferredType = *module->astTypes.find(expr);
} }
return inferredType; void AnyTypeSummary::visit(const Scope* scope, AstStat* stat, const Module* module, NotNull<BuiltinTypes> builtinTypes)
}
void AnyTypeSummary::traverse(Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes)
{
Scope* scope = findInnerMostScope(src->location, module);
visit(scope, src, module, builtinTypes);
}
void AnyTypeSummary::visit(Scope* scope, AstStat* stat, Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
RecursionLimiter limiter{&recursionCount, FInt::LuauAnySummaryRecursionLimit}; RecursionLimiter limiter{&recursionCount, FInt::LuauAnySummaryRecursionLimit};
@ -114,7 +93,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStat* stat, Module* module, NotNull<
return visit(scope, s, module, builtinTypes); return visit(scope, s, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatBlock* block, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
RecursionCounter counter{&recursionCount}; RecursionCounter counter{&recursionCount};
@ -125,37 +104,38 @@ void AnyTypeSummary::visit(Scope* scope, AstStatBlock* block, Module* module, No
visit(scope, stat, module, builtinTypes); visit(scope, stat, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatIf* ifStatement, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
if (ifStatement->thenbody) if (ifStatement->thenbody)
{ {
Scope* thenScope = findInnerMostScope(ifStatement->thenbody->location, module); const Scope* thenScope = findInnerMostScope(ifStatement->thenbody->location, module);
visit(thenScope, ifStatement->thenbody, module, builtinTypes); visit(thenScope, ifStatement->thenbody, module, builtinTypes);
} }
if (ifStatement->elsebody) if (ifStatement->elsebody)
{ {
Scope* elseScope = findInnerMostScope(ifStatement->elsebody->location, module); const Scope* elseScope = findInnerMostScope(ifStatement->elsebody->location, module);
visit(elseScope, ifStatement->elsebody, module, builtinTypes); visit(elseScope, ifStatement->elsebody, module, builtinTypes);
} }
} }
void AnyTypeSummary::visit(Scope* scope, AstStatWhile* while_, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
Scope* whileScope = findInnerMostScope(while_->location, module); const Scope* whileScope = findInnerMostScope(while_->location, module);
visit(whileScope, while_->body, module, builtinTypes); visit(whileScope, while_->body, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatRepeat* repeat, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
Scope* repeatScope = findInnerMostScope(repeat->location, module); const Scope* repeatScope = findInnerMostScope(repeat->location, module);
visit(repeatScope, repeat->body, module, builtinTypes); visit(repeatScope, repeat->body, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatReturn* ret, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
// Scope* outScope = findOuterScope(ret->location, module); const Scope* retScope = findInnerMostScope(ret->location, module);
Scope* retScope = findInnerMostScope(ret->location, module);
auto ctxNode = getNode(rootSrc, ret);
for (auto val : ret->list) for (auto val : ret->list)
{ {
@ -163,7 +143,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatReturn* ret, Module* module, Not
{ {
TelemetryTypePair types; TelemetryTypePair types;
types.inferredType = toString(lookupType(val, module, builtinTypes)); types.inferredType = toString(lookupType(val, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(ret), types}; TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
@ -174,19 +154,19 @@ void AnyTypeSummary::visit(Scope* scope, AstStatReturn* ret, Module* module, Not
TelemetryTypePair types; TelemetryTypePair types;
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
auto inf = getInferredType(cast->expr, module); types.inferredType = toString(lookupType(cast->expr, module, builtinTypes));
if (inf)
types.inferredType = toString(*inf);
TypeInfo ti{Pattern::Casts, toString(ret), types}; TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
} }
} }
void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
auto ctxNode = getNode(rootSrc, local);
TypePackId values = reconstructTypePack(local->values, module, builtinTypes); TypePackId values = reconstructTypePack(local->values, module, builtinTypes);
auto [head, tail] = flatten(values); auto [head, tail] = flatten(values);
@ -203,18 +183,30 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
TelemetryTypePair types; TelemetryTypePair types;
types.annotatedType = toString(annot); types.annotatedType = toString(annot);
types.inferredType = toString(lookupType(local->values.data[posn], module, builtinTypes));
auto inf = getInferredType(local->values.data[posn], module); TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types};
if (inf)
types.inferredType = toString(*inf);
TypeInfo ti{Pattern::VarAnnot, toString(local), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
const AstExprTypeAssertion* maybeRequire = local->values.data[posn]->as<AstExprTypeAssertion>();
if (!maybeRequire)
continue;
if (isAnyCast(scope, local->values.data[posn], module, builtinTypes))
{
TelemetryTypePair types;
types.inferredType = toString(head[std::min(local->values.size - 1, posn)]);
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
} }
else else
{ {
if (std::min(local->values.size - 1, posn) < head.size()) if (std::min(local->values.size - 1, posn) < head.size())
{ {
if (loc->annotation) if (loc->annotation)
@ -227,7 +219,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
types.annotatedType = toString(annot); types.annotatedType = toString(annot);
types.inferredType = toString(head[std::min(local->values.size - 1, posn)]); types.inferredType = toString(head[std::min(local->values.size - 1, posn)]);
TypeInfo ti{Pattern::VarAnnot, toString(local), types}; TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
@ -242,7 +234,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
types.inferredType = toString(*tail); types.inferredType = toString(*tail);
TypeInfo ti{Pattern::VarAny, toString(local), types}; TypeInfo ti{Pattern::VarAny, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
@ -253,20 +245,22 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocal* local, Module* module, No
} }
} }
void AnyTypeSummary::visit(Scope* scope, AstStatFor* for_, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
Scope* forScope = findInnerMostScope(for_->location, module); const Scope* forScope = findInnerMostScope(for_->location, module);
visit(forScope, for_->body, module, builtinTypes); visit(forScope, for_->body, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatForIn* forIn, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
Scope* loopScope = findInnerMostScope(forIn->location, module); const Scope* loopScope = findInnerMostScope(forIn->location, module);
visit(loopScope, forIn->body, module, builtinTypes); visit(loopScope, forIn->body, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
auto ctxNode = getNode(rootSrc, assign);
TypePackId values = reconstructTypePack(assign->values, module, builtinTypes); TypePackId values = reconstructTypePack(assign->values, module, builtinTypes);
auto [head, tail] = flatten(values); auto [head, tail] = flatten(values);
@ -290,7 +284,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
else else
types.inferredType = toString(builtinTypes->nilType); types.inferredType = toString(builtinTypes->nilType);
TypeInfo ti{Pattern::Assign, toString(assign), types}; TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
++posn; ++posn;
@ -302,11 +296,9 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
{ {
TelemetryTypePair types; TelemetryTypePair types;
auto inf = getInferredType(val, module); types.inferredType = toString(lookupType(val, module, builtinTypes));
if (inf)
types.inferredType = toString(*inf);
TypeInfo ti{Pattern::FuncApp, toString(assign), types}; TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
@ -317,11 +309,9 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
TelemetryTypePair types; TelemetryTypePair types;
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
auto inf = getInferredType(val, module); types.inferredType = toString(lookupType(val, module, builtinTypes));
if (inf)
types.inferredType = toString(*inf);
TypeInfo ti{Pattern::Casts, toString(assign), types}; TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
@ -335,14 +325,16 @@ void AnyTypeSummary::visit(Scope* scope, AstStatAssign* assign, Module* module,
types.inferredType = toString(*tail); types.inferredType = toString(*tail);
TypeInfo ti{Pattern::Assign, toString(assign), types}; TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
} }
void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
auto ctxNode = getNode(rootSrc, assign);
TelemetryTypePair types; TelemetryTypePair types;
types.inferredType = toString(lookupType(assign->value, module, builtinTypes)); types.inferredType = toString(lookupType(assign->value, module, builtinTypes));
@ -352,7 +344,7 @@ void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module*
{ {
if (containsAny(*module->astTypes.find(assign->var))) if (containsAny(*module->astTypes.find(assign->var)))
{ {
TypeInfo ti{Pattern::Assign, toString(assign), types}; TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
@ -360,14 +352,14 @@ void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module*
{ {
if (containsAny(*module->astTypePacks.find(assign->var))) if (containsAny(*module->astTypePacks.find(assign->var)))
{ {
TypeInfo ti{Pattern::Assign, toString(assign), types}; TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
if (isAnyCall(scope, assign->value, module, builtinTypes)) if (isAnyCall(scope, assign->value, module, builtinTypes))
{ {
TypeInfo ti{Pattern::FuncApp, toString(assign), types}; TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
@ -376,17 +368,15 @@ void AnyTypeSummary::visit(Scope* scope, AstStatCompoundAssign* assign, Module*
if (auto cast = assign->value->as<AstExprTypeAssertion>()) if (auto cast = assign->value->as<AstExprTypeAssertion>())
{ {
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes)); types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
auto inf = getInferredType(cast->expr, module); types.inferredType = toString(lookupType(cast->expr, module, builtinTypes));
if (inf)
types.inferredType = toString(*inf);
TypeInfo ti{Pattern::Casts, toString(assign), types}; TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
} }
void AnyTypeSummary::visit(Scope* scope, AstStatFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
TelemetryTypePair types; TelemetryTypePair types;
types.inferredType = toString(lookupType(function->func, module, builtinTypes)); types.inferredType = toString(lookupType(function->func, module, builtinTypes));
@ -413,25 +403,27 @@ void AnyTypeSummary::visit(Scope* scope, AstStatFunction* function, Module* modu
visit(scope, function->func->body, module, builtinTypes); visit(scope, function->func->body, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatLocalFunction* function, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
TelemetryTypePair types; TelemetryTypePair types;
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
if (hasVariadicAnys(scope, function->func, module, builtinTypes)) if (hasVariadicAnys(scope, function->func, module, builtinTypes))
{ {
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::VarAny, toString(function), types}; TypeInfo ti{Pattern::VarAny, toString(function), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
if (hasArgAnys(scope, function->func, module, builtinTypes)) if (hasArgAnys(scope, function->func, module, builtinTypes))
{ {
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::FuncArg, toString(function), types}; TypeInfo ti{Pattern::FuncArg, toString(function), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
if (hasAnyReturns(scope, function->func, module, builtinTypes)) if (hasAnyReturns(scope, function->func, module, builtinTypes))
{ {
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::FuncRet, toString(function), types}; TypeInfo ti{Pattern::FuncRet, toString(function), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
@ -440,8 +432,9 @@ void AnyTypeSummary::visit(Scope* scope, AstStatLocalFunction* function, Module*
visit(scope, function->func->body, module, builtinTypes); visit(scope, function->func->body, module, builtinTypes);
} }
void AnyTypeSummary::visit(Scope* scope, AstStatTypeAlias* alias, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
auto ctxNode = getNode(rootSrc, alias);
auto annot = lookupAnnotation(alias->type, module, builtinTypes); auto annot = lookupAnnotation(alias->type, module, builtinTypes);
if (containsAny(annot)) if (containsAny(annot))
@ -450,33 +443,34 @@ void AnyTypeSummary::visit(Scope* scope, AstStatTypeAlias* alias, Module* module
TelemetryTypePair types; TelemetryTypePair types;
types.annotatedType = toString(annot); types.annotatedType = toString(annot);
TypeInfo ti{Pattern::Alias, toString(alias), types}; TypeInfo ti{Pattern::Alias, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
void AnyTypeSummary::visit(Scope* scope, AstStatExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes) void AnyTypeSummary::visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
auto ctxNode = getNode(rootSrc, expr);
if (isAnyCall(scope, expr->expr, module, builtinTypes)) if (isAnyCall(scope, expr->expr, module, builtinTypes))
{ {
TelemetryTypePair types; TelemetryTypePair types;
types.inferredType = toString(lookupType(expr->expr, module, builtinTypes)); types.inferredType = toString(lookupType(expr->expr, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(expr), types}; TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti); typeInfo.push_back(ti);
} }
} }
void AnyTypeSummary::visit(Scope* scope, AstStatDeclareGlobal* declareGlobal, Module* module, NotNull<BuiltinTypes> builtinTypes) {} void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(Scope* scope, AstStatDeclareClass* declareClass, Module* module, NotNull<BuiltinTypes> builtinTypes) {} void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(Scope* scope, AstStatDeclareFunction* declareFunction, Module* module, NotNull<BuiltinTypes> builtinTypes) {} void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(Scope* scope, AstStatError* error, Module* module, NotNull<BuiltinTypes> builtinTypes) {} void AnyTypeSummary::visit(const Scope* scope, AstStatError* error, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
TypeId AnyTypeSummary::checkForFamilyInhabitance(TypeId instance, Location location) TypeId AnyTypeSummary::checkForFamilyInhabitance(const TypeId instance, const Location location)
{ {
if (seenTypeFamilyInstances.find(instance)) if (seenTypeFamilyInstances.find(instance))
return instance; return instance;
@ -485,13 +479,13 @@ TypeId AnyTypeSummary::checkForFamilyInhabitance(TypeId instance, Location locat
return instance; return instance;
} }
TypeId AnyTypeSummary::lookupType(AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes) TypeId AnyTypeSummary::lookupType(const AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
TypeId* ty = module->astTypes.find(expr); const TypeId* ty = module->astTypes.find(expr);
if (ty) if (ty)
return checkForFamilyInhabitance(follow(*ty), expr->location); return checkForFamilyInhabitance(follow(*ty), expr->location);
TypePackId* tp = module->astTypePacks.find(expr); const TypePackId* tp = module->astTypePacks.find(expr);
if (tp) if (tp)
{ {
if (auto fst = first(*tp, /*ignoreHiddenVariadics*/ false)) if (auto fst = first(*tp, /*ignoreHiddenVariadics*/ false))
@ -503,7 +497,7 @@ TypeId AnyTypeSummary::lookupType(AstExpr* expr, Module* module, NotNull<Builtin
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();
} }
TypePackId AnyTypeSummary::reconstructTypePack(AstArray<AstExpr*> exprs, Module* module, NotNull<BuiltinTypes> builtinTypes) TypePackId AnyTypeSummary::reconstructTypePack(AstArray<AstExpr*> exprs, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
if (exprs.size == 0) if (exprs.size == 0)
return arena.addTypePack(TypePack{{}, std::nullopt}); return arena.addTypePack(TypePack{{}, std::nullopt});
@ -515,14 +509,14 @@ TypePackId AnyTypeSummary::reconstructTypePack(AstArray<AstExpr*> exprs, Module*
head.push_back(lookupType(exprs.data[i], module, builtinTypes)); head.push_back(lookupType(exprs.data[i], module, builtinTypes));
} }
TypePackId* tail = module->astTypePacks.find(exprs.data[exprs.size - 1]); const TypePackId* tail = module->astTypePacks.find(exprs.data[exprs.size - 1]);
if (tail) if (tail)
return arena.addTypePack(TypePack{std::move(head), follow(*tail)}); return arena.addTypePack(TypePack{std::move(head), follow(*tail)});
else else
return arena.addTypePack(TypePack{std::move(head), builtinTypes->errorRecoveryTypePack()}); return arena.addTypePack(TypePack{std::move(head), builtinTypes->errorRecoveryTypePack()});
} }
bool AnyTypeSummary::isAnyCall(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes) bool AnyTypeSummary::isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
if (auto call = expr->as<AstExprCall>()) if (auto call = expr->as<AstExprCall>())
{ {
@ -537,7 +531,7 @@ bool AnyTypeSummary::isAnyCall(Scope* scope, AstExpr* expr, Module* module, NotN
return false; return false;
} }
bool AnyTypeSummary::hasVariadicAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes) bool AnyTypeSummary::hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
if (expr->vararg && expr->varargAnnotation) if (expr->vararg && expr->varargAnnotation)
{ {
@ -550,7 +544,7 @@ bool AnyTypeSummary::hasVariadicAnys(Scope* scope, AstExprFunction* expr, Module
return false; return false;
} }
bool AnyTypeSummary::hasArgAnys(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes) bool AnyTypeSummary::hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
if (expr->args.size > 0) if (expr->args.size > 0)
{ {
@ -569,7 +563,7 @@ bool AnyTypeSummary::hasArgAnys(Scope* scope, AstExprFunction* expr, Module* mod
return false; return false;
} }
bool AnyTypeSummary::hasAnyReturns(Scope* scope, AstExprFunction* expr, Module* module, NotNull<BuiltinTypes> builtinTypes) bool AnyTypeSummary::hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
if (!expr->returnAnnotation) if (!expr->returnAnnotation)
{ {
@ -596,7 +590,7 @@ bool AnyTypeSummary::hasAnyReturns(Scope* scope, AstExprFunction* expr, Module*
return false; return false;
} }
bool AnyTypeSummary::isAnyCast(Scope* scope, AstExpr* expr, Module* module, NotNull<BuiltinTypes> builtinTypes) bool AnyTypeSummary::isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{ {
if (auto cast = expr->as<AstExprTypeAssertion>()) if (auto cast = expr->as<AstExprTypeAssertion>())
{ {
@ -609,7 +603,7 @@ bool AnyTypeSummary::isAnyCast(Scope* scope, AstExpr* expr, Module* module, NotN
return false; return false;
} }
TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, Module* module, NotNull<BuiltinTypes> builtintypes) TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, const Module* module, NotNull<BuiltinTypes> builtintypes)
{ {
if (FFlag::DebugLuauMagicTypes) if (FFlag::DebugLuauMagicTypes)
{ {
@ -623,14 +617,14 @@ TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, Module* module, Not
} }
} }
TypeId* ty = module->astResolvedTypes.find(annotation); const TypeId* ty = module->astResolvedTypes.find(annotation);
if (ty) if (ty)
return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); return checkForTypeFunctionInhabitance(follow(*ty), annotation->location);
else else
return checkForTypeFunctionInhabitance(builtintypes->errorRecoveryType(), annotation->location); return checkForTypeFunctionInhabitance(builtintypes->errorRecoveryType(), annotation->location);
} }
TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(TypeId instance, Location location) TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(const TypeId instance, const Location location)
{ {
if (seenTypeFunctionInstances.find(instance)) if (seenTypeFunctionInstances.find(instance))
return instance; return instance;
@ -639,9 +633,9 @@ TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(TypeId instance, Location
return instance; return instance;
} }
std::optional<TypePackId> AnyTypeSummary::lookupPackAnnotation(AstTypePack* annotation, Module* module) std::optional<TypePackId> AnyTypeSummary::lookupPackAnnotation(AstTypePack* annotation, const Module* module)
{ {
TypePackId* tp = module->astResolvedTypePacks.find(annotation); const TypePackId* tp = module->astResolvedTypePacks.find(annotation);
if (tp != nullptr) if (tp != nullptr)
return {follow(*tp)}; return {follow(*tp)};
return {}; return {};
@ -786,9 +780,9 @@ bool AnyTypeSummary::containsAny(TypePackId typ)
return found; return found;
} }
Scope* AnyTypeSummary::findInnerMostScope(Location location, Module* module) const Scope* AnyTypeSummary::findInnerMostScope(const Location location, const Module* module)
{ {
Scope* bestScope = module->getModuleScope().get(); const Scope* bestScope = module->getModuleScope().get();
bool didNarrow = false; bool didNarrow = false;
do do
@ -808,6 +802,69 @@ Scope* AnyTypeSummary::findInnerMostScope(Location location, Module* module)
return bestScope; return bestScope;
} }
std::optional<AstExpr*> AnyTypeSummary::matchRequire(const AstExprCall& call)
{
const char* require = "require";
if (call.args.size != 1)
return std::nullopt;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != require)
return std::nullopt;
if (call.args.size != 1)
return std::nullopt;
return call.args.data[0];
}
AstNode* AnyTypeSummary::getNode(AstStatBlock* root, AstNode* node)
{
FindReturnAncestry finder(node, root->location.end);
root->visit(&finder);
if (!finder.currNode)
finder.currNode = node;
LUAU_ASSERT(finder.found && finder.currNode);
return finder.currNode;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstStatLocalFunction* node)
{
currNode = node;
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstStatFunction* node)
{
currNode = node;
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstType* node)
{
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstNode* node)
{
if (node == stat)
{
found = true;
}
if (node->location.end == rootEnd && stat->location.end >= rootEnd)
{
currNode = node;
found = true;
}
return !found;
}
AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryTypePair type) AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryTypePair type)
: code(code) : code(code)
, node(node) , node(node)
@ -815,6 +872,12 @@ AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryType
{ {
} }
AnyTypeSummary::FindReturnAncestry::FindReturnAncestry(AstNode* stat, Position rootEnd)
: stat(stat)
, rootEnd(rootEnd)
{
}
AnyTypeSummary::AnyTypeSummary() {} AnyTypeSummary::AnyTypeSummary() {}
} // namespace Luau } // namespace Luau

View File

@ -9,8 +9,14 @@
namespace Luau namespace Luau
{ {
Anyification::Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler, Anyification::Anyification(
TypeId anyType, TypePackId anyTypePack) TypeArena* arena,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
)
: Substitution(TxnLog::empty(), arena) : Substitution(TxnLog::empty(), arena)
, scope(scope) , scope(scope)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
@ -20,8 +26,14 @@ Anyification::Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<Built
{ {
} }
Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler, Anyification::Anyification(
TypeId anyType, TypePackId anyTypePack) TypeArena* arena,
const ScopePtr& scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter* iceHandler,
TypeId anyType,
TypePackId anyTypePack
)
: Anyification(arena, NotNull{scope.get()}, builtinTypes, iceHandler, anyType, anyTypePack) : Anyification(arena, NotNull{scope.get()}, builtinTypes, iceHandler, anyType, anyTypePack)
{ {
} }

View File

@ -273,9 +273,14 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprGroup* node) void write(class AstExprGroup* node)
{ {
writeNode(node, "AstExprGroup", [&]() { writeNode(
node,
"AstExprGroup",
[&]()
{
write("expr", node->expr); write("expr", node->expr);
}); }
);
} }
void write(class AstExprConstantNil* node) void write(class AstExprConstantNil* node)
@ -285,37 +290,62 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprConstantBool* node) void write(class AstExprConstantBool* node)
{ {
writeNode(node, "AstExprConstantBool", [&]() { writeNode(
node,
"AstExprConstantBool",
[&]()
{
write("value", node->value); write("value", node->value);
}); }
);
} }
void write(class AstExprConstantNumber* node) void write(class AstExprConstantNumber* node)
{ {
writeNode(node, "AstExprConstantNumber", [&]() { writeNode(
node,
"AstExprConstantNumber",
[&]()
{
write("value", node->value); write("value", node->value);
}); }
);
} }
void write(class AstExprConstantString* node) void write(class AstExprConstantString* node)
{ {
writeNode(node, "AstExprConstantString", [&]() { writeNode(
node,
"AstExprConstantString",
[&]()
{
write("value", node->value); write("value", node->value);
}); }
);
} }
void write(class AstExprLocal* node) void write(class AstExprLocal* node)
{ {
writeNode(node, "AstExprLocal", [&]() { writeNode(
node,
"AstExprLocal",
[&]()
{
write("local", node->local); write("local", node->local);
}); }
);
} }
void write(class AstExprGlobal* node) void write(class AstExprGlobal* node)
{ {
writeNode(node, "AstExprGlobal", [&]() { writeNode(
node,
"AstExprGlobal",
[&]()
{
write("global", node->name); write("global", node->name);
}); }
);
} }
void write(class AstExprVarargs* node) void write(class AstExprVarargs* node)
@ -349,35 +379,54 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprCall* node) void write(class AstExprCall* node)
{ {
writeNode(node, "AstExprCall", [&]() { writeNode(
node,
"AstExprCall",
[&]()
{
PROP(func); PROP(func);
PROP(args); PROP(args);
PROP(self); PROP(self);
PROP(argLocation); PROP(argLocation);
}); }
);
} }
void write(class AstExprIndexName* node) void write(class AstExprIndexName* node)
{ {
writeNode(node, "AstExprIndexName", [&]() { writeNode(
node,
"AstExprIndexName",
[&]()
{
PROP(expr); PROP(expr);
PROP(index); PROP(index);
PROP(indexLocation); PROP(indexLocation);
PROP(op); PROP(op);
}); }
);
} }
void write(class AstExprIndexExpr* node) void write(class AstExprIndexExpr* node)
{ {
writeNode(node, "AstExprIndexExpr", [&]() { writeNode(
node,
"AstExprIndexExpr",
[&]()
{
PROP(expr); PROP(expr);
PROP(index); PROP(index);
}); }
);
} }
void write(class AstExprFunction* node) void write(class AstExprFunction* node)
{ {
writeNode(node, "AstExprFunction", [&]() { writeNode(
node,
"AstExprFunction",
[&]()
{
PROP(generics); PROP(generics);
PROP(genericPacks); PROP(genericPacks);
if (node->self) if (node->self)
@ -393,7 +442,8 @@ struct AstJsonEncoder : public AstVisitor
PROP(body); PROP(body);
PROP(functionDepth); PROP(functionDepth);
PROP(debugname); PROP(debugname);
}); }
);
} }
void write(const std::optional<AstTypeList>& typeList) void write(const std::optional<AstTypeList>& typeList)
@ -475,28 +525,43 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprIfElse* node) void write(class AstExprIfElse* node)
{ {
writeNode(node, "AstExprIfElse", [&]() { writeNode(
node,
"AstExprIfElse",
[&]()
{
PROP(condition); PROP(condition);
PROP(hasThen); PROP(hasThen);
PROP(trueExpr); PROP(trueExpr);
PROP(hasElse); PROP(hasElse);
PROP(falseExpr); PROP(falseExpr);
}); }
);
} }
void write(class AstExprInterpString* node) void write(class AstExprInterpString* node)
{ {
writeNode(node, "AstExprInterpString", [&]() { writeNode(
node,
"AstExprInterpString",
[&]()
{
PROP(strings); PROP(strings);
PROP(expressions); PROP(expressions);
}); }
);
} }
void write(class AstExprTable* node) void write(class AstExprTable* node)
{ {
writeNode(node, "AstExprTable", [&]() { writeNode(
node,
"AstExprTable",
[&]()
{
PROP(items); PROP(items);
}); }
);
} }
void write(AstExprUnary::Op op) void write(AstExprUnary::Op op)
@ -514,10 +579,15 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprUnary* node) void write(class AstExprUnary* node)
{ {
writeNode(node, "AstExprUnary", [&]() { writeNode(
node,
"AstExprUnary",
[&]()
{
PROP(op); PROP(op);
PROP(expr); PROP(expr);
}); }
);
} }
void write(AstExprBinary::Op op) void write(AstExprBinary::Op op)
@ -563,32 +633,51 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstExprBinary* node) void write(class AstExprBinary* node)
{ {
writeNode(node, "AstExprBinary", [&]() { writeNode(
node,
"AstExprBinary",
[&]()
{
PROP(op); PROP(op);
PROP(left); PROP(left);
PROP(right); PROP(right);
}); }
);
} }
void write(class AstExprTypeAssertion* node) void write(class AstExprTypeAssertion* node)
{ {
writeNode(node, "AstExprTypeAssertion", [&]() { writeNode(
node,
"AstExprTypeAssertion",
[&]()
{
PROP(expr); PROP(expr);
PROP(annotation); PROP(annotation);
}); }
);
} }
void write(class AstExprError* node) void write(class AstExprError* node)
{ {
writeNode(node, "AstExprError", [&]() { writeNode(
node,
"AstExprError",
[&]()
{
PROP(expressions); PROP(expressions);
PROP(messageIndex); PROP(messageIndex);
}); }
);
} }
void write(class AstStatBlock* node) void write(class AstStatBlock* node)
{ {
writeNode(node, "AstStatBlock", [&]() { writeNode(
node,
"AstStatBlock",
[&]()
{
writeRaw(",\"hasEnd\":"); writeRaw(",\"hasEnd\":");
write(node->hasEnd); write(node->hasEnd);
writeRaw(",\"body\":["); writeRaw(",\"body\":[");
@ -603,35 +692,51 @@ struct AstJsonEncoder : public AstVisitor
write(stat); write(stat);
} }
writeRaw("]"); writeRaw("]");
}); }
);
} }
void write(class AstStatIf* node) void write(class AstStatIf* node)
{ {
writeNode(node, "AstStatIf", [&]() { writeNode(
node,
"AstStatIf",
[&]()
{
PROP(condition); PROP(condition);
PROP(thenbody); PROP(thenbody);
if (node->elsebody) if (node->elsebody)
PROP(elsebody); PROP(elsebody);
write("hasThen", node->thenLocation.has_value()); write("hasThen", node->thenLocation.has_value());
}); }
);
} }
void write(class AstStatWhile* node) void write(class AstStatWhile* node)
{ {
writeNode(node, "AstStatWhile", [&]() { writeNode(
node,
"AstStatWhile",
[&]()
{
PROP(condition); PROP(condition);
PROP(body); PROP(body);
PROP(hasDo); PROP(hasDo);
}); }
);
} }
void write(class AstStatRepeat* node) void write(class AstStatRepeat* node)
{ {
writeNode(node, "AstStatRepeat", [&]() { writeNode(
node,
"AstStatRepeat",
[&]()
{
PROP(condition); PROP(condition);
PROP(body); PROP(body);
}); }
);
} }
void write(class AstStatBreak* node) void write(class AstStatBreak* node)
@ -646,29 +751,48 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstStatReturn* node) void write(class AstStatReturn* node)
{ {
writeNode(node, "AstStatReturn", [&]() { writeNode(
node,
"AstStatReturn",
[&]()
{
PROP(list); PROP(list);
}); }
);
} }
void write(class AstStatExpr* node) void write(class AstStatExpr* node)
{ {
writeNode(node, "AstStatExpr", [&]() { writeNode(
node,
"AstStatExpr",
[&]()
{
PROP(expr); PROP(expr);
}); }
);
} }
void write(class AstStatLocal* node) void write(class AstStatLocal* node)
{ {
writeNode(node, "AstStatLocal", [&]() { writeNode(
node,
"AstStatLocal",
[&]()
{
PROP(vars); PROP(vars);
PROP(values); PROP(values);
}); }
);
} }
void write(class AstStatFor* node) void write(class AstStatFor* node)
{ {
writeNode(node, "AstStatFor", [&]() { writeNode(
node,
"AstStatFor",
[&]()
{
PROP(var); PROP(var);
PROP(from); PROP(from);
PROP(to); PROP(to);
@ -676,67 +800,102 @@ struct AstJsonEncoder : public AstVisitor
PROP(step); PROP(step);
PROP(body); PROP(body);
PROP(hasDo); PROP(hasDo);
}); }
);
} }
void write(class AstStatForIn* node) void write(class AstStatForIn* node)
{ {
writeNode(node, "AstStatForIn", [&]() { writeNode(
node,
"AstStatForIn",
[&]()
{
PROP(vars); PROP(vars);
PROP(values); PROP(values);
PROP(body); PROP(body);
PROP(hasIn); PROP(hasIn);
PROP(hasDo); PROP(hasDo);
}); }
);
} }
void write(class AstStatAssign* node) void write(class AstStatAssign* node)
{ {
writeNode(node, "AstStatAssign", [&]() { writeNode(
node,
"AstStatAssign",
[&]()
{
PROP(vars); PROP(vars);
PROP(values); PROP(values);
}); }
);
} }
void write(class AstStatCompoundAssign* node) void write(class AstStatCompoundAssign* node)
{ {
writeNode(node, "AstStatCompoundAssign", [&]() { writeNode(
node,
"AstStatCompoundAssign",
[&]()
{
PROP(op); PROP(op);
PROP(var); PROP(var);
PROP(value); PROP(value);
}); }
);
} }
void write(class AstStatFunction* node) void write(class AstStatFunction* node)
{ {
writeNode(node, "AstStatFunction", [&]() { writeNode(
node,
"AstStatFunction",
[&]()
{
PROP(name); PROP(name);
PROP(func); PROP(func);
}); }
);
} }
void write(class AstStatLocalFunction* node) void write(class AstStatLocalFunction* node)
{ {
writeNode(node, "AstStatLocalFunction", [&]() { writeNode(
node,
"AstStatLocalFunction",
[&]()
{
PROP(name); PROP(name);
PROP(func); PROP(func);
}); }
);
} }
void write(class AstStatTypeAlias* node) void write(class AstStatTypeAlias* node)
{ {
writeNode(node, "AstStatTypeAlias", [&]() { writeNode(
node,
"AstStatTypeAlias",
[&]()
{
PROP(name); PROP(name);
PROP(generics); PROP(generics);
PROP(genericPacks); PROP(genericPacks);
PROP(type); PROP(type);
PROP(exported); PROP(exported);
}); }
);
} }
void write(class AstStatDeclareFunction* node) void write(class AstStatDeclareFunction* node)
{ {
writeNode(node, "AstStatDeclareFunction", [&]() { writeNode(
node,
"AstStatDeclareFunction",
[&]()
{
// TODO: attributes // TODO: attributes
PROP(name); PROP(name);
@ -755,19 +914,25 @@ struct AstJsonEncoder : public AstVisitor
PROP(retTypes); PROP(retTypes);
PROP(generics); PROP(generics);
PROP(genericPacks); PROP(genericPacks);
}); }
);
} }
void write(class AstStatDeclareGlobal* node) void write(class AstStatDeclareGlobal* node)
{ {
writeNode(node, "AstStatDeclareGlobal", [&]() { writeNode(
node,
"AstStatDeclareGlobal",
[&]()
{
PROP(name); PROP(name);
if (FFlag::LuauDeclarationExtraPropData) if (FFlag::LuauDeclarationExtraPropData)
PROP(nameLocation); PROP(nameLocation);
PROP(type); PROP(type);
}); }
);
} }
void write(const AstDeclaredClassProp& prop) void write(const AstDeclaredClassProp& prop)
@ -791,21 +956,31 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstStatDeclareClass* node) void write(class AstStatDeclareClass* node)
{ {
writeNode(node, "AstStatDeclareClass", [&]() { writeNode(
node,
"AstStatDeclareClass",
[&]()
{
PROP(name); PROP(name);
if (node->superName) if (node->superName)
write("superName", *node->superName); write("superName", *node->superName);
PROP(props); PROP(props);
PROP(indexer); PROP(indexer);
}); }
);
} }
void write(class AstStatError* node) void write(class AstStatError* node)
{ {
writeNode(node, "AstStatError", [&]() { writeNode(
node,
"AstStatError",
[&]()
{
PROP(expressions); PROP(expressions);
PROP(statements); PROP(statements);
}); }
);
} }
void write(struct AstTypeOrPack node) void write(struct AstTypeOrPack node)
@ -818,7 +993,11 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstTypeReference* node) void write(class AstTypeReference* node)
{ {
writeNode(node, "AstTypeReference", [&]() { writeNode(
node,
"AstTypeReference",
[&]()
{
if (node->prefix) if (node->prefix)
PROP(prefix); PROP(prefix);
if (node->prefixLocation) if (node->prefixLocation)
@ -826,7 +1005,8 @@ struct AstJsonEncoder : public AstVisitor
PROP(name); PROP(name);
PROP(nameLocation); PROP(nameLocation);
PROP(parameters); PROP(parameters);
}); }
);
} }
void write(const AstTableProp& prop) void write(const AstTableProp& prop)
@ -845,10 +1025,15 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstTypeTable* node) void write(class AstTypeTable* node)
{ {
writeNode(node, "AstTypeTable", [&]() { writeNode(
node,
"AstTypeTable",
[&]()
{
PROP(props); PROP(props);
PROP(indexer); PROP(indexer);
}); }
);
} }
void write(struct AstTableIndexer* indexer) void write(struct AstTableIndexer* indexer)
@ -871,78 +1056,128 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstTypeFunction* node) void write(class AstTypeFunction* node)
{ {
writeNode(node, "AstTypeFunction", [&]() { writeNode(
node,
"AstTypeFunction",
[&]()
{
PROP(generics); PROP(generics);
PROP(genericPacks); PROP(genericPacks);
PROP(argTypes); PROP(argTypes);
PROP(argNames); PROP(argNames);
PROP(returnTypes); PROP(returnTypes);
}); }
);
} }
void write(class AstTypeTypeof* node) void write(class AstTypeTypeof* node)
{ {
writeNode(node, "AstTypeTypeof", [&]() { writeNode(
node,
"AstTypeTypeof",
[&]()
{
PROP(expr); PROP(expr);
}); }
);
} }
void write(class AstTypeUnion* node) void write(class AstTypeUnion* node)
{ {
writeNode(node, "AstTypeUnion", [&]() { writeNode(
node,
"AstTypeUnion",
[&]()
{
PROP(types); PROP(types);
}); }
);
} }
void write(class AstTypeIntersection* node) void write(class AstTypeIntersection* node)
{ {
writeNode(node, "AstTypeIntersection", [&]() { writeNode(
node,
"AstTypeIntersection",
[&]()
{
PROP(types); PROP(types);
}); }
);
} }
void write(class AstTypeError* node) void write(class AstTypeError* node)
{ {
writeNode(node, "AstTypeError", [&]() { writeNode(
node,
"AstTypeError",
[&]()
{
PROP(types); PROP(types);
PROP(messageIndex); PROP(messageIndex);
}); }
);
} }
void write(class AstTypePackExplicit* node) void write(class AstTypePackExplicit* node)
{ {
writeNode(node, "AstTypePackExplicit", [&]() { writeNode(
node,
"AstTypePackExplicit",
[&]()
{
PROP(typeList); PROP(typeList);
}); }
);
} }
void write(class AstTypePackVariadic* node) void write(class AstTypePackVariadic* node)
{ {
writeNode(node, "AstTypePackVariadic", [&]() { writeNode(
node,
"AstTypePackVariadic",
[&]()
{
PROP(variadicType); PROP(variadicType);
}); }
);
} }
void write(class AstTypePackGeneric* node) void write(class AstTypePackGeneric* node)
{ {
writeNode(node, "AstTypePackGeneric", [&]() { writeNode(
node,
"AstTypePackGeneric",
[&]()
{
PROP(genericName); PROP(genericName);
}); }
);
} }
bool visit(class AstTypeSingletonBool* node) override bool visit(class AstTypeSingletonBool* node) override
{ {
writeNode(node, "AstTypeSingletonBool", [&]() { writeNode(
node,
"AstTypeSingletonBool",
[&]()
{
write("value", node->value); write("value", node->value);
}); }
);
return false; return false;
} }
bool visit(class AstTypeSingletonString* node) override bool visit(class AstTypeSingletonString* node) override
{ {
writeNode(node, "AstTypeSingletonString", [&]() { writeNode(
node,
"AstTypeSingletonString",
[&]()
{
write("value", node->value); write("value", node->value);
}); }
);
return false; return false;
} }

View File

@ -331,9 +331,14 @@ static std::optional<AstStatLocal*> findBindingLocalStatement(const SourceModule
return std::nullopt; return std::nullopt;
std::vector<AstNode*> nodes = findAstAncestryOfPosition(source, binding.location.begin); std::vector<AstNode*> nodes = findAstAncestryOfPosition(source, binding.location.begin);
auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { auto iter = std::find_if(
nodes.rbegin(),
nodes.rend(),
[](AstNode* node)
{
return node->is<AstStatLocal>(); return node->is<AstStatLocal>();
}); }
);
return iter != nodes.rend() ? std::make_optional((*iter)->as<AstStatLocal>()) : std::nullopt; return iter != nodes.rend() ? std::make_optional((*iter)->as<AstStatLocal>()) : std::nullopt;
} }
@ -472,7 +477,11 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
} }
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol( static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol) const Module& module,
const TypeId ty,
const AstExpr* parentExpr,
const std::optional<DocumentationSymbol> documentationSymbol
)
{ {
if (!documentationSymbol) if (!documentationSymbol)
return std::nullopt; return std::nullopt;

View File

@ -15,8 +15,8 @@
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
static const std::unordered_set<std::string> kStatementStartingKeywords = { static const std::unordered_set<std::string> kStatementStartingKeywords =
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
namespace Luau namespace Luau
{ {
@ -161,7 +161,13 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
} }
static TypeCorrectKind checkTypeCorrectKind( static TypeCorrectKind checkTypeCorrectKind(
const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, AstNode* node, Position position, TypeId ty) const Module& module,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
AstNode* node,
Position position,
TypeId ty
)
{ {
ty = follow(ty); ty = follow(ty);
@ -176,7 +182,8 @@ static TypeCorrectKind checkTypeCorrectKind(
TypeId expectedType = follow(*typeAtPosition); TypeId expectedType = follow(*typeAtPosition);
auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) { auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv)
{
if (std::optional<TypeId> firstRetTy = first(ftv->retTypes)) if (std::optional<TypeId> firstRetTy = first(ftv->retTypes))
return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes);
@ -209,9 +216,18 @@ enum class PropIndexType
Key, Key,
}; };
static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, TypeId rootTy, TypeId ty, static void autocompleteProps(
PropIndexType indexType, const std::vector<AstNode*>& nodes, AutocompleteEntryMap& result, std::unordered_set<TypeId>& seen, const Module& module,
std::optional<const ClassType*> containingClass = std::nullopt) TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
TypeId rootTy,
TypeId ty,
PropIndexType indexType,
const std::vector<AstNode*>& nodes,
AutocompleteEntryMap& result,
std::unordered_set<TypeId>& seen,
std::optional<const ClassType*> containingClass = std::nullopt
)
{ {
rootTy = follow(rootTy); rootTy = follow(rootTy);
ty = follow(ty); ty = follow(ty);
@ -220,13 +236,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
return; return;
seen.insert(ty); seen.insert(ty);
auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) { auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type)
{
if (indexType == PropIndexType::Key) if (indexType == PropIndexType::Key)
return false; return false;
bool calledWithSelf = indexType == PropIndexType::Colon; bool calledWithSelf = indexType == PropIndexType::Colon;
auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) { auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv)
{
// Strong match with definition is a success // Strong match with definition is a success
if (calledWithSelf == ftv->hasSelf) if (calledWithSelf == ftv->hasSelf)
return true; return true;
@ -265,7 +283,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
return calledWithSelf; return calledWithSelf;
}; };
auto fillProps = [&](const ClassType::Props& props) { auto fillProps = [&](const ClassType::Props& props)
{
for (const auto& [name, prop] : props) for (const auto& [name, prop] : props)
{ {
// We are walking up the class hierarchy, so if we encounter a property that we have // We are walking up the class hierarchy, so if we encounter a property that we have
@ -291,13 +310,26 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
ParenthesesRecommendation parens = ParenthesesRecommendation parens =
indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect);
result[name] = AutocompleteEntry{AutocompleteEntryKind::Property, type, prop.deprecated, isWrongIndexer(type), typeCorrect, result[name] = AutocompleteEntry{
containingClass, &prop, prop.documentationSymbol, {}, parens, {}, indexType == PropIndexType::Colon}; AutocompleteEntryKind::Property,
type,
prop.deprecated,
isWrongIndexer(type),
typeCorrect,
containingClass,
&prop,
prop.documentationSymbol,
{},
parens,
{},
indexType == PropIndexType::Colon
};
} }
} }
}; };
auto fillMetatableProps = [&](const TableType* mtable) { auto fillMetatableProps = [&](const TableType* mtable)
{
auto indexIt = mtable->props.find("__index"); auto indexIt = mtable->props.find("__index");
if (indexIt != mtable->props.end()) if (indexIt != mtable->props.end())
{ {
@ -409,7 +441,11 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
} }
static void autocompleteKeywords( static void autocompleteKeywords(
const SourceModule& sourceModule, const std::vector<AstNode*>& ancestry, Position position, AutocompleteEntryMap& result) const SourceModule& sourceModule,
const std::vector<AstNode*>& ancestry,
Position position,
AutocompleteEntryMap& result
)
{ {
LUAU_ASSERT(!ancestry.empty()); LUAU_ASSERT(!ancestry.empty());
@ -429,15 +465,28 @@ static void autocompleteKeywords(
} }
} }
static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, TypeId ty, PropIndexType indexType, static void autocompleteProps(
const std::vector<AstNode*>& nodes, AutocompleteEntryMap& result) const Module& module,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
TypeId ty,
PropIndexType indexType,
const std::vector<AstNode*>& nodes,
AutocompleteEntryMap& result
)
{ {
std::unordered_set<TypeId> seen; std::unordered_set<TypeId> seen;
autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen);
} }
AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes, TypeId ty, AutocompleteEntryMap autocompleteProps(
PropIndexType indexType, const std::vector<AstNode*>& nodes) const Module& module,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes,
TypeId ty,
PropIndexType indexType,
const std::vector<AstNode*>& nodes
)
{ {
AutocompleteEntryMap result; AutocompleteEntryMap result;
autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result);
@ -472,7 +521,8 @@ static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node
return; return;
} }
auto formatKey = [addQuotes](const std::string& key) { auto formatKey = [addQuotes](const std::string& key)
{
if (addQuotes) if (addQuotes)
return "\"" + escape(key) + "\""; return "\"" + escape(key) + "\"";
@ -705,9 +755,14 @@ static std::optional<bool> functionIsExpectedAt(const Module& module, AstNode* n
if (const IntersectionType* itv = get<IntersectionType>(expectedType)) if (const IntersectionType* itv = get<IntersectionType>(expectedType))
{ {
return std::all_of(begin(itv->parts), end(itv->parts), [](auto&& ty) { return std::all_of(
begin(itv->parts),
end(itv->parts),
[](auto&& ty)
{
return get<FunctionType>(Luau::follow(ty)) != nullptr; return get<FunctionType>(Luau::follow(ty)) != nullptr;
}); }
);
} }
if (const UnionType* utv = get<UnionType>(expectedType)) if (const UnionType* utv = get<UnionType>(expectedType))
@ -727,15 +782,31 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
for (const auto& [name, ty] : scope->exportedTypeBindings) for (const auto& [name, ty] : scope->exportedTypeBindings)
{ {
if (!result.count(name)) if (!result.count(name))
result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, result[name] = AutocompleteEntry{
std::nullopt, ty.type->documentationSymbol}; AutocompleteEntryKind::Type,
ty.type,
false,
false,
TypeCorrectKind::None,
std::nullopt,
std::nullopt,
ty.type->documentationSymbol
};
} }
for (const auto& [name, ty] : scope->privateTypeBindings) for (const auto& [name, ty] : scope->privateTypeBindings)
{ {
if (!result.count(name)) if (!result.count(name))
result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, result[name] = AutocompleteEntry{
std::nullopt, ty.type->documentationSymbol}; AutocompleteEntryKind::Type,
ty.type,
false,
false,
TypeCorrectKind::None,
std::nullopt,
std::nullopt,
ty.type->documentationSymbol
};
} }
for (const auto& [name, _] : scope->importedTypeBindings) for (const auto& [name, _] : scope->importedTypeBindings)
@ -825,7 +896,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
else if (AstExprFunction* node = parent->as<AstExprFunction>()) else if (AstExprFunction* node = parent->as<AstExprFunction>())
{ {
// For lookup inside expected function type if that's available // For lookup inside expected function type if that's available
auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* { auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType*
{
auto it = module.astExpectedTypes.find(expr); auto it = module.astExpectedTypes.find(expr);
if (!it) if (!it)
@ -1029,7 +1101,11 @@ static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding&
} }
static AutocompleteEntryMap autocompleteStatement( static AutocompleteEntryMap autocompleteStatement(
const SourceModule& sourceModule, const Module& module, const std::vector<AstNode*>& ancestry, Position position) const SourceModule& sourceModule,
const Module& module,
const std::vector<AstNode*>& ancestry,
Position position
)
{ {
// This is inefficient. :( // This is inefficient. :(
ScopePtr scope = findScopeAtPosition(module, position); ScopePtr scope = findScopeAtPosition(module, position);
@ -1051,8 +1127,18 @@ static AutocompleteEntryMap autocompleteStatement(
std::string n = toString(name); std::string n = toString(name);
if (!result.count(n)) if (!result.count(n))
result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, TypeCorrectKind::None, std::nullopt, result[n] = {
std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None)}; AutocompleteEntryKind::Binding,
binding.typeId,
binding.deprecated,
false,
TypeCorrectKind::None,
std::nullopt,
std::nullopt,
binding.documentationSymbol,
{},
getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None)
};
} }
scope = scope->parent; scope = scope->parent;
@ -1122,7 +1208,11 @@ static AutocompleteEntryMap autocompleteStatement(
// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) // Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`)
static bool autocompleteIfElseExpression( static bool autocompleteIfElseExpression(
const AstNode* node, const std::vector<AstNode*>& ancestry, const Position& position, AutocompleteEntryMap& outResult) const AstNode* node,
const std::vector<AstNode*>& ancestry,
const Position& position,
AutocompleteEntryMap& outResult
)
{ {
AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr;
if (!parent) if (!parent)
@ -1161,8 +1251,15 @@ static bool autocompleteIfElseExpression(
} }
} }
static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull<BuiltinTypes> builtinTypes, static AutocompleteContext autocompleteExpression(
TypeArena* typeArena, const std::vector<AstNode*>& ancestry, Position position, AutocompleteEntryMap& result) const SourceModule& sourceModule,
const Module& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
const std::vector<AstNode*>& ancestry,
Position position,
AutocompleteEntryMap& result
)
{ {
LUAU_ASSERT(!ancestry.empty()); LUAU_ASSERT(!ancestry.empty());
@ -1197,8 +1294,18 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu
{ {
TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId);
result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, result[n] = {
binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; AutocompleteEntryKind::Binding,
binding.typeId,
binding.deprecated,
false,
typeCorrect,
std::nullopt,
std::nullopt,
binding.documentationSymbol,
{},
getParenRecommendation(binding.typeId, ancestry, typeCorrect)
};
} }
} }
@ -1225,8 +1332,14 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu
return AutocompleteContext::Expression; return AutocompleteContext::Expression;
} }
static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull<BuiltinTypes> builtinTypes, static AutocompleteResult autocompleteExpression(
TypeArena* typeArena, const std::vector<AstNode*>& ancestry, Position position) const SourceModule& sourceModule,
const Module& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
const std::vector<AstNode*>& ancestry,
Position position
)
{ {
AutocompleteEntryMap result; AutocompleteEntryMap result;
AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result); AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result);
@ -1312,8 +1425,13 @@ static std::optional<std::string> getStringContents(const AstNode* node)
} }
} }
static std::optional<AutocompleteEntryMap> autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, static std::optional<AutocompleteEntryMap> autocompleteStringParams(
const std::vector<AstNode*>& nodes, Position position, StringCompletionCallback callback) const SourceModule& sourceModule,
const ModulePtr& module,
const std::vector<AstNode*>& nodes,
Position position,
StringCompletionCallback callback
)
{ {
if (nodes.size() < 2) if (nodes.size() < 2)
{ {
@ -1354,7 +1472,8 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
std::optional<std::string> candidateString = getStringContents(nodes.back()); std::optional<std::string> candidateString = getStringContents(nodes.back());
auto performCallback = [&](const FunctionType* funcType) -> std::optional<AutocompleteEntryMap> { auto performCallback = [&](const FunctionType* funcType) -> std::optional<AutocompleteEntryMap>
{
for (const std::string& tag : funcType->tags) for (const std::string& tag : funcType->tags)
{ {
if (std::optional<AutocompleteEntryMap> ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) if (std::optional<AutocompleteEntryMap> ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString))
@ -1463,7 +1582,11 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func
} }
static std::optional<AutocompleteEntry> makeAnonymousAutofilled( static std::optional<AutocompleteEntry> makeAnonymousAutofilled(
const ModulePtr& module, Position position, const AstNode* node, const std::vector<AstNode*>& ancestry) const ModulePtr& module,
Position position,
const AstNode* node,
const std::vector<AstNode*>& ancestry
)
{ {
const AstExprCall* call = node->as<AstExprCall>(); const AstExprCall* call = node->as<AstExprCall>();
if (!call && ancestry.size() > 1) if (!call && ancestry.size() > 1)
@ -1530,8 +1653,15 @@ static std::optional<AutocompleteEntry> makeAnonymousAutofilled(
return std::make_optional(std::move(entry)); return std::make_optional(std::move(entry));
} }
static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull<BuiltinTypes> builtinTypes, static AutocompleteResult autocomplete(
TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback) const SourceModule& sourceModule,
const ModulePtr& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
Scope* globalScope,
Position position,
StringCompletionCallback callback
)
{ {
if (isWithinComment(sourceModule, position)) if (isWithinComment(sourceModule, position))
return {}; return {};
@ -1662,8 +1792,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
} }
else if (AstStatIf* statIf = node->as<AstStatIf>(); statIf && !statIf->elseLocation.has_value()) else if (AstStatIf* statIf = node->as<AstStatIf>(); statIf && !statIf->elseLocation.has_value())
{ {
return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, return {
ancestry, AutocompleteContext::Keyword}; {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}},
ancestry,
AutocompleteContext::Keyword
};
} }
else if (AstStatIf* statIf = parent->as<AstStatIf>(); statIf && node->is<AstStatBlock>()) else if (AstStatIf* statIf = parent->as<AstStatIf>(); statIf && node->is<AstStatBlock>())
{ {

View File

@ -29,15 +29,35 @@ namespace Luau
{ {
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate); TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable( static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate); TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert( static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate); TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack( static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate); TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire( static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate); TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
);
static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionSelect(MagicFunctionCallContext context);
@ -61,26 +81,51 @@ TypeId makeOption(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId t
} }
TypeId makeFunction( TypeId makeFunction(
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked) TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked
)
{ {
return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes, checked); return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes, checked);
} }
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, TypeId makeFunction(
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked) TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<TypeId> retTypes,
bool checked
)
{ {
return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes, checked); return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes, checked);
} }
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, TypeId makeFunction(
std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes, bool checked) TypeArena& arena,
std::optional<TypeId> selfType,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked
)
{ {
return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes, checked); return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes, checked);
} }
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, TypeId makeFunction(
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, TypeArena& arena,
std::initializer_list<TypeId> retTypes, bool checked) std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes,
bool checked
)
{ {
std::vector<TypeId> params; std::vector<TypeId> params;
if (selfType) if (selfType)
@ -219,7 +264,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()}); builtinTypeFunctions().addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()});
LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(
globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete); globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete
);
LUAU_ASSERT(loadResult.success); LUAU_ASSERT(loadResult.success);
TypeId genericK = arena.addType(GenericType{"K"}); TypeId genericK = arena.addType(GenericType{"K"});
@ -313,10 +359,12 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// declare function assert<T>(value: T, errorMessage: string?): intersect<T, ~(false?)> // declare function assert<T>(value: T, errorMessage: string?): intersect<T, ~(false?)>
TypeId genericT = arena.addType(GenericType{"T"}); TypeId genericT = arena.addType(GenericType{"T"});
TypeId refinedTy = arena.addType(TypeFunctionInstanceType{ TypeId refinedTy = arena.addType(TypeFunctionInstanceType{
NotNull{&builtinTypeFunctions().intersectFunc}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}}); NotNull{&builtinTypeFunctions().intersectFunc}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}
});
TypeId assertTy = arena.addType(FunctionType{ TypeId assertTy = arena.addType(FunctionType{
{genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})}); {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})
});
addGlobalBinding(globals, "assert", assertTy, "@luau"); addGlobalBinding(globals, "assert", assertTy, "@luau");
} }
@ -380,7 +428,11 @@ static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes,
} }
std::optional<WithPredicate<TypePackId>> magicFunctionFormat( std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
@ -529,7 +581,11 @@ static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch( static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack); const auto& [params, tail] = flatten(paramPack);
@ -594,7 +650,11 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch( static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack); const auto& [params, tail] = flatten(paramPack);
@ -666,7 +726,11 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionFind( static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack); const auto& [params, tail] = flatten(paramPack);
@ -804,9 +868,11 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
const TypeId replArgType = const TypeId replArgType = arena->addType(UnionType{
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), {stringType,
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}}); arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}
});
const TypeId gsubFunc = const TypeId gsubFunc =
makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
const TypeId gmatchFunc = const TypeId gmatchFunc =
@ -815,14 +881,17 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
FunctionType matchFuncTy{ FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}; arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})
};
matchFuncTy.isCheckedFunction = true; matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy); const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch); attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), FunctionType findFuncTy{
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}; arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})
};
findFuncTy.isCheckedFunction = true; findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy); const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind); attachMagicFunction(findFunc, magicFunctionFind);
@ -857,10 +926,19 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
{"reverse", {stringToStringType}}, {"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}},
{"upper", {stringToStringType}}, {"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, {"split",
{makeFunction(
*arena,
stringType,
{},
{},
{optionalString},
{},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})}, {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})},
/* checked */ true)}}, /* checked */ true
{"pack", {arena->addType(FunctionType{ )}},
{"pack",
{arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}), arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack, oneStringPack,
})}}, })}},
@ -879,7 +957,11 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
@ -965,7 +1047,11 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context)
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable( static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
@ -1043,7 +1129,11 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert( static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, predicates] = withPredicate; auto [paramPack, predicates] = withPredicate;
@ -1073,7 +1163,11 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionPack( static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
@ -1174,7 +1268,11 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire( static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate) TypeChecker& typechecker,
const ScopePtr& scope,
const AstExprCall& expr,
WithPredicate<TypePackId> withPredicate
)
{ {
TypeArena& arena = typechecker.currentModule->internalTypes; TypeArena& arena = typechecker.currentModule->internalTypes;

View File

@ -227,19 +227,23 @@ private:
void cloneChildren(TypeId ty) void cloneChildren(TypeId ty)
{ {
return visit( return visit(
[&](auto&& t) { [&](auto&& t)
{
return cloneChildren(&t); return cloneChildren(&t);
}, },
asMutable(ty)->ty); asMutable(ty)->ty
);
} }
void cloneChildren(TypePackId tp) void cloneChildren(TypePackId tp)
{ {
return visit( return visit(
[&](auto&& t) { [&](auto&& t)
{
return cloneChildren(&t); return cloneChildren(&t);
}, },
asMutable(tp)->ty); asMutable(tp)->ty
);
} }
void cloneChildren(Kind kind) void cloneChildren(Kind kind)

View File

@ -189,10 +189,18 @@ bool hasFreeType(TypeId ty)
} // namespace } // namespace
ConstraintGenerator::ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, ConstraintGenerator::ConstraintGenerator(
NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, ModulePtr module,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, DcrLogger* logger, NotNull<DataFlowGraph> dfg, NotNull<Normalizer> normalizer,
std::vector<RequireCycle> requireCycles) NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger,
NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles
)
: module(module) : module(module)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, arena(normalizer->arena) , arena(normalizer->arena)
@ -240,9 +248,15 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
NotNull<Constraint> genConstraint = NotNull<Constraint> genConstraint =
addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())}); addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())});
getMutable<BlockedType>(result)->setOwner(genConstraint); getMutable<BlockedType>(result)->setOwner(genConstraint);
forEachConstraint(start, end, this, [genConstraint](const ConstraintPtr& c) { forEachConstraint(
start,
end,
this,
[genConstraint](const ConstraintPtr& c)
{
genConstraint->dependencies.push_back(NotNull{c.get()}); genConstraint->dependencies.push_back(NotNull{c.get()});
}); }
);
interiorTypes.pop_back(); interiorTypes.pop_back();
@ -354,10 +368,17 @@ NotNull<Constraint> ConstraintGenerator::addConstraint(const ScopePtr& scope, st
return NotNull{constraints.emplace_back(std::move(c)).get()}; return NotNull{constraints.emplace_back(std::move(c)).get()};
} }
void ConstraintGenerator::unionRefinements(const ScopePtr& scope, Location location, const RefinementContext& lhs, const RefinementContext& rhs, void ConstraintGenerator::unionRefinements(
RefinementContext& dest, std::vector<ConstraintV>* constraints) const ScopePtr& scope,
Location location,
const RefinementContext& lhs,
const RefinementContext& rhs,
RefinementContext& dest,
std::vector<ConstraintV>* constraints
)
{
const auto intersect = [&](const std::vector<TypeId>& types)
{ {
const auto intersect = [&](const std::vector<TypeId>& types) {
if (1 == types.size()) if (1 == types.size())
return types[0]; return types[0];
else if (2 == types.size()) else if (2 == types.size())
@ -386,8 +407,15 @@ void ConstraintGenerator::unionRefinements(const ScopePtr& scope, Location locat
} }
} }
void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location location, RefinementId refinement, RefinementContext* refis, bool sense, void ConstraintGenerator::computeRefinement(
bool eq, std::vector<ConstraintV>* constraints) const ScopePtr& scope,
Location location,
RefinementId refinement,
RefinementContext* refis,
bool sense,
bool eq,
std::vector<ConstraintV>* constraints
)
{ {
if (!refinement) if (!refinement)
return; return;
@ -555,8 +583,11 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat
switch (shouldSuppressErrors(normalizer, ty)) switch (shouldSuppressErrors(normalizer, ty))
{ {
case ErrorSuppression::DoNotSuppress: case ErrorSuppression::DoNotSuppress:
{
if (!get<NeverType>(follow(ty)))
ty = makeIntersect(scope, location, ty, dt); ty = makeIntersect(scope, location, ty, dt);
break; break;
}
case ErrorSuppression::Suppress: case ErrorSuppression::Suppress:
ty = makeIntersect(scope, location, ty, dt); ty = makeIntersect(scope, location, ty, dt);
ty = makeUnion(scope, location, ty, builtinTypes->errorType); ty = makeUnion(scope, location, ty, builtinTypes->errorType);
@ -688,6 +719,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStat* stat)
return visit(scope, f); return visit(scope, f);
else if (auto a = stat->as<AstStatTypeAlias>()) else if (auto a = stat->as<AstStatTypeAlias>())
return visit(scope, a); return visit(scope, a);
else if (auto f = stat->as<AstStatTypeFunction>())
return visit(scope, f);
else if (auto s = stat->as<AstStatDeclareGlobal>()) else if (auto s = stat->as<AstStatDeclareGlobal>())
return visit(scope, s); return visit(scope, s);
else if (auto s = stat->as<AstStatDeclareFunction>()) else if (auto s = stat->as<AstStatDeclareFunction>())
@ -792,11 +825,15 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat
auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack});
forEachConstraint(start, end, this, forEachConstraint(
start,
end,
this,
[&uc](const ConstraintPtr& runBefore) [&uc](const ConstraintPtr& runBefore)
{ {
uc->dependencies.push_back(NotNull{runBefore.get()}); uc->dependencies.push_back(NotNull{runBefore.get()});
}); }
);
for (TypeId t : valueTypes) for (TypeId t : valueTypes)
getMutable<BlockedType>(t)->setOwner(uc); getMutable<BlockedType>(t)->setOwner(uc);
@ -875,7 +912,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFor* for_)
if (for_->var->annotation) if (for_->var->annotation)
annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false); annotationTy = resolveType(scope, for_->var->annotation, /* inTypeArguments */ false);
auto inferNumber = [&](AstExpr* expr) { auto inferNumber = [&](AstExpr* expr)
{
if (!expr) if (!expr)
return; return;
@ -929,7 +967,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI
} }
auto iterable = addConstraint( auto iterable = addConstraint(
loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}); loopScope, getLocation(forIn->values), IterableConstraint{iterator, variableTypes, forIn->values.data[0], &module->astForInNextTypes}
);
for (TypeId var : variableTypes) for (TypeId var : variableTypes)
{ {
@ -943,9 +982,15 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatForIn* forI
Checkpoint end = checkpoint(this); Checkpoint end = checkpoint(this);
// This iter constraint must dispatch first. // This iter constraint must dispatch first.
forEachConstraint(start, end, this, [&iterable](const ConstraintPtr& runLater) { forEachConstraint(
start,
end,
this,
[&iterable](const ConstraintPtr& runLater)
{
runLater->dependencies.push_back(iterable); runLater->dependencies.push_back(iterable);
}); }
);
return ControlFlow::None; return ControlFlow::None;
} }
@ -1011,7 +1056,12 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFuncti
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature});
Constraint* previous = nullptr; Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) { forEachConstraint(
start,
end,
this,
[&c, &previous](const ConstraintPtr& constraint)
{
c->dependencies.push_back(NotNull{constraint.get()}); c->dependencies.push_back(NotNull{constraint.get()});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns) if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
@ -1021,7 +1071,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFuncti
previous = constraint.get(); previous = constraint.get();
} }
}); }
);
getMutable<BlockedType>(functionType)->setOwner(addConstraint(scope, std::move(c))); getMutable<BlockedType>(functionType)->setOwner(addConstraint(scope, std::move(c)));
module->astTypes[function->func] = functionType; module->astTypes[function->func] = functionType;
@ -1055,7 +1106,12 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f
getMutable<BlockedType>(generalizedType)->setOwner(c); getMutable<BlockedType>(generalizedType)->setOwner(c);
Constraint* previous = nullptr; Constraint* previous = nullptr;
forEachConstraint(start, end, this, [&c, &previous](const ConstraintPtr& constraint) { forEachConstraint(
start,
end,
this,
[&c, &previous](const ConstraintPtr& constraint)
{
c->dependencies.push_back(NotNull{constraint.get()}); c->dependencies.push_back(NotNull{constraint.get()});
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns) if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
@ -1065,7 +1121,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f
previous = constraint.get(); previous = constraint.get();
} }
}); }
);
} }
DefId def = dfg->getDef(function->name); DefId def = dfg->getDef(function->name);
@ -1211,7 +1268,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAss
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatIf* ifStatement) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatIf* ifStatement)
{ {
RefinementId refinement = [&]() { RefinementId refinement = [&]()
{
InConditionalContext flipper{&typeContext}; InConditionalContext flipper{&typeContext};
return check(scope, ifStatement->condition, std::nullopt).refinement; return check(scope, ifStatement->condition, std::nullopt).refinement;
}(); }();
@ -1293,18 +1351,26 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias*
for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false))
typePackParams.push_back(tpParam.second.tp); typePackParams.push_back(tpParam.second.tp);
addConstraint(scope, alias->type->location, addConstraint(
scope,
alias->type->location,
NameConstraint{ NameConstraint{
ty, ty,
alias->name.value, alias->name.value,
/*synthetic=*/false, /*synthetic=*/false,
std::move(typeParams), std::move(typeParams),
std::move(typePackParams), std::move(typePackParams),
}); }
);
return ControlFlow::None; return ControlFlow::None;
} }
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function)
{
return ControlFlow::None;
}
ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareGlobal* global)
{ {
LUAU_ASSERT(global->type); LUAU_ASSERT(global->type);
@ -1350,8 +1416,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas
if (!get<ClassType>(follow(*superTy))) if (!get<ClassType>(follow(*superTy)))
{ {
reportError(declaredClass->location, reportError(
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}); declaredClass->location,
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}
);
return ControlFlow::None; return ControlFlow::None;
} }
@ -1579,7 +1647,11 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstArray<Ast
} }
InferencePack ConstraintGenerator::checkPack( InferencePack ConstraintGenerator::checkPack(
const ScopePtr& scope, AstExpr* expr, const std::vector<std::optional<TypeId>>& expectedTypes, bool generalize) const ScopePtr& scope,
AstExpr* expr,
const std::vector<std::optional<TypeId>>& expectedTypes,
bool generalize
)
{ {
RecursionCounter counter{&recursionCount}; RecursionCounter counter{&recursionCount};
@ -1661,7 +1733,6 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
std::vector<std::optional<TypeId>> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType); std::vector<std::optional<TypeId>> expectedTypesForCall = getExpectedCallTypesForFunctionOverloads(fnType);
module->astOriginalCallTypes[call->func] = fnType;
module->astOriginalCallTypes[call] = fnType; module->astOriginalCallTypes[call] = fnType;
Checkpoint argBeginCheckpoint = checkpoint(this); Checkpoint argBeginCheckpoint = checkpoint(this);
@ -1796,14 +1867,25 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
* 4. Solve the call * 4. Solve the call
*/ */
NotNull<Constraint> checkConstraint = addConstraint(scope, call->func->location, NotNull<Constraint> checkConstraint = addConstraint(
FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}}); scope,
call->func->location,
FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}}
);
forEachConstraint(funcBeginCheckpoint, funcEndCheckpoint, this, [checkConstraint](const ConstraintPtr& constraint) { forEachConstraint(
funcBeginCheckpoint,
funcEndCheckpoint,
this,
[checkConstraint](const ConstraintPtr& constraint)
{
checkConstraint->dependencies.emplace_back(constraint.get()); checkConstraint->dependencies.emplace_back(constraint.get());
}); }
);
NotNull<Constraint> callConstraint = addConstraint(scope, call->func->location, NotNull<Constraint> callConstraint = addConstraint(
scope,
call->func->location,
FunctionCallConstraint{ FunctionCallConstraint{
fnType, fnType,
argPack, argPack,
@ -1811,17 +1893,24 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
call, call,
std::move(discriminantTypes), std::move(discriminantTypes),
&module->astOverloadResolvedTypes, &module->astOverloadResolvedTypes,
}); }
);
getMutable<BlockedTypePack>(rets)->owner = callConstraint.get(); getMutable<BlockedTypePack>(rets)->owner = callConstraint.get();
callConstraint->dependencies.push_back(checkConstraint); callConstraint->dependencies.push_back(checkConstraint);
forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [checkConstraint, callConstraint](const ConstraintPtr& constraint) { forEachConstraint(
argBeginCheckpoint,
argEndCheckpoint,
this,
[checkConstraint, callConstraint](const ConstraintPtr& constraint)
{
constraint->dependencies.emplace_back(checkConstraint); constraint->dependencies.emplace_back(checkConstraint);
callConstraint->dependencies.emplace_back(constraint.get()); callConstraint->dependencies.emplace_back(constraint.get());
}); }
);
return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; return InferencePack{rets, {refinementArena.variadic(returnRefinements)}};
} }
@ -1974,7 +2063,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa
} }
Inference ConstraintGenerator::checkIndexName( Inference ConstraintGenerator::checkIndexName(
const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation) const ScopePtr& scope,
const RefinementKey* key,
AstExpr* indexee,
const std::string& index,
Location indexLocation
)
{ {
TypeId obj = check(scope, indexee).ty; TypeId obj = check(scope, indexee).ty;
TypeId result = nullptr; TypeId result = nullptr;
@ -2005,7 +2099,8 @@ Inference ConstraintGenerator::checkIndexName(
result = arena->addType(BlockedType{}); result = arena->addType(BlockedType{});
auto c = addConstraint( auto c = addConstraint(
scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}); scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)}
);
getMutable<BlockedType>(result)->setOwner(c); getMutable<BlockedType>(result)->setOwner(c);
} }
@ -2076,7 +2171,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun
interiorTypes.pop_back(); interiorTypes.pop_back();
Constraint* previous = nullptr; Constraint* previous = nullptr;
forEachConstraint(startCheckpoint, endCheckpoint, this, [gc, &previous](const ConstraintPtr& constraint) { forEachConstraint(
startCheckpoint,
endCheckpoint,
this,
[gc, &previous](const ConstraintPtr& constraint)
{
gc->dependencies.emplace_back(constraint.get()); gc->dependencies.emplace_back(constraint.get());
if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns) if (auto psc = get<PackSubtypeConstraint>(*constraint); psc && psc->returns)
@ -2086,7 +2186,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun
previous = constraint.get(); previous = constraint.get();
} }
}); }
);
if (generalize && hasFreeType(sig.signature)) if (generalize && hasFreeType(sig.signature))
{ {
@ -2187,9 +2288,13 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar
} }
case AstExprBinary::Op::CompareGe: case AstExprBinary::Op::CompareGe:
{ {
TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().ltFunc, TypeId resultType = createTypeFunctionInstance(
builtinTypeFunctions().ltFunc,
{rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)`
{}, scope, binary->location); {},
scope,
binary->location
);
return Inference{resultType, std::move(refinement)}; return Inference{resultType, std::move(refinement)};
} }
case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLe:
@ -2202,7 +2307,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar
TypeId resultType = createTypeFunctionInstance( TypeId resultType = createTypeFunctionInstance(
builtinTypeFunctions().leFunc, builtinTypeFunctions().leFunc,
{rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)`
{}, scope, binary->location); {},
scope,
binary->location
);
return Inference{resultType, std::move(refinement)}; return Inference{resultType, std::move(refinement)};
} }
case AstExprBinary::Op::CompareEq: case AstExprBinary::Op::CompareEq:
@ -2234,7 +2342,8 @@ builtinTypeFunctions().leFunc,
Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType) Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType)
{ {
RefinementId refinement = [&]() { RefinementId refinement = [&]()
{
InConditionalContext flipper{&typeContext}; InConditionalContext flipper{&typeContext};
ScopePtr condScope = childScope(ifElse->condition, scope); ScopePtr condScope = childScope(ifElse->condition, scope);
return check(condScope, ifElse->condition).refinement; return check(condScope, ifElse->condition).refinement;
@ -2266,7 +2375,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprInterpString*
} }
std::tuple<TypeId, TypeId, RefinementId> ConstraintGenerator::checkBinary( std::tuple<TypeId, TypeId, RefinementId> ConstraintGenerator::checkBinary(
const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType) const ScopePtr& scope,
AstExprBinary* binary,
std::optional<TypeId> expectedType
)
{ {
if (binary->op == AstExprBinary::And) if (binary->op == AstExprBinary::And)
{ {
@ -2460,7 +2572,8 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexName* e
bool incremented = recordPropertyAssignment(lhsTy); bool incremented = recordPropertyAssignment(lhsTy);
auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, expr->indexLocation, propTy, incremented}); auto apc =
addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, expr->index.value, rhsType, expr->indexLocation, propTy, incremented});
getMutable<BlockedType>(propTy)->setOwner(apc); getMutable<BlockedType>(propTy)->setOwner(apc);
} }
@ -2476,7 +2589,9 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprIndexExpr* e
bool incremented = recordPropertyAssignment(lhsTy); bool incremented = recordPropertyAssignment(lhsTy);
auto apc = addConstraint(scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, expr->index->location, propTy, incremented}); auto apc = addConstraint(
scope, expr->location, AssignPropConstraint{lhsTy, std::move(propName), rhsType, expr->index->location, propTy, incremented}
);
getMutable<BlockedType>(propTy)->setOwner(apc); getMutable<BlockedType>(propTy)->setOwner(apc);
return; return;
@ -2505,7 +2620,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
TypeIds indexKeyLowerBound; TypeIds indexKeyLowerBound;
TypeIds indexValueLowerBound; TypeIds indexValueLowerBound;
auto createIndexer = [&indexKeyLowerBound, &indexValueLowerBound](const Location& location, TypeId currentIndexType, TypeId currentResultType) { auto createIndexer = [&indexKeyLowerBound, &indexValueLowerBound](const Location& location, TypeId currentIndexType, TypeId currentResultType)
{
indexKeyLowerBound.insert(follow(currentIndexType)); indexKeyLowerBound.insert(follow(currentIndexType));
indexValueLowerBound.insert(follow(currentResultType)); indexValueLowerBound.insert(follow(currentResultType));
}; };
@ -2565,14 +2681,19 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice};
std::vector<TypeId> toBlock; std::vector<TypeId> toBlock;
matchLiteralType( matchLiteralType(
NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock); NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock
);
} }
return Inference{ty}; return Inference{ty};
} }
ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignature( ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignature(
const ScopePtr& parent, AstExprFunction* fn, std::optional<TypeId> expectedType, std::optional<Location> originalName) const ScopePtr& parent,
AstExprFunction* fn,
std::optional<TypeId> expectedType,
std::optional<Location> originalName
)
{ {
ScopePtr signatureScope = nullptr; ScopePtr signatureScope = nullptr;
ScopePtr bodyScope = nullptr; ScopePtr bodyScope = nullptr;
@ -3076,7 +3197,11 @@ TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, const Ast
} }
std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGenerator::createGenerics( std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGenerator::createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache, bool addTypes) const ScopePtr& scope,
AstArray<AstGenericType> generics,
bool useCache,
bool addTypes
)
{ {
std::vector<std::pair<Name, GenericTypeDefinition>> result; std::vector<std::pair<Name, GenericTypeDefinition>> result;
for (const auto& generic : generics) for (const auto& generic : generics)
@ -3106,7 +3231,11 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGenerator::createG
} }
std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGenerator::createGenericPacks( std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGenerator::createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> generics, bool useCache, bool addTypes) const ScopePtr& scope,
AstArray<AstGenericTypePack> generics,
bool useCache,
bool addTypes
)
{ {
std::vector<std::pair<Name, GenericTypePackDefinition>> result; std::vector<std::pair<Name, GenericTypePackDefinition>> result;
for (const auto& generic : generics) for (const auto& generic : generics)
@ -3323,7 +3452,8 @@ std::vector<std::optional<TypeId>> ConstraintGenerator::getExpectedCallTypesForF
// For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n, // For a list of functions f_0 : e_0 -> r_0, ... f_n : e_n -> r_n,
// emit a list of arguments that the function could take at each position // emit a list of arguments that the function could take at each position
// by unioning the arguments at each place // by unioning the arguments at each place
auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { auto assignOption = [this, &expectedTypes](size_t index, TypeId ty)
{
if (index == expectedTypes.size()) if (index == expectedTypes.size())
expectedTypes.push_back(ty); expectedTypes.push_back(ty);
else if (ty) else if (ty)
@ -3372,7 +3502,12 @@ std::vector<std::optional<TypeId>> ConstraintGenerator::getExpectedCallTypesForF
} }
TypeId ConstraintGenerator::createTypeFunctionInstance( TypeId ConstraintGenerator::createTypeFunctionInstance(
const TypeFunction& function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments, const ScopePtr& scope, Location location) const TypeFunction& function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
const ScopePtr& scope,
Location location
)
{ {
TypeId result = arena->addTypeFunction(function, typeArguments, packArguments); TypeId result = arena->addTypeFunction(function, typeArguments, packArguments);
addConstraint(scope, location, ReduceConstraint{result}); addConstraint(scope, location, ReduceConstraint{result});

View File

@ -89,8 +89,13 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
return true; return true;
} }
static std::pair<std::vector<TypeId>, std::vector<TypePackId>> saturateArguments(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, static std::pair<std::vector<TypeId>, std::vector<TypePackId>> saturateArguments(
const TypeFun& fn, const std::vector<TypeId>& rawTypeArguments, const std::vector<TypePackId>& rawPackArguments) TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
const TypeFun& fn,
const std::vector<TypeId>& rawTypeArguments,
const std::vector<TypePackId>& rawPackArguments
)
{ {
std::vector<TypeId> saturatedTypeArguments; std::vector<TypeId> saturatedTypeArguments;
std::vector<TypeId> extraTypes; std::vector<TypeId> extraTypes;
@ -310,8 +315,16 @@ struct InstantiationQueuer : TypeOnceVisitor
} }
}; };
ConstraintSolver::ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints, ConstraintSolver::ConstraintSolver(
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger, TypeCheckLimits limits) NotNull<Normalizer> normalizer,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles,
DcrLogger* logger,
TypeCheckLimits limits
)
: arena(normalizer->arena) : arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes) , builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer) , normalizer(normalizer)
@ -374,7 +387,8 @@ void ConstraintSolver::run()
if (FFlag::DebugLuauLogSolver) if (FFlag::DebugLuauLogSolver)
{ {
printf( printf(
"Starting solver for module %s (%s)\n", moduleResolver->getHumanReadableModuleName(currentModuleName).c_str(), currentModuleName.c_str()); "Starting solver for module %s (%s)\n", moduleResolver->getHumanReadableModuleName(currentModuleName).c_str(), currentModuleName.c_str()
);
dump(this, opts); dump(this, opts);
printf("Bindings:\n"); printf("Bindings:\n");
dumpBindings(rootScope, opts); dumpBindings(rootScope, opts);
@ -385,7 +399,8 @@ void ConstraintSolver::run()
logger->captureInitialSolverState(rootScope, unsolvedConstraints); logger->captureInitialSolverState(rootScope, unsolvedConstraints);
} }
auto runSolverPass = [&](bool force) { auto runSolverPass = [&](bool force)
{
bool progress = false; bool progress = false;
size_t i = 0; size_t i = 0;
@ -489,7 +504,7 @@ void ConstraintSolver::run()
} while (progress); } while (progress);
if (!unsolvedConstraints.empty()) if (!unsolvedConstraints.empty())
reportError(InternalError{"Type inference failed to complete, you may see some confusing types and type errors."}, Location{}); reportError(ConstraintSolvingIncompleteError{}, Location{});
// After we have run all the constraints, type functions should be generalized // After we have run all the constraints, type functions should be generalized
// At this point, we can try to perform one final simplification to suss out // At this point, we can try to perform one final simplification to suss out
@ -730,7 +745,8 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
* applies constraints to the types of the iterators. * applies constraints to the types of the iterators.
*/ */
auto block_ = [&](auto&& t) { auto block_ = [&](auto&& t)
{
if (force) if (force)
{ {
// If we haven't figured out the type of the iteratee by now, // If we haven't figured out the type of the iteratee by now,
@ -891,7 +907,8 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
return true; return true;
} }
auto bindResult = [this, &c, constraint](TypeId result) { auto bindResult = [this, &c, constraint](TypeId result)
{
LUAU_ASSERT(get<PendingExpansionType>(c.target)); LUAU_ASSERT(get<PendingExpansionType>(c.target));
shiftReferences(c.target, result); shiftReferences(c.target, result);
bind(constraint, c.target, result); bind(constraint, c.target, result);
@ -929,14 +946,27 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments);
bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { bool sameTypes = std::equal(
typeArguments.begin(),
typeArguments.end(),
tf->typeParams.begin(),
tf->typeParams.end(),
[](auto&& itp, auto&& p)
{
return itp == p.ty; return itp == p.ty;
}); }
);
bool samePacks = bool samePacks = std::equal(
std::equal(packArguments.begin(), packArguments.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itp, auto&& p) { packArguments.begin(),
packArguments.end(),
tf->typePackParams.begin(),
tf->typePackParams.end(),
[](auto&& itp, auto&& p)
{
return itp == p.tp; return itp == p.tp;
}); }
);
// If we're instantiating the type with its generic saturatedTypeArguments we are // If we're instantiating the type with its generic saturatedTypeArguments we are
// performing the identity substitution. We can just short-circuit and bind // performing the identity substitution. We can just short-circuit and bind
@ -1023,9 +1053,14 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
//clang-format off //clang-format off
bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)) || bool needsClone = follow(tf->type) == target || (tfTable != nullptr && tfTable == getTableType(target)) ||
std::any_of(typeArguments.begin(), typeArguments.end(), [&](const auto& other) { std::any_of(
typeArguments.begin(),
typeArguments.end(),
[&](const auto& other)
{
return other == target; return other == target;
}); }
);
//clang-format on //clang-format on
// Only tables have the properties we're trying to set. // Only tables have the properties we're trying to set.
@ -1120,7 +1155,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (blocked) if (blocked)
return false; return false;
auto collapse = [](const auto* t) -> std::optional<TypeId> { auto collapse = [](const auto* t) -> std::optional<TypeId>
{
auto it = begin(t); auto it = begin(t);
auto endIt = end(t); auto endIt = end(t);
@ -1145,6 +1181,9 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
// We don't support magic __call metamethods. // We don't support magic __call metamethods.
if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) if (std::optional<TypeId> callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location))
{ {
if (isBlocked(*callMm))
return block(*callMm, constraint);
argsHead.insert(argsHead.begin(), fn); argsHead.insert(argsHead.begin(), fn);
if (argsTail && isBlocked(*argsTail)) if (argsTail && isBlocked(*argsTail))
@ -1195,7 +1234,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
} }
OverloadResolver resolver{ OverloadResolver resolver{
builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location}; builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location
};
auto [status, overload] = resolver.selectOverload(fn, argsPack); auto [status, overload] = resolver.selectOverload(fn, argsPack);
TypeId overloadToUse = fn; TypeId overloadToUse = fn;
if (status == OverloadResolver::Analysis::Ok) if (status == OverloadResolver::Analysis::Ok)
@ -1334,8 +1374,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
} }
} }
} }
else if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || else if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantNil>())
expr->is<AstExprConstantNil>())
{ {
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
u2.unify(actualArgTy, expectedArgTy); u2.unify(actualArgTy, expectedArgTy);
@ -1421,7 +1460,13 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
} }
bool ConstraintSolver::tryDispatchHasIndexer( bool ConstraintSolver::tryDispatchHasIndexer(
int& recursionDepth, NotNull<const Constraint> constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set<TypeId>& seen) int& recursionDepth,
NotNull<const Constraint> constraint,
TypeId subjectType,
TypeId indexType,
TypeId resultType,
Set<TypeId>& seen
)
{ {
RecursionLimiter _rl{&recursionDepth, FInt::LuauSolverRecursionLimit}; RecursionLimiter _rl{&recursionDepth, FInt::LuauSolverRecursionLimit};
@ -1455,7 +1500,8 @@ bool ConstraintSolver::tryDispatchHasIndexer(
FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType}; FreeType freeResult{ft->scope, builtinTypes->neverType, builtinTypes->unknownType};
emplace<FreeType>(constraint, resultType, freeResult); emplace<FreeType>(constraint, resultType, freeResult);
TypeId upperBound = arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, TableState::Unsealed}); TypeId upperBound =
arena->addType(TableType{/* props */ {}, TableIndexer{indexType, resultType}, TypeLevel{}, ft->scope, TableState::Unsealed});
unify(constraint, subjectType, upperBound); unify(constraint, subjectType, upperBound);
@ -1777,7 +1823,8 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull<const
// Important: In every codepath through this function, the type `c.propType` // Important: In every codepath through this function, the type `c.propType`
// must be bound to something, even if it's just the errorType. // must be bound to something, even if it's just the errorType.
auto tableStuff = [&](TableType* lhsTable) -> std::optional<bool> { auto tableStuff = [&](TableType* lhsTable) -> std::optional<bool>
{
if (lhsTable->indexer) if (lhsTable->indexer)
{ {
unify(constraint, indexType, lhsTable->indexer->indexType); unify(constraint, indexType, lhsTable->indexer->indexType);
@ -2074,7 +2121,8 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
return true; return true;
} }
auto unpack = [&](TypeId ty) { auto unpack = [&](TypeId ty)
{
for (TypeId varTy : c.variables) for (TypeId varTy : c.variables)
{ {
LUAU_ASSERT(get<BlockedType>(varTy)); LUAU_ASSERT(get<BlockedType>(varTy));
@ -2200,7 +2248,12 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
} }
bool ConstraintSolver::tryDispatchIterableFunction( bool ConstraintSolver::tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force) TypeId nextTy,
TypeId tableTy,
const IterableConstraint& c,
NotNull<const Constraint> constraint,
bool force
)
{ {
const FunctionType* nextFn = get<FunctionType>(nextTy); const FunctionType* nextFn = get<FunctionType>(nextTy);
// If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place.
@ -2237,7 +2290,10 @@ bool ConstraintSolver::tryDispatchIterableFunction(
} }
NotNull<const Constraint> ConstraintSolver::unpackAndAssign( NotNull<const Constraint> ConstraintSolver::unpackAndAssign(
const std::vector<TypeId> destTypes, TypePackId srcTypes, NotNull<const Constraint> constraint) const std::vector<TypeId> destTypes,
TypePackId srcTypes,
NotNull<const Constraint> constraint
)
{ {
auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes}); auto c = pushConstraint(constraint->scope, constraint->location, UnpackConstraint{destTypes, srcTypes});
@ -2251,15 +2307,28 @@ NotNull<const Constraint> ConstraintSolver::unpackAndAssign(
return c; return c;
} }
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType, std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(
const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification) NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
ValueContext context,
bool inConditional,
bool suppressSimplification
)
{ {
DenseHashSet<TypeId> seen{nullptr}; DenseHashSet<TypeId> seen{nullptr};
return lookupTableProp(constraint, subjectType, propName, context, inConditional, suppressSimplification, seen); return lookupTableProp(constraint, subjectType, propName, context, inConditional, suppressSimplification, seen);
} }
std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(NotNull<const Constraint> constraint, TypeId subjectType, std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTableProp(
const std::string& propName, ValueContext context, bool inConditional, bool suppressSimplification, DenseHashSet<TypeId>& seen) NotNull<const Constraint> constraint,
TypeId subjectType,
const std::string& propName,
ValueContext context,
bool inConditional,
bool suppressSimplification,
DenseHashSet<TypeId>& seen
)
{ {
if (seen.contains(subjectType)) if (seen.contains(subjectType))
return {}; return {};

View File

@ -204,7 +204,8 @@ void DataFlowGraphBuilder::joinBindings(DfgScope* p, const DfgScope& a, const Df
void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const DfgScope& b) void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const DfgScope& b)
{ {
auto phinodify = [this](DfgScope* scope, const auto& a, const auto& b, DefId parent) mutable { auto phinodify = [this](DfgScope* scope, const auto& a, const auto& b, DefId parent) mutable
{
auto& p = scope->props[parent]; auto& p = scope->props[parent];
for (const auto& [k, defA] : a) for (const auto& [k, defA] : a)
{ {
@ -373,6 +374,8 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
return visit(scope, l); return visit(scope, l);
else if (auto t = s->as<AstStatTypeAlias>()) else if (auto t = s->as<AstStatTypeAlias>())
return visit(scope, t); return visit(scope, t);
else if (auto f = s->as<AstStatTypeFunction>())
return visit(scope, f);
else if (auto d = s->as<AstStatDeclareGlobal>()) else if (auto d = s->as<AstStatDeclareGlobal>())
return visit(scope, d); return visit(scope, d);
else if (auto d = s->as<AstStatDeclareFunction>()) else if (auto d = s->as<AstStatDeclareFunction>())
@ -631,6 +634,14 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t)
return ControlFlow::None; return ControlFlow::None;
} }
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeFunction* f)
{
DfgScope* unreachable = childScope(scope);
visitExpr(unreachable, f->body);
return ControlFlow::None;
}
ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d)
{ {
DefId def = defArena->freshCell(); DefId def = defArena->freshCell();
@ -691,7 +702,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
return {NotNull{*def}, key ? *key : nullptr}; return {NotNull{*def}, key ? *key : nullptr};
} }
auto go = [&]() -> DataFlowResult { auto go = [&]() -> DataFlowResult
{
if (auto g = e->as<AstExprGroup>()) if (auto g = e->as<AstExprGroup>())
return visitExpr(scope, g); return visitExpr(scope, g);
else if (auto c = e->as<AstExprConstantNil>()) else if (auto c = e->as<AstExprConstantNil>())
@ -910,7 +922,8 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* er
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment) void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment)
{ {
auto go = [&]() { auto go = [&]()
{
if (auto l = e->as<AstExprLocal>()) if (auto l = e->as<AstExprLocal>())
return visitLValue(scope, l, incomingDef, isCompoundAssignment); return visitLValue(scope, l, incomingDef, isCompoundAssignment);
else if (auto g = e->as<AstExprGlobal>()) else if (auto g = e->as<AstExprGlobal>())

View File

@ -124,7 +124,8 @@ void write(JsonEmitter& emitter, const ConstraintBlock& block)
ObjectEmitter o = emitter.writeObject(); ObjectEmitter o = emitter.writeObject();
o.writePair("stringification", block.stringification); o.writePair("stringification", block.stringification);
auto go = [&o](auto&& t) { auto go = [&o](auto&& t)
{
using T = std::decay_t<decltype(t)>; using T = std::decay_t<decltype(t)>;
o.writePair("id", toPointerId(t)); o.writePair("id", toPointerId(t));
@ -350,8 +351,12 @@ void DcrLogger::popBlock(NotNull<const Constraint> block)
} }
} }
static void snapshotTypeStrings(const std::vector<ExprTypesAtLocation>& interestedExprs, static void snapshotTypeStrings(
const std::vector<AnnotationTypesAtLocation>& interestedAnnots, DenseHashMap<const void*, std::string>& map, ToStringOptions& opts) const std::vector<ExprTypesAtLocation>& interestedExprs,
const std::vector<AnnotationTypesAtLocation>& interestedAnnots,
DenseHashMap<const void*, std::string>& map,
ToStringOptions& opts
)
{ {
for (const ExprTypesAtLocation& tys : interestedExprs) for (const ExprTypesAtLocation& tys : interestedExprs)
{ {
@ -368,7 +373,10 @@ static void snapshotTypeStrings(const std::vector<ExprTypesAtLocation>& interest
} }
void DcrLogger::captureBoundaryState( void DcrLogger::captureBoundaryState(
BoundarySnapshot& target, const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints) BoundarySnapshot& target,
const Scope* rootScope,
const std::vector<NotNull<const Constraint>>& unsolvedConstraints
)
{ {
target.rootScope = snapshotScope(rootScope, opts); target.rootScope = snapshotScope(rootScope, opts);
target.unsolvedConstraints.clear(); target.unsolvedConstraints.clear();
@ -391,7 +399,11 @@ void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vec
} }
StepSnapshot DcrLogger::prepareStepSnapshot( StepSnapshot DcrLogger::prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints) const Scope* rootScope,
NotNull<const Constraint> current,
bool force,
const std::vector<NotNull<const Constraint>>& unsolvedConstraints
)
{ {
ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts);
DenseHashMap<const Constraint*, ConstraintSnapshot> constraints{nullptr}; DenseHashMap<const Constraint*, ConstraintSnapshot> constraints{nullptr};

View File

@ -286,15 +286,22 @@ struct FindSeteqCounterexampleResult
bool inLeft; bool inLeft;
}; };
static FindSeteqCounterexampleResult findSeteqCounterexample( static FindSeteqCounterexampleResult findSeteqCounterexample(
DifferEnvironment& env, const std::vector<TypeId>& left, const std::vector<TypeId>& right); DifferEnvironment& env,
const std::vector<TypeId>& left,
const std::vector<TypeId>& right
);
static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right); static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right);
/** /**
* The last argument gives context info on which complex type contained the TypePack. * The last argument gives context info on which complex type contained the TypePack.
*/ */
static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right);
static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, static DifferResult diffCanonicalTpShape(
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left, const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right); DifferEnvironment& env,
DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right
);
static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right); static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right);
static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right); static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right);
@ -324,8 +331,13 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right)
if (leftTable->props.find(field) == leftTable->props.end()) if (leftTable->props.find(field) == leftTable->props.end())
{ {
// right has a field the left doesn't // right has a field the left doesn't
return DifferResult{DiffError{DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::nullopts(), return DifferResult{DiffError{
DiffPathNodeLeaf::detailsTableProperty(value.type(), field), env.getDevFixFriendlyNameLeft(), env.getDevFixFriendlyNameRight()}}; DiffError::Kind::MissingTableProperty,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsTableProperty(value.type(), field),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight()
}};
} }
} }
// left and right have the same set of keys // left and right have the same set of keys
@ -491,7 +503,10 @@ static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right)
} }
static FindSeteqCounterexampleResult findSeteqCounterexample( static FindSeteqCounterexampleResult findSeteqCounterexample(
DifferEnvironment& env, const std::vector<TypeId>& left, const std::vector<TypeId>& right) DifferEnvironment& env,
const std::vector<TypeId>& left,
const std::vector<TypeId>& right
)
{ {
std::unordered_set<size_t> unmatchedRightIdxes; std::unordered_set<size_t> unmatchedRightIdxes;
for (size_t i = 0; i < right.size(); i++) for (size_t i = 0; i < right.size(); i++)
@ -760,8 +775,12 @@ static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonN
return diffHandleFlattenedTail(env, possibleNonNormalErrorKind, *leftFlatTpi.second, *rightFlatTpi.second); return diffHandleFlattenedTail(env, possibleNonNormalErrorKind, *leftFlatTpi.second, *rightFlatTpi.second);
} }
static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, static DifferResult diffCanonicalTpShape(
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left, const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right) DifferEnvironment& env,
DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right
)
{ {
if (left.first.size() == right.first.size() && left.second.has_value() == right.second.has_value()) if (left.first.size() == right.first.size() && left.second.has_value() == right.second.has_value())
return DifferResult{}; return DifferResult{};

View File

@ -21,7 +21,12 @@ LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauImproveNonFunctionCallError, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauImproveNonFunctionCallError, false)
static std::string wrongNumberOfArgsString( static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) size_t expectedCount,
std::optional<size_t> maximumCount,
size_t actualCount,
const char* argPrefix = nullptr,
bool isVariadic = false
)
{ {
std::string s = "expects "; std::string s = "expects ";
@ -65,8 +70,21 @@ namespace Luau
{ {
// this list of binary operator type functions is used for better stringification of type functions errors // this list of binary operator type functions is used for better stringification of type functions errors
static const std::unordered_map<std::string, const char*> kBinaryOps{{"add", "+"}, {"sub", "-"}, {"mul", "*"}, {"div", "/"}, {"idiv", "//"}, static const std::unordered_map<std::string, const char*> kBinaryOps{
{"pow", "^"}, {"mod", "%"}, {"concat", ".."}, {"and", "and"}, {"or", "or"}, {"lt", "< or >="}, {"le", "<= or >"}, {"eq", "== or ~="}}; {"add", "+"},
{"sub", "-"},
{"mul", "*"},
{"div", "/"},
{"idiv", "//"},
{"pow", "^"},
{"mod", "%"},
{"concat", ".."},
{"and", "and"},
{"or", "or"},
{"lt", "< or >="},
{"le", "<= or >"},
{"eq", "== or ~="}
};
// this list of unary operator type functions is used for better stringification of type functions errors // this list of unary operator type functions is used for better stringification of type functions errors
static const std::unordered_map<std::string, const char*> kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}}; static const std::unordered_map<std::string, const char*> kUnaryOps{{"unm", "-"}, {"len", "#"}, {"not", "not"}};
@ -86,12 +104,15 @@ struct ErrorConverter
std::string result; std::string result;
auto quote = [&](std::string s) { auto quote = [&](std::string s)
{
return "'" + s + "'"; return "'" + s + "'";
}; };
auto constructErrorMessage = [&](std::string givenType, std::string wantedType, std::optional<std::string> givenModule, auto constructErrorMessage =
std::optional<std::string> wantedModule) -> std::string { [&](std::string givenType, std::string wantedType, std::optional<std::string> givenModule, std::optional<std::string> wantedModule
) -> std::string
{
std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType); std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType);
std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType); std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType);
size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength); size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength);
@ -351,6 +372,11 @@ struct ErrorConverter
return e.message; return e.message;
} }
std::string operator()(const Luau::ConstraintSolvingIncompleteError& e) const
{
return "Type inference failed to complete, you may see some confusing types and type errors.";
}
std::optional<TypeId> findCallMetamethod(TypeId type) const std::optional<TypeId> findCallMetamethod(TypeId type) const
{ {
type = follow(type); type = follow(type);
@ -987,6 +1013,11 @@ bool InternalError::operator==(const InternalError& rhs) const
return message == rhs.message; return message == rhs.message;
} }
bool ConstraintSolvingIncompleteError::operator==(const ConstraintSolvingIncompleteError& rhs) const
{
return true;
}
bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const
{ {
return ty == rhs.ty; return ty == rhs.ty;
@ -1177,11 +1208,13 @@ bool containsParseErrorName(const TypeError& error)
template<typename T> template<typename T>
void copyError(T& e, TypeArena& destArena, CloneState& cloneState) void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
{ {
auto clone = [&](auto&& ty) { auto clone = [&](auto&& ty)
{
return ::Luau::clone(ty, destArena, cloneState); return ::Luau::clone(ty, destArena, cloneState);
}; };
auto visitErrorData = [&](auto&& e) { auto visitErrorData = [&](auto&& e)
{
copyError(e, destArena, cloneState); copyError(e, destArena, cloneState);
}; };
@ -1256,6 +1289,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
else if constexpr (std::is_same_v<T, InternalError>) else if constexpr (std::is_same_v<T, InternalError>)
{ {
} }
else if constexpr (std::is_same_v<T, ConstraintSolvingIncompleteError>)
{
}
else if constexpr (std::is_same_v<T, CannotCallNonFunction>) else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
{ {
e.ty = clone(e.ty); e.ty = clone(e.ty);
@ -1363,7 +1399,8 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena, NotNull<BuiltinTypes> bu
{ {
CloneState cloneState{builtinTypes}; CloneState cloneState{builtinTypes};
auto visitErrorData = [&](auto&& e) { auto visitErrorData = [&](auto&& e)
{
copyError(e, destArena, cloneState); copyError(e, destArena, cloneState);
}; };

View File

@ -176,8 +176,14 @@ static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, S
} }
} }
LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, LoadDefinitionFileResult Frontend::loadDefinitionFile(
const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) GlobalTypes& globals,
ScopePtr targetScope,
std::string_view source,
const std::string& packageName,
bool captureComments,
bool typeCheckForAutocomplete
)
{ {
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
@ -269,7 +275,10 @@ namespace
{ {
static ErrorVec accumulateErrors( static ErrorVec accumulateErrors(
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, ModuleResolver& moduleResolver, const ModuleName& name) const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
ModuleResolver& moduleResolver,
const ModuleName& name
)
{ {
DenseHashSet<ModuleName> seen{{}}; DenseHashSet<ModuleName> seen{{}};
std::vector<ModuleName> queue{name}; std::vector<ModuleName> queue{name};
@ -301,9 +310,14 @@ static ErrorVec accumulateErrors(
Module& module = *modulePtr; Module& module = *modulePtr;
std::sort(module.errors.begin(), module.errors.end(), [](const TypeError& e1, const TypeError& e2) -> bool { std::sort(
module.errors.begin(),
module.errors.end(),
[](const TypeError& e1, const TypeError& e2) -> bool
{
return e1.location.begin > e2.location.begin; return e1.location.begin > e2.location.begin;
}); }
);
result.insert(result.end(), module.errors.begin(), module.errors.end()); result.insert(result.end(), module.errors.begin(), module.errors.end());
} }
@ -334,8 +348,12 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
// For each such path, record the full path and the location of the require in the starting module. // For each such path, record the full path and the location of the require in the starting module.
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles(const FileResolver* resolver, std::vector<RequireCycle> getRequireCycles(
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, const SourceNode* start, bool stopAtFirst = false) const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
const SourceNode* start,
bool stopAtFirst = false
)
{ {
std::vector<RequireCycle> result; std::vector<RequireCycle> result;
@ -503,7 +521,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{ {
item.module->ats.root = toString(sourceModule.root); item.module->ats.root = toString(sourceModule.root);
} }
item.module->ats.rootSrc = sourceModule.root;
item.module->ats.traverse(item.module.get(), sourceModule.root, NotNull{&builtinTypes_}); item.module->ats.traverse(item.module.get(), sourceModule.root, NotNull{&builtinTypes_});
} }
} }
@ -522,8 +540,11 @@ void Frontend::queueModuleCheck(const ModuleName& name)
moduleQueue.push_back(name); moduleQueue.push_back(name);
} }
std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptions> optionOverride, std::vector<ModuleName> Frontend::checkQueuedModules(
std::function<void(std::function<void()> task)> executeTask, std::function<bool(size_t done, size_t total)> progress) std::optional<FrontendOptions> optionOverride,
std::function<void(std::function<void()> task)> executeTask,
std::function<bool(size_t done, size_t total)> progress
)
{ {
FrontendOptions frontendOptions = optionOverride.value_or(options); FrontendOptions frontendOptions = optionOverride.value_or(options);
if (FFlag::DebugLuauDeferredConstraintResolution) if (FFlag::DebugLuauDeferredConstraintResolution)
@ -548,9 +569,15 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
} }
std::vector<ModuleName> queue; std::vector<ModuleName> queue;
bool cycleDetected = parseGraph(queue, name, frontendOptions.forAutocomplete, [&seen](const ModuleName& name) { bool cycleDetected = parseGraph(
queue,
name,
frontendOptions.forAutocomplete,
[&seen](const ModuleName& name)
{
return seen.contains(name); return seen.contains(name);
}); }
);
addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions); addBuildQueueItems(buildQueueItems, queue, cycleDetected, seen, frontendOptions);
} }
@ -570,7 +597,8 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
// Default task execution is single-threaded and immediate // Default task execution is single-threaded and immediate
if (!executeTask) if (!executeTask)
{ {
executeTask = [](std::function<void()> task) { executeTask = [](std::function<void()> task)
{
task(); task();
}; };
} }
@ -582,7 +610,8 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
size_t processing = 0; size_t processing = 0;
size_t remaining = buildQueueItems.size(); size_t remaining = buildQueueItems.size();
auto itemTask = [&](size_t i) { auto itemTask = [&](size_t i)
{
BuildQueueItem& item = buildQueueItems[i]; BuildQueueItem& item = buildQueueItems[i];
try try
@ -602,18 +631,23 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
cv.notify_one(); cv.notify_one();
}; };
auto sendItemTask = [&](size_t i) { auto sendItemTask = [&](size_t i)
{
BuildQueueItem& item = buildQueueItems[i]; BuildQueueItem& item = buildQueueItems[i];
item.processing = true; item.processing = true;
processing++; processing++;
executeTask([&itemTask, i]() { executeTask(
[&itemTask, i]()
{
itemTask(i); itemTask(i);
}); }
);
}; };
auto sendCycleItemTask = [&] { auto sendCycleItemTask = [&]
{
for (size_t i = 0; i < buildQueueItems.size(); i++) for (size_t i = 0; i < buildQueueItems.size(); i++)
{ {
BuildQueueItem& item = buildQueueItems[i]; BuildQueueItem& item = buildQueueItems[i];
@ -662,9 +696,13 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
std::unique_lock guard(mtx); std::unique_lock guard(mtx);
// If nothing is ready yet, wait // If nothing is ready yet, wait
cv.wait(guard, [&readyQueueItems] { cv.wait(
guard,
[&readyQueueItems]
{
return !readyQueueItems.empty(); return !readyQueueItems.empty();
}); }
);
// Handle checked items // Handle checked items
for (size_t i : readyQueueItems) for (size_t i : readyQueueItems)
@ -782,7 +820,11 @@ std::optional<CheckResult> Frontend::getCheckResult(const ModuleName& name, bool
} }
bool Frontend::parseGraph( bool Frontend::parseGraph(
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip) std::vector<ModuleName>& buildQueue,
const ModuleName& root,
bool forAutocomplete,
std::function<bool(const ModuleName&)> canSkip
)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
@ -884,8 +926,13 @@ bool Frontend::parseGraph(
return cyclic; return cyclic;
} }
void Frontend::addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected, void Frontend::addBuildQueueItems(
DenseHashSet<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions) std::vector<BuildQueueItem>& items,
std::vector<ModuleName>& buildQueue,
bool cycleDetected,
DenseHashSet<Luau::ModuleName>& seen,
const FrontendOptions& frontendOptions
)
{ {
for (const ModuleName& moduleName : buildQueue) for (const ModuleName& moduleName : buildQueue)
{ {
@ -981,8 +1028,15 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
if (item.options.forAutocomplete) if (item.options.forAutocomplete)
{ {
// The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features // The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features
ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, ModulePtr moduleForAutocomplete = check(
/*recordJsonLog*/ false, typeCheckLimits); sourceModule,
Mode::Strict,
requireCycles,
environmentScope,
/*forAutocomplete*/ true,
/*recordJsonLog*/ false,
typeCheckLimits
);
double duration = getTimestamp() - timestamp; double duration = getTimestamp() - timestamp;
@ -1209,14 +1263,37 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
return const_cast<Frontend*>(this)->getSourceModule(moduleName); return const_cast<Frontend*>(this)->getSourceModule(moduleName);
} }
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes, ModulePtr check(
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver, const SourceModule& sourceModule,
const ScopePtr& parentScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options, Mode mode,
TypeCheckLimits limits, std::function<void(const ModuleName&, std::string)> writeJsonLog) const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits,
std::function<void(const ModuleName&, std::string)> writeJsonLog
)
{ {
const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson;
return check(sourceModule, mode, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, return check(
std::move(prepareModuleScope), options, limits, recordJsonLog, writeJsonLog); sourceModule,
mode,
requireCycles,
builtinTypes,
iceHandler,
moduleResolver,
fileResolver,
parentScope,
std::move(prepareModuleScope),
options,
limits,
recordJsonLog,
writeJsonLog
);
} }
struct InternalTypeFinder : TypeOnceVisitor struct InternalTypeFinder : TypeOnceVisitor
@ -1263,10 +1340,21 @@ struct InternalTypeFinder : TypeOnceVisitor
} }
}; };
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes, ModulePtr check(
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver, const SourceModule& sourceModule,
const ScopePtr& parentScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options, Mode mode,
TypeCheckLimits limits, bool recordJsonLog, std::function<void(const ModuleName&, std::string)> writeJsonLog) const std::vector<RequireCycle>& requireCycles,
NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler,
NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options,
TypeCheckLimits limits,
bool recordJsonLog,
std::function<void(const ModuleName&, std::string)> writeJsonLog
)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::check", "Typechecking"); LUAU_TIMETRACE_SCOPE("Frontend::check", "Typechecking");
LUAU_TIMETRACE_ARGUMENT("module", sourceModule.name.c_str()); LUAU_TIMETRACE_ARGUMENT("module", sourceModule.name.c_str());
@ -1300,14 +1388,32 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
ConstraintGenerator cg{result, NotNull{&normalizer}, moduleResolver, builtinTypes, iceHandler, parentScope, std::move(prepareModuleScope), ConstraintGenerator cg{
logger.get(), NotNull{&dfg}, requireCycles}; result,
NotNull{&normalizer},
moduleResolver,
builtinTypes,
iceHandler,
parentScope,
std::move(prepareModuleScope),
logger.get(),
NotNull{&dfg},
requireCycles
};
cg.visitModuleRoot(sourceModule.root); cg.visitModuleRoot(sourceModule.root);
result->errors = std::move(cg.errors); result->errors = std::move(cg.errors);
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), result->name, moduleResolver, requireCycles, ConstraintSolver cs{
logger.get(), limits}; NotNull{&normalizer},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
result->name,
moduleResolver,
requireCycles,
logger.get(),
limits
};
if (options.randomizeConstraintResolutionSeed) if (options.randomizeConstraintResolutionSeed)
cs.randomize(*options.randomizeConstraintResolutionSeed); cs.randomize(*options.randomizeConstraintResolutionSeed);
@ -1419,22 +1525,41 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<R
return result; return result;
} }
ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vector<RequireCycle> requireCycles, ModulePtr Frontend::check(
std::optional<ScopePtr> environmentScope, bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits) const SourceModule& sourceModule,
Mode mode,
std::vector<RequireCycle> requireCycles,
std::optional<ScopePtr> environmentScope,
bool forAutocomplete,
bool recordJsonLog,
TypeCheckLimits typeCheckLimits
)
{ {
if (FFlag::DebugLuauDeferredConstraintResolution) if (FFlag::DebugLuauDeferredConstraintResolution)
{ {
auto prepareModuleScopeWrap = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) { auto prepareModuleScopeWrap = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope)
{
if (prepareModuleScope) if (prepareModuleScope)
prepareModuleScope(name, scope, forAutocomplete); prepareModuleScope(name, scope, forAutocomplete);
}; };
try try
{ {
return Luau::check(sourceModule, mode, requireCycles, builtinTypes, NotNull{&iceHandler}, return Luau::check(
NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, sourceModule,
environmentScope ? *environmentScope : globals.globalScope, prepareModuleScopeWrap, options, typeCheckLimits, recordJsonLog, mode,
writeJsonLog); requireCycles,
builtinTypes,
NotNull{&iceHandler},
NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver},
NotNull{fileResolver},
environmentScope ? *environmentScope : globals.globalScope,
prepareModuleScopeWrap,
options,
typeCheckLimits,
recordJsonLog,
writeJsonLog
);
} }
catch (const InternalCompilerError& err) catch (const InternalCompilerError& err)
{ {
@ -1445,12 +1570,17 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect
} }
else else
{ {
TypeChecker typeChecker(forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope, TypeChecker typeChecker(
forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver, builtinTypes, &iceHandler); forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope,
forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver,
builtinTypes,
&iceHandler
);
if (prepareModuleScope) if (prepareModuleScope)
{ {
typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope) { typeChecker.prepareModuleScope = [this, forAutocomplete](const ModuleName& name, const ScopePtr& scope)
{
prepareModuleScope(name, scope, forAutocomplete); prepareModuleScope(name, scope, forAutocomplete);
}; };
} }

View File

@ -26,8 +26,14 @@ struct MutatingGeneralizer : TypeOnceVisitor
bool isWithinFunction = false; bool isWithinFunction = false;
bool avoidSealingTables = false; bool avoidSealingTables = false;
MutatingGeneralizer(NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<DenseHashSet<TypeId>> cachedTypes, MutatingGeneralizer(
DenseHashMap<const void*, size_t> positiveTypes, DenseHashMap<const void*, size_t> negativeTypes, bool avoidSealingTables) NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes,
DenseHashMap<const void*, size_t> positiveTypes,
DenseHashMap<const void*, size_t> negativeTypes,
bool avoidSealingTables
)
: TypeOnceVisitor(/* skipBoundTypes */ true) : TypeOnceVisitor(/* skipBoundTypes */ true)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, scope(scope) , scope(scope)
@ -867,8 +873,14 @@ struct TypeCacher : TypeOnceVisitor
} }
}; };
std::optional<TypeId> generalize(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, std::optional<TypeId> generalize(
NotNull<DenseHashSet<TypeId>> cachedTypes, TypeId ty, bool avoidSealingTables) NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes,
TypeId ty,
bool avoidSealingTables
)
{ {
ty = follow(ty); ty = follow(ty);

View File

@ -102,8 +102,15 @@ TypePackId Instantiation::clean(TypePackId tp)
return tp; return tp;
} }
void ReplaceGenerics::resetState(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope, void ReplaceGenerics::resetState(
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks) const TxnLog* log,
TypeArena* arena,
NotNull<BuiltinTypes> builtinTypes,
TypeLevel level,
Scope* scope,
const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks
)
{ {
LUAU_ASSERT(FFlag::LuauReusableSubstitutions); LUAU_ASSERT(FFlag::LuauReusableSubstitutions);
@ -187,7 +194,12 @@ TypePackId ReplaceGenerics::clean(TypePackId tp)
} }
std::optional<TypeId> instantiate( std::optional<TypeId> instantiate(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty) NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
TypeId ty
)
{ {
ty = follow(ty); ty = follow(ty);

View File

@ -8,6 +8,23 @@ bool Instantiation2::ignoreChildren(TypeId ty)
{ {
if (get<ClassType>(ty)) if (get<ClassType>(ty))
return true; return true;
if (auto ftv = get<FunctionType>(ty))
{
if (ftv->hasNoFreeOrGenericTypes)
return false;
// If this function type quantifies over these generics, we don't want substitution to
// go any further into them because it's being shadowed in this case.
for (auto generic : ftv->generics)
if (genericSubstitutions.contains(generic))
return true;
for (auto generic : ftv->genericPacks)
if (genericPackSubstitutions.contains(generic))
return true;
}
return false; return false;
} }
@ -47,14 +64,22 @@ TypePackId Instantiation2::clean(TypePackId tp)
} }
std::optional<TypeId> instantiate2( std::optional<TypeId> instantiate2(
TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions, DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypeId ty) TypeArena* arena,
DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypeId ty
)
{ {
Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)};
return instantiation.substitute(ty); return instantiation.substitute(ty);
} }
std::optional<TypePackId> instantiate2( std::optional<TypePackId> instantiate2(
TypeArena* arena, DenseHashMap<TypeId, TypeId> genericSubstitutions, DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions, TypePackId tp) TypeArena* arena,
DenseHashMap<TypeId, TypeId> genericSubstitutions,
DenseHashMap<TypePackId, TypePackId> genericPackSubstitutions,
TypePackId tp
)
{ {
Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)};
return instantiation.substitute(tp); return instantiation.substitute(tp);

View File

@ -114,6 +114,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "GenericError { " << err.message << " }"; stream << "GenericError { " << err.message << " }";
else if constexpr (std::is_same_v<T, InternalError>) else if constexpr (std::is_same_v<T, InternalError>)
stream << "InternalError { " << err.message << " }"; stream << "InternalError { " << err.message << " }";
else if constexpr (std::is_same_v<T, ConstraintSolvingIncompleteError>)
stream << "ConstraintSolvingIncompleteError {}";
else if constexpr (std::is_same_v<T, CannotCallNonFunction>) else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; stream << "CannotCallNonFunction { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, ExtraInformation>) else if constexpr (std::is_same_v<T, ExtraInformation>)
@ -259,7 +261,8 @@ std::ostream& operator<<(std::ostream& stream, const CannotAssignToNever::Reason
std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data)
{ {
auto cb = [&](const auto& e) { auto cb = [&](const auto& e)
{
return errorToString(stream, e); return errorToString(stream, e);
}; };
visit(cb, data); visit(cb, data);

View File

@ -275,8 +275,14 @@ private:
else if (g->deprecated) else if (g->deprecated)
{ {
if (const char* replacement = *g->deprecated; replacement && strlen(replacement)) if (const char* replacement = *g->deprecated; replacement && strlen(replacement))
emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", emitWarning(
gv->name.value, replacement); *context,
LintWarning::Code_DeprecatedGlobal,
gv->location,
"Global '%s' is deprecated, use '%s' instead",
gv->name.value,
replacement
);
else else
emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value);
} }
@ -291,18 +297,33 @@ private:
AstExprFunction* top = g.functionRef.back(); AstExprFunction* top = g.functionRef.back();
if (top->debugname.value) if (top->debugname.value)
emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, emitWarning(
"Global '%s' is only used in the enclosing function '%s'; consider changing it to local", g.firstRef->name.value, *context,
top->debugname.value); LintWarning::Code_GlobalUsedAsLocal,
g.firstRef->location,
"Global '%s' is only used in the enclosing function '%s'; consider changing it to local",
g.firstRef->name.value,
top->debugname.value
);
else else
emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, emitWarning(
*context,
LintWarning::Code_GlobalUsedAsLocal,
g.firstRef->location,
"Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local",
g.firstRef->name.value, top->location.begin.line + 1); g.firstRef->name.value,
top->location.begin.line + 1
);
} }
else if (g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && g.firstRef->name != context->placeholder) else if (g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && g.firstRef->name != context->placeholder)
{ {
emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, emitWarning(
"Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value); *context,
LintWarning::Code_GlobalUsedAsLocal,
g.firstRef->location,
"Global '%s' is never read before being written. Consider changing it to local",
g.firstRef->name.value
);
} }
} }
} }
@ -329,7 +350,8 @@ private:
if (node->name == context->placeholder) if (node->name == context->placeholder)
emitWarning( emitWarning(
*context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"
);
return true; return true;
} }
@ -338,7 +360,8 @@ private:
{ {
if (node->local->name == context->placeholder) if (node->local->name == context->placeholder)
emitWarning( emitWarning(
*context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"
);
return true; return true;
} }
@ -366,8 +389,13 @@ private:
} }
if (g.builtin) if (g.builtin)
emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, emitWarning(
"Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); *context,
LintWarning::Code_BuiltinGlobalWrite,
gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name",
gv->name.value
);
else else
g.assigned = true; g.assigned = true;
@ -396,8 +424,13 @@ private:
Global& g = globals[gv->name]; Global& g = globals[gv->name];
if (g.builtin) if (g.builtin)
emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, emitWarning(
"Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); *context,
LintWarning::Code_BuiltinGlobalWrite,
gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name",
gv->name.value
);
else else
{ {
g.assigned = true; g.assigned = true;
@ -565,8 +598,12 @@ private:
if (node->body.data[i - 1]->hasSemicolon) if (node->body.data[i - 1]->hasSemicolon)
continue; continue;
emitWarning(*context, LintWarning::Code_SameLineStatement, location, emitWarning(
"A new statement is on the same line; add semi-colon on previous statement to silence"); *context,
LintWarning::Code_SameLineStatement,
location,
"A new statement is on the same line; add semi-colon on previous statement to silence"
);
lastLine = location.begin.line; lastLine = location.begin.line;
} }
@ -613,7 +650,8 @@ private:
if (location.begin.column <= top.start.begin.column) if (location.begin.column <= top.start.begin.column)
{ {
emitWarning( emitWarning(
*context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence"); *context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence"
);
top.flagged = true; top.flagged = true;
} }
@ -727,8 +765,14 @@ private:
// don't warn on inter-function shadowing since it is much more fragile wrt refactoring // don't warn on inter-function shadowing since it is much more fragile wrt refactoring
if (shadow->functionDepth == local->functionDepth) if (shadow->functionDepth == local->functionDepth)
emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows previous declaration at line %d", emitWarning(
local->name.value, shadow->location.begin.line + 1); *context,
LintWarning::Code_LocalShadow,
local->location,
"Variable '%s' shadows previous declaration at line %d",
local->name.value,
shadow->location.begin.line + 1
);
} }
else if (Global* global = globals.find(local->name)) else if (Global* global = globals.find(local->name))
{ {
@ -736,8 +780,14 @@ private:
; // there are many builtins with common names like 'table'; some of them are deprecated as well ; // there are many builtins with common names like 'table'; some of them are deprecated as well
else if (global->firstRef) else if (global->firstRef)
{ {
emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows a global variable used at line %d", emitWarning(
local->name.value, global->firstRef->location.begin.line + 1); *context,
LintWarning::Code_LocalShadow,
local->location,
"Variable '%s' shadows a global variable used at line %d",
local->name.value,
global->firstRef->location.begin.line + 1
);
} }
else else
{ {
@ -752,14 +802,21 @@ private:
return; return;
if (info.function) if (info.function)
emitWarning(*context, LintWarning::Code_FunctionUnused, local->location, "Function '%s' is never used; prefix with '_' to silence", emitWarning(
local->name.value); *context,
LintWarning::Code_FunctionUnused,
local->location,
"Function '%s' is never used; prefix with '_' to silence",
local->name.value
);
else if (info.import) else if (info.import)
emitWarning(*context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", emitWarning(
local->name.value); *context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", local->name.value
);
else else
emitWarning(*context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", emitWarning(
local->name.value); *context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", local->name.value
);
} }
bool isRequireCall(AstExpr* expr) bool isRequireCall(AstExpr* expr)
@ -913,8 +970,13 @@ private:
for (auto& g : globals) for (auto& g : globals)
{ {
if (g.second.function && !g.second.used && g.first.value[0] != '_') if (g.second.function && !g.second.used && g.first.value[0] != '_')
emitWarning(*context, LintWarning::Code_FunctionUnused, g.second.location, "Function '%s' is never used; prefix with '_' to silence", emitWarning(
g.first.value); *context,
LintWarning::Code_FunctionUnused,
g.second.location,
"Function '%s' is never used; prefix with '_' to silence",
g.first.value
);
} }
} }
@ -1013,8 +1075,13 @@ private:
if (step == Error && si->is<AstStatExpr>() && next->is<AstStatReturn>() && i + 2 == stat->body.size) if (step == Error && si->is<AstStatExpr>() && next->is<AstStatReturn>() && i + 2 == stat->body.size)
return Error; return Error;
emitWarning(*context, LintWarning::Code_UnreachableCode, next->location, "Unreachable code (previous statement always %ss)", emitWarning(
getReason(step)); *context,
LintWarning::Code_UnreachableCode,
next->location,
"Unreachable code (previous statement always %ss)",
getReason(step)
);
return step; return step;
} }
} }
@ -1209,22 +1276,34 @@ private:
// for i=#t,1 do // for i=#t,1 do
if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 1.0) if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 1.0)
emitWarning( emitWarning(
*context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"
);
// for i=8,1 do // for i=8,1 do
else if (fc && tc && fc->value > tc->value) else if (fc && tc && fc->value > tc->value)
emitWarning( emitWarning(
*context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"
);
// for i=1,8.75 do // for i=1,8.75 do
else if (fc && tc && getLoopEnd(fc->value, tc->value) != tc->value) else if (fc && tc && getLoopEnd(fc->value, tc->value) != tc->value)
emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop ends at %g instead of %g; did you forget to specify step?", emitWarning(
getLoopEnd(fc->value, tc->value), tc->value); *context,
LintWarning::Code_ForRange,
rangeLocation,
"For loop ends at %g instead of %g; did you forget to specify step?",
getLoopEnd(fc->value, tc->value),
tc->value
);
// for i=0,#t do // for i=0,#t do
else if (fc && tu && fc->value == 0.0 && tu->op == AstExprUnary::Len) else if (fc && tu && fc->value == 0.0 && tu->op == AstExprUnary::Len)
emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop starts at 0, but arrays start at 1"); emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop starts at 0, but arrays start at 1");
// for i=#t,0 do // for i=#t,0 do
else if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 0.0) else if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 0.0)
emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, emitWarning(
"For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"); *context,
LintWarning::Code_ForRange,
rangeLocation,
"For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"
);
} }
return true; return true;
@ -1252,16 +1331,27 @@ private:
AstExpr* last = values.data[values.size - 1]; AstExpr* last = values.data[values.size - 1];
if (vars < values.size) if (vars < values.size)
emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, emitWarning(
"Assigning %d values to %d variables leaves some values unused", int(values.size), int(vars)); *context,
LintWarning::Code_UnbalancedAssignment,
location,
"Assigning %d values to %d variables leaves some values unused",
int(values.size),
int(vars)
);
else if (last->is<AstExprCall>() || last->is<AstExprVarargs>()) else if (last->is<AstExprCall>() || last->is<AstExprVarargs>())
; // we don't know how many values the last expression returns ; // we don't know how many values the last expression returns
else if (last->is<AstExprConstantNil>()) else if (last->is<AstExprConstantNil>())
; // last expression is nil which explicitly silences the nil-init warning ; // last expression is nil which explicitly silences the nil-init warning
else else
emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, emitWarning(
"Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence", int(values.size), *context,
int(vars)); LintWarning::Code_UnbalancedAssignment,
location,
"Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence",
int(values.size),
int(vars)
);
} }
} }
@ -1344,13 +1434,22 @@ private:
Location location = getEndLocation(bodyf); Location location = getEndLocation(bodyf);
if (node->debugname.value) if (node->debugname.value)
emitWarning(*context, LintWarning::Code_ImplicitReturn, location, emitWarning(
*context,
LintWarning::Code_ImplicitReturn,
location,
"Function '%s' can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", "Function '%s' can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence",
node->debugname.value, vret->location.begin.line + 1); node->debugname.value,
vret->location.begin.line + 1
);
else else
emitWarning(*context, LintWarning::Code_ImplicitReturn, location, emitWarning(
*context,
LintWarning::Code_ImplicitReturn,
location,
"Function can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", "Function can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence",
vret->location.begin.line + 1); vret->location.begin.line + 1
);
} }
return true; return true;
@ -1821,23 +1920,41 @@ private:
int& line = names[&expr->value]; int& line = names[&expr->value];
if (line) if (line)
emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, emitWarning(
"Table field '%.*s' is a duplicate; previously defined at line %d", int(expr->value.size), expr->value.data, line); *context,
LintWarning::Code_TableLiteral,
expr->location,
"Table field '%.*s' is a duplicate; previously defined at line %d",
int(expr->value.size),
expr->value.data,
line
);
else else
line = expr->location.begin.line + 1; line = expr->location.begin.line + 1;
} }
else if (AstExprConstantNumber* expr = item.key->as<AstExprConstantNumber>()) else if (AstExprConstantNumber* expr = item.key->as<AstExprConstantNumber>())
{ {
if (expr->value >= 1 && expr->value <= double(count) && double(int(expr->value)) == expr->value) if (expr->value >= 1 && expr->value <= double(count) && double(int(expr->value)) == expr->value)
emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, emitWarning(
"Table index %d is a duplicate; previously defined as a list entry", int(expr->value)); *context,
LintWarning::Code_TableLiteral,
expr->location,
"Table index %d is a duplicate; previously defined as a list entry",
int(expr->value)
);
else if (expr->value >= 0 && expr->value <= double(INT_MAX) && double(int(expr->value)) == expr->value) else if (expr->value >= 0 && expr->value <= double(INT_MAX) && double(int(expr->value)) == expr->value)
{ {
int& line = indices[int(expr->value)]; int& line = indices[int(expr->value)];
if (line) if (line)
emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, emitWarning(
"Table index %d is a duplicate; previously defined at line %d", int(expr->value), line); *context,
LintWarning::Code_TableLiteral,
expr->location,
"Table index %d is a duplicate; previously defined at line %d",
int(expr->value),
line
);
else else
line = expr->location.begin.line + 1; line = expr->location.begin.line + 1;
} }
@ -1875,18 +1992,41 @@ private:
if (int(rec->access) & int(item.access)) if (int(rec->access) & int(item.access))
{ {
if (rec->access == item.access) if (rec->access == item.access)
emitWarning(*context, LintWarning::Code_TableLiteral, item.location, emitWarning(
"Table type field '%s' is a duplicate; previously defined at line %d", item.name.value, rec->location.begin.line + 1); *context,
LintWarning::Code_TableLiteral,
item.location,
"Table type field '%s' is a duplicate; previously defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else if (rec->access == AstTableAccess::ReadWrite) else if (rec->access == AstTableAccess::ReadWrite)
emitWarning(*context, LintWarning::Code_TableLiteral, item.location, emitWarning(
"Table type field '%s' is already read-write; previously defined at line %d", item.name.value, *context,
rec->location.begin.line + 1); LintWarning::Code_TableLiteral,
item.location,
"Table type field '%s' is already read-write; previously defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else if (rec->access == AstTableAccess::Read) else if (rec->access == AstTableAccess::Read)
emitWarning(*context, LintWarning::Code_TableLiteral, rec->location, emitWarning(
"Table type field '%s' already has a read type defined at line %d", item.name.value, rec->location.begin.line + 1); *context,
LintWarning::Code_TableLiteral,
rec->location,
"Table type field '%s' already has a read type defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else if (rec->access == AstTableAccess::Write) else if (rec->access == AstTableAccess::Write)
emitWarning(*context, LintWarning::Code_TableLiteral, rec->location, emitWarning(
"Table type field '%s' already has a write type defined at line %d", item.name.value, rec->location.begin.line + 1); *context,
LintWarning::Code_TableLiteral,
rec->location,
"Table type field '%s' already has a write type defined at line %d",
item.name.value,
rec->location.begin.line + 1
);
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
} }
@ -1904,8 +2044,14 @@ private:
int& line = names[item.name]; int& line = names[item.name];
if (line) if (line)
emitWarning(*context, LintWarning::Code_TableLiteral, item.location, emitWarning(
"Table type field '%s' is a duplicate; previously defined at line %d", item.name.value, line); *context,
LintWarning::Code_TableLiteral,
item.location,
"Table type field '%s' is a duplicate; previously defined at line %d",
item.name.value,
line
);
else else
line = item.location.begin.line + 1; line = item.location.begin.line + 1;
} }
@ -1966,9 +2112,14 @@ private:
if (l.defined && !l.initialized && !l.assigned && l.firstUse) if (l.defined && !l.initialized && !l.assigned && l.firstUse)
{ {
emitWarning(*context, LintWarning::Code_UninitializedLocal, l.firstUse->location, emitWarning(
"Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence", local->name.value, *context,
local->location.begin.line + 1); LintWarning::Code_UninitializedLocal,
l.firstUse->location,
"Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence",
local->name.value,
local->location.begin.line + 1
);
} }
} }
} }
@ -2102,8 +2253,14 @@ private:
void report(const std::string& name, Location location, Location otherLocation) void report(const std::string& name, Location location, Location otherLocation)
{ {
emitWarning(*context, LintWarning::Code_DuplicateFunction, location, "Duplicate function definition: '%s' also defined on line %d", emitWarning(
name.c_str(), otherLocation.begin.line + 1); *context,
LintWarning::Code_DuplicateFunction,
location,
"Duplicate function definition: '%s' also defined on line %d",
name.c_str(),
otherLocation.begin.line + 1
);
} }
}; };
@ -2152,7 +2309,8 @@ private:
const char* suggestion = (fenv->name == "getfenv") ? "; consider using 'debug.info' instead" : ""; const char* suggestion = (fenv->name == "getfenv") ? "; consider using 'debug.info' instead" : "";
emitWarning( emitWarning(
*context, LintWarning::Code_DeprecatedApi, node->location, "Function '%s' is deprecated%s", fenv->name.value, suggestion); *context, LintWarning::Code_DeprecatedApi, node->location, "Function '%s' is deprecated%s", fenv->name.value, suggestion
);
} }
} }
} }
@ -2265,7 +2423,8 @@ private:
if (!tty->indexer && !tty->props.empty() && tty->state != TableState::Generic) if (!tty->indexer && !tty->props.empty() && tty->state != TableState::Generic)
emitWarning( emitWarning(
*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op); *context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op
);
else if (tty->indexer && isString(tty->indexer->indexType)) // note: to avoid complexity of subtype tests we just check if the key is a string else if (tty->indexer && isString(tty->indexer->indexType)) // note: to avoid complexity of subtype tests we just check if the key is a string
emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table with string keys is likely a bug", op); emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table with string keys is likely a bug", op);
} }
@ -2283,9 +2442,13 @@ private:
size_t ret = getReturnCount(follow(*funty)); size_t ret = getReturnCount(follow(*funty));
if (ret > 1) if (ret > 1)
emitWarning(*context, LintWarning::Code_TableOperations, tail->location, emitWarning(
*context,
LintWarning::Code_TableOperations,
tail->location,
"table.insert may change behavior if the call returns more than one result; consider adding parentheses around second " "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second "
"argument"); "argument"
);
} }
} }
} }
@ -2294,28 +2457,44 @@ private:
{ {
// table.insert(t, 0, ?) // table.insert(t, 0, ?)
if (isConstant(args[1], 0.0)) if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, emitWarning(
"table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"); *context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
// table.insert(t, #t, ?) // table.insert(t, #t, ?)
if (isLength(args[1], args[0])) if (isLength(args[1], args[0]))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.insert will insert the value before the last element, which is likely a bug; consider removing the second argument or " "table.insert will insert the value before the last element, which is likely a bug; consider removing the second argument or "
"wrap it in parentheses to silence"); "wrap it in parentheses to silence"
);
// table.insert(t, #t+1, ?) // table.insert(t, #t+1, ?)
if (AstExprBinary* add = args[1]->as<AstExprBinary>(); if (AstExprBinary* add = args[1]->as<AstExprBinary>();
add && add->op == AstExprBinary::Add && isLength(add->left, args[0]) && isConstant(add->right, 1.0)) add && add->op == AstExprBinary::Add && isLength(add->left, args[0]) && isConstant(add->right, 1.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, emitWarning(
"table.insert will append the value to the table; consider removing the second argument for efficiency"); *context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.insert will append the value to the table; consider removing the second argument for efficiency"
);
} }
if (func->index == "remove" && node->args.size >= 2) if (func->index == "remove" && node->args.size >= 2)
{ {
// table.remove(t, 0) // table.remove(t, 0)
if (isConstant(args[1], 0.0)) if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, emitWarning(
"table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"); *context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
// note: it's tempting to check for table.remove(t, #t), which is equivalent to table.remove(t), but it's correct, occurs frequently, // note: it's tempting to check for table.remove(t, #t), which is equivalent to table.remove(t), but it's correct, occurs frequently,
// and also reads better. // and also reads better.
@ -2323,35 +2502,55 @@ private:
// table.remove(t, #t-1) // table.remove(t, #t-1)
if (AstExprBinary* sub = args[1]->as<AstExprBinary>(); if (AstExprBinary* sub = args[1]->as<AstExprBinary>();
sub && sub->op == AstExprBinary::Sub && isLength(sub->left, args[0]) && isConstant(sub->right, 1.0)) sub && sub->op == AstExprBinary::Sub && isLength(sub->left, args[0]) && isConstant(sub->right, 1.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, emitWarning(
*context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.remove will remove the value before the last element, which is likely a bug; consider removing the second argument or " "table.remove will remove the value before the last element, which is likely a bug; consider removing the second argument or "
"wrap it in parentheses to silence"); "wrap it in parentheses to silence"
);
} }
if (func->index == "move" && node->args.size >= 4) if (func->index == "move" && node->args.size >= 4)
{ {
// table.move(t, 0, _, _) // table.move(t, 0, _, _)
if (isConstant(args[1], 0.0)) if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, emitWarning(
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); *context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
// table.move(t, _, _, 0) // table.move(t, _, _, 0)
else if (isConstant(args[3], 0.0)) else if (isConstant(args[3], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, emitWarning(
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); *context,
LintWarning::Code_TableOperations,
args[3]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"
);
} }
if (func->index == "create" && node->args.size == 2) if (func->index == "create" && node->args.size == 2)
{ {
// table.create(n, {...}) // table.create(n, {...})
if (args[1]->is<AstExprTable>()) if (args[1]->is<AstExprTable>())
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, emitWarning(
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); *context,
LintWarning::Code_TableOperations,
args[1]->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"
);
// table.create(n, {...} :: ?) // table.create(n, {...} :: ?)
if (AstExprTypeAssertion* as = args[1]->as<AstExprTypeAssertion>(); as && as->expr->is<AstExprTable>()) if (AstExprTypeAssertion* as = args[1]->as<AstExprTypeAssertion>(); as && as->expr->is<AstExprTable>())
emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location, emitWarning(
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); *context,
LintWarning::Code_TableOperations,
as->expr->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"
);
} }
} }
@ -2543,11 +2742,21 @@ private:
if (similar(conditions[j], conditions[i])) if (similar(conditions[j], conditions[i]))
{ {
if (conditions[i]->location.begin.line == conditions[j]->location.begin.line) if (conditions[i]->location.begin.line == conditions[j]->location.begin.line)
emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, emitWarning(
"Condition has already been checked on column %d", conditions[j]->location.begin.column + 1); *context,
LintWarning::Code_DuplicateCondition,
conditions[i]->location,
"Condition has already been checked on column %d",
conditions[j]->location.begin.column + 1
);
else else
emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, emitWarning(
"Condition has already been checked on line %d", conditions[j]->location.begin.line + 1); *context,
LintWarning::Code_DuplicateCondition,
conditions[i]->location,
"Condition has already been checked on line %d",
conditions[j]->location.begin.line + 1
);
break; break;
} }
} }
@ -2592,11 +2801,23 @@ private:
if (local->shadow && locals[local->shadow] == node && !ignoreDuplicate(local)) if (local->shadow && locals[local->shadow] == node && !ignoreDuplicate(local))
{ {
if (local->shadow->location.begin.line == local->location.begin.line) if (local->shadow->location.begin.line == local->location.begin.line)
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on column %d", emitWarning(
local->name.value, local->shadow->location.begin.column + 1); *context,
LintWarning::Code_DuplicateLocal,
local->location,
"Variable '%s' already defined on column %d",
local->name.value,
local->shadow->location.begin.column + 1
);
else else
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on line %d", emitWarning(
local->name.value, local->shadow->location.begin.line + 1); *context,
LintWarning::Code_DuplicateLocal,
local->location,
"Variable '%s' already defined on line %d",
local->name.value,
local->shadow->location.begin.line + 1
);
} }
} }
@ -2620,11 +2841,23 @@ private:
if (local->shadow == node->self) if (local->shadow == node->self)
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter 'self' already defined implicitly"); emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter 'self' already defined implicitly");
else if (local->shadow->location.begin.line == local->location.begin.line) else if (local->shadow->location.begin.line == local->location.begin.line)
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on column %d", emitWarning(
local->name.value, local->shadow->location.begin.column + 1); *context,
LintWarning::Code_DuplicateLocal,
local->location,
"Function parameter '%s' already defined on column %d",
local->name.value,
local->shadow->location.begin.column + 1
);
else else
emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on line %d", emitWarning(
local->name.value, local->shadow->location.begin.line + 1); *context,
LintWarning::Code_DuplicateLocal,
local->location,
"Function parameter '%s' already defined on line %d",
local->name.value,
local->shadow->location.begin.line + 1
);
} }
} }
@ -2668,10 +2901,14 @@ private:
alt = "false"; alt = "false";
if (alt) if (alt)
emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location, emitWarning(
*context,
LintWarning::Code_MisleadingAndOr,
node->location,
"The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else " "The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else "
"expression instead", "expression instead",
alt); alt
);
return true; return true;
} }
@ -2699,16 +2936,28 @@ private:
case ConstantNumberParseResult::Malformed: case ConstantNumberParseResult::Malformed:
break; break;
case ConstantNumberParseResult::Imprecise: case ConstantNumberParseResult::Imprecise:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, emitWarning(
"Number literal exceeded available precision and was truncated to closest representable number"); *context,
LintWarning::Code_IntegerParsing,
node->location,
"Number literal exceeded available precision and was truncated to closest representable number"
);
break; break;
case ConstantNumberParseResult::BinOverflow: case ConstantNumberParseResult::BinOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, emitWarning(
"Binary number literal exceeded available precision and was truncated to 2^64"); *context,
LintWarning::Code_IntegerParsing,
node->location,
"Binary number literal exceeded available precision and was truncated to 2^64"
);
break; break;
case ConstantNumberParseResult::HexOverflow: case ConstantNumberParseResult::HexOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, emitWarning(
"Hexadecimal number literal exceeded available precision and was truncated to 2^64"); *context,
LintWarning::Code_IntegerParsing,
node->location,
"Hexadecimal number literal exceeded available precision and was truncated to 2^64"
);
break; break;
} }
@ -2759,12 +3008,24 @@ private:
std::string op = toString(node->op); std::string op = toString(node->op);
if (isEquality(node->op)) if (isEquality(node->op))
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, emitWarning(
"not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", op.c_str(), op.c_str(), *context,
node->op == AstExprBinary::CompareEq ? "~=" : "=="); LintWarning::Code_ComparisonPrecedence,
node->location,
"not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence",
op.c_str(),
op.c_str(),
node->op == AstExprBinary::CompareEq ? "~=" : "=="
);
else else
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, emitWarning(
"not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", op.c_str(), op.c_str()); *context,
LintWarning::Code_ComparisonPrecedence,
node->location,
"not X %s Y is equivalent to (not X) %s Y; add parentheses to silence",
op.c_str(),
op.c_str()
);
} }
else if (AstExprBinary* left = node->left->as<AstExprBinary>(); left && isComparison(left->op)) else if (AstExprBinary* left = node->left->as<AstExprBinary>(); left && isComparison(left->op))
{ {
@ -2772,12 +3033,29 @@ private:
std::string rop = toString(node->op); std::string rop = toString(node->op);
if (isEquality(left->op) || isEquality(node->op)) if (isEquality(left->op) || isEquality(node->op))
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, emitWarning(
"X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str()); *context,
LintWarning::Code_ComparisonPrecedence,
node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence",
lop.c_str(),
rop.c_str(),
lop.c_str(),
rop.c_str()
);
else else
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, emitWarning(
"X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str(), *context,
lop.c_str(), rop.c_str()); LintWarning::Code_ComparisonPrecedence,
node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?",
lop.c_str(),
rop.c_str(),
lop.c_str(),
rop.c_str(),
lop.c_str(),
rop.c_str()
);
} }
return true; return true;
@ -2843,8 +3121,12 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
if (!hc.header) if (!hc.header)
{ {
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, emitWarning(
"Comment directive is ignored because it is placed after the first non-comment token"); context,
LintWarning::Code_CommentDirective,
hc.location,
"Comment directive is ignored because it is placed after the first non-comment token"
);
} }
else else
{ {
@ -2865,21 +3147,36 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
// skip Unknown // skip Unknown
if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1)) if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, emitWarning(
"nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion); context,
LintWarning::Code_CommentDirective,
hc.location,
"nolint directive refers to unknown lint rule '%s'; did you mean '%s'?",
rule,
suggestion
);
else else
emitWarning( emitWarning(
context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule); context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule
);
} }
} }
else if (first == "nocheck" || first == "nonstrict" || first == "strict") else if (first == "nocheck" || first == "nonstrict" || first == "strict")
{ {
if (space != std::string::npos) if (space != std::string::npos)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, emitWarning(
"Comment directive with the type checking mode has extra symbols at the end of the line"); context,
LintWarning::Code_CommentDirective,
hc.location,
"Comment directive with the type checking mode has extra symbols at the end of the line"
);
else if (seenMode) else if (seenMode)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, emitWarning(
"Comment directive with the type checking mode has already been used"); context,
LintWarning::Code_CommentDirective,
hc.location,
"Comment directive with the type checking mode has already been used"
);
else else
seenMode = true; seenMode = true;
} }
@ -2894,15 +3191,21 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
const char* level = hc.content.c_str() + notspace; const char* level = hc.content.c_str() + notspace;
if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2")) if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2"))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, emitWarning(
"optimize directive uses unknown optimization level '%s', 0..2 expected", level); context,
LintWarning::Code_CommentDirective,
hc.location,
"optimize directive uses unknown optimization level '%s', 0..2 expected",
level
);
} }
} }
else if (first == "native") else if (first == "native")
{ {
if (space != std::string::npos) if (space != std::string::npos)
emitWarning( emitWarning(
context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line"); context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line"
);
} }
else else
{ {
@ -2916,11 +3219,19 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
}; };
if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments)))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?", emitWarning(
int(first.size()), first.data(), suggestion); context,
LintWarning::Code_CommentDirective,
hc.location,
"Unknown comment directive '%.*s'; did you mean '%s'?",
int(first.size()),
first.data(),
suggestion
);
else else
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), emitWarning(
first.data()); context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), first.data()
);
} }
} }
} }
@ -2973,8 +3284,12 @@ private:
{ {
if (attribute->type == AstAttr::Type::Native) if (attribute->type == AstAttr::Type::Native)
{ {
emitWarning(*context, LintWarning::Code_RedundantNativeAttribute, attribute->location, emitWarning(
"native attribute on a function is redundant in a native module; consider removing it"); *context,
LintWarning::Code_RedundantNativeAttribute,
attribute->location,
"native attribute on a function is redundant in a native module; consider removing it"
);
} }
} }
@ -2982,8 +3297,14 @@ private:
} }
}; };
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, std::vector<LintWarning> lint(
const std::vector<HotComment>& hotcomments, const LintOptions& options) AstStat* root,
const AstNameTable& names,
const ScopePtr& env,
const Module* module,
const std::vector<HotComment>& hotcomments,
const LintOptions& options
)
{ {
LintContext context; LintContext context;
@ -3068,8 +3389,7 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence))
LintComparisonPrecedence::process(context); LintComparisonPrecedence::process(context);
if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
context.warningEnabled(LintWarning::Code_RedundantNativeAttribute))
{ {
if (hasNativeCommentDirective(hotcomments)) if (hasNativeCommentDirective(hotcomments))
LintRedundantNativeAttribute::process(context); LintRedundantNativeAttribute::process(context);

View File

@ -24,8 +24,8 @@ static bool contains(Position pos, Comment comment)
{ {
if (comment.location.contains(pos)) if (comment.location.contains(pos))
return true; return true;
else if (comment.type == Lexeme::BrokenComment && else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't
comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end // have an end
return true; return true;
else if (comment.type == Lexeme::Comment && comment.location.end == pos) else if (comment.type == Lexeme::Comment && comment.location.end == pos)
return true; return true;
@ -36,9 +36,14 @@ static bool contains(Position pos, Comment comment)
static bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos) static bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
{ {
auto iter = std::lower_bound( auto iter = std::lower_bound(
commentLocations.begin(), commentLocations.end(), Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { commentLocations.begin(),
commentLocations.end(),
Comment{Lexeme::Comment, Location{pos, pos}},
[](const Comment& a, const Comment& b)
{
return a.location.end < b.location.end; return a.location.end < b.location.end;
}); }
);
if (iter == commentLocations.end()) if (iter == commentLocations.end())
return false; return false;

View File

@ -69,7 +69,11 @@ struct NonStrictContext
NonStrictContext& operator=(NonStrictContext&&) = default; NonStrictContext& operator=(NonStrictContext&&) = default;
static NonStrictContext disjunction( static NonStrictContext disjunction(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right) NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
const NonStrictContext& left,
const NonStrictContext& right
)
{ {
// disjunction implements union over the domain of keys // disjunction implements union over the domain of keys
// if the default value for a defId not in the map is `never` // if the default value for a defId not in the map is `never`
@ -94,7 +98,11 @@ struct NonStrictContext
} }
static NonStrictContext conjunction( static NonStrictContext conjunction(
NotNull<BuiltinTypes> builtins, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right) NotNull<BuiltinTypes> builtins,
NotNull<TypeArena> arena,
const NonStrictContext& left,
const NonStrictContext& right
)
{ {
NonStrictContext conj{}; NonStrictContext conj{};
@ -160,8 +168,15 @@ struct NonStrictTypeChecker
const NotNull<TypeCheckLimits> limits; const NotNull<TypeCheckLimits> limits;
NonStrictTypeChecker(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, const NotNull<InternalErrorReporter> ice, NonStrictTypeChecker(
NotNull<UnifierSharedState> unifierState, NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, Module* module) NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
const NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
NotNull<TypeCheckLimits> limits,
Module* module
)
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, ice(ice) , ice(ice)
, arena(arena) , arena(arena)
@ -213,7 +228,8 @@ struct NonStrictTypeChecker
return instance; return instance;
ErrorVec errors = ErrorVec errors =
reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true).errors; reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
.errors;
if (errors.empty()) if (errors.empty())
noTypeFunctionErrors.insert(instance); noTypeFunctionErrors.insert(instance);
@ -271,6 +287,8 @@ struct NonStrictTypeChecker
return visit(s); return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>()) else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s); return visit(s);
else if (auto f = stat->as<AstStatTypeFunction>())
return visit(f);
else if (auto s = stat->as<AstStatDeclareFunction>()) else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s); return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>()) else if (auto s = stat->as<AstStatDeclareGlobal>())
@ -395,6 +413,12 @@ struct NonStrictTypeChecker
return {}; return {};
} }
NonStrictContext visit(AstStatTypeFunction* typeFunc)
{
reportError(GenericError{"This syntax is not supported"}, typeFunc->location);
return {};
}
NonStrictContext visit(AstStatDeclareFunction* declFn) NonStrictContext visit(AstStatDeclareFunction* declFn)
{ {
return {}; return {};
@ -726,8 +750,15 @@ private:
}; };
}; };
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState, void checkNonStrict(
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module) NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice,
NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg,
NotNull<TypeCheckLimits> limits,
const SourceModule& sourceModule,
Module* module
)
{ {
LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking");

View File

@ -159,10 +159,15 @@ size_t TypeIds::getHash() const
bool TypeIds::isNever() const bool TypeIds::isNever() const
{ {
return std::all_of(begin(), end(), [&](TypeId i) { return std::all_of(
begin(),
end(),
[&](TypeId i)
{
// If each typeid is never, then I guess typeid's is also never? // If each typeid is never, then I guess typeid's is also never?
return get<NeverType>(i) != nullptr; return get<NeverType>(i) != nullptr;
}); }
);
} }
bool TypeIds::operator==(const TypeIds& there) const bool TypeIds::operator==(const TypeIds& there) const
@ -371,10 +376,15 @@ bool NormalizedType::shouldSuppressErrors() const
bool NormalizedType::hasTopTable() const bool NormalizedType::hasTopTable() const
{ {
return hasTables() && std::any_of(tables.begin(), tables.end(), [&](TypeId ty) { return hasTables() && std::any_of(
tables.begin(),
tables.end(),
[&](TypeId ty)
{
auto primTy = get<PrimitiveType>(ty); auto primTy = get<PrimitiveType>(ty);
return primTy && primTy->type == PrimitiveType::Type::Table; return primTy && primTy->type == PrimitiveType::Type::Table;
}); }
);
} }
bool NormalizedType::hasTops() const bool NormalizedType::hasTops() const
@ -806,7 +816,8 @@ static bool areNormalizedClasses(const NormalizedClassType& tys)
if (isSubclass(ctv, octv)) if (isSubclass(ctv, octv))
{ {
auto iss = [ctv](TypeId t) { auto iss = [ctv](TypeId t)
{
const ClassType* c = get<ClassType>(t); const ClassType* c = get<ClassType>(t);
if (!c) if (!c)
return false; return false;
@ -970,7 +981,6 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector<TypeId>
NormalizedType norm{builtinTypes}; NormalizedType norm{builtinTypes};
norm.tops = builtinTypes->anyType; norm.tops = builtinTypes->anyType;
// Now we need to intersect the two types // Now we need to intersect the two types
Set<TypeId> seenSetTypes{nullptr};
for (auto ty : intersections) for (auto ty : intersections)
{ {
NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet); NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet);
@ -1417,8 +1427,9 @@ std::optional<TypePackId> Normalizer::unionOfTypePacks(TypePackId here, TypePack
itt++; itt++;
} }
auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, auto dealWithDifferentArities =
bool& thereSubHere) { [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere)
{
if (ith != end(here)) if (ith != end(here))
{ {
TypeId tty = builtinTypes->nilType; TypeId tty = builtinTypes->nilType;
@ -1803,8 +1814,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t
} }
else if (get<UnknownType>(here.tops)) else if (get<UnknownType>(here.tops))
return NormalizationResult::True; return NormalizationResult::True;
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
get<TypeFunctionInstanceType>(there))
{ {
if (tyvarIndex(there) <= ignoreSmallerTyvars) if (tyvarIndex(there) <= ignoreSmallerTyvars)
return NormalizationResult::True; return NormalizationResult::True;
@ -2379,8 +2389,9 @@ std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, T
itt++; itt++;
} }
auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, auto dealWithDifferentArities =
bool& thereSubHere) { [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere)
{
if (ith != end(here)) if (ith != end(here))
{ {
TypeId tty = builtinTypes->nilType; TypeId tty = builtinTypes->nilType;
@ -2570,7 +2581,7 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
} }
} }
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenSet); NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy);
// Cleanup // Cleanup
if (fixCyclicTablesBlowingStack()) if (fixCyclicTablesBlowingStack())
@ -3088,8 +3099,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
} }
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || else if (get<GenericType>(there) || get<FreeType>(there) || get<BlockedType>(there) || get<PendingExpansionType>(there) || get<TypeFunctionInstanceType>(there))
get<TypeFunctionInstanceType>(there))
{ {
NormalizedType thereNorm{builtinTypes}; NormalizedType thereNorm{builtinTypes};
NormalizedType topNorm{builtinTypes}; NormalizedType topNorm{builtinTypes};
@ -3441,7 +3451,12 @@ bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, Not
} }
bool isConsistentSubtype( bool isConsistentSubtype(
TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice) TypePackId subPack,
TypePackId superPack,
NotNull<Scope> scope,
NotNull<BuiltinTypes> builtinTypes,
InternalErrorReporter& ice
)
{ {
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);

View File

@ -13,8 +13,15 @@
namespace Luau namespace Luau
{ {
OverloadResolver::OverloadResolver(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Normalizer> normalizer, NotNull<Scope> scope, OverloadResolver::OverloadResolver(
NotNull<InternalErrorReporter> reporter, NotNull<TypeCheckLimits> limits, Location callLocation) NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> reporter,
NotNull<TypeCheckLimits> limits,
Location callLocation
)
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(arena) , arena(arena)
, normalizer(normalizer) , normalizer(normalizer)
@ -28,10 +35,15 @@ OverloadResolver::OverloadResolver(NotNull<BuiltinTypes> builtinTypes, NotNull<T
std::pair<OverloadResolver::Analysis, TypeId> OverloadResolver::selectOverload(TypeId ty, TypePackId argsPack) std::pair<OverloadResolver::Analysis, TypeId> OverloadResolver::selectOverload(TypeId ty, TypePackId argsPack)
{ {
auto tryOne = [&](TypeId f) { auto tryOne = [&](TypeId f)
{
if (auto ftv = get<FunctionType>(f)) if (auto ftv = get<FunctionType>(f))
{ {
Subtyping::Variance variance = subtyping.variance;
subtyping.variance = Subtyping::Variance::Contravariant;
SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes); SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes);
subtyping.variance = variance;
if (r.isSubtype) if (r.isSubtype)
return true; return true;
} }
@ -137,7 +149,12 @@ std::optional<ErrorVec> OverloadResolver::testIsSubtype(const Location& location
} }
std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload( std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload(
TypeId fnTy, const TypePack* args, AstExpr* fnLoc, const std::vector<AstExpr*>* argExprs, bool callMetamethodOk) TypeId fnTy,
const TypePack* args,
AstExpr* fnLoc,
const std::vector<AstExpr*>* argExprs,
bool callMetamethodOk
)
{ {
fnTy = follow(fnTy); fnTy = follow(fnTy);
@ -173,7 +190,12 @@ bool OverloadResolver::isLiteral(AstExpr* expr)
} }
std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_( std::pair<OverloadResolver::Analysis, ErrorVec> OverloadResolver::checkOverload_(
TypeId fnTy, const FunctionType* fn, const TypePack* args, AstExpr* fnExpr, const std::vector<AstExpr*>* argExprs) TypeId fnTy,
const FunctionType* fn,
const TypePack* args,
AstExpr* fnExpr,
const std::vector<AstExpr*>* argExprs
)
{ {
FunctionGraphReductionResult result = FunctionGraphReductionResult result =
reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true); reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true);
@ -373,9 +395,17 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors)
// we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. // we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`.
// this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. // this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed.
std::optional<TypeId> selectOverload(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Normalizer> normalizer, std::optional<TypeId> selectOverload(
NotNull<Scope> scope, NotNull<InternalErrorReporter> iceReporter, NotNull<TypeCheckLimits> limits, const Location& location, TypeId fn, NotNull<BuiltinTypes> builtinTypes,
TypePackId argsPack) NotNull<TypeArena> arena,
NotNull<Normalizer> normalizer,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
const Location& location,
TypeId fn,
TypePackId argsPack
)
{ {
OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location}; OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location};
auto [status, overload] = resolver.selectOverload(fn, argsPack); auto [status, overload] = resolver.selectOverload(fn, argsPack);
@ -389,9 +419,17 @@ std::optional<TypeId> selectOverload(NotNull<BuiltinTypes> builtinTypes, NotNull
return {}; return {};
} }
SolveResult solveFunctionCall(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Normalizer> normalizer, SolveResult solveFunctionCall(
NotNull<InternalErrorReporter> iceReporter, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, const Location& location, TypeId fn, NotNull<TypeArena> arena,
TypePackId argsPack) NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
const Location& location,
TypeId fn,
TypePackId argsPack
)
{ {
std::optional<TypeId> overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack); std::optional<TypeId> overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack);
if (!overloadToUse) if (!overloadToUse)

View File

@ -8,8 +8,6 @@
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAG(DebugLuauSharedSelf)
namespace Luau namespace Luau
{ {
@ -99,45 +97,6 @@ struct Quantifier final : TypeOnceVisitor
}; };
void quantify(TypeId ty, TypeLevel level) void quantify(TypeId ty, TypeLevel level)
{
if (FFlag::DebugLuauSharedSelf)
{
ty = follow(ty);
if (auto ttv = getTableType(ty); ttv && ttv->selfTy)
{
Quantifier selfQ{level};
selfQ.traverse(*ttv->selfTy);
Quantifier q{level};
q.traverse(ty);
for (const auto& [_, prop] : ttv->props)
{
auto ftv = getMutable<FunctionType>(follow(prop.type()));
if (!ftv || !ftv->hasSelf)
continue;
if (Luau::first(ftv->argTypes) == ttv->selfTy)
{
ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end());
}
}
}
else if (auto ftv = getMutable<FunctionType>(ty))
{
Quantifier q{level};
q.traverse(ty);
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoFreeOrGenericTypes = true;
}
}
else
{ {
Quantifier q{level}; Quantifier q{level};
q.traverse(ty); q.traverse(ty);
@ -147,7 +106,6 @@ void quantify(TypeId ty, TypeLevel level)
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
} }
}
struct PureQuantifier : Substitution struct PureQuantifier : Substitution
{ {

View File

@ -238,12 +238,22 @@ Relation relateTables(TypeId left, TypeId right, SimplifierSeenSet& seen)
LUAU_ASSERT(1 == rightTable->props.size()); LUAU_ASSERT(1 == rightTable->props.size());
// Disjoint props have nothing in common // Disjoint props have nothing in common
// t1 with props p1's cannot appear in t2 and t2 with props p2's cannot appear in t1 // t1 with props p1's cannot appear in t2 and t2 with props p2's cannot appear in t1
bool foundPropFromLeftInRight = std::any_of(begin(leftTable->props), end(leftTable->props), [&](auto prop) { bool foundPropFromLeftInRight = std::any_of(
begin(leftTable->props),
end(leftTable->props),
[&](auto prop)
{
return rightTable->props.count(prop.first) > 0; return rightTable->props.count(prop.first) > 0;
}); }
bool foundPropFromRightInLeft = std::any_of(begin(rightTable->props), end(rightTable->props), [&](auto prop) { );
bool foundPropFromRightInLeft = std::any_of(
begin(rightTable->props),
end(rightTable->props),
[&](auto prop)
{
return leftTable->props.count(prop.first) > 0; return leftTable->props.count(prop.first) > 0;
}); }
);
if (!foundPropFromLeftInRight && !foundPropFromRightInLeft && leftTable->props.size() >= 1 && rightTable->props.size() >= 1) if (!foundPropFromLeftInRight && !foundPropFromRightInLeft && leftTable->props.size() >= 1 && rightTable->props.size() >= 1)
return Relation::Disjoint; return Relation::Disjoint;
@ -1112,8 +1122,13 @@ std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right)
{ {
case Relation::Disjoint: case Relation::Disjoint:
return builtinTypes->neverType; return builtinTypes->neverType;
case Relation::Superset:
case Relation::Coincident: case Relation::Coincident:
return right; return right;
case Relation::Subset:
if (1 == rt->props.size())
return left;
break;
default: default:
break; break;
} }
@ -1121,6 +1136,40 @@ std::optional<TypeId> TypeSimplifier::basicIntersect(TypeId left, TypeId right)
} }
else if (1 == rt->props.size()) else if (1 == rt->props.size())
return basicIntersect(right, left); return basicIntersect(right, left);
// If two tables have disjoint properties and indexers, we can combine them.
if (!lt->indexer && !rt->indexer && lt->state == TableState::Sealed && rt->state == TableState::Sealed)
{
if (rt->props.empty())
return left;
bool areDisjoint = true;
for (const auto& [name, leftProp]: lt->props)
{
if (rt->props.count(name))
{
areDisjoint = false;
break;
}
}
if (areDisjoint)
{
TableType::Props mergedProps = lt->props;
for (const auto& [name, rightProp]: rt->props)
mergedProps[name] = rightProp;
return arena->addType(TableType{
mergedProps,
std::nullopt,
TypeLevel{},
lt->scope,
TableState::Sealed
});
}
}
return std::nullopt;
} }
Relation relation = relate(left, right); Relation relation = relate(left, right);

View File

@ -18,7 +18,8 @@ namespace Luau
static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone)
{ {
auto go = [ty, &dest, alwaysClone](auto&& a) { auto go = [ty, &dest, alwaysClone](auto&& a)
{
using T = std::decay_t<decltype(a)>; using T = std::decay_t<decltype(a)>;
// The pointer identities of free and local types is very important. // The pointer identities of free and local types is very important.
@ -672,7 +673,8 @@ TypePackId Substitution::clone(TypePackId tp)
else if (const TypeFunctionInstanceTypePack* tfitp = get<TypeFunctionInstanceTypePack>(tp)) else if (const TypeFunctionInstanceTypePack* tfitp = get<TypeFunctionInstanceTypePack>(tp))
{ {
TypeFunctionInstanceTypePack clone{ TypeFunctionInstanceTypePack clone{
tfitp->function, std::vector<TypeId>(tfitp->typeArguments.size()), std::vector<TypePackId>(tfitp->packArguments.size())}; tfitp->function, std::vector<TypeId>(tfitp->typeArguments.size()), std::vector<TypePackId>(tfitp->packArguments.size())
};
clone.typeArguments.assign(tfitp->typeArguments.begin(), tfitp->typeArguments.end()); clone.typeArguments.assign(tfitp->typeArguments.begin(), tfitp->typeArguments.end());
clone.packArguments.assign(tfitp->packArguments.begin(), tfitp->packArguments.end()); clone.packArguments.assign(tfitp->packArguments.begin(), tfitp->packArguments.end());
return addTypePack(std::move(clone)); return addTypePack(std::move(clone));

View File

@ -91,7 +91,8 @@ static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const S
else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant) else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant)
{ {
SubtypingReasoning inverseReasoning = SubtypingReasoning{ SubtypingReasoning inverseReasoning = SubtypingReasoning{
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant}; r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant
};
if (b.contains(inverseReasoning)) if (b.contains(inverseReasoning))
result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant}); result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant});
else else
@ -106,7 +107,8 @@ static SubtypingReasonings mergeReasonings(const SubtypingReasonings& a, const S
else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant) else if (r.variance == SubtypingVariance::Covariant || r.variance == SubtypingVariance::Contravariant)
{ {
SubtypingReasoning inverseReasoning = SubtypingReasoning{ SubtypingReasoning inverseReasoning = SubtypingReasoning{
r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant}; r.subPath, r.superPath, r.variance == SubtypingVariance::Covariant ? SubtypingVariance::Contravariant : SubtypingVariance::Covariant
};
if (a.contains(inverseReasoning)) if (a.contains(inverseReasoning))
result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant}); result.insert(SubtypingReasoning{r.subPath, r.superPath, SubtypingVariance::Invariant});
else else
@ -267,7 +269,11 @@ struct ApplyMappedGenerics : Substitution
ApplyMappedGenerics( ApplyMappedGenerics(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, MappedGenerics& mappedGenerics, MappedGenericPacks& mappedGenericPacks) NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
MappedGenerics& mappedGenerics,
MappedGenericPacks& mappedGenericPacks
)
: Substitution(TxnLog::empty(), arena) : Substitution(TxnLog::empty(), arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, arena(arena) , arena(arena)
@ -323,8 +329,13 @@ std::optional<TypeId> SubtypingEnvironment::applyMappedGenerics(NotNull<BuiltinT
return amg.substitute(ty); return amg.substitute(ty);
} }
Subtyping::Subtyping(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> typeArena, NotNull<Normalizer> normalizer, Subtyping::Subtyping(
NotNull<InternalErrorReporter> iceReporter, NotNull<Scope> scope) NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> typeArena,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<Scope> scope
)
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
, arena(typeArena) , arena(typeArena)
, normalizer(normalizer) , normalizer(normalizer)
@ -1243,8 +1254,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
std::vector<SubtypingResult> results; std::vector<SubtypingResult> results;
if (auto subIter = subTable->props.find(name); subIter != subTable->props.end()) if (auto subIter = subTable->props.find(name); subIter != subTable->props.end())
results.push_back(isCovariantWith(env, subIter->second, superProp, name)); results.push_back(isCovariantWith(env, subIter->second, superProp, name));
else if (subTable->indexer)
if (subTable->indexer)
{ {
if (isCovariantWith(env, builtinTypes->stringType, subTable->indexer->indexType).isSubtype) if (isCovariantWith(env, builtinTypes->stringType, subTable->indexer->indexType).isSubtype)
{ {
@ -1317,7 +1327,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Clas
} }
SubtypingResult Subtyping::isCovariantWith( SubtypingResult Subtyping::isCovariantWith(
SubtypingEnvironment& env, TypeId subTy, const ClassType* subClass, TypeId superTy, const TableType* superTable) SubtypingEnvironment& env,
TypeId subTy,
const ClassType* subClass,
TypeId superTy,
const TableType* superTable
)
{ {
SubtypingResult result{true}; SubtypingResult result{true};
@ -1366,7 +1381,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Prim
{ {
if (auto stringTable = get<TableType>(it->second.type())) if (auto stringTable = get<TableType>(it->second.type()))
result.orElse( result.orElse(
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build())); isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build())
);
} }
} }
} }
@ -1388,7 +1404,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Sing
{ {
if (auto stringTable = get<TableType>(it->second.type())) if (auto stringTable = get<TableType>(it->second.type()))
result.orElse( result.orElse(
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build())); isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().readProp("__index").build())
);
} }
} }
} }
@ -1429,7 +1446,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Prop
} }
SubtypingResult Subtyping::isCovariantWith( SubtypingResult Subtyping::isCovariantWith(
SubtypingEnvironment& env, const std::shared_ptr<const NormalizedType>& subNorm, const std::shared_ptr<const NormalizedType>& superNorm) SubtypingEnvironment& env,
const std::shared_ptr<const NormalizedType>& subNorm,
const std::shared_ptr<const NormalizedType>& superNorm
)
{ {
if (!subNorm || !superNorm) if (!subNorm || !superNorm)
return {false, true}; return {false, true};
@ -1540,7 +1560,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Norm
} }
SubtypingResult Subtyping::isCovariantWith( SubtypingResult Subtyping::isCovariantWith(
SubtypingEnvironment& env, const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction) SubtypingEnvironment& env,
const NormalizedFunctionType& subFunction,
const NormalizedFunctionType& superFunction
)
{ {
if (subFunction.isNever()) if (subFunction.isNever())
return {true}; return {true};

View File

@ -13,8 +13,10 @@ namespace Luau
static bool isLiteral(const AstExpr* expr) static bool isLiteral(const AstExpr* expr)
{ {
return (expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() || return (
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()); expr->is<AstExprTable>() || expr->is<AstExprFunction>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantString>() ||
expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>()
);
} }
// A fast approximation of subTy <: superTy // A fast approximation of subTy <: superTy
@ -108,9 +110,17 @@ static std::optional<TypeId> extractMatchingTableType(std::vector<TypeId>& table
return std::nullopt; return std::nullopt;
} }
TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes, NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes, TypeId matchLiteralType(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<Unifier2> unifier, TypeId expectedType, TypeId exprType, NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
const AstExpr* expr, std::vector<TypeId>& toBlock) NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
NotNull<Unifier2> unifier,
TypeId expectedType,
TypeId exprType,
const AstExpr* expr,
std::vector<TypeId>& toBlock
)
{ {
/* /*
* Table types that arise from literal table expressions have some * Table types that arise from literal table expressions have some
@ -208,7 +218,7 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
if (auto exprTable = expr->as<AstExprTable>()) if (auto exprTable = expr->as<AstExprTable>())
{ {
TableType* tableTy = getMutable<TableType>(exprType); TableType* const tableTy = getMutable<TableType>(exprType);
LUAU_ASSERT(tableTy); LUAU_ASSERT(tableTy);
const TableType* expectedTableTy = get<TableType>(expectedType); const TableType* expectedTableTy = get<TableType>(expectedType);
@ -260,8 +270,17 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
(*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;
(*astExpectedTypes)[item.value] = expectedTableTy->indexer->indexResultType; (*astExpectedTypes)[item.value] = expectedTableTy->indexer->indexResultType;
TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, TypeId matchedType = matchLiteralType(
expectedTableTy->indexer->indexResultType, propTy, item.value, toBlock); astTypes,
astExpectedTypes,
builtinTypes,
arena,
unifier,
expectedTableTy->indexer->indexResultType,
propTy,
item.value,
toBlock
);
if (tableTy->indexer) if (tableTy->indexer)
unifier->unify(matchedType, tableTy->indexer->indexResultType); unifier->unify(matchedType, tableTy->indexer->indexResultType);
@ -334,8 +353,17 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
LUAU_ASSERT(propTy); LUAU_ASSERT(propTy);
unifier->unify(expectedTableTy->indexer->indexType, builtinTypes->numberType); unifier->unify(expectedTableTy->indexer->indexType, builtinTypes->numberType);
TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, TypeId matchedType = matchLiteralType(
expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock); astTypes,
astExpectedTypes,
builtinTypes,
arena,
unifier,
expectedTableTy->indexer->indexResultType,
*propTy,
item.value,
toBlock
);
// if the index result type is the prop type, we can replace it with the matched type here. // if the index result type is the prop type, we can replace it with the matched type here.
if (tableTy->indexer->indexResultType == *propTy) if (tableTy->indexer->indexResultType == *propTy)
@ -410,6 +438,15 @@ TypeId matchLiteralType(NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
if (exprProp.readTy || exprProp.writeTy) if (exprProp.readTy || exprProp.writeTy)
tableTy->props[*key] = std::move(exprProp); tableTy->props[*key] = std::move(exprProp);
} }
// If the expected table has an indexer, then the provided table can
// have one too.
// TODO: If the expected table also has an indexer, we might want to
// push the expected indexer's types into it.
if (expectedTableTy->indexer && !tableTy->indexer)
{
tableTy->indexer = expectedTableTy->indexer;
}
} }
return exprType; return exprType;

View File

@ -146,7 +146,8 @@ void StateDot::visitChildren(TypeId ty, int index)
startNode(index); startNode(index);
startNodeLabel(); startNodeLabel();
auto go = [&](auto&& t) { auto go = [&](auto&& t)
{
using T = std::decay_t<decltype(t)>; using T = std::decay_t<decltype(t)>;
if constexpr (std::is_same_v<T, BoundType>) if constexpr (std::is_same_v<T, BoundType>)

View File

@ -168,7 +168,8 @@ struct StringifierState
DenseHashMap<TypeId, std::string> cycleNames{{}}; DenseHashMap<TypeId, std::string> cycleNames{{}};
DenseHashMap<TypePackId, std::string> cycleTpNames{{}}; DenseHashMap<TypePackId, std::string> cycleTpNames{{}};
Set<void*> seen{{}}; Set<void*> seen{{}};
// `$$$` was chosen as the tombstone for `usedNames` since it is not a valid name syntactically and is relatively short for string comparison reasons. // `$$$` was chosen as the tombstone for `usedNames` since it is not a valid name syntactically and is relatively short for string comparison
// reasons.
DenseHashSet<std::string> usedNames{"$$$"}; DenseHashSet<std::string> usedNames{"$$$"};
size_t indentation = 0; size_t indentation = 0;
@ -356,10 +357,12 @@ struct TypeStringifier
} }
Luau::visit( Luau::visit(
[this, tv](auto&& t) { [this, tv](auto&& t)
{
return (*this)(tv, t); return (*this)(tv, t);
}, },
tv->ty); tv->ty
);
} }
void emitKey(const std::string& name) void emitKey(const std::string& name)
@ -1104,10 +1107,12 @@ struct TypePackStringifier
} }
Luau::visit( Luau::visit(
[this, tp](auto&& t) { [this, tp](auto&& t)
{
return (*this)(tp, t); return (*this)(tp, t);
}, },
tp->ty); tp->ty
);
} }
void operator()(TypePackId, const TypePack& tp) void operator()(TypePackId, const TypePack& tp)
@ -1272,8 +1277,13 @@ void TypeStringifier::stringify(TypePackId tpid, const std::vector<std::optional
tps.stringify(tpid); tps.stringify(tpid);
} }
static void assignCycleNames(const std::set<TypeId>& cycles, const std::set<TypePackId>& cycleTPs, DenseHashMap<TypeId, std::string>& cycleNames, static void assignCycleNames(
DenseHashMap<TypePackId, std::string>& cycleTpNames, bool exhaustive) const std::set<TypeId>& cycles,
const std::set<TypePackId>& cycleTPs,
DenseHashMap<TypeId, std::string>& cycleNames,
DenseHashMap<TypePackId, std::string>& cycleTpNames,
bool exhaustive
)
{ {
int nextIndex = 1; int nextIndex = 1;
@ -1285,9 +1295,14 @@ static void assignCycleNames(const std::set<TypeId>& cycles, const std::set<Type
if (auto ttv = get<TableType>(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) if (auto ttv = get<TableType>(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name))
{ {
// If we have a cycle type in type parameters, assign a cycle name for this named table // If we have a cycle type in type parameters, assign a cycle name for this named table
if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { if (std::find_if(
ttv->instantiatedTypeParams.begin(),
ttv->instantiatedTypeParams.end(),
[&](auto&& el)
{
return cycles.count(follow(el)); return cycles.count(follow(el));
}) != ttv->instantiatedTypeParams.end()) }
) != ttv->instantiatedTypeParams.end())
cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName;
continue; continue;
@ -1381,9 +1396,14 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
state.exhaustive = true; state.exhaustive = true;
std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()};
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { std::sort(
sortedCycleNames.begin(),
sortedCycleNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second; return a.second < b.second;
}); }
);
bool semi = false; bool semi = false;
for (const auto& [cycleTy, name] : sortedCycleNames) for (const auto& [cycleTy, name] : sortedCycleNames)
@ -1394,18 +1414,25 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
state.emit(name); state.emit(name);
state.emit(" = "); state.emit(" = ");
Luau::visit( Luau::visit(
[&tvs, cycleTy = cycleTy](auto&& t) { [&tvs, cycleTy = cycleTy](auto&& t)
{
return tvs(cycleTy, t); return tvs(cycleTy, t);
}, },
cycleTy->ty); cycleTy->ty
);
semi = true; semi = true;
} }
std::vector<std::pair<TypePackId, std::string>> sortedCycleTpNames(state.cycleTpNames.begin(), state.cycleTpNames.end()); std::vector<std::pair<TypePackId, std::string>> sortedCycleTpNames(state.cycleTpNames.begin(), state.cycleTpNames.end());
std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) { std::sort(
sortedCycleTpNames.begin(),
sortedCycleTpNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second; return a.second < b.second;
}); }
);
TypePackStringifier tps{state}; TypePackStringifier tps{state};
@ -1417,10 +1444,12 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
state.emit(name); state.emit(name);
state.emit(" = "); state.emit(" = ");
Luau::visit( Luau::visit(
[&tps, cycleTy = cycleTp](auto&& t) { [&tps, cycleTy = cycleTp](auto&& t)
{
return tps(cycleTy, t); return tps(cycleTy, t);
}, },
cycleTp->ty); cycleTp->ty
);
semi = true; semi = true;
} }
@ -1474,9 +1503,14 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
state.exhaustive = true; state.exhaustive = true;
std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()};
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { std::sort(
sortedCycleNames.begin(),
sortedCycleNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second; return a.second < b.second;
}); }
);
bool semi = false; bool semi = false;
for (const auto& [cycleTy, name] : sortedCycleNames) for (const auto& [cycleTy, name] : sortedCycleNames)
@ -1487,18 +1521,25 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
state.emit(name); state.emit(name);
state.emit(" = "); state.emit(" = ");
Luau::visit( Luau::visit(
[&tvs, cycleTy = cycleTy](auto t) { [&tvs, cycleTy = cycleTy](auto t)
{
return tvs(cycleTy, t); return tvs(cycleTy, t);
}, },
cycleTy->ty); cycleTy->ty
);
semi = true; semi = true;
} }
std::vector<std::pair<TypePackId, std::string>> sortedCycleTpNames{state.cycleTpNames.begin(), state.cycleTpNames.end()}; std::vector<std::pair<TypePackId, std::string>> sortedCycleTpNames{state.cycleTpNames.begin(), state.cycleTpNames.end()};
std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) { std::sort(
sortedCycleTpNames.begin(),
sortedCycleTpNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second; return a.second < b.second;
}); }
);
TypePackStringifier tps{tvs.state}; TypePackStringifier tps{tvs.state};
@ -1510,10 +1551,12 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
state.emit(name); state.emit(name);
state.emit(" = "); state.emit(" = ");
Luau::visit( Luau::visit(
[&tps, cycleTp = cycleTp](auto t) { [&tps, cycleTp = cycleTp](auto t)
{
return tps(cycleTp, t); return tps(cycleTp, t);
}, },
cycleTp->ty); cycleTp->ty
);
semi = true; semi = true;
} }
@ -1713,10 +1756,12 @@ std::string toStringVector(const std::vector<TypeId>& types, ToStringOptions& op
std::string toString(const Constraint& constraint, ToStringOptions& opts) std::string toString(const Constraint& constraint, ToStringOptions& opts)
{ {
auto go = [&opts](auto&& c) -> std::string { auto go = [&opts](auto&& c) -> std::string
{
using T = std::decay_t<decltype(c)>; using T = std::decay_t<decltype(c)>;
auto tos = [&opts](auto&& a) { auto tos = [&opts](auto&& a)
{
return toString(a, opts); return toString(a, opts);
}; };

View File

@ -28,8 +28,8 @@ bool isIdentifierChar(char c)
return isIdentifierStartChar(c) || isDigit(c); return isIdentifierStartChar(c) || isDigit(c);
} }
const std::vector<std::string> keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", const std::vector<std::string> keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in",
"not", "or", "repeat", "return", "then", "true", "until", "while"}; "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"};
} // namespace } // namespace
@ -844,6 +844,15 @@ struct Printer
visualizeTypeAnnotation(*a->type); visualizeTypeAnnotation(*a->type);
} }
} }
else if (const auto& t = program.as<AstStatTypeFunction>())
{
if (writeTypes)
{
writer.keyword("type function");
writer.identifier(t->name.value);
visualizeFunctionBody(*t->body);
}
}
else if (const auto& a = program.as<AstStatError>()) else if (const auto& a = program.as<AstStatError>())
{ {
writer.symbol("(error-stat"); writer.symbol("(error-stat");

View File

@ -469,7 +469,11 @@ std::optional<TypeLevel> TxnLog::getLevel(TypeId ty) const
TypeId TxnLog::follow(TypeId ty) const TypeId TxnLog::follow(TypeId ty) const
{ {
return Luau::follow(ty, this, [](const void* ctx, TypeId ty) -> TypeId { return Luau::follow(
ty,
this,
[](const void* ctx, TypeId ty) -> TypeId
{
const TxnLog* self = static_cast<const TxnLog*>(ctx); const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingType* state = self->pending(ty); PendingType* state = self->pending(ty);
@ -480,12 +484,17 @@ TypeId TxnLog::follow(TypeId ty) const
// that normally apply. This is safe because follow will only call get<> // that normally apply. This is safe because follow will only call get<>
// on the returned pointer. // on the returned pointer.
return const_cast<const Type*>(&state->pending); return const_cast<const Type*>(&state->pending);
}); }
);
} }
TypePackId TxnLog::follow(TypePackId tp) const TypePackId TxnLog::follow(TypePackId tp) const
{ {
return Luau::follow(tp, this, [](const void* ctx, TypePackId tp) -> TypePackId { return Luau::follow(
tp,
this,
[](const void* ctx, TypePackId tp) -> TypePackId
{
const TxnLog* self = static_cast<const TxnLog*>(ctx); const TxnLog* self = static_cast<const TxnLog*>(ctx);
PendingTypePack* state = self->pending(tp); PendingTypePack* state = self->pending(tp);
@ -496,7 +505,8 @@ TypePackId TxnLog::follow(TypePackId tp) const
// invariants that normally apply. This is safe because follow will // invariants that normally apply. This is safe because follow will
// only call get<> on the returned pointer. // only call get<> on the returned pointer.
return const_cast<const TypePackVar*>(&state->pending); return const_cast<const TypePackVar*>(&state->pending);
}); }
);
} }
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TxnLog::getChanges() const std::pair<std::vector<TypeId>, std::vector<TypePackId>> TxnLog::getChanges() const

View File

@ -58,9 +58,15 @@ TypeId follow(TypeId t)
TypeId follow(TypeId t, FollowOption followOption) TypeId follow(TypeId t, FollowOption followOption)
{ {
return follow(t, followOption, nullptr, [](const void*, TypeId t) -> TypeId { return follow(
t,
followOption,
nullptr,
[](const void*, TypeId t) -> TypeId
{
return t; return t;
}); }
);
} }
TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId))
@ -70,7 +76,8 @@ TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeI
TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId)) TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId))
{ {
auto advance = [followOption, context, mapper](TypeId ty) -> std::optional<TypeId> { auto advance = [followOption, context, mapper](TypeId ty) -> std::optional<TypeId>
{
TypeId mapped = mapper(context, ty); TypeId mapped = mapper(context, ty);
if (auto btv = get<Unifiable::Bound<TypeId>>(mapped)) if (auto btv = get<Unifiable::Bound<TypeId>>(mapped))
@ -259,7 +266,8 @@ bool isOverloadedFunction(TypeId ty)
if (!get<IntersectionType>(follow(ty))) if (!get<IntersectionType>(follow(ty)))
return false; return false;
auto isFunction = [](TypeId part) -> bool { auto isFunction = [](TypeId part) -> bool
{
return get<FunctionType>(part); return get<FunctionType>(part);
}; };
@ -567,7 +575,11 @@ void BlockedType::replaceOwner(Constraint* newOwner)
} }
PendingExpansionType::PendingExpansionType( PendingExpansionType::PendingExpansionType(
std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments) std::optional<AstName> prefix,
AstName name,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments
)
: prefix(prefix) : prefix(prefix)
, name(name) , name(name)
, typeArguments(typeArguments) , typeArguments(typeArguments)
@ -596,7 +608,13 @@ FunctionType::FunctionType(TypeLevel level, TypePackId argTypes, TypePackId retT
} }
FunctionType::FunctionType( FunctionType::FunctionType(
TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf) TypeLevel level,
Scope* scope,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn)) : definition(std::move(defn))
, level(level) , level(level)
, scope(scope) , scope(scope)
@ -606,8 +624,14 @@ FunctionType::FunctionType(
{ {
} }
FunctionType::FunctionType(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes, FunctionType::FunctionType(
std::optional<FunctionDefinition> defn, bool hasSelf) std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn)) : definition(std::move(defn))
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
@ -617,8 +641,15 @@ FunctionType::FunctionType(std::vector<TypeId> generics, std::vector<TypePackId>
{ {
} }
FunctionType::FunctionType(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, FunctionType::FunctionType(
TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf) TypeLevel level,
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn)) : definition(std::move(defn))
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
@ -629,8 +660,16 @@ FunctionType::FunctionType(TypeLevel level, std::vector<TypeId> generics, std::v
{ {
} }
FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, FunctionType::FunctionType(
TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf) TypeLevel level,
Scope* scope,
std::vector<TypeId> generics,
std::vector<TypePackId> genericPacks,
TypePackId argTypes,
TypePackId retTypes,
std::optional<FunctionDefinition> defn,
bool hasSelf
)
: definition(std::move(defn)) : definition(std::move(defn))
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
@ -644,8 +683,15 @@ FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector<TypeId> ge
Property::Property() {} Property::Property() {}
Property::Property(TypeId readTy, bool deprecated, const std::string& deprecatedSuggestion, std::optional<Location> location, const Tags& tags, Property::Property(
const std::optional<std::string>& documentationSymbol, std::optional<Location> typeLocation) TypeId readTy,
bool deprecated,
const std::string& deprecatedSuggestion,
std::optional<Location> location,
const Tags& tags,
const std::optional<std::string>& documentationSymbol,
std::optional<Location> typeLocation
)
: deprecated(deprecated) : deprecated(deprecated)
, deprecatedSuggestion(deprecatedSuggestion) , deprecatedSuggestion(deprecatedSuggestion)
, location(location) , location(location)
@ -953,9 +999,15 @@ Type& Type::operator=(const Type& rhs)
return *this; return *this;
} }
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, TypeId makeFunction(
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, TypeArena& arena,
std::initializer_list<TypeId> retTypes); std::optional<TypeId> selfType,
std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes
);
TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes); // BuiltinDefinitions.cpp TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes); // BuiltinDefinitions.cpp

View File

@ -166,7 +166,8 @@ public:
} }
return allocator->alloc<AstTypeReference>( return allocator->alloc<AstTypeReference>(
Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters); Location(), std::nullopt, AstName(ttv.name->c_str()), std::nullopt, Location(), parameters.size != 0, parameters
);
} }
if (hasSeen(&ttv)) if (hasSeen(&ttv))
@ -319,7 +320,8 @@ public:
retTailAnnotation = rehydrate(*retTail); retTailAnnotation = rehydrate(*retTail);
return allocator->alloc<AstTypeFunction>( return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}
);
} }
AstType* operator()(const Unifiable::Error&) AstType* operator()(const Unifiable::Error&)
{ {
@ -328,7 +330,8 @@ public:
AstType* operator()(const GenericType& gtv) AstType* operator()(const GenericType& gtv)
{ {
return allocator->alloc<AstTypeReference>( return allocator->alloc<AstTypeReference>(
Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location()); Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)), std::nullopt, Location()
);
} }
AstType* operator()(const Unifiable::Bound<TypeId>& bound) AstType* operator()(const Unifiable::Bound<TypeId>& bound)
{ {

View File

@ -442,8 +442,12 @@ struct TypeChecker2
return instance; return instance;
seenTypeFunctionInstances.insert(instance); seenTypeFunctionInstances.insert(instance);
ErrorVec errors = reduceTypeFunctions(instance, location, ErrorVec errors = reduceTypeFunctions(
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true) instance,
location,
TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits},
true
)
.errors; .errors;
if (!isErrorSuppressing(location, instance)) if (!isErrorSuppressing(location, instance))
reportErrors(std::move(errors)); reportErrors(std::move(errors));
@ -488,7 +492,8 @@ struct TypeChecker2
{ {
TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); TypeId argTy = lookupAnnotation(ref->parameters.data[0].type);
luauPrintLine(format( luauPrintLine(format(
"_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str())); "_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str()
));
return follow(argTy); return follow(argTy);
} }
} }
@ -597,6 +602,8 @@ struct TypeChecker2
return visit(s); return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>()) else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s); return visit(s);
else if (auto f = stat->as<AstStatTypeFunction>())
return visit(f);
else if (auto s = stat->as<AstStatDeclareFunction>()) else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s); return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>()) else if (auto s = stat->as<AstStatDeclareGlobal>())
@ -728,7 +735,8 @@ struct TypeChecker2
local->values.data[local->values.size - 1]->is<AstExprCall>() ? CountMismatch::FunctionResult local->values.data[local->values.size - 1]->is<AstExprCall>() ? CountMismatch::FunctionResult
: CountMismatch::ExprListResult, : CountMismatch::ExprListResult,
}, },
errorLocation); errorLocation
);
} }
} }
} }
@ -744,7 +752,8 @@ struct TypeChecker2
testIsSubtype(builtinTypes->numberType, annotatedType, forStatement->var->location); testIsSubtype(builtinTypes->numberType, annotatedType, forStatement->var->location);
} }
auto checkNumber = [this](AstExpr* expr) { auto checkNumber = [this](AstExpr* expr)
{
if (!expr) if (!expr)
return; return;
@ -839,7 +848,8 @@ struct TypeChecker2
} }
TypeId iteratorTy = follow(iteratorTypes.head[0]); TypeId iteratorTy = follow(iteratorTypes.head[0]);
auto checkFunction = [this, &arena, &forInStatement, &variableTypes](const FunctionType* iterFtv, std::vector<TypeId> iterTys, bool isMm) { auto checkFunction = [this, &arena, &forInStatement, &variableTypes](const FunctionType* iterFtv, std::vector<TypeId> iterTys, bool isMm)
{
if (iterTys.size() < 1 || iterTys.size() > 3) if (iterTys.size() < 1 || iterTys.size() > 3)
{ {
if (isMm) if (isMm)
@ -856,7 +866,8 @@ struct TypeChecker2
{ {
if (isMm) if (isMm)
reportError( reportError(
GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)
);
else else
reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location);
} }
@ -1143,6 +1154,13 @@ struct TypeChecker2
visit(stat->type); visit(stat->type);
} }
void visit(AstStatTypeFunction* stat)
{
// TODO: add type checking for user-defined type functions
reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}});
}
void visit(AstTypeList types) void visit(AstTypeList types)
{ {
for (AstType* type : types.types) for (AstType* type : types.types)
@ -1349,11 +1367,6 @@ struct TypeChecker2
args.head.push_back(lookupType(indexExpr->expr)); args.head.push_back(lookupType(indexExpr->expr));
argExprs.push_back(indexExpr->expr); argExprs.push_back(indexExpr->expr);
} }
else if (findMetatableEntry(builtinTypes, module->errors, *originalCallTy, "__call", call->func->location))
{
args.head.insert(args.head.begin(), lookupType(call->func));
argExprs.push_back(call->func);
}
for (size_t i = 0; i < call->args.size; ++i) for (size_t i = 0; i < call->args.size; ++i)
{ {
@ -1698,12 +1711,17 @@ struct TypeChecker2
// together. For now, this will work. // together. For now, this will work.
reportError( reportError(
GenericError{format( GenericError{format(
"Parameter '%s' has been reduced to never. This function is not callable with any possible value.", arg->name.value)}, "Parameter '%s' has been reduced to never. This function is not callable with any possible value.", arg->name.value
arg->location); )},
arg->location
);
for (const auto& [site, component] : *contributors) for (const auto& [site, component] : *contributors)
reportError(ExtraInformation{format("Parameter '%s' is required to be a subtype of '%s' here.", arg->name.value, reportError(
toString(component).c_str())}, ExtraInformation{
site); format("Parameter '%s' is required to be a subtype of '%s' here.", arg->name.value, toString(component).c_str())
},
site
);
} }
} }
@ -1739,8 +1757,10 @@ struct TypeChecker2
{ {
TypeFunctionReductionGuessResult result = guesser.guessTypeFunctionReductionForFunctionExpr(*fn, inferredFtv, retTy); TypeFunctionReductionGuessResult result = guesser.guessTypeFunctionReductionForFunctionExpr(*fn, inferredFtv, retTy);
if (result.shouldRecommendAnnotation) if (result.shouldRecommendAnnotation)
reportError(ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, reportError(
fn->location); ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType},
fn->location
);
} }
} }
} }
@ -1881,9 +1901,12 @@ struct TypeChecker2
if ((get<BlockedType>(leftType) || get<FreeType>(leftType) || get<GenericType>(leftType)) && !isEquality && !isLogical) if ((get<BlockedType>(leftType) || get<FreeType>(leftType) || get<GenericType>(leftType)) && !isEquality && !isLogical)
{ {
auto name = getIdentifierOfBaseVar(expr->left); auto name = getIdentifierOfBaseVar(expr->left);
reportError(CannotInferBinaryOperation{expr->op, name, reportError(
isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, CannotInferBinaryOperation{
expr->location); expr->op, name, isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation
},
expr->location
);
return leftType; return leftType;
} }
@ -1897,7 +1920,8 @@ struct TypeChecker2
if (isEquality && !matches) if (isEquality && !matches)
{ {
auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional<TypeId> otherMt) { auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional<TypeId> otherMt)
{
for (TypeId option : utv) for (TypeId option : utv)
{ {
if (getMetatable(follow(option), builtinTypes) == otherMt) if (getMetatable(follow(option), builtinTypes) == otherMt)
@ -1929,9 +1953,15 @@ struct TypeChecker2
if (!matches && isComparison) if (!matches && isComparison)
{ {
reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", reportError(
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, GenericError{format(
expr->location); "Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(leftType).c_str(),
toString(rightType).c_str(),
toString(expr->op).c_str()
)},
expr->location
);
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();
} }
@ -2034,17 +2064,29 @@ struct TypeChecker2
{ {
if (isComparison) if (isComparison)
{ {
reportError(GenericError{format( reportError(
GenericError{format(
"Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod",
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)}, toString(leftType).c_str(),
expr->location); toString(rightType).c_str(),
toString(expr->op).c_str(),
it->second
)},
expr->location
);
} }
else else
{ {
reportError(GenericError{format( reportError(
GenericError{format(
"Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod",
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)}, toString(expr->op).c_str(),
expr->location); toString(leftType).c_str(),
toString(rightType).c_str(),
it->second
)},
expr->location
);
} }
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();
@ -2053,15 +2095,27 @@ struct TypeChecker2
{ {
if (isComparison) if (isComparison)
{ {
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", reportError(
toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, GenericError{format(
expr->location); "Types '%s' and '%s' cannot be compared with %s because neither type has a metatable",
toString(leftType).c_str(),
toString(rightType).c_str(),
toString(expr->op).c_str()
)},
expr->location
);
} }
else else
{ {
reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", reportError(
toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())}, GenericError{format(
expr->location); "Operator %s is not applicable for '%s' and '%s' because neither type has a metatable",
toString(expr->op).c_str(),
toString(leftType).c_str(),
toString(rightType).c_str()
)},
expr->location
);
} }
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();
@ -2111,9 +2165,15 @@ struct TypeChecker2
return builtinTypes->booleanType; return builtinTypes->booleanType;
} }
reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), reportError(
toString(rightType).c_str(), toString(expr->op).c_str())}, GenericError{format(
expr->location); "Types '%s' and '%s' cannot be compared with relational operator %s",
toString(leftType).c_str(),
toString(rightType).c_str(),
toString(expr->op).c_str()
)},
expr->location
);
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();
} }
@ -2297,13 +2357,23 @@ struct TypeChecker2
size_t typesRequired = alias->typeParams.size(); size_t typesRequired = alias->typeParams.size();
size_t packsRequired = alias->typePackParams.size(); size_t packsRequired = alias->typePackParams.size();
bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) { bool hasDefaultTypes = std::any_of(
alias->typeParams.begin(),
alias->typeParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value(); return el.defaultValue.has_value();
}); }
);
bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) { bool hasDefaultPacks = std::any_of(
alias->typePackParams.begin(),
alias->typePackParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value(); return el.defaultValue.has_value();
}); }
);
if (!ty->hasParameterList) if (!ty->hasParameterList)
{ {
@ -2385,13 +2455,15 @@ struct TypeChecker2
if (typesProvided != typesRequired || packsProvided != packsRequired) if (typesProvided != typesRequired || packsProvided != packsRequired)
{ {
reportError(IncorrectGenericParameterCount{ reportError(
IncorrectGenericParameterCount{
/* name */ ty->name.value, /* name */ ty->name.value,
/* typeFun */ *alias, /* typeFun */ *alias,
/* actualParameters */ typesProvided, /* actualParameters */ typesProvided,
/* actualPackParameters */ packsProvided, /* actualPackParameters */ packsProvided,
}, },
ty->location); ty->location
);
} }
} }
else else
@ -2403,7 +2475,8 @@ struct TypeChecker2
ty->name.value, ty->name.value,
SwappedGenericTypeParameter::Kind::Type, SwappedGenericTypeParameter::Kind::Type,
}, },
ty->location); ty->location
);
} }
else else
{ {
@ -2501,7 +2574,8 @@ struct TypeChecker2
tp->genericName.value, tp->genericName.value,
SwappedGenericTypeParameter::Kind::Pack, SwappedGenericTypeParameter::Kind::Pack,
}, },
tp->location); tp->location
);
} }
else else
{ {
@ -2715,8 +2789,14 @@ struct TypeChecker2
* contains the prop, and * contains the prop, and
* * A vector of types that do not contain the prop. * * A vector of types that do not contain the prop.
*/ */
PropertyTypes lookupProp(const NormalizedType* norm, const std::string& prop, ValueContext context, const Location& location, PropertyTypes lookupProp(
TypeId astIndexExprType, std::vector<TypeError>& errors) const NormalizedType* norm,
const std::string& prop,
ValueContext context,
const Location& location,
TypeId astIndexExprType,
std::vector<TypeError>& errors
)
{ {
std::vector<TypeId> typesOfProp; std::vector<TypeId> typesOfProp;
std::vector<TypeId> typesMissingTheProp; std::vector<TypeId> typesMissingTheProp;
@ -2724,7 +2804,8 @@ struct TypeChecker2
// this is `false` if we ever hit the resource limits during any of our uses of `fetch`. // this is `false` if we ever hit the resource limits during any of our uses of `fetch`.
bool normValid = true; bool normValid = true;
auto fetch = [&](TypeId ty) { auto fetch = [&](TypeId ty)
{
NormalizationResult result = normalizer.isInhabited(ty); NormalizationResult result = normalizer.isInhabited(ty);
if (result == NormalizationResult::HitLimits) if (result == NormalizationResult::HitLimits)
normValid = false; normValid = false;
@ -2875,8 +2956,15 @@ struct TypeChecker2
std::optional<TypeId> result; std::optional<TypeId> result;
}; };
PropertyType hasIndexTypeFromType(TypeId ty, const std::string& prop, ValueContext context, const Location& location, DenseHashSet<TypeId>& seen, PropertyType hasIndexTypeFromType(
TypeId astIndexExprType, std::vector<TypeError>& errors) TypeId ty,
const std::string& prop,
ValueContext context,
const Location& location,
DenseHashSet<TypeId>& seen,
TypeId astIndexExprType,
std::vector<TypeError>& errors
)
{ {
// If we have already encountered this type, we must assume that some // If we have already encountered this type, we must assume that some
// other codepath will do the right thing and signal false if the // other codepath will do the right thing and signal false if the
@ -2982,7 +3070,8 @@ struct TypeChecker2
std::string_view sv(utk->key); std::string_view sv(utk->key);
std::set<Name> candidates; std::set<Name> candidates;
auto accumulate = [&](const TableType::Props& props) { auto accumulate = [&](const TableType::Props& props)
{
for (const auto& [name, ty] : props) for (const auto& [name, ty] : props)
{ {
if (sv != name && equalsLower(sv, name)) if (sv != name && equalsLower(sv, name))
@ -3055,8 +3144,14 @@ struct TypeChecker2
} }
}; };
void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifierState, NotNull<TypeCheckLimits> limits, DcrLogger* logger, void check(
const SourceModule& sourceModule, Module* module) NotNull<BuiltinTypes> builtinTypes,
NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits,
DcrLogger* logger,
const SourceModule& sourceModule,
Module* module
)
{ {
LUAU_TIMETRACE_SCOPE("check", "Typechecking"); LUAU_TIMETRACE_SCOPE("check", "Typechecking");
@ -3064,6 +3159,12 @@ void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifi
typeChecker.visit(sourceModule.root); typeChecker.visit(sourceModule.root);
// if the only error we're producing is one about constraint solving being incomplete, we can silence it.
// this means we won't give this warning if types seem totally nonsensical, but there are no other errors.
// this is probably, on the whole, a good decision to not annoy users though.
if (module->errors.size() == 1 && get<ConstraintSolvingIncompleteError>(module->errors[0]))
module->errors.clear();
unfreeze(module->interfaceTypes); unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes);
freeze(module->interfaceTypes); freeze(module->interfaceTypes);

View File

@ -112,8 +112,15 @@ struct TypeFunctionReducer
// Local to the constraint being reduced. // Local to the constraint being reduced.
Location location; Location location;
TypeFunctionReducer(VecDeque<TypeId> queuedTys, VecDeque<TypePackId> queuedTps, TypeOrTypePackIdSet shouldGuess, std::vector<TypeId> cyclicTypes, TypeFunctionReducer(
Location location, TypeFunctionContext ctx, bool force = false) VecDeque<TypeId> queuedTys,
VecDeque<TypePackId> queuedTps,
TypeOrTypePackIdSet shouldGuess,
std::vector<TypeId> cyclicTypes,
Location location,
TypeFunctionContext ctx,
bool force = false
)
: ctx(ctx) : ctx(ctx)
, queuedTys(std::move(queuedTys)) , queuedTys(std::move(queuedTys))
, queuedTps(std::move(queuedTps)) , queuedTps(std::move(queuedTps))
@ -218,8 +225,12 @@ struct TypeFunctionReducer
else if (!reduction.uninhabited && !force) else if (!reduction.uninhabited && !force)
{ {
if (FFlag::DebugLuauLogTypeFamilies) if (FFlag::DebugLuauLogTypeFamilies)
printf("%s is irreducible; blocked on %zu types, %zu packs\n", toString(subject, {true}).c_str(), reduction.blockedTypes.size(), printf(
reduction.blockedPacks.size()); "%s is irreducible; blocked on %zu types, %zu packs\n",
toString(subject, {true}).c_str(),
reduction.blockedTypes.size(),
reduction.blockedPacks.size()
);
for (TypeId b : reduction.blockedTypes) for (TypeId b : reduction.blockedTypes)
result.blockedTypes.insert(b); result.blockedTypes.insert(b);
@ -371,7 +382,8 @@ struct TypeFunctionReducer
if (tryGuessing(subject)) if (tryGuessing(subject))
return; return;
TypeFunctionReductionResult<TypePackId> result = tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); TypeFunctionReductionResult<TypePackId> result =
tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx});
handleTypeFunctionReduction(subject, result); handleTypeFunctionReduction(subject, result);
} }
} }
@ -385,8 +397,15 @@ struct TypeFunctionReducer
} }
}; };
static FunctionGraphReductionResult reduceFunctionsInternal(VecDeque<TypeId> queuedTys, VecDeque<TypePackId> queuedTps, TypeOrTypePackIdSet shouldGuess, static FunctionGraphReductionResult reduceFunctionsInternal(
std::vector<TypeId> cyclics, Location location, TypeFunctionContext ctx, bool force) VecDeque<TypeId> queuedTys,
VecDeque<TypePackId> queuedTps,
TypeOrTypePackIdSet shouldGuess,
std::vector<TypeId> cyclics,
Location location,
TypeFunctionContext ctx,
bool force
)
{ {
TypeFunctionReducer reducer{std::move(queuedTys), std::move(queuedTps), std::move(shouldGuess), std::move(cyclics), location, ctx, force}; TypeFunctionReducer reducer{std::move(queuedTys), std::move(queuedTps), std::move(shouldGuess), std::move(cyclics), location, ctx, force};
int iterationCount = 0; int iterationCount = 0;
@ -422,8 +441,15 @@ FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location loc
if (collector.tys.empty() && collector.tps.empty()) if (collector.tys.empty() && collector.tps.empty())
return {}; return {};
return reduceFunctionsInternal(std::move(collector.tys), std::move(collector.tps), std::move(collector.shouldGuess), return reduceFunctionsInternal(
std::move(collector.cyclicInstance), location, ctx, force); std::move(collector.tys),
std::move(collector.tps),
std::move(collector.shouldGuess),
std::move(collector.cyclicInstance),
location,
ctx,
force
);
} }
FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location location, TypeFunctionContext ctx, bool force) FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location location, TypeFunctionContext ctx, bool force)
@ -442,8 +468,15 @@ FunctionGraphReductionResult reduceTypeFunctions(TypePackId entrypoint, Location
if (collector.tys.empty() && collector.tps.empty()) if (collector.tys.empty() && collector.tps.empty())
return {}; return {};
return reduceFunctionsInternal(std::move(collector.tys), std::move(collector.tps), std::move(collector.shouldGuess), return reduceFunctionsInternal(
std::move(collector.cyclicInstance), location, ctx, force); std::move(collector.tys),
std::move(collector.tps),
std::move(collector.shouldGuess),
std::move(collector.cyclicInstance),
location,
ctx,
force
);
} }
bool isPending(TypeId ty, ConstraintSolver* solver) bool isPending(TypeId ty, ConstraintSolver* solver)
@ -452,8 +485,14 @@ bool isPending(TypeId ty, ConstraintSolver* solver)
} }
template<typename F, typename... Args> template<typename F, typename... Args>
static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunctionApp(F f, TypeId instance, const std::vector<TypeId>& typeParams, static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunctionApp(
const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, Args&&... args) F f,
TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
Args&&... args
)
{ {
// op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d)
bool uninhabited = false; bool uninhabited = false;
@ -529,7 +568,11 @@ static std::optional<TypeFunctionReductionResult<TypeId>> tryDistributeTypeFunct
} }
TypeFunctionReductionResult<TypeId> notTypeFunction( TypeFunctionReductionResult<TypeId> notTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 1 || !packParams.empty()) if (typeParams.size() != 1 || !packParams.empty())
{ {
@ -553,7 +596,11 @@ TypeFunctionReductionResult<TypeId> notTypeFunction(
} }
TypeFunctionReductionResult<TypeId> lenTypeFunction( TypeFunctionReductionResult<TypeId> lenTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 1 || !packParams.empty()) if (typeParams.size() != 1 || !packParams.empty())
{ {
@ -645,7 +692,11 @@ TypeFunctionReductionResult<TypeId> lenTypeFunction(
} }
TypeFunctionReductionResult<TypeId> unmTypeFunction( TypeFunctionReductionResult<TypeId> unmTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 1 || !packParams.empty()) if (typeParams.size() != 1 || !packParams.empty())
{ {
@ -744,8 +795,13 @@ NotNull<Constraint> TypeFunctionContext::pushConstraint(ConstraintV&& c)
return newConstraint; return newConstraint;
} }
TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(TypeId instance, const std::vector<TypeId>& typeParams, TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(
const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, const std::string metamethod) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
const std::string metamethod
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -848,7 +904,11 @@ TypeFunctionReductionResult<TypeId> numericBinopTypeFunction(TypeId instance, co
} }
TypeFunctionReductionResult<TypeId> addTypeFunction( TypeFunctionReductionResult<TypeId> addTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -860,7 +920,11 @@ TypeFunctionReductionResult<TypeId> addTypeFunction(
} }
TypeFunctionReductionResult<TypeId> subTypeFunction( TypeFunctionReductionResult<TypeId> subTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -872,7 +936,11 @@ TypeFunctionReductionResult<TypeId> subTypeFunction(
} }
TypeFunctionReductionResult<TypeId> mulTypeFunction( TypeFunctionReductionResult<TypeId> mulTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -884,7 +952,11 @@ TypeFunctionReductionResult<TypeId> mulTypeFunction(
} }
TypeFunctionReductionResult<TypeId> divTypeFunction( TypeFunctionReductionResult<TypeId> divTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -896,7 +968,11 @@ TypeFunctionReductionResult<TypeId> divTypeFunction(
} }
TypeFunctionReductionResult<TypeId> idivTypeFunction( TypeFunctionReductionResult<TypeId> idivTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -908,7 +984,11 @@ TypeFunctionReductionResult<TypeId> idivTypeFunction(
} }
TypeFunctionReductionResult<TypeId> powTypeFunction( TypeFunctionReductionResult<TypeId> powTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -920,7 +1000,11 @@ TypeFunctionReductionResult<TypeId> powTypeFunction(
} }
TypeFunctionReductionResult<TypeId> modTypeFunction( TypeFunctionReductionResult<TypeId> modTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -932,7 +1016,11 @@ TypeFunctionReductionResult<TypeId> modTypeFunction(
} }
TypeFunctionReductionResult<TypeId> concatTypeFunction( TypeFunctionReductionResult<TypeId> concatTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -1040,7 +1128,11 @@ TypeFunctionReductionResult<TypeId> concatTypeFunction(
} }
TypeFunctionReductionResult<TypeId> andTypeFunction( TypeFunctionReductionResult<TypeId> andTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -1091,7 +1183,11 @@ TypeFunctionReductionResult<TypeId> andTypeFunction(
} }
TypeFunctionReductionResult<TypeId> orTypeFunction( TypeFunctionReductionResult<TypeId> orTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -1141,8 +1237,13 @@ TypeFunctionReductionResult<TypeId> orTypeFunction(
return {overallResult.result, false, std::move(blockedTypes), {}}; return {overallResult.result, false, std::move(blockedTypes), {}};
} }
static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(TypeId instance, const std::vector<TypeId>& typeParams, static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(
const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, const std::string metamethod) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
const std::string metamethod
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
@ -1281,7 +1382,11 @@ static TypeFunctionReductionResult<TypeId> comparisonTypeFunction(TypeId instanc
} }
TypeFunctionReductionResult<TypeId> ltTypeFunction( TypeFunctionReductionResult<TypeId> ltTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -1293,7 +1398,11 @@ TypeFunctionReductionResult<TypeId> ltTypeFunction(
} }
TypeFunctionReductionResult<TypeId> leTypeFunction( TypeFunctionReductionResult<TypeId> leTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -1305,7 +1414,11 @@ TypeFunctionReductionResult<TypeId> leTypeFunction(
} }
TypeFunctionReductionResult<TypeId> eqTypeFunction( TypeFunctionReductionResult<TypeId> eqTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -1436,7 +1549,11 @@ struct FindRefinementBlockers : TypeOnceVisitor
TypeFunctionReductionResult<TypeId> refineTypeFunction( TypeFunctionReductionResult<TypeId> refineTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -1521,7 +1638,11 @@ TypeFunctionReductionResult<TypeId> refineTypeFunction(
} }
TypeFunctionReductionResult<TypeId> singletonTypeFunction( TypeFunctionReductionResult<TypeId> singletonTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 1 || !packParams.empty()) if (typeParams.size() != 1 || !packParams.empty())
{ {
@ -1558,7 +1679,11 @@ TypeFunctionReductionResult<TypeId> singletonTypeFunction(
} }
TypeFunctionReductionResult<TypeId> unionTypeFunction( TypeFunctionReductionResult<TypeId> unionTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (!packParams.empty()) if (!packParams.empty())
{ {
@ -1619,7 +1744,11 @@ TypeFunctionReductionResult<TypeId> unionTypeFunction(
TypeFunctionReductionResult<TypeId> intersectTypeFunction( TypeFunctionReductionResult<TypeId> intersectTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (!packParams.empty()) if (!packParams.empty())
{ {
@ -1726,7 +1855,11 @@ bool computeKeysOf(TypeId ty, Set<std::string>& result, DenseHashSet<TypeId>& se
} }
TypeFunctionReductionResult<TypeId> keyofFunctionImpl( TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, bool isRaw) const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
bool isRaw
)
{ {
if (typeParams.size() != 1 || !packParams.empty()) if (typeParams.size() != 1 || !packParams.empty())
{ {
@ -1843,7 +1976,11 @@ TypeFunctionReductionResult<TypeId> keyofFunctionImpl(
} }
TypeFunctionReductionResult<TypeId> keyofTypeFunction( TypeFunctionReductionResult<TypeId> keyofTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 1 || !packParams.empty()) if (typeParams.size() != 1 || !packParams.empty())
{ {
@ -1855,7 +1992,11 @@ TypeFunctionReductionResult<TypeId> keyofTypeFunction(
} }
TypeFunctionReductionResult<TypeId> rawkeyofTypeFunction( TypeFunctionReductionResult<TypeId> rawkeyofTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 1 || !packParams.empty()) if (typeParams.size() != 1 || !packParams.empty())
{ {
@ -1870,7 +2011,12 @@ TypeFunctionReductionResult<TypeId> rawkeyofTypeFunction(
If found, appends that property to `result` and returns true If found, appends that property to `result` and returns true
Else, returns false */ Else, returns false */
bool searchPropsAndIndexer( bool searchPropsAndIndexer(
TypeId ty, TableType::Props tblProps, std::optional<TableIndexer> tblIndexer, DenseHashSet<TypeId>& result, NotNull<TypeFunctionContext> ctx) TypeId ty,
TableType::Props tblProps,
std::optional<TableIndexer> tblIndexer,
DenseHashSet<TypeId>& result,
NotNull<TypeFunctionContext> ctx
)
{ {
ty = follow(ty); ty = follow(ty);
@ -1961,7 +2107,11 @@ bool tblIndexInto(TypeId indexer, TypeId indexee, DenseHashSet<TypeId>& result,
indexer refers to the type that is used to access indexee indexer refers to the type that is used to access indexee
Example: index<Person, "name"> => `Person` is the indexee and `"name"` is the indexer */ Example: index<Person, "name"> => `Person` is the indexee and `"name"` is the indexer */
TypeFunctionReductionResult<TypeId> indexFunctionImpl( TypeFunctionReductionResult<TypeId> indexFunctionImpl(
const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx, bool isRaw) const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx,
bool isRaw
)
{ {
TypeId indexeeTy = follow(typeParams.at(0)); TypeId indexeeTy = follow(typeParams.at(0));
std::shared_ptr<const NormalizedType> indexeeNormTy = ctx->normalizer->normalize(indexeeTy); std::shared_ptr<const NormalizedType> indexeeNormTy = ctx->normalizer->normalize(indexeeTy);
@ -2053,9 +2203,15 @@ TypeFunctionReductionResult<TypeId> indexFunctionImpl(
} }
// Call `follow()` on each element to resolve all Bound types before returning // Call `follow()` on each element to resolve all Bound types before returning
std::transform(properties.begin(), properties.end(), properties.begin(), [](TypeId ty) { std::transform(
properties.begin(),
properties.end(),
properties.begin(),
[](TypeId ty)
{
return follow(ty); return follow(ty);
}); }
);
// If the type being reduced to is a single type, no need to union // If the type being reduced to is a single type, no need to union
if (properties.size() == 1) if (properties.size() == 1)
@ -2065,7 +2221,11 @@ TypeFunctionReductionResult<TypeId> indexFunctionImpl(
} }
TypeFunctionReductionResult<TypeId> indexTypeFunction( TypeFunctionReductionResult<TypeId> indexTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -2077,7 +2237,11 @@ TypeFunctionReductionResult<TypeId> indexTypeFunction(
} }
TypeFunctionReductionResult<TypeId> rawgetTypeFunction( TypeFunctionReductionResult<TypeId> rawgetTypeFunction(
TypeId instance, const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFunctionContext> ctx) TypeId instance,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams,
NotNull<TypeFunctionContext> ctx
)
{ {
if (typeParams.size() != 2 || !packParams.empty()) if (typeParams.size() != 2 || !packParams.empty())
{ {
@ -2119,7 +2283,8 @@ BuiltinTypeFunctions::BuiltinTypeFunctions()
void BuiltinTypeFunctions::addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const void BuiltinTypeFunctions::addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const
{ {
// make a type function for a one-argument type function // make a type function for a one-argument type function
auto mkUnaryTypeFunction = [&](const TypeFunction* tf) { auto mkUnaryTypeFunction = [&](const TypeFunction* tf)
{
TypeId t = arena->addType(GenericType{"T"}); TypeId t = arena->addType(GenericType{"T"});
GenericTypeDefinition genericT{t}; GenericTypeDefinition genericT{t};
@ -2127,7 +2292,8 @@ void BuiltinTypeFunctions::addToScope(NotNull<TypeArena> arena, NotNull<Scope> s
}; };
// make a type function for a two-argument type function // make a type function for a two-argument type function
auto mkBinaryTypeFunction = [&](const TypeFunction* tf) { auto mkBinaryTypeFunction = [&](const TypeFunction* tf)
{
TypeId t = arena->addType(GenericType{"T"}); TypeId t = arena->addType(GenericType{"T"});
TypeId u = arena->addType(GenericType{"U"}); TypeId u = arena->addType(GenericType{"U"});
GenericTypeDefinition genericT{t}; GenericTypeDefinition genericT{t};

View File

@ -128,7 +128,10 @@ std::optional<TypePackId> TypeFunctionReductionGuesser::guess(TypePackId tp)
} }
TypeFunctionReductionGuessResult TypeFunctionReductionGuesser::guessTypeFunctionReductionForFunctionExpr( TypeFunctionReductionGuessResult TypeFunctionReductionGuesser::guessTypeFunctionReductionForFunctionExpr(
const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy) const AstExprFunction& expr,
const FunctionType* ftv,
TypeId retTy
)
{ {
InstanceCollector2 collector; InstanceCollector2 collector;
collector.traverse(retTy); collector.traverse(retTy);
@ -204,8 +207,9 @@ std::optional<TypeId> TypeFunctionReductionGuesser::guessType(TypeId arg)
bool TypeFunctionReductionGuesser::isNumericBinopFunction(const TypeFunctionInstanceType& instance) bool TypeFunctionReductionGuesser::isNumericBinopFunction(const TypeFunctionInstanceType& instance)
{ {
return instance.function->name == "add" || instance.function->name == "sub" || instance.function->name == "mul" || instance.function->name == "div" || return instance.function->name == "add" || instance.function->name == "sub" || instance.function->name == "mul" ||
instance.function->name == "idiv" || instance.function->name == "pow" || instance.function->name == "mod"; instance.function->name == "div" || instance.function->name == "idiv" || instance.function->name == "pow" ||
instance.function->name == "mod";
} }
bool TypeFunctionReductionGuesser::isComparisonFunction(const TypeFunctionInstanceType& instance) bool TypeFunctionReductionGuesser::isComparisonFunction(const TypeFunctionInstanceType& instance)
@ -350,7 +354,8 @@ TypeFunctionInferenceResult TypeFunctionReductionGuesser::inferComparisonFunctio
TypeId lhsTy = follow(instance->typeArguments[0]); TypeId lhsTy = follow(instance->typeArguments[0]);
TypeId rhsTy = follow(instance->typeArguments[1]); TypeId rhsTy = follow(instance->typeArguments[1]);
auto comparisonInference = [&](TypeId op) -> TypeFunctionInferenceResult { auto comparisonInference = [&](TypeId op) -> TypeFunctionInferenceResult
{
return TypeFunctionInferenceResult{{op, op}, builtins->booleanType}; return TypeFunctionInferenceResult{{op, op}, builtins->booleanType};
}; };

View File

@ -31,9 +31,7 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300)
LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false)
LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false)
LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false) LUAU_FASTFLAGVARIABLE(LuauReusableSubstitutions, false)
@ -294,13 +292,6 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
currentModule->cancelled = true; currentModule->cancelled = true;
} }
if (FFlag::DebugLuauSharedSelf)
{
for (auto& [ty, scope] : deferredQuantification)
Luau::quantify(ty, scope->level);
deferredQuantification.clear();
}
if (get<FreeTypePack>(follow(moduleScope->returnType))) if (get<FreeTypePack>(follow(moduleScope->returnType)))
moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt});
else else
@ -379,6 +370,8 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program)
ice("Should not be calling two-argument check() on a function statement", program.location); ice("Should not be calling two-argument check() on a function statement", program.location);
else if (auto typealias = program.as<AstStatTypeAlias>()) else if (auto typealias = program.as<AstStatTypeAlias>())
return check(scope, *typealias); return check(scope, *typealias);
else if (auto typefunction = program.as<AstStatTypeFunction>())
return check(scope, *typefunction);
else if (auto global = program.as<AstStatDeclareGlobal>()) else if (auto global = program.as<AstStatDeclareGlobal>())
{ {
TypeId globalType = resolveType(scope, *global->type); TypeId globalType = resolveType(scope, *global->type);
@ -517,7 +510,8 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope,
std::unordered_map<AstStat*, std::pair<TypeId, ScopePtr>> functionDecls; std::unordered_map<AstStat*, std::pair<TypeId, ScopePtr>> functionDecls;
auto checkBody = [&](AstStat* stat) { auto checkBody = [&](AstStat* stat)
{
if (auto fun = stat->as<AstStatFunction>()) if (auto fun = stat->as<AstStatFunction>())
{ {
LUAU_ASSERT(functionDecls.count(stat)); LUAU_ASSERT(functionDecls.count(stat));
@ -581,32 +575,9 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope,
} }
else if (auto fun = (*protoIter)->as<AstStatFunction>()) else if (auto fun = (*protoIter)->as<AstStatFunction>())
{ {
std::optional<TypeId> selfType; std::optional<TypeId> selfType; // TODO clip
std::optional<TypeId> expectedType; std::optional<TypeId> expectedType;
if (FFlag::DebugLuauSharedSelf)
{
if (auto name = fun->name->as<AstExprIndexName>())
{
TypeId baseTy = checkExpr(scope, *name->expr).type;
tablify(baseTy);
if (!fun->func->self)
expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false);
else if (auto ttv = getMutableTableType(baseTy))
{
if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy)
{
ttv->selfTy = anyIfNonstrict(freshType(ttv->level));
deferredQuantification.push_back({baseTy, scope});
}
selfType = ttv->selfTy;
}
}
}
else
{
if (!fun->func->self) if (!fun->func->self)
{ {
if (auto name = fun->name->as<AstExprIndexName>()) if (auto name = fun->name->as<AstExprIndexName>())
@ -615,7 +586,6 @@ ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope,
expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false);
} }
} }
}
auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, selfType, expectedType); auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, selfType, expectedType);
auto [funTy, funScope] = pair; auto [funTy, funScope] = pair;
@ -1563,14 +1533,26 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty
// Additionally, we can't modify types that come from other modules // Additionally, we can't modify types that come from other modules
if (ttv->name || follow(ty)->owningArena != &currentModule->internalTypes) if (ttv->name || follow(ty)->owningArena != &currentModule->internalTypes)
{ {
bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), bool sameTys = std::equal(
binding->typeParams.end(), [](auto&& itp, auto&& tp) { ttv->instantiatedTypeParams.begin(),
ttv->instantiatedTypeParams.end(),
binding->typeParams.begin(),
binding->typeParams.end(),
[](auto&& itp, auto&& tp)
{
return itp == tp.ty; return itp == tp.ty;
}); }
bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), binding->typePackParams.begin(), );
binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { bool sameTps = std::equal(
ttv->instantiatedTypePackParams.begin(),
ttv->instantiatedTypePackParams.end(),
binding->typePackParams.begin(),
binding->typePackParams.end(),
[](auto&& itpp, auto&& tpp)
{
return itpp == tpp.tp; return itpp == tpp.tp;
}); }
);
// Copy can be skipped if this is an identical alias // Copy can be skipped if this is an identical alias
if (!ttv->name || ttv->name != name || !sameTys || !sameTps) if (!ttv->name || ttv->name != name || !sameTys || !sameTps)
@ -1630,6 +1612,13 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& ty
return ControlFlow::None; return ControlFlow::None;
} }
ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeFunction& typefunction)
{
reportError(TypeError{typefunction.location, GenericError{"This syntax is not supported"}});
return ControlFlow::None;
}
void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel)
{ {
Name name = typealias.name.value; Name name = typealias.name.value;
@ -1704,8 +1693,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de
if (!get<ClassType>(follow(*superTy))) if (!get<ClassType>(follow(*superTy)))
{ {
reportError(declaredClass.location, reportError(
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); declaredClass.location,
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}
);
incorrectClassDefinitions.insert(&declaredClass); incorrectClassDefinitions.insert(&declaredClass);
return; return;
} }
@ -1852,15 +1843,27 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFuncti
std::vector<TypeId> genericTys; std::vector<TypeId> genericTys;
genericTys.reserve(generics.size()); genericTys.reserve(generics.size());
std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { std::transform(
generics.begin(),
generics.end(),
std::back_inserter(genericTys),
[](auto&& el)
{
return el.ty; return el.ty;
}); }
);
std::vector<TypePackId> genericTps; std::vector<TypePackId> genericTps;
genericTps.reserve(genericPacks.size()); genericTps.reserve(genericPacks.size());
std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { std::transform(
genericPacks.begin(),
genericPacks.end(),
std::back_inserter(genericTps),
[](auto&& el)
{
return el.tp; return el.tp;
}); }
);
TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId argPack = resolveTypePack(funScope, global.params);
TypePackId retPack = resolveTypePack(funScope, global.retTypes); TypePackId retPack = resolveTypePack(funScope, global.retTypes);
@ -2085,7 +2088,12 @@ std::optional<TypeId> TypeChecker::findMetatableEntry(TypeId type, std::string e
} }
std::optional<TypeId> TypeChecker::getIndexTypeFromType( std::optional<TypeId> TypeChecker::getIndexTypeFromType(
const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) const ScopePtr& scope,
TypeId type,
const Name& name,
const Location& location,
bool addErrors
)
{ {
size_t errorCount = currentModule->errors.size(); size_t errorCount = currentModule->errors.size();
@ -2098,7 +2106,12 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
} }
std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl( std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl(
const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) const ScopePtr& scope,
TypeId type,
const Name& name,
const Location& location,
bool addErrors
)
{ {
type = follow(type); type = follow(type);
@ -2297,7 +2310,11 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
} }
TypeId TypeChecker::checkExprTable( TypeId TypeChecker::checkExprTable(
const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes, std::optional<TypeId> expectedType) const ScopePtr& scope,
const AstExprTable& expr,
const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType
)
{ {
TableType::Props props; TableType::Props props;
std::optional<TableIndexer> indexer; std::optional<TableIndexer> indexer;
@ -2526,8 +2543,10 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
return WithPredicate{retType}; return WithPredicate{retType};
} }
reportError(expr.location, reportError(
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); expr.location,
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}
);
return WithPredicate{errorRecoveryType(scope)}; return WithPredicate{errorRecoveryType(scope)};
} }
@ -2674,7 +2693,8 @@ static std::optional<bool> areEqComparable(NotNull<TypeArena> arena, NotNull<Nor
a = follow(a); a = follow(a);
b = follow(b); b = follow(b);
auto isExempt = [](TypeId t) { auto isExempt = [](TypeId t)
{
return isNil(t) || get<FreeType>(t); return isNil(t) || get<FreeType>(t);
}; };
@ -2705,9 +2725,15 @@ static std::optional<bool> areEqComparable(NotNull<TypeArena> arena, NotNull<Nor
} }
TypeId TypeChecker::checkRelationalOperation( TypeId TypeChecker::checkRelationalOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates
)
{
auto stripNil = [this](TypeId ty, bool isOrOp = false)
{ {
auto stripNil = [this](TypeId ty, bool isOrOp = false) {
ty = follow(ty); ty = follow(ty);
if (!isNonstrictMode() && !isOrOp) if (!isNonstrictMode() && !isOrOp)
return ty; return ty;
@ -2788,7 +2814,8 @@ TypeId TypeChecker::checkRelationalOperation(
if (!*eqTestResult) if (!*eqTestResult)
{ {
reportError( reportError(
expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())}); expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())}
);
return errorRecoveryType(booleanType); return errorRecoveryType(booleanType);
} }
} }
@ -2821,16 +2848,24 @@ TypeId TypeChecker::checkRelationalOperation(
// we need to be conservative in the old solver to deliver a reasonable developer experience. // we need to be conservative in the old solver to deliver a reasonable developer experience.
if (!isEquality && state.errors.empty() && isBoolean(leftType)) if (!isEquality && state.errors.empty() && isBoolean(leftType))
{ {
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", reportError(
toString(leftType).c_str(), toString(expr.op).c_str())}); expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
} }
} }
else else
{ {
if (!isEquality && state.errors.empty() && (get<UnionType>(leftType) || isBoolean(leftType))) if (!isEquality && state.errors.empty() && (get<UnionType>(leftType) || isBoolean(leftType)))
{ {
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", reportError(
toString(leftType).c_str(), toString(expr.op).c_str())}); expr.location,
GenericError{
format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())
}
);
} }
} }
@ -2879,8 +2914,14 @@ TypeId TypeChecker::checkRelationalOperation(
if (!matches) if (!matches)
{ {
reportError( reportError(
expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", expr.location,
toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); GenericError{format(
"Types %s and %s cannot be compared with %s because they do not have the same metatable",
toString(lhsType).c_str(),
toString(rhsType).c_str(),
toString(expr.op).c_str()
)}
);
return errorRecoveryType(booleanType); return errorRecoveryType(booleanType);
} }
} }
@ -2911,7 +2952,8 @@ TypeId TypeChecker::checkRelationalOperation(
TypeId actualFunctionType = addType(FunctionType(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); TypeId actualFunctionType = addType(FunctionType(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType})));
state.tryUnify( state.tryUnify(
instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true
);
state.log.commit(); state.log.commit();
@ -2921,7 +2963,8 @@ TypeId TypeChecker::checkRelationalOperation(
else if (needsMetamethod) else if (needsMetamethod)
{ {
reportError( reportError(
expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}
);
return errorRecoveryType(booleanType); return errorRecoveryType(booleanType);
} }
} }
@ -2935,8 +2978,12 @@ TypeId TypeChecker::checkRelationalOperation(
if (needsMetamethod) if (needsMetamethod)
{ {
reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", reportError(
toString(lhsType).c_str(), toString(expr.op).c_str())}); expr.location,
GenericError{
format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())
}
);
return errorRecoveryType(booleanType); return errorRecoveryType(booleanType);
} }
@ -3006,7 +3053,12 @@ TypeId TypeChecker::checkRelationalOperation(
} }
TypeId TypeChecker::checkBinaryOperation( TypeId TypeChecker::checkBinaryOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) const ScopePtr& scope,
const AstExprBinary& expr,
TypeId lhsType,
TypeId rhsType,
const PredicateVec& predicates
)
{ {
switch (expr.op) switch (expr.op)
{ {
@ -3057,7 +3109,8 @@ TypeId TypeChecker::checkBinaryOperation(
if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType))
{ {
auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId
{
TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypeId actualFunctionType = instantiate(scope, fnt, expr.location);
TypePackId arguments = addTypePack({lhst, rhst}); TypePackId arguments = addTypePack({lhst, rhst});
TypePackId retTypePack = freshTypePack(scope); TypePackId retTypePack = freshTypePack(scope);
@ -3104,8 +3157,15 @@ TypeId TypeChecker::checkBinaryOperation(
return checkMetatableCall(*fnt, rhsType, lhsType); return checkMetatableCall(*fnt, rhsType, lhsType);
} }
reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), reportError(
toString(lhsType).c_str(), toString(rhsType).c_str())}); expr.location,
GenericError{format(
"Binary operator '%s' not supported by types '%s' and '%s'",
toString(expr.op).c_str(),
toString(lhsType).c_str(),
toString(rhsType).c_str()
)}
);
return errorRecoveryType(scope); return errorRecoveryType(scope);
} }
@ -3537,7 +3597,8 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
// Primarily about detecting duplicates. // Primarily about detecting duplicates.
TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level)
{ {
auto freshTy = [&]() { auto freshTy = [&]()
{
return freshType(level); return freshType(level);
}; };
@ -3610,8 +3671,14 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T
// `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X`
// to get type `(X) -> X`, then we quantify the free types to get the final // to get type `(X) -> X`, then we quantify the free types to get the final
// generic type `<a>(a) -> a`. // generic type `<a>(a) -> a`.
std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
std::optional<Location> originalName, std::optional<TypeId> selfType, std::optional<TypeId> expectedType) const ScopePtr& scope,
int subLevel,
const AstExprFunction& expr,
std::optional<Location> originalName,
std::optional<TypeId> selfType,
std::optional<TypeId> expectedType
)
{ {
ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel);
@ -3704,26 +3771,12 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(const ScopePtr&
funScope->returnType = retPack; funScope->returnType = retPack;
if (FFlag::DebugLuauSharedSelf)
{
if (expr.self) if (expr.self)
{ {
// TODO: generic self types: CLI-39906
TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope));
funScope->bindings[expr.self] = {selfTy, expr.self->location};
argTypes.push_back(selfTy);
}
}
else
{
if (expr.self)
{
// TODO: generic self types: CLI-39906
TypeId selfType = anyIfNonstrict(freshType(funScope)); TypeId selfType = anyIfNonstrict(freshType(funScope));
funScope->bindings[expr.self] = {selfType, expr.self->location}; funScope->bindings[expr.self] = {selfType, expr.self->location};
argTypes.push_back(selfType); argTypes.push_back(selfType);
} }
}
// Prepare expected argument type iterators if we have an expected function type // Prepare expected argument type iterators if we have an expected function type
TypePackIterator expectedArgsCurr, expectedArgsEnd; TypePackIterator expectedArgsCurr, expectedArgsEnd;
@ -3911,8 +3964,14 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
} }
} }
void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId argPack, TypePackId paramPack, void TypeChecker::checkArgumentList(
const std::vector<Location>& argLocations) const ScopePtr& scope,
const AstExpr& funName,
Unifier& state,
TypePackId argPack,
TypePackId paramPack,
const std::vector<Location>& argLocations
)
{ {
/* Important terminology refresher: /* Important terminology refresher:
* A function requires parameters. * A function requires parameters.
@ -3924,7 +3983,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
size_t paramIndex = 0; size_t paramIndex = 0;
auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() { auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]()
{
// For this case, we want the error span to cover every errant extra parameter // For this case, we want the error span to cover every errant extra parameter
Location location = state.location; Location location = state.location;
if (!argLocations.empty()) if (!argLocations.empty())
@ -3936,8 +3996,10 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
namePath = *path; namePath = *path;
auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack);
state.reportError(TypeError{location, state.reportError(TypeError{
CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); location,
CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}
});
}; };
while (true) while (true)
@ -4044,7 +4106,8 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
namePath = *path; namePath = *path;
state.reportError(TypeError{ state.reportError(TypeError{
funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}
});
return; return;
} }
++paramIter; ++paramIter;
@ -4188,7 +4251,8 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
// We break this function up into a lambda here to limit our stack footprint. // We break this function up into a lambda here to limit our stack footprint.
// The vectors used by this function aren't allocated until the lambda is actually called. // The vectors used by this function aren't allocated until the lambda is actually called.
auto the_rest = [&]() -> WithPredicate<TypePackId> { auto the_rest = [&]() -> WithPredicate<TypePackId>
{
// checkExpr will log the pre-instantiated type of the function. // checkExpr will log the pre-instantiated type of the function.
// That's not nearly as interesting as the instantiated type, which will include details about how // That's not nearly as interesting as the instantiated type, which will include details about how
// generic functions are being instantiated for this particular callsite. // generic functions are being instantiated for this particular callsite.
@ -4231,7 +4295,8 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
fn = follow(fn); fn = follow(fn);
if (auto ret = checkCallOverload( if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors
))
return *ret; return *ret;
} }
@ -4258,7 +4323,8 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
{ {
std::vector<std::optional<TypeId>> expectedTypes; std::vector<std::optional<TypeId>> expectedTypes;
auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { auto assignOption = [this, &expectedTypes](size_t index, TypeId ty)
{
if (index == expectedTypes.size()) if (index == expectedTypes.size())
{ {
expectedTypes.push_back(ty); expectedTypes.push_back(ty);
@ -4317,9 +4383,19 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
* If this was an optional, callers would have to pay the stack cost for the result. This is problematic * If this was an optional, callers would have to pay the stack cost for the result. This is problematic
* for functions that need to support recursion up to 600 levels deep. * for functions that need to support recursion up to 600 levels deep.
*/ */
std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(
TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult, const ScopePtr& scope,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors) const AstExprCall& expr,
TypeId fn,
TypePackId retPack,
TypePackId argPack,
TypePack* args,
const std::vector<Location>* argLocations,
const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<TypeId>& overloadsThatDont,
std::vector<OverloadErrorEntry>& errors
)
{ {
LUAU_ASSERT(argLocations); LUAU_ASSERT(argLocations);
@ -4453,8 +4529,13 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const
return nullptr; return nullptr;
} }
bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations, bool TypeChecker::handleSelfCallMismatch(
const std::vector<OverloadErrorEntry>& errors) const ScopePtr& scope,
const AstExprCall& expr,
TypePack* args,
const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors
)
{ {
// No overloads succeeded: Scan for one that would have worked had the user // No overloads succeeded: Scan for one that would have worked had the user
// used a.b() rather than a:b() or vice versa. // used a.b() rather than a:b() or vice versa.
@ -4521,13 +4602,19 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
return false; return false;
} }
void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, void TypeChecker::reportOverloadResolutionError(
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount, const ScopePtr& scope,
std::vector<OverloadErrorEntry>& errors) const AstExprCall& expr,
TypePackId retPack,
TypePackId argPack,
const std::vector<Location>& argLocations,
const std::vector<TypeId>& overloads,
const std::vector<TypeId>& overloadsThatMatchArgCount,
std::vector<OverloadErrorEntry>& errors
)
{ {
if (overloads.size() == 1) if (overloads.size() == 1)
{ {
if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
errors.front().log.commit(); errors.front().log.commit();
reportErrors(errors.front().errors); reportErrors(errors.front().errors);
@ -4551,13 +4638,17 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast
const FunctionType* ftv = get<FunctionType>(overload); const FunctionType* ftv = get<FunctionType>(overload);
auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) { auto error = std::find_if(
errors.begin(),
errors.end(),
[ftv](const OverloadErrorEntry& e)
{
return ftv == e.fnTy; return ftv == e.fnTy;
}); }
);
LUAU_ASSERT(error != errors.end()); LUAU_ASSERT(error != errors.end());
if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
error->log.commit(); error->log.commit();
reportErrors(error->errors); reportErrors(error->errors);
@ -4601,14 +4692,21 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast
return; return;
} }
WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs, WithPredicate<TypePackId> TypeChecker::checkExprList(
bool substituteFreeForNil, const std::vector<bool>& instantiateGenerics, const std::vector<std::optional<TypeId>>& expectedTypes) const ScopePtr& scope,
const Location& location,
const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil,
const std::vector<bool>& instantiateGenerics,
const std::vector<std::optional<TypeId>>& expectedTypes
)
{ {
bool uninhabitable = false; bool uninhabitable = false;
TypePackId pack = addTypePack(TypePack{}); TypePackId pack = addTypePack(TypePack{});
PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up?
auto insert = [&predicates](PredicateVec& vec) { auto insert = [&predicates](PredicateVec& vec)
{
for (Predicate& c : vec) for (Predicate& c : vec)
predicates.push_back(std::move(c)); predicates.push_back(std::move(c));
}; };
@ -4875,20 +4973,10 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
{ {
ty = follow(ty); ty = follow(ty);
if (FFlag::DebugLuauSharedSelf)
{
if (auto ftv = get<FunctionType>(ty))
Luau::quantify(ty, scope->level);
else if (auto ttv = getTableType(ty); ttv && ttv->selfTy)
Luau::quantify(ty, scope->level);
}
else
{
const FunctionType* ftv = get<FunctionType>(ty); const FunctionType* ftv = get<FunctionType>(ty);
if (ftv) if (ftv)
Luau::quantify(ty, scope->level); Luau::quantify(ty, scope->level);
}
return ty; return ty;
} }
@ -5031,11 +5119,17 @@ LUAU_NOINLINE void TypeChecker::throwUserCancelError()
void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec)
{ {
// Remove errors with names that were generated by recovery from a parse error // Remove errors with names that were generated by recovery from a parse error
errVec.erase(std::remove_if(errVec.begin(), errVec.end(), errVec.erase(
[](auto& err) { std::remove_if(
errVec.begin(),
errVec.end(),
[](auto& err)
{
return containsParseErrorName(err); return containsParseErrorName(err);
}), }
errVec.end()); ),
errVec.end()
);
for (auto& err : errVec) for (auto& err : errVec)
{ {
@ -5049,7 +5143,8 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d
std::string_view sv(utk->key); std::string_view sv(utk->key);
std::set<Name> candidates; std::set<Name> candidates;
auto accumulate = [&](const TableType::Props& props) { auto accumulate = [&](const TableType::Props& props)
{
for (const auto& [name, ty] : props) for (const auto& [name, ty] : props)
{ {
if (sv != name && equalsLower(sv, name)) if (sv != name && equalsLower(sv, name))
@ -5103,7 +5198,11 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio
void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) void TypeChecker::merge(RefinementMap& l, const RefinementMap& r)
{ {
Luau::merge(l, r, [this](TypeId a, TypeId b) { Luau::merge(
l,
r,
[this](TypeId a, TypeId b)
{
// TODO: normalize(UnionType{{a, b}}) // TODO: normalize(UnionType{{a, b}})
std::unordered_set<TypeId> set; std::unordered_set<TypeId> set;
@ -5121,7 +5220,8 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r)
if (set.size() == 1) if (set.size() == 1)
return options[0]; return options[0];
return addType(UnionType{std::move(options)}); return addType(UnionType{std::move(options)});
}); }
);
} }
Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location)
@ -5172,7 +5272,8 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess)
TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy)
{ {
return [this, sense, emptySetTy](TypeId ty) -> std::optional<TypeId> { return [this, sense, emptySetTy](TypeId ty) -> std::optional<TypeId>
{
// any/error/free gets a special pass unconditionally because they can't be decided. // any/error/free gets a special pass unconditionally because they can't be decided.
if (get<AnyType>(ty) || get<ErrorType>(ty) || get<FreeType>(ty)) if (get<AnyType>(ty) || get<ErrorType>(ty) || get<FreeType>(ty))
return ty; return ty;
@ -5314,12 +5415,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
return tf->type; return tf->type;
bool parameterCountErrorReported = false; bool parameterCountErrorReported = false;
bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { bool hasDefaultTypes = std::any_of(
tf->typeParams.begin(),
tf->typeParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value(); return el.defaultValue.has_value();
}); }
bool hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { );
bool hasDefaultPacks = std::any_of(
tf->typePackParams.begin(),
tf->typePackParams.end(),
[](auto&& el)
{
return el.defaultValue.has_value(); return el.defaultValue.has_value();
}); }
);
if (!lit->hasParameterList) if (!lit->hasParameterList)
{ {
@ -5442,7 +5553,8 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
{ {
if (!parameterCountErrorReported) if (!parameterCountErrorReported)
reportError( reportError(
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}
);
// Pad the types out with error recovery types // Pad the types out with error recovery types
while (typeParams.size() < tf->typeParams.size()) while (typeParams.size() < tf->typeParams.size())
@ -5451,13 +5563,26 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
typePackParams.push_back(errorRecoveryTypePack(scope)); typePackParams.push_back(errorRecoveryTypePack(scope));
} }
bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { bool sameTys = std::equal(
typeParams.begin(),
typeParams.end(),
tf->typeParams.begin(),
tf->typeParams.end(),
[](auto&& itp, auto&& tp)
{
return itp == tp.ty; return itp == tp.ty;
}); }
);
bool sameTps = std::equal( bool sameTps = std::equal(
typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { typePackParams.begin(),
typePackParams.end(),
tf->typePackParams.begin(),
tf->typePackParams.end(),
[](auto&& itpp, auto&& tpp)
{
return itpp == tpp.tp; return itpp == tpp.tp;
}); }
);
// If the generic parameters and the type arguments are the same, we are about to // If the generic parameters and the type arguments are the same, we are about to
// perform an identity substitution, which we can just short-circuit. // perform an identity substitution, which we can just short-circuit.
@ -5512,15 +5637,27 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
std::vector<TypeId> genericTys; std::vector<TypeId> genericTys;
genericTys.reserve(generics.size()); genericTys.reserve(generics.size());
std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { std::transform(
generics.begin(),
generics.end(),
std::back_inserter(genericTys),
[](auto&& el)
{
return el.ty; return el.ty;
}); }
);
std::vector<TypePackId> genericTps; std::vector<TypePackId> genericTps;
genericTps.reserve(genericPacks.size()); genericTps.reserve(genericPacks.size());
std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { std::transform(
genericPacks.begin(),
genericPacks.end(),
std::back_inserter(genericTps),
[](auto&& el)
{
return el.tp; return el.tp;
}); }
);
TypeId fnType = addType(FunctionType{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); TypeId fnType = addType(FunctionType{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes});
@ -5641,8 +5778,13 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack
return result; return result;
} }
TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, TypeId TypeChecker::instantiateTypeFun(
const std::vector<TypePackId>& typePackParams, const Location& location) const ScopePtr& scope,
const TypeFun& tf,
const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams,
const Location& location
)
{ {
if (tf.typeParams.empty() && tf.typePackParams.empty()) if (tf.typeParams.empty() && tf.typePackParams.empty())
return tf.type; return tf.type;
@ -5706,8 +5848,14 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
return instantiated; return instantiated;
} }
GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, GenericTypeDefinitions TypeChecker::createGenericTypes(
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames, bool useCache) const ScopePtr& scope,
std::optional<TypeLevel> levelOpt,
const AstNode& node,
const AstArray<AstGenericType>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames,
bool useCache
)
{ {
LUAU_ASSERT(scope->parent); LUAU_ASSERT(scope->parent);
@ -5835,7 +5983,8 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const
} }
} }
auto intoType = [this](const std::unordered_set<TypeId>& s) -> std::optional<TypeId> { auto intoType = [this](const std::unordered_set<TypeId>& s) -> std::optional<TypeId>
{
if (s.empty()) if (s.empty())
return std::nullopt; return std::nullopt;
@ -6022,7 +6171,8 @@ void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const Sc
void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense) void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense)
{ {
auto predicate = [&](TypeId option) -> std::optional<TypeId> { auto predicate = [&](TypeId option) -> std::optional<TypeId>
{
// This by itself is not truly enough to determine that A is stronger than B or vice versa. // This by itself is not truly enough to determine that A is stronger than B or vice versa.
bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty(); bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty();
bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty(); bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty();
@ -6085,8 +6235,10 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
return; return;
} }
auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional<TypeId> mapsTo = std::nullopt) { auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional<TypeId> mapsTo = std::nullopt)
TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional<TypeId> { {
TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional<TypeId>
{
if (sense && get<UnknownType>(ty)) if (sense && get<UnknownType>(ty))
return mapsTo.value_or(ty); return mapsTo.value_or(ty);
@ -6117,22 +6269,31 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
return refine(isBuffer, bufferType); return refine(isBuffer, bufferType);
else if (typeguardP.kind == "table") else if (typeguardP.kind == "table")
{ {
return refine([](TypeId ty) -> bool { return refine(
[](TypeId ty) -> bool
{
return isTableIntersection(ty) || get<TableType>(ty) || get<MetatableType>(ty); return isTableIntersection(ty) || get<TableType>(ty) || get<MetatableType>(ty);
}); }
);
} }
else if (typeguardP.kind == "function") else if (typeguardP.kind == "function")
{ {
return refine([](TypeId ty) -> bool { return refine(
[](TypeId ty) -> bool
{
return isOverloadedFunction(ty) || get<FunctionType>(ty); return isOverloadedFunction(ty) || get<FunctionType>(ty);
}); }
);
} }
else if (typeguardP.kind == "userdata") else if (typeguardP.kind == "userdata")
{ {
// For now, we don't really care about being accurate with userdata if the typeguard was using typeof. // For now, we don't really care about being accurate with userdata if the typeguard was using typeof.
return refine([](TypeId ty) -> bool { return refine(
[](TypeId ty) -> bool
{
return get<ClassType>(ty); return get<ClassType>(ty);
}); }
);
} }
if (!typeguardP.isTypeof) if (!typeguardP.isTypeof)
@ -6162,7 +6323,8 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r
void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense) void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense)
{ {
// This refinement will require success typing to do everything correctly. For now, we can get most of the way there. // This refinement will require success typing to do everything correctly. For now, we can get most of the way there.
auto options = [](TypeId ty) -> std::vector<TypeId> { auto options = [](TypeId ty) -> std::vector<TypeId>
{
if (auto utv = get<UnionType>(follow(ty))) if (auto utv = get<UnionType>(follow(ty)))
return std::vector<TypeId>(begin(utv), end(utv)); return std::vector<TypeId>(begin(utv), end(utv));
return {ty}; return {ty};
@ -6173,7 +6335,8 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc
if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable))
return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here.
auto predicate = [&](TypeId option) -> std::optional<TypeId> { auto predicate = [&](TypeId option) -> std::optional<TypeId>
{
if (!sense && isNil(eqP.type)) if (!sense && isNil(eqP.type))
return (isUndecidable(option) || !isNil(option)) ? std::optional<TypeId>(option) : std::nullopt; return (isUndecidable(option) || !isNil(option)) ? std::optional<TypeId>(option) : std::nullopt;

View File

@ -257,14 +257,20 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs)
TypePackId follow(TypePackId tp) TypePackId follow(TypePackId tp)
{ {
return follow(tp, nullptr, [](const void*, TypePackId t) { return follow(
tp,
nullptr,
[](const void*, TypePackId t)
{
return t; return t;
}); }
);
} }
TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId)) TypePackId follow(TypePackId tp, const void* context, TypePackId (*mapper)(const void*, TypePackId))
{ {
auto advance = [context, mapper](TypePackId ty) -> std::optional<TypePackId> { auto advance = [context, mapper](TypePackId ty) -> std::optional<TypePackId>
{
TypePackId mapped = mapper(context, ty); TypePackId mapped = mapper(context, ty);
if (const Unifiable::Bound<TypePackId>* btv = get<Unifiable::Bound<TypePackId>>(mapped)) if (const Unifiable::Bound<TypePackId>* btv = get<Unifiable::Bound<TypePackId>>(mapped))

View File

@ -534,7 +534,8 @@ std::string toString(const TypePath::Path& path, bool prefixDot)
std::stringstream result; std::stringstream result;
bool first = true; bool first = true;
auto strComponent = [&](auto&& c) { auto strComponent = [&](auto&& c)
{
using T = std::decay_t<decltype(c)>; using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, TypePath::Property>) if constexpr (std::is_same_v<T, TypePath::Property>)
{ {
@ -626,7 +627,8 @@ std::string toString(const TypePath::Path& path, bool prefixDot)
static bool traverse(TraversalState& state, const Path& path) static bool traverse(TraversalState& state, const Path& path)
{ {
auto step = [&state](auto&& c) { auto step = [&state](auto&& c)
{
return state.traverse(c); return state.traverse(c);
}; };

View File

@ -24,7 +24,8 @@ bool occursCheck(TypeId needle, TypeId haystack)
LUAU_ASSERT(get<BlockedType>(needle) || get<PendingExpansionType>(needle)); LUAU_ASSERT(get<BlockedType>(needle) || get<PendingExpansionType>(needle));
haystack = follow(haystack); haystack = follow(haystack);
auto checkHaystack = [needle](TypeId haystack) { auto checkHaystack = [needle](TypeId haystack)
{
return occursCheck(needle, haystack); return occursCheck(needle, haystack);
}; };
@ -92,7 +93,12 @@ std::optional<Property> findTableProperty(NotNull<BuiltinTypes> builtinTypes, Er
} }
std::optional<TypeId> findMetatableEntry( std::optional<TypeId> findMetatableEntry(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId type,
const std::string& entry,
Location location
)
{ {
type = follow(type); type = follow(type);
@ -120,13 +126,24 @@ std::optional<TypeId> findMetatableEntry(
} }
std::optional<TypeId> findTablePropertyRespectingMeta( std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
Location location
)
{ {
return findTablePropertyRespectingMeta(builtinTypes, errors, ty, name, ValueContext::RValue, location); return findTablePropertyRespectingMeta(builtinTypes, errors, ty, name, ValueContext::RValue, location);
} }
std::optional<TypeId> findTablePropertyRespectingMeta( std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<BuiltinTypes> builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, ValueContext context, Location location) NotNull<BuiltinTypes> builtinTypes,
ErrorVec& errors,
TypeId ty,
const std::string& name,
ValueContext context,
Location location
)
{ {
if (get<AnyType>(ty)) if (get<AnyType>(ty))
return ty; return ty;
@ -217,7 +234,12 @@ std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log,
} }
TypePack extendTypePack( TypePack extendTypePack(
TypeArena& arena, NotNull<BuiltinTypes> builtinTypes, TypePackId pack, size_t length, std::vector<std::optional<TypeId>> overrides) TypeArena& arena,
NotNull<BuiltinTypes> builtinTypes,
TypePackId pack,
size_t length,
std::vector<std::optional<TypeId>> overrides
)
{ {
TypePack result; TypePack result;

View File

@ -19,7 +19,6 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit)
LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauErrorRecoveryType)
LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false)
LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false)
LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false)
@ -329,7 +328,8 @@ TypePackId Widen::operator()(TypePackId tp)
std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors) std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
{ {
auto isUnificationTooComplex = [](const TypeError& te) { auto isUnificationTooComplex = [](const TypeError& te)
{
return nullptr != get<UnificationTooComplex>(te); return nullptr != get<UnificationTooComplex>(te);
}; };
@ -342,7 +342,8 @@ std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
std::optional<TypeError> hasCountMismatch(const ErrorVec& errors) std::optional<TypeError> hasCountMismatch(const ErrorVec& errors)
{ {
auto isCountMismatch = [](const TypeError& te) { auto isCountMismatch = [](const TypeError& te)
{
return nullptr != get<CountMismatch>(te); return nullptr != get<CountMismatch>(te);
}; };
@ -771,47 +772,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ
} }
} }
if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types});
else
{
// even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option.
auto tryBind = [this, subTy](TypeId superOption) {
superOption = log.follow(superOption);
// just skip if the superOption is not free-ish.
auto ttv = log.getMutable<TableType>(superOption);
if (!log.is<FreeType>(superOption) && (!ttv || ttv->state != TableState::Free))
return;
// If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype
// test is successful.
if (auto subUnion = get<UnionType>(subTy))
{
if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption))
return;
}
// Since we have already checked if S <: T, checking it again will not queue up the type for replacement.
// So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set.
if (log.haveSeen(subTy, superOption))
{
// TODO: would it be nice for TxnLog::replace to do this?
if (log.is<TableType>(superOption))
log.bindTable(superOption, subTy);
else
log.replace(superOption, *subTy);
}
};
if (auto superUnion = log.getMutable<UnionType>(superTy))
{
for (TypeId ty : superUnion)
tryBind(ty);
}
else
tryBind(superTy);
}
if (unificationTooComplex) if (unificationTooComplex)
reportError(*unificationTooComplex); reportError(*unificationTooComplex);
@ -954,7 +915,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
return reportError(location, NormalizationTooComplex{}); return reportError(location, NormalizationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
innerState.tryUnifyNormalizedTypes( innerState.tryUnifyNormalizedTypes(
subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption
);
else else
innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible");
@ -985,7 +947,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
failure = true; failure = true;
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError( reportError(
location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()}); location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()}
);
else else
reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible", mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible", mismatchContext()});
} }
@ -1151,7 +1114,13 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
} }
void Unifier::tryUnifyNormalizedTypes( void Unifier::tryUnifyNormalizedTypes(
TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional<TypeError> error) TypeId subTy,
TypeId superTy,
const NormalizedType& subNorm,
const NormalizedType& superNorm,
std::string reason,
std::optional<TypeError> error
)
{ {
if (get<AnyType>(superNorm.tops)) if (get<AnyType>(superNorm.tops))
return; return;
@ -1394,7 +1363,8 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy)
if (subTyInfo && *subTyInfo) if (subTyInfo && *subTyInfo)
return false; return false;
auto skipCacheFor = [this](TypeId ty) { auto skipCacheFor = [this](TypeId ty)
{
SkipCacheForType visitor{sharedState.skipCacheForType, types}; SkipCacheForType visitor{sharedState.skipCacheForType, types};
visitor.traverse(ty); visitor.traverse(ty);
@ -1674,7 +1644,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
superIter.scope = scope.get(); superIter.scope = scope.get();
subIter.scope = scope.get(); subIter.scope = scope.get();
auto mkFreshType = [this](Scope* scope, TypeLevel level) { auto mkFreshType = [this](Scope* scope, TypeLevel level)
{
if (FFlag::DebugLuauDeferredConstraintResolution) if (FFlag::DebugLuauDeferredConstraintResolution)
return freshType(NotNull{types}, builtinTypes, scope); return freshType(NotNull{types}, builtinTypes, scope);
else else
@ -1970,8 +1941,16 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), reportError(
innerState.errors.front(), mismatchContext()}); location,
TypeMismatch{
superTy,
subTy,
format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front(),
mismatchContext()
}
);
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()});
@ -1985,8 +1964,16 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes))
reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos) else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), reportError(
innerState.errors.front(), mismatchContext()}); location,
TypeMismatch{
superTy,
subTy,
format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front(),
mismatchContext()
}
);
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()});
} }
@ -2402,7 +2389,8 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
if (!superTable || superTable->state != TableState::Free) if (!superTable || superTable->state != TableState::Free)
return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()});
auto fail = [&](std::optional<TypeError> e) { auto fail = [&](std::optional<TypeError> e)
{
std::string reason = "The former's metatable does not satisfy the requirements."; std::string reason = "The former's metatable does not satisfy the requirements.";
if (e) if (e)
reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e, mismatchContext()}); reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e, mismatchContext()});
@ -2497,7 +2485,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
reportError(*e); reportError(*e);
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError( reportError(
location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}); location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}
);
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
failure |= innerState.failure; failure |= innerState.failure;
@ -2535,8 +2524,10 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
if (auto e = hasUnificationTooComplex(innerState.errors)) if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e); reportError(*e);
else if (!innerState.errors.empty()) else if (!innerState.errors.empty())
reportError(TypeError{location, reportError(TypeError{
TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}}); location,
TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}
});
else if (!missingProperty) else if (!missingProperty)
{ {
log.concat(std::move(innerState.log)); log.concat(std::move(innerState.log));
@ -2574,7 +2565,8 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
if (reversed) if (reversed)
std::swap(superTy, subTy); std::swap(superTy, subTy);
auto fail = [&]() { auto fail = [&]()
{
if (!reversed) if (!reversed)
reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); reportError(location, TypeMismatch{superTy, subTy, mismatchContext()});
else else
@ -2770,8 +2762,15 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
} }
} }
static void tryUnifyWithAny(std::vector<TypeId>& queue, Unifier& state, DenseHashSet<TypeId>& seen, DenseHashSet<TypePackId>& seenTypePacks, static void tryUnifyWithAny(
const TypeArena* typeArena, TypeId anyType, TypePackId anyTypePack) std::vector<TypeId>& queue,
Unifier& state,
DenseHashSet<TypeId>& seen,
DenseHashSet<TypePackId>& seenTypePacks,
const TypeArena* typeArena,
TypeId anyType,
TypePackId anyTypePack
)
{ {
while (!queue.empty()) while (!queue.empty())
{ {
@ -2927,7 +2926,8 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
bool occurrence = false; bool occurrence = false;
auto check = [&](TypeId tv) { auto check = [&](TypeId tv)
{
if (occursCheck(seen, needle, tv)) if (occursCheck(seen, needle, tv))
occurrence = true; occurrence = true;
}; };
@ -3064,8 +3064,10 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s
if (auto e = hasUnificationTooComplex(innerErrors)) if (auto e = hasUnificationTooComplex(innerErrors))
reportError(*e); reportError(*e);
else if (!innerErrors.empty()) else if (!innerErrors.empty())
reportError(TypeError{location, reportError(TypeError{
TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}}); location,
TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}
});
} }
void Unifier::ice(const std::string& message, const Location& location) void Unifier::ice(const std::string& message, const Location& location)

View File

@ -33,7 +33,8 @@ static bool areCompatible(TypeId left, TypeId right)
const TableType* rightTable = p.second; const TableType* rightTable = p.second;
LUAU_ASSERT(rightTable); LUAU_ASSERT(rightTable);
const auto missingPropIsCompatible = [](const Property& leftProp, const TableType* rightTable) { const auto missingPropIsCompatible = [](const Property& leftProp, const TableType* rightTable)
{
// Two tables may be compatible even if their shapes aren't exactly the // Two tables may be compatible even if their shapes aren't exactly the
// same if the extra property is optional, free (and therefore // same if the extra property is optional, free (and therefore
// potentially optional), or if the right table has an indexer. Or if // potentially optional), or if the right table has an indexer. Or if
@ -96,8 +97,13 @@ Unifier2::Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes,
{ {
} }
Unifier2::Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice, Unifier2::Unifier2(
DenseHashSet<const void*>* uninhabitedTypeFunctions) NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope,
NotNull<InternalErrorReporter> ice,
DenseHashSet<const void*>* uninhabitedTypeFunctions
)
: arena(arena) : arena(arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, scope(scope) , scope(scope)
@ -251,7 +257,8 @@ bool Unifier2::unifyFreeWithType(TypeId subTy, TypeId superTy)
FreeType* subFree = getMutable<FreeType>(subTy); FreeType* subFree = getMutable<FreeType>(subTy);
LUAU_ASSERT(subFree); LUAU_ASSERT(subFree);
auto doDefault = [&]() { auto doDefault = [&]()
{
subFree->upperBound = mkIntersection(subFree->upperBound, superTy); subFree->upperBound = mkIntersection(subFree->upperBound, superTy);
expandedFreeTypes[subTy].push_back(superTy); expandedFreeTypes[subTy].push_back(superTy);
return true; return true;
@ -841,7 +848,8 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypeId>& seen, TypeId needl
OccursCheckResult occurrence = OccursCheckResult::Pass; OccursCheckResult occurrence = OccursCheckResult::Pass;
auto check = [&](TypeId ty) { auto check = [&](TypeId ty)
{
if (occursCheck(seen, needle, ty) == OccursCheckResult::Fail) if (occursCheck(seen, needle, ty) == OccursCheckResult::Fail)
occurrence = OccursCheckResult::Fail; occurrence = OccursCheckResult::Fail;
}; };

View File

@ -384,7 +384,13 @@ public:
LUAU_RTTI(AstExprIndexName) LUAU_RTTI(AstExprIndexName)
AstExprIndexName( AstExprIndexName(
const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op); const Location& location,
AstExpr* expr,
const AstName& index,
const Location& indexLocation,
const Position& opPosition,
char op
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -413,11 +419,22 @@ class AstExprFunction : public AstExpr
public: public:
LUAU_RTTI(AstExprFunction) LUAU_RTTI(AstExprFunction)
AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics, AstExprFunction(
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& location,
const Location& varargLocation, AstStatBlock* body, size_t functionDepth, const AstName& debugname, const AstArray<AstAttr*>& attributes,
const std::optional<AstTypeList>& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, const AstArray<AstGenericType>& generics,
const std::optional<Location>& argLocation = std::nullopt); const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self,
const AstArray<AstLocal*>& args,
bool vararg,
const Location& varargLocation,
AstStatBlock* body,
size_t functionDepth,
const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation = {},
AstTypePack* varargAnnotation = nullptr,
const std::optional<Location>& argLocation = std::nullopt
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -603,8 +620,14 @@ class AstStatIf : public AstStat
public: public:
LUAU_RTTI(AstStatIf) LUAU_RTTI(AstStatIf)
AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional<Location>& thenLocation, AstStatIf(
const std::optional<Location>& elseLocation); const Location& location,
AstExpr* condition,
AstStatBlock* thenbody,
AstStat* elsebody,
const std::optional<Location>& thenLocation,
const std::optional<Location>& elseLocation
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -698,8 +721,12 @@ class AstStatLocal : public AstStat
public: public:
LUAU_RTTI(AstStatLocal) LUAU_RTTI(AstStatLocal)
AstStatLocal(const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values, AstStatLocal(
const std::optional<Location>& equalsSignLocation); const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
const std::optional<Location>& equalsSignLocation
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -714,8 +741,16 @@ class AstStatFor : public AstStat
public: public:
LUAU_RTTI(AstStatFor) LUAU_RTTI(AstStatFor)
AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, AstStatFor(
const Location& doLocation); const Location& location,
AstLocal* var,
AstExpr* from,
AstExpr* to,
AstExpr* step,
AstStatBlock* body,
bool hasDo,
const Location& doLocation
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -734,8 +769,16 @@ class AstStatForIn : public AstStat
public: public:
LUAU_RTTI(AstStatForIn) LUAU_RTTI(AstStatForIn)
AstStatForIn(const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values, AstStatBlock* body, bool hasIn, AstStatForIn(
const Location& inLocation, bool hasDo, const Location& doLocation); const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
AstStatBlock* body,
bool hasIn,
const Location& inLocation,
bool hasDo,
const Location& doLocation
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -808,8 +851,15 @@ class AstStatTypeAlias : public AstStat
public: public:
LUAU_RTTI(AstStatTypeAlias) LUAU_RTTI(AstStatTypeAlias)
AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, const AstArray<AstGenericType>& generics, AstStatTypeAlias(
const AstArray<AstGenericTypePack>& genericPacks, AstType* type, bool exported); const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
AstType* type,
bool exported
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -821,6 +871,20 @@ public:
bool exported; bool exported;
}; };
class AstStatTypeFunction : public AstStat
{
public:
LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body);
void visit(AstVisitor* visitor) override;
AstName name;
Location nameLocation;
AstExprFunction* body;
};
class AstStatDeclareGlobal : public AstStat class AstStatDeclareGlobal : public AstStat
{ {
public: public:
@ -840,13 +904,32 @@ class AstStatDeclareFunction : public AstStat
public: public:
LUAU_RTTI(AstStatDeclareFunction) LUAU_RTTI(AstStatDeclareFunction)
AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, const AstArray<AstGenericType>& generics, AstStatDeclareFunction(
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& location,
const Location& varargLocation, const AstTypeList& retTypes); const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
);
AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name, const Location& nameLocation, AstStatDeclareFunction(
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const Location& location,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes); const AstArray<AstAttr*>& attributes,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -896,8 +979,13 @@ class AstStatDeclareClass : public AstStat
public: public:
LUAU_RTTI(AstStatDeclareClass) LUAU_RTTI(AstStatDeclareClass)
AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName, const AstArray<AstDeclaredClassProp>& props, AstStatDeclareClass(
AstTableIndexer* indexer = nullptr); const Location& location,
const AstName& name,
std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props,
AstTableIndexer* indexer = nullptr
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -934,8 +1022,15 @@ class AstTypeReference : public AstType
public: public:
LUAU_RTTI(AstTypeReference) LUAU_RTTI(AstTypeReference)
AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, std::optional<Location> prefixLocation, AstTypeReference(
const Location& nameLocation, bool hasParameterList = false, const AstArray<AstTypeOrPack>& parameters = {}); const Location& location,
std::optional<AstName> prefix,
AstName name,
std::optional<Location> prefixLocation,
const Location& nameLocation,
bool hasParameterList = false,
const AstArray<AstTypeOrPack>& parameters = {}
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
@ -974,12 +1069,24 @@ class AstTypeFunction : public AstType
public: public:
LUAU_RTTI(AstTypeFunction) LUAU_RTTI(AstTypeFunction)
AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstTypeFunction(
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes); const Location& location,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
);
AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics, AstTypeFunction(
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const Location& location,
const AstTypeList& returnTypes); const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;

View File

@ -55,7 +55,12 @@ class Parser
{ {
public: public:
static ParseResult parse( static ParseResult parse(
const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions()); const char* buffer,
std::size_t bufferSize,
AstNameTable& names,
Allocator& allocator,
ParseOptions options = ParseOptions()
);
private: private:
struct Name; struct Name;
@ -140,6 +145,9 @@ private:
// type Name `=' Type // type Name `=' Type
AstStat* parseTypeAlias(const Location& start, bool exported); AstStat* parseTypeAlias(const Location& start, bool exported);
// type function Name ... end
AstStat* parseTypeFunction(const Location& start);
AstDeclaredClassProp parseDeclaredClassMethod(); AstDeclaredClassProp parseDeclaredClassMethod();
// `declare global' Name: Type | // `declare global' Name: Type |
@ -157,7 +165,12 @@ private:
// funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type]
// funcbody ::= funcbodyhead block end // funcbody ::= funcbodyhead block end
std::pair<AstExprFunction*, AstLocal*> parseFunctionBody( std::pair<AstExprFunction*, AstLocal*> parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes); bool hasself,
const Lexeme& matchFunction,
const AstName& debugname,
const Name* localName,
const AstArray<AstAttr*>& attributes
);
// explist ::= {exp `,'} exp // explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result); void parseExprList(TempVector<AstExpr*>& result);
@ -191,9 +204,15 @@ private:
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation); AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation);
AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes); AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics, AstType* parseFunctionTypeTail(
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, const Lexeme& begin,
AstTypePack* varargAnnotation); const AstArray<AstAttr*>& attributes,
AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params,
AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation
);
AstType* parseTableType(bool inDeclarationContext = false); AstType* parseTableType(bool inDeclarationContext = false);
AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false);
@ -315,8 +334,13 @@ private:
void reportNameError(const char* context); void reportNameError(const char* context);
AstStatError* reportStatError(const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements, AstStatError* reportStatError(
const char* format, ...) LUAU_PRINTF_ATTR(5, 6); const Location& location,
const AstArray<AstExpr*>& expressions,
const AstArray<AstStat*>& statements,
const char* format,
...
) LUAU_PRINTF_ATTR(5, 6);
AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5);
AstTypeError* reportTypeError(const Location& location, const AstArray<AstType*>& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstTypeError* reportTypeError(const Location& location, const AstArray<AstType*>& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5);
// `parseErrorLocation` is associated with the parser error // `parseErrorLocation` is associated with the parser error

View File

@ -141,7 +141,13 @@ void AstExprCall::visit(AstVisitor* visitor)
} }
AstExprIndexName::AstExprIndexName( AstExprIndexName::AstExprIndexName(
const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op) const Location& location,
AstExpr* expr,
const AstName& index,
const Location& indexLocation,
const Position& opPosition,
char op
)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, expr(expr) , expr(expr)
, index(index) , index(index)
@ -173,10 +179,22 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
} }
} }
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics, AstExprFunction::AstExprFunction(
const AstArray<AstGenericTypePack>& genericPacks, AstLocal* self, const AstArray<AstLocal*>& args, bool vararg, const Location& varargLocation, const Location& location,
AstStatBlock* body, size_t functionDepth, const AstName& debugname, const std::optional<AstTypeList>& returnAnnotation, const AstArray<AstAttr*>& attributes,
AstTypePack* varargAnnotation, const std::optional<Location>& argLocation) const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
AstLocal* self,
const AstArray<AstLocal*>& args,
bool vararg,
const Location& varargLocation,
AstStatBlock* body,
size_t functionDepth,
const AstName& debugname,
const std::optional<AstTypeList>& returnAnnotation,
AstTypePack* varargAnnotation,
const std::optional<Location>& argLocation
)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, attributes(attributes) , attributes(attributes)
, generics(generics) , generics(generics)
@ -418,8 +436,14 @@ void AstStatBlock::visit(AstVisitor* visitor)
} }
} }
AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, AstStatIf::AstStatIf(
const std::optional<Location>& thenLocation, const std::optional<Location>& elseLocation) const Location& location,
AstExpr* condition,
AstStatBlock* thenbody,
AstStat* elsebody,
const std::optional<Location>& thenLocation,
const std::optional<Location>& elseLocation
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, condition(condition) , condition(condition)
, thenbody(thenbody) , thenbody(thenbody)
@ -524,7 +548,11 @@ void AstStatExpr::visit(AstVisitor* visitor)
} }
AstStatLocal::AstStatLocal( AstStatLocal::AstStatLocal(
const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values, const std::optional<Location>& equalsSignLocation) const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
const std::optional<Location>& equalsSignLocation
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, vars(vars) , vars(vars)
, values(values) , values(values)
@ -548,7 +576,15 @@ void AstStatLocal::visit(AstVisitor* visitor)
} }
AstStatFor::AstStatFor( AstStatFor::AstStatFor(
const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, const Location& doLocation) const Location& location,
AstLocal* var,
AstExpr* from,
AstExpr* to,
AstExpr* step,
AstStatBlock* body,
bool hasDo,
const Location& doLocation
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, var(var) , var(var)
, from(from) , from(from)
@ -577,8 +613,16 @@ void AstStatFor::visit(AstVisitor* visitor)
} }
} }
AstStatForIn::AstStatForIn(const Location& location, const AstArray<AstLocal*>& vars, const AstArray<AstExpr*>& values, AstStatBlock* body, AstStatForIn::AstStatForIn(
bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation) const Location& location,
const AstArray<AstLocal*>& vars,
const AstArray<AstExpr*>& values,
AstStatBlock* body,
bool hasIn,
const Location& inLocation,
bool hasDo,
const Location& doLocation
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, vars(vars) , vars(vars)
, values(values) , values(values)
@ -672,8 +716,15 @@ void AstStatLocalFunction::visit(AstVisitor* visitor)
func->visit(visitor); func->visit(visitor);
} }
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const Location& nameLocation, AstStatTypeAlias::AstStatTypeAlias(
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstType* type, bool exported) const Location& location,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
AstType* type,
bool exported
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
, nameLocation(nameLocation) , nameLocation(nameLocation)
@ -704,6 +755,20 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
} }
} }
AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body)
: AstStat(ClassIndex(), location)
, name(name)
, nameLocation(nameLocation)
, body(body)
{
}
void AstStatTypeFunction::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
body->visit(visitor);
}
AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type) AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, const Location& nameLocation, AstType* type)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
@ -718,9 +783,18 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor)
type->visit(visitor); type->visit(visitor);
} }
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const Location& nameLocation, AstStatDeclareFunction::AstStatDeclareFunction(
const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const Location& location,
const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes) const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, attributes() , attributes()
, name(name) , name(name)
@ -735,9 +809,19 @@ AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const A
{ {
} }
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstName& name, AstStatDeclareFunction::AstStatDeclareFunction(
const Location& nameLocation, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, const Location& location,
const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, bool vararg, const Location& varargLocation, const AstTypeList& retTypes) const AstArray<AstAttr*>& attributes,
const AstName& name,
const Location& nameLocation,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames,
bool vararg,
const Location& varargLocation,
const AstTypeList& retTypes
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, attributes(attributes) , attributes(attributes)
, name(name) , name(name)
@ -772,8 +856,13 @@ bool AstStatDeclareFunction::isCheckedFunction() const
return false; return false;
} }
AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional<AstName> superName, AstStatDeclareClass::AstStatDeclareClass(
const AstArray<AstDeclaredClassProp>& props, AstTableIndexer* indexer) const Location& location,
const AstName& name,
std::optional<AstName> superName,
const AstArray<AstDeclaredClassProp>& props,
AstTableIndexer* indexer
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
, superName(superName) , superName(superName)
@ -792,7 +881,11 @@ void AstStatDeclareClass::visit(AstVisitor* visitor)
} }
AstStatError::AstStatError( AstStatError::AstStatError(
const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements, unsigned messageIndex) const Location& location,
const AstArray<AstExpr*>& expressions,
const AstArray<AstStat*>& statements,
unsigned messageIndex
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, expressions(expressions) , expressions(expressions)
, statements(statements) , statements(statements)
@ -812,8 +905,15 @@ void AstStatError::visit(AstVisitor* visitor)
} }
} }
AstTypeReference::AstTypeReference(const Location& location, std::optional<AstName> prefix, AstName name, std::optional<Location> prefixLocation, AstTypeReference::AstTypeReference(
const Location& nameLocation, bool hasParameterList, const AstArray<AstTypeOrPack>& parameters) const Location& location,
std::optional<AstName> prefix,
AstName name,
std::optional<Location> prefixLocation,
const Location& nameLocation,
bool hasParameterList,
const AstArray<AstTypeOrPack>& parameters
)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, hasParameterList(hasParameterList) , hasParameterList(hasParameterList)
, prefix(prefix) , prefix(prefix)
@ -860,8 +960,14 @@ void AstTypeTable::visit(AstVisitor* visitor)
} }
} }
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks, AstTypeFunction::AstTypeFunction(
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes) const Location& location,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, attributes() , attributes()
, generics(generics) , generics(generics)
@ -873,9 +979,15 @@ AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGen
LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size);
} }
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstAttr*>& attributes, const AstArray<AstGenericType>& generics, AstTypeFunction::AstTypeFunction(
const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const Location& location,
const AstTypeList& returnTypes) const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics,
const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes
)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, attributes(attributes) , attributes(attributes)
, generics(generics) , generics(generics)

View File

@ -1808,9 +1808,15 @@ static const Confusable kConfusables[] =
const char* findConfusable(uint32_t codepoint) const char* findConfusable(uint32_t codepoint)
{ {
auto it = std::lower_bound(std::begin(kConfusables), std::end(kConfusables), codepoint, [](const Confusable& lhs, uint32_t rhs) { auto it = std::lower_bound(
std::begin(kConfusables),
std::end(kConfusables),
codepoint,
[](const Confusable& lhs, uint32_t rhs)
{
return lhs.codepoint < rhs; return lhs.codepoint < rhs;
}); }
);
return (it != std::end(kConfusables) && it->codepoint == codepoint) ? it->text : nullptr; return (it != std::end(kConfusables) && it->codepoint == codepoint) ? it->text : nullptr;
} }

View File

@ -92,8 +92,10 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz
, length(unsigned(size)) , length(unsigned(size))
, data(data) , data(data)
{ {
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || LUAU_ASSERT(
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment
);
} }
Lexeme::Lexeme(const Location& location, Type type, const char* name) Lexeme::Lexeme(const Location& location, Type type, const char* name)
@ -107,14 +109,16 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name)
unsigned int Lexeme::getLength() const unsigned int Lexeme::getLength() const
{ {
LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || LUAU_ASSERT(
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd ||
type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment
);
return length; return length;
} }
static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in",
"repeat", "return", "then", "true", "until", "while"}; "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"};
std::string Lexeme::toString() const std::string Lexeme::toString() const
{ {

View File

@ -20,6 +20,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false)
LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false) LUAU_FASTFLAGVARIABLE(LuauDeclarationExtraPropData, false)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false)
namespace Luau namespace Luau
{ {
@ -785,9 +786,13 @@ AstStat* Parser::parseAttributeStat()
return parseDeclaration(expr->location, attributes); return parseDeclaration(expr->location, attributes);
} }
default: default:
return reportStatError(lexer.current().location, {}, {}, return reportStatError(
lexer.current().location,
{},
{},
"Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s instead", "Expected 'function', 'local function', 'declare function' or a function type declaration after attribute, but got %s instead",
lexer.current().toString().c_str()); lexer.current().toString().c_str()
);
} }
} }
@ -825,8 +830,13 @@ AstStat* Parser::parseLocal(const AstArray<AstAttr*>& attributes)
{ {
if (attributes.size != 0) if (attributes.size != 0)
{ {
return reportStatError(lexer.current().location, {}, {}, "Expected 'function' after local declaration with attribute, but got %s instead", return reportStatError(
lexer.current().toString().c_str()); lexer.current().location,
{},
{},
"Expected 'function' after local declaration with attribute, but got %s instead",
lexer.current().toString().c_str()
);
} }
matchRecoveryStopOnToken['=']++; matchRecoveryStopOnToken['=']++;
@ -880,6 +890,15 @@ AstStat* Parser::parseReturn()
// type Name [`<' varlist `>'] `=' Type // type Name [`<' varlist `>'] `=' Type
AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
{ {
// parsing a type function
if (FFlag::LuauUserDefinedTypeFunctions)
{
if (lexer.current().type == Lexeme::ReservedFunction)
return parseTypeFunction(start);
}
// parsing a type alias
// note: `type` token is already parsed for us, so we just need to parse the rest // note: `type` token is already parsed for us, so we just need to parse the rest
std::optional<Name> name = parseNameOpt("type name"); std::optional<Name> name = parseNameOpt("type name");
@ -897,6 +916,26 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
return allocator.alloc<AstStatTypeAlias>(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); return allocator.alloc<AstStatTypeAlias>(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported);
} }
// type function Name `(' arglist `)' `=' funcbody `end'
AstStat* Parser::parseTypeFunction(const Location& start)
{
Lexeme matchFn = lexer.current();
nextLexeme();
// parse the name of the type function
std::optional<Name> fnName = parseNameOpt("type function name");
if (!fnName)
fnName = Name(nameError, lexer.current().location);
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
AstExprFunction* body = parseFunctionBody(/* hasself */ false, matchFn, fnName->name, nullptr, AstArray<AstAttr*>({nullptr, 0})).first;
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body);
}
AstDeclaredClassProp Parser::parseDeclaredClassMethod() AstDeclaredClassProp Parser::parseDeclaredClassMethod()
{ {
Location start; Location start;
@ -940,8 +979,12 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr)
{ {
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, return AstDeclaredClassProp{
reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; fnName.name,
FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{},
reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"),
true
};
} }
// Skip the first index. // Skip the first index.
@ -959,10 +1002,16 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
report(start, "All declaration parameters aside from 'self' must be annotated"); report(start, "All declaration parameters aside from 'self' must be annotated");
AstType* fnType = allocator.alloc<AstTypeFunction>( AstType* fnType = allocator.alloc<AstTypeFunction>(
Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes
);
return AstDeclaredClassProp{fnName.name, FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{}, fnType, true, return AstDeclaredClassProp{
FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{}}; fnName.name,
FFlag::LuauDeclarationExtraPropData ? fnName.location : Location{},
fnType,
true,
FFlag::LuauDeclarationExtraPropData ? Location(start, end) : Location{}
};
} }
AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes) AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*>& attributes)
@ -970,8 +1019,13 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
// `declare` token is already parsed at this point // `declare` token is already parsed at this point
if ((attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction)) if ((attributes.size != 0) && (lexer.current().type != Lexeme::ReservedFunction))
return reportStatError(lexer.current().location, {}, {}, "Expected a function type declaration after attribute, but got %s instead", return reportStatError(
lexer.current().toString().c_str()); lexer.current().location,
{},
{},
"Expected a function type declaration after attribute, but got %s instead",
lexer.current().toString().c_str()
);
if (lexer.current().type == Lexeme::ReservedFunction) if (lexer.current().type == Lexeme::ReservedFunction)
{ {
@ -1014,11 +1068,33 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated");
if (FFlag::LuauDeclarationExtraPropData) if (FFlag::LuauDeclarationExtraPropData)
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, globalName.location, generics, return allocator.alloc<AstStatDeclareFunction>(
genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), vararg, varargLocation, retTypes); Location(start, end),
attributes,
globalName.name,
globalName.location,
generics,
genericPacks,
AstTypeList{copy(vars), varargAnnotation},
copy(varNames),
vararg,
varargLocation,
retTypes
);
else else
return allocator.alloc<AstStatDeclareFunction>(Location(start, end), attributes, globalName.name, Location{}, generics, genericPacks, return allocator.alloc<AstStatDeclareFunction>(
AstTypeList{copy(vars), varargAnnotation}, copy(varNames), false, Location{}, retTypes); Location(start, end),
attributes,
globalName.name,
Location{},
generics,
genericPacks,
AstTypeList{copy(vars), varargAnnotation},
copy(varNames),
false,
Location{},
retTypes
);
} }
else if (AstName(lexer.current().name) == "class") else if (AstName(lexer.current().name) == "class")
{ {
@ -1064,7 +1140,8 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
if (chars && !containsNull) if (chars && !containsNull)
props.push_back(AstDeclaredClassProp{ props.push_back(AstDeclaredClassProp{
AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation())}); AstName(chars->data), Location(nameBegin, nameEnd), type, false, Location(begin.location, lexer.previousLocation())
});
else else
report(begin.location, "String literal contains malformed escape sequence or \\0"); report(begin.location, "String literal contains malformed escape sequence or \\0");
} }
@ -1107,8 +1184,8 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
Name propName = parseName("property name"); Name propName = parseName("property name");
expectAndConsume(':', "property type annotation"); expectAndConsume(':', "property type annotation");
AstType* propType = parseType(); AstType* propType = parseType();
props.push_back( props.push_back(AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())}
AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())}); );
} }
else else
{ {
@ -1130,7 +1207,8 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArray<AstAttr*
AstType* type = parseType(/* in declaration context */ true); AstType* type = parseType(/* in declaration context */ true);
return allocator.alloc<AstStatDeclareGlobal>( return allocator.alloc<AstStatDeclareGlobal>(
Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type); Location(start, type->location), globalName->name, FFlag::LuauDeclarationExtraPropData ? globalName->location : Location{}, type
);
} }
else else
{ {
@ -1205,7 +1283,12 @@ std::pair<AstLocal*, AstArray<AstLocal*>> Parser::prepareFunctionArguments(const
// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end
// parlist ::= bindinglist [`,' `...'] | `...' // parlist ::= bindinglist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody( std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName, const AstArray<AstAttr*>& attributes) bool hasself,
const Lexeme& matchFunction,
const AstName& debugname,
const Name* localName,
const AstArray<AstAttr*>& attributes
)
{ {
Location start = matchFunction.location; Location start = matchFunction.location;
@ -1257,9 +1340,25 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction);
body->hasEnd = hasEnd; body->hasEnd = hasEnd;
return {allocator.alloc<AstExprFunction>(Location(start, end), attributes, generics, genericPacks, self, vars, vararg, varargLocation, body, return {
functionStack.size(), debugname, typelist, varargAnnotation, argLocation), allocator.alloc<AstExprFunction>(
funLocal}; Location(start, end),
attributes,
generics,
genericPacks,
self,
vars,
vararg,
varargLocation,
body,
functionStack.size(),
debugname,
typelist,
varargAnnotation,
argLocation
),
funLocal
};
} }
// explist ::= {exp `,'} exp // explist ::= {exp `,'} exp
@ -1656,9 +1755,15 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray<AstAttr*>
return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; return {parseFunctionTypeTail(begin, attributes, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
} }
AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAttr*>& attributes, AstArray<AstGenericType> generics, AstType* Parser::parseFunctionTypeTail(
AstArray<AstGenericTypePack> genericPacks, AstArray<AstType*> params, AstArray<std::optional<AstArgumentName>> paramNames, const Lexeme& begin,
AstTypePack* varargAnnotation) const AstArray<AstAttr*>& attributes,
AstArray<AstGenericType> generics,
AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*> params,
AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation
)
{ {
incrementRecursionCounter("type annotation"); incrementRecursionCounter("type annotation");
@ -1683,7 +1788,8 @@ AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, const AstArray<AstAt
AstTypeList paramTypes = AstTypeList{params, varargAnnotation}; AstTypeList paramTypes = AstTypeList{params, varargAnnotation};
return allocator.alloc<AstTypeFunction>( return allocator.alloc<AstTypeFunction>(
Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList); Location(begin.location, endLocation), attributes, generics, genericPacks, paramTypes, paramNames, returnTypeList
);
} }
// Type ::= // Type ::=
@ -1760,8 +1866,11 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin)
if (isUnion && isIntersection) if (isUnion && isIntersection)
{ {
return reportTypeError(Location(begin, parts.back()->location), copy(parts), return reportTypeError(
"Mixing union and intersection types is not allowed; consider wrapping in parentheses."); Location(begin, parts.back()->location),
copy(parts),
"Mixing union and intersection types is not allowed; consider wrapping in parentheses."
);
} }
location.end = parts.back()->location.end; location.end = parts.back()->location.end;
@ -1922,7 +2031,8 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
Location end = lexer.previousLocation(); Location end = lexer.previousLocation();
return { return {
allocator.alloc<AstTypeReference>(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {}}; allocator.alloc<AstTypeReference>(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {}
};
} }
else if (lexer.current().type == '{') else if (lexer.current().type == '{')
{ {
@ -1936,10 +2046,15 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext)
{ {
nextLexeme(); nextLexeme();
return {reportTypeError(start, {}, return {
reportTypeError(
start,
{},
"Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
"...any'"), "...any'"
{}}; ),
{}
};
} }
else else
{ {
@ -2114,8 +2229,7 @@ std::optional<AstExprBinary::Op> Parser::checkBinaryConfusables(const BinaryOpPr
report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?"); report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?");
return AstExprBinary::Or; return AstExprBinary::Or;
} }
else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit)
binaryPriority[AstExprBinary::CompareNe].left > limit)
{ {
nextLexeme(); nextLexeme();
report(Location(start, next.location), "Unexpected '!='; did you mean '~='?"); report(Location(start, next.location), "Unexpected '!='; did you mean '~='?");
@ -2129,6 +2243,7 @@ std::optional<AstExprBinary::Op> Parser::checkBinaryConfusables(const BinaryOpPr
// where `binop' is any binary operator with a priority higher than `limit' // where `binop' is any binary operator with a priority higher than `limit'
AstExpr* Parser::parseExpr(unsigned int limit) AstExpr* Parser::parseExpr(unsigned int limit)
{ {
// clang-format off
static const BinaryOpPriority binaryPriority[] = { static const BinaryOpPriority binaryPriority[] = {
{6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `//' `%' {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `//' `%'
{10, 9}, {5, 4}, // power and concat (right associative) {10, 9}, {5, 4}, // power and concat (right associative)
@ -2136,6 +2251,8 @@ AstExpr* Parser::parseExpr(unsigned int limit)
{3, 3}, {3, 3}, {3, 3}, {3, 3}, // order {3, 3}, {3, 3}, {3, 3}, {3, 3}, // order
{2, 2}, {1, 1} // logical (and/or) {2, 2}, {1, 1} // logical (and/or)
}; };
// clang-format on
static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op"); static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op");
unsigned int oldRecursionCount = recursionCounter; unsigned int oldRecursionCount = recursionCounter;
@ -2414,7 +2531,8 @@ AstExpr* Parser::parseSimpleExpr()
if (lexer.current().type != Lexeme::ReservedFunction) if (lexer.current().type != Lexeme::ReservedFunction)
{ {
return reportExprError( return reportExprError(
start, {}, "Expected 'function' declaration after attribute, but got %s instead", lexer.current().toString().c_str()); start, {}, "Expected 'function' declaration after attribute, but got %s instead", lexer.current().toString().c_str()
);
} }
} }
@ -2447,8 +2565,7 @@ AstExpr* Parser::parseSimpleExpr()
{ {
return parseNumber(); return parseNumber();
} }
else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple)
lexer.current().type == Lexeme::InterpStringSimple)
{ {
return parseString(); return parseString();
} }
@ -2548,15 +2665,22 @@ LUAU_NOINLINE AstExpr* Parser::reportFunctionArgsError(AstExpr* func, bool self)
} }
else else
{ {
return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), return reportExprError(
"Expected '(', '{' or <string> when parsing function call, got %s", lexer.current().toString().c_str()); Location(func->location.begin, lexer.current().location.begin),
copy({func}),
"Expected '(', '{' or <string> when parsing function call, got %s",
lexer.current().toString().c_str()
);
} }
} }
LUAU_NOINLINE void Parser::reportAmbiguousCallError() LUAU_NOINLINE void Parser::reportAmbiguousCallError()
{ {
report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " report(
"new statement; use ';' to separate statements"); lexer.current().location,
"Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of "
"new statement; use ';' to separate statements"
);
} }
// tableconstructor ::= `{' [fieldlist] `}' // tableconstructor ::= `{' [fieldlist] `}'
@ -2868,8 +2992,10 @@ AstArray<AstTypeOrPack> Parser::parseTypeParams()
std::optional<AstArray<char>> Parser::parseCharArray() std::optional<AstArray<char>> Parser::parseCharArray()
{ {
LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || LUAU_ASSERT(
lexer.current().type == Lexeme::InterpStringSimple); lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString ||
lexer.current().type == Lexeme::InterpStringSimple
);
scratchData.assign(lexer.current().data, lexer.current().getLength()); scratchData.assign(lexer.current().data, lexer.current().getLength());
@ -2911,8 +3037,10 @@ AstExpr* Parser::parseInterpString()
do do
{ {
Lexeme currentLexeme = lexer.current(); Lexeme currentLexeme = lexer.current();
LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || LUAU_ASSERT(
currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple); currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid ||
currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple
);
endLocation = currentLexeme.location; endLocation = currentLexeme.location;
@ -3013,7 +3141,8 @@ AstLocal* Parser::pushLocal(const Binding& binding)
AstLocal*& local = localMap[name.name]; AstLocal*& local = localMap[name.name];
local = allocator.alloc<AstLocal>( local = allocator.alloc<AstLocal>(
name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation); name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation
);
localStack.push_back(local); localStack.push_back(local);
@ -3146,11 +3275,25 @@ LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Ma
std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString(); std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString();
if (lexer.current().location.begin.line == begin.position.line) if (lexer.current().location.begin.line == begin.position.line)
report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), matchString.c_str(), report(
begin.position.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); lexer.current().location,
"Expected %s (to close %s at column %d), got %s%s",
typeString.c_str(),
matchString.c_str(),
begin.position.column + 1,
lexer.current().toString().c_str(),
extra ? extra : ""
);
else else
report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), matchString.c_str(), report(
begin.position.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); lexer.current().location,
"Expected %s (to close %s at line %d), got %s%s",
typeString.c_str(),
matchString.c_str(),
begin.position.line + 1,
lexer.current().toString().c_str(),
extra ? extra : ""
);
} }
bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin) bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin)
@ -3287,7 +3430,12 @@ LUAU_NOINLINE void Parser::reportNameError(const char* context)
} }
AstStatError* Parser::reportStatError( AstStatError* Parser::reportStatError(
const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements, const char* format, ...) const Location& location,
const AstArray<AstExpr*>& expressions,
const AstArray<AstStat*>& statements,
const char* format,
...
)
{ {
va_list args; va_list args;
va_start(args, format); va_start(args, format);

View File

@ -141,7 +141,8 @@ size_t editDistance(std::string_view a, std::string_view b)
size_t maxDistance = a.size() + b.size(); size_t maxDistance = a.size() + b.size();
std::vector<size_t> distances((a.size() + 2) * (b.size() + 2), 0); std::vector<size_t> distances((a.size() + 2) * (b.size() + 2), 0);
auto getPos = [b](size_t x, size_t y) -> size_t { auto getPos = [b](size_t x, size_t y) -> size_t
{
return (x * (b.size() + 2)) + y; return (x * (b.size() + 2)) + y;
}; };

View File

@ -184,8 +184,14 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Ev
Token& token = context.tokens[ev.token]; Token& token = context.tokens[ev.token];
formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category, formatAppend(
ev.data.microsec, threadId); temp,
R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)",
token.name,
token.category,
ev.data.microsec,
threadId
);
unfinishedEnter = true; unfinishedEnter = true;
} }
break; break;
@ -201,10 +207,13 @@ void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector<Ev
unfinishedEnter = false; unfinishedEnter = false;
} }
formatAppend(temp, formatAppend(
temp,
R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)" R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)"
"\n", "\n",
ev.data.microsec, threadId); ev.data.microsec,
threadId
);
break; break;
case EventType::ArgName: case EventType::ArgName:
LUAU_ASSERT(unfinishedEnter); LUAU_ASSERT(unfinishedEnter);

View File

@ -64,8 +64,13 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data)) if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str());
else else
report(format, humanReadableName.c_str(), error.location, "TypeError", report(
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); format,
humanReadableName.c_str(),
error.location,
"TypeError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()
);
} }
static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning)
@ -235,9 +240,12 @@ struct TaskScheduler
{ {
for (unsigned i = 0; i < threadCount; i++) for (unsigned i = 0; i < threadCount; i++)
{ {
workers.emplace_back([this] { workers.emplace_back(
[this]
{
workerFunction(); workerFunction();
}); }
);
} }
} }
@ -254,9 +262,13 @@ struct TaskScheduler
{ {
std::unique_lock guard(mtx); std::unique_lock guard(mtx);
cv.wait(guard, [this] { cv.wait(
guard,
[this]
{
return !tasks.empty(); return !tasks.empty();
}); }
);
std::function<void()> task = tasks.front(); std::function<void()> task = tasks.front();
tasks.pop(); tasks.pop();
@ -351,7 +363,8 @@ int main(int argc, char** argv)
if (FFlag::DebugLuauLogSolverToJsonFile) if (FFlag::DebugLuauLogSolverToJsonFile)
{ {
frontend.writeJsonLog = [&basePath](const Luau::ModuleName& moduleName, std::string log) { frontend.writeJsonLog = [&basePath](const Luau::ModuleName& moduleName, std::string log)
{
std::string path = moduleName + ".log.json"; std::string path = moduleName + ".log.json";
size_t pos = moduleName.find_last_of('/'); size_t pos = moduleName.find_last_of('/');
if (pos != std::string::npos) if (pos != std::string::npos)
@ -390,9 +403,13 @@ int main(int argc, char** argv)
{ {
TaskScheduler scheduler(threadCount); TaskScheduler scheduler(threadCount);
checkedModules = frontend.checkQueuedModules(std::nullopt, [&](std::function<void()> f) { checkedModules = frontend.checkQueuedModules(
std::nullopt,
[&](std::function<void()> f)
{
scheduler.push(std::move(f)); scheduler.push(std::move(f));
}); }
);
} }
catch (const Luau::InternalCompilerError& ice) catch (const Luau::InternalCompilerError& ice)
{ {
@ -403,8 +420,13 @@ int main(int argc, char** argv)
Luau::TypeError error(location, moduleName, Luau::InternalError{ice.message}); Luau::TypeError error(location, moduleName, Luau::InternalError{ice.message});
report(format, humanReadableName.c_str(), location, "InternalCompilerError", report(
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); format,
humanReadableName.c_str(),
location,
"InternalCompilerError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()
);
return 1; return 1;
} }

View File

@ -231,7 +231,10 @@ static void serializeScriptSummary(const std::string& file, const std::vector<Fu
} }
static bool serializeSummaries( static bool serializeSummaries(
const std::vector<std::string>& files, const std::vector<std::vector<FunctionBytecodeSummary>>& scriptSummaries, const std::string& summaryFile) const std::vector<std::string>& files,
const std::vector<std::vector<FunctionBytecodeSummary>>& scriptSummaries,
const std::string& summaryFile
)
{ {
FILE* fp = fopen(summaryFile.c_str(), "w"); FILE* fp = fopen(summaryFile.c_str(), "w");

View File

@ -108,7 +108,11 @@ static void reportError(const char* name, const Luau::CompileError& error)
} }
static std::string getCodegenAssembly( static std::string getCodegenAssembly(
const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options, Luau::CodeGen::LoweringStats* stats) const char* name,
const std::string& bytecode,
Luau::CodeGen::AssemblyOptions options,
Luau::CodeGen::LoweringStats* stats
)
{ {
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close); std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get(); lua_State* L = globalState.get();
@ -326,8 +330,10 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
if (format == CompileFormat::Text) if (format == CompileFormat::Text)
{ {
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | bcb.setDumpFlags(
Luau::BytecodeBuilder::Dump_Remarks | Luau::BytecodeBuilder::Dump_Types); Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks | Luau::BytecodeBuilder::Dump_Types
);
bcb.setDumpSource(*source); bcb.setDumpSource(*source);
} }
else if (format == CompileFormat::Remarks) else if (format == CompileFormat::Remarks)
@ -335,11 +341,12 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source); bcb.setDumpSource(*source);
} }
else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || format == CompileFormat::CodegenVerbose)
format == CompileFormat::CodegenVerbose)
{ {
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | bcb.setDumpFlags(
Luau::BytecodeBuilder::Dump_Remarks); Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks
);
bcb.setDumpSource(*source); bcb.setDumpSource(*source);
} }
@ -623,19 +630,37 @@ int main(int argc, char** argv)
if (compileFormat == CompileFormat::Null) if (compileFormat == CompileFormat::Null)
{ {
printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024), printf(
stats.readTime, stats.parseTime, stats.compileTime); "Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n",
int(stats.lines / 1000),
int(stats.bytecode / 1024),
stats.readTime,
stats.parseTime,
stats.compileTime
);
} }
else if (compileFormat == CompileFormat::CodegenNull) else if (compileFormat == CompileFormat::CodegenNull)
{ {
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", printf(
int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), "Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n",
stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, int(stats.lines / 1000),
stats.codegenTime); int(stats.bytecode / 1024),
int(stats.codegen / 1024),
stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode),
stats.readTime,
stats.parseTime,
stats.compileTime,
stats.codegenTime
);
printf("Lowering: regalloc failed: %d, lowering failed %d; spills to stack: %d, spills to restore: %d, max spill slot %u\n", printf(
stats.lowerStats.regAllocErrors, stats.lowerStats.loweringErrors, stats.lowerStats.spillsToSlot, stats.lowerStats.spillsToRestore, "Lowering: regalloc failed: %d, lowering failed %d; spills to stack: %d, spills to restore: %d, max spill slot %u\n",
stats.lowerStats.maxSpillSlotsUsed); stats.lowerStats.regAllocErrors,
stats.lowerStats.loweringErrors,
stats.lowerStats.spillsToSlot,
stats.lowerStats.spillsToRestore,
stats.lowerStats.maxSpillSlotsUsed
);
} }
if (recordStats != RecordStats::None) if (recordStats != RecordStats::None)

View File

@ -442,12 +442,16 @@ std::vector<std::string> getSourceFiles(int argc, char** argv)
if (isDirectory(argv[i])) if (isDirectory(argv[i]))
{ {
traverseDirectory(argv[i], [&](const std::string& name) { traverseDirectory(
argv[i],
[&](const std::string& name)
{
std::string ext = getExtension(name); std::string ext = getExtension(name);
if (ext == ".lua" || ext == ".luau") if (ext == ".lua" || ext == ".luau")
files.push_back(name); files.push_back(name);
}); }
);
} }
else else
{ {

View File

@ -54,8 +54,9 @@ void setLuauFlags(const char* list)
else if (value == "false" || value == "False") else if (value == "false" || value == "False")
setLuauFlag(key, false); setLuauFlag(key, false);
else else
fprintf(stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), fprintf(
key.data()); stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), key.data()
);
} }
else else
{ {

View File

@ -131,8 +131,13 @@ void profilerDump(const char* path)
fclose(f); fclose(f);
printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, printf(
static_cast<long long>(gProfiler.samples.load()), static_cast<long long>(gProfiler.data.size())); "Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n",
path,
double(total) / 1e6,
static_cast<long long>(gProfiler.samples.load()),
static_cast<long long>(gProfiler.data.size())
);
uint64_t totalgc = 0; uint64_t totalgc = 0;
for (uint64_t p : gProfiler.gc) for (uint64_t p : gProfiler.gc)

View File

@ -184,7 +184,8 @@ struct Reducer
{ {
std::vector<AstStat*> result; std::vector<AstStat*> result;
auto append = [&](AstStatBlock* block) { auto append = [&](AstStatBlock* block)
{
if (block) if (block)
result.insert(result.end(), block->body.data, block->body.data + block->body.size); result.insert(result.end(), block->body.data, block->body.data + block->body.size);
}; };
@ -250,7 +251,8 @@ struct Reducer
std::vector<std::pair<Span, Span>> result; std::vector<std::pair<Span, Span>> result;
auto append = [&result](Span a, Span b) { auto append = [&result](Span a, Span b)
{
if (a.first == a.second && b.first == b.second) if (a.first == a.second && b.first == b.second)
return; return;
else else

View File

@ -388,8 +388,13 @@ static void safeGetTable(lua_State* L, int tableIndex)
// completePartialMatches finds keys that match the specified 'prefix' // completePartialMatches finds keys that match the specified 'prefix'
// Note: the table/object to be searched must be on the top of the Lua stack // Note: the table/object to be searched must be on the top of the Lua stack
static void completePartialMatches(lua_State* L, bool completeOnlyFunctions, const std::string& editBuffer, std::string_view prefix, static void completePartialMatches(
const AddCompletionCallback& addCompletionCallback) lua_State* L,
bool completeOnlyFunctions,
const std::string& editBuffer,
std::string_view prefix,
const AddCompletionCallback& addCompletionCallback
)
{ {
for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++) for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++)
{ {
@ -483,9 +488,14 @@ static void icGetCompletions(ic_completion_env_t* cenv, const char* editBuffer)
{ {
auto* L = reinterpret_cast<lua_State*>(ic_completion_arg(cenv)); auto* L = reinterpret_cast<lua_State*>(ic_completion_arg(cenv));
getCompletions(L, std::string(editBuffer), [cenv](const std::string& completion, const std::string& display) { getCompletions(
L,
std::string(editBuffer),
[cenv](const std::string& completion, const std::string& display)
{
ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr);
}); }
);
} }
static bool isMethodOrFunctionChar(const char* s, long len) static bool isMethodOrFunctionChar(const char* s, long len)
@ -788,9 +798,13 @@ int replMain(int argc, char** argv)
// note, there's no need to close the log explicitly as it will be closed when the process exits // note, there's no need to close the log explicitly as it will be closed when the process exits
FILE* codegenPerfLog = fopen(path, "w"); FILE* codegenPerfLog = fopen(path, "w");
Luau::CodeGen::setPerfLog(codegenPerfLog, [](void* context, uintptr_t addr, unsigned size, const char* symbol) { Luau::CodeGen::setPerfLog(
codegenPerfLog,
[](void* context, uintptr_t addr, unsigned size, const char* symbol)
{
fprintf(static_cast<FILE*>(context), "%016lx %08x %s\n", long(addr), size, symbol); fprintf(static_cast<FILE*>(context), "%016lx %08x %s\n", long(addr), size, symbol);
}); }
);
#else #else
fprintf(stderr, "--codegen-perf option is only supported on Linux\n"); fprintf(stderr, "--codegen-perf option is only supported on Linux\n");
return 1; return 1;

View File

@ -223,9 +223,15 @@ void RequireResolver::substituteAliasIfPresent(std::string& path)
std::optional<std::string> RequireResolver::getAlias(std::string alias) std::optional<std::string> RequireResolver::getAlias(std::string alias)
{ {
std::transform(alias.begin(), alias.end(), alias.begin(), [](unsigned char c) { std::transform(
alias.begin(),
alias.end(),
alias.begin(),
[](unsigned char c)
{
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
}); }
);
while (!config.aliases.count(alias) && !isConfigFullyResolved) while (!config.aliases.count(alias) && !isConfigFullyResolved)
{ {
parseNextConfig(); parseNextConfig();

View File

@ -212,8 +212,19 @@ public:
private: private:
// Instruction archetypes // Instruction archetypes
void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, void placeBinary(
uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg); const char* name,
OperandX64 lhs,
OperandX64 rhs,
uint8_t codeimm8,
uint8_t codeimm,
uint8_t codeimmImm8,
uint8_t code8rev,
uint8_t coderev,
uint8_t code8,
uint8_t code,
uint8_t opreg
);
void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg); void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg);
void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code);
void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code);
@ -228,7 +239,16 @@ private:
void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix); void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix);
void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix);
void placeAvx( void placeAvx(
const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); const char* name,
OperandX64 dst,
OperandX64 src1,
OperandX64 src2,
uint8_t imm8,
uint8_t code,
bool setW,
uint8_t mode,
uint8_t prefix
);
// Instruction components // Instruction components
void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs, int32_t extraCodeBytes = 0); void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs, int32_t extraCodeBytes = 0);

View File

@ -25,7 +25,14 @@ struct CodeAllocator
// To allow allocation while previously allocated code is already running, allocation has page granularity // To allow allocation while previously allocated code is already running, allocation has page granularity
// It's important to group functions together so that page alignment won't result in a lot of wasted space // It's important to group functions together so that page alignment won't result in a lot of wasted space
bool allocate( bool allocate(
const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart); const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize,
uint8_t*& result,
size_t& resultSize,
uint8_t*& resultCodeStart
);
// Provided to unwind info callbacks // Provided to unwind info callbacks
void* context = nullptr; void* context = nullptr;

View File

@ -77,8 +77,8 @@ struct IrOp;
using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength);
using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostVectorNamecallHandler = bool (*)( using HostVectorNamecallHandler =
IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos);
enum class HostMetamethod enum class HostMetamethod
{ {
@ -99,12 +99,21 @@ enum class HostMetamethod
using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength);
using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method);
using HostUserdataAccessHandler = bool (*)( using HostUserdataAccessHandler =
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos);
using HostUserdataMetamethodHandler = bool (*)( using HostUserdataMetamethodHandler =
IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos);
using HostUserdataNamecallHandler = bool (*)( using HostUserdataNamecallHandler = bool (*)(
IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); IrBuilder& builder,
uint8_t type,
const char* member,
size_t memberLength,
int argResReg,
int sourceReg,
int params,
int results,
int pcpos
);
struct HostIrHooks struct HostIrHooks
{ {
@ -196,7 +205,11 @@ using UniqueSharedCodeGenContext = std::unique_ptr<SharedCodeGenContext, SharedC
[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext); [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext);
[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext( [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(
size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); size_t blockSize,
size_t maxTotalSize,
AllocationCallback* allocationCallback,
void* allocationCallbackContext
);
// Destroys the provided SharedCodeGenContext. All Luau VMs using the // Destroys the provided SharedCodeGenContext. All Luau VMs using the
// SharedCodeGenContext must be destroyed before this function is called. // SharedCodeGenContext must be destroyed before this function is called.

View File

@ -135,7 +135,11 @@ struct IdfContext
// 'Iterated' comes from the definition where we recompute the IDFn+1 = DF(S) while adding IDFn to S until a fixed point is reached // 'Iterated' comes from the definition where we recompute the IDFn+1 = DF(S) while adding IDFn to S until a fixed point is reached
// Iterated dominance frontier has been shown to be equal to the set of nodes where phi instructions have to be inserted // Iterated dominance frontier has been shown to be equal to the set of nodes where phi instructions have to be inserted
void computeIteratedDominanceFrontierForDefs( void computeIteratedDominanceFrontierForDefs(
IdfContext& ctx, const IrFunction& function, const std::vector<uint32_t>& defBlocks, const std::vector<uint32_t>& liveInBlocks); IdfContext& ctx,
const IrFunction& function,
const std::vector<uint32_t>& defBlocks,
const std::vector<uint32_t>& liveInBlocks
);
// Function used to update all CFG data // Function used to update all CFG data
void computeCfgInfo(IrFunction& function); void computeCfgInfo(IrFunction& function);

View File

@ -36,9 +36,21 @@ const char* getBytecodeTypeName(uint8_t type, const char* const* userdataTypes);
void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes); void toString(std::string& result, const BytecodeTypes& bcTypes, const char* const* userdataTypes);
void toStringDetailed( void toStringDetailed(
IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, const IrInst& inst, uint32_t instIdx, IncludeUseInfo includeUseInfo); IrToStringContext& ctx,
void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t blockIdx, IncludeUseInfo includeUseInfo, IncludeCfgInfo includeCfgInfo, const IrBlock& block,
IncludeRegFlowInfo includeRegFlowInfo); uint32_t blockIdx,
const IrInst& inst,
uint32_t instIdx,
IncludeUseInfo includeUseInfo
);
void toStringDetailed(
IrToStringContext& ctx,
const IrBlock& block,
uint32_t blockIdx,
IncludeUseInfo includeUseInfo,
IncludeCfgInfo includeCfgInfo,
IncludeRegFlowInfo includeRegFlowInfo
);
std::string toString(const IrFunction& function, IncludeUseInfo includeUseInfo); std::string toString(const IrFunction& function, IncludeUseInfo includeUseInfo);

View File

@ -42,8 +42,12 @@ class SharedCodeAllocator;
class NativeModule class NativeModule
{ {
public: public:
NativeModule(SharedCodeAllocator* allocator, const std::optional<ModuleId>& moduleId, const uint8_t* moduleBaseAddress, NativeModule(
std::vector<NativeProtoExecDataPtr> nativeProtos) noexcept; SharedCodeAllocator* allocator,
const std::optional<ModuleId>& moduleId,
const uint8_t* moduleBaseAddress,
std::vector<NativeProtoExecDataPtr> nativeProtos
) noexcept;
NativeModule(const NativeModule&) = delete; NativeModule(const NativeModule&) = delete;
NativeModule(NativeModule&&) = delete; NativeModule(NativeModule&&) = delete;
@ -132,11 +136,22 @@ public:
// data and code such that it can be executed). Like std::map::insert, the // data and code such that it can be executed). Like std::map::insert, the
// bool result is true if a new module was created; false if an existing // bool result is true if a new module was created; false if an existing
// module is being returned. // module is being returned.
std::pair<NativeModuleRef, bool> getOrInsertNativeModule(const ModuleId& moduleId, std::vector<NativeProtoExecDataPtr> nativeProtos, std::pair<NativeModuleRef, bool> getOrInsertNativeModule(
const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize); const ModuleId& moduleId,
std::vector<NativeProtoExecDataPtr> nativeProtos,
const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize
);
NativeModuleRef insertAnonymousNativeModule( NativeModuleRef insertAnonymousNativeModule(
std::vector<NativeProtoExecDataPtr> nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize); std::vector<NativeProtoExecDataPtr> nativeProtos,
const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize
);
// If a NativeModule exists for the given ModuleId and that NativeModule // If a NativeModule exists for the given ModuleId and that NativeModule
// is no longer referenced, the NativeModule is destroyed. This should // is no longer referenced, the NativeModule is destroyed. This should

View File

@ -49,8 +49,13 @@ public:
// mov rbp, rsp // mov rbp, rsp
// push reg in the order specified in regs // push reg in the order specified in regs
// sub rsp, stackSize // sub rsp, stackSize
virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr, virtual void prologueX64(
const std::vector<X64::RegisterX64>& simd) = 0; uint32_t prologueSize,
uint32_t stackSize,
bool setupFrame,
std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd
) = 0;
virtual size_t getUnwindInfoSize(size_t blockSize) const = 0; virtual size_t getUnwindInfoSize(size_t blockSize) const = 0;

View File

@ -30,8 +30,13 @@ public:
void finishInfo() override; void finishInfo() override;
void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override; void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr, void prologueX64(
const std::vector<X64::RegisterX64>& simd) override; uint32_t prologueSize,
uint32_t stackSize,
bool setupFrame,
std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd
) override;
size_t getUnwindInfoSize(size_t blockSize = 0) const override; size_t getUnwindInfoSize(size_t blockSize = 0) const override;

View File

@ -50,8 +50,13 @@ public:
void finishInfo() override; void finishInfo() override;
void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override; void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr, void prologueX64(
const std::vector<X64::RegisterX64>& simd) override; uint32_t prologueSize,
uint32_t stackSize,
bool setupFrame,
std::initializer_list<X64::RegisterX64> gpr,
const std::vector<X64::RegisterX64>& simd
) override;
size_t getUnwindInfoSize(size_t blockSize = 0) const override; size_t getUnwindInfoSize(size_t blockSize = 0) const override;

View File

@ -17,8 +17,8 @@ namespace A64
static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered"); static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
static const char* textForCondition[] = { static const char* textForCondition[] =
"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"}; {"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered"); static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
const unsigned kMaxAlign = 32; const unsigned kMaxAlign = 32;
@ -968,8 +968,10 @@ void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | ((shift < 0 ? -shift : shift) << 10) | (src2.index << 16) | (N << 21) | (int(shift < 0) << 22) | place(
(op << 24) | sf); dst.index | (src1.index << 5) | ((shift < 0 ? -shift : shift) << 10) | (src2.index << 16) | (N << 21) | (int(shift < 0) << 22) | (op << 24) |
sf
);
commit(); commit();
} }
@ -1173,7 +1175,15 @@ void AssemblyBuilderA64::placeP(const char* name, RegisterA64 src1, RegisterA64
} }
void AssemblyBuilderA64::placeCS( void AssemblyBuilderA64::placeCS(
const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond, uint8_t op, uint8_t opc, int invert) const char* name,
RegisterA64 dst,
RegisterA64 src1,
RegisterA64 src2,
ConditionA64 cond,
uint8_t op,
uint8_t opc,
int invert
)
{ {
if (logText) if (logText)
log(name, dst, src1, src2, cond); log(name, dst, src1, src2, cond);

View File

@ -15,21 +15,22 @@ namespace X64
// TODO: more assertions on operand sizes // TODO: more assertions on operand sizes
static const uint8_t codeForCondition[] = { static const uint8_t codeForCondition[] = {0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd,
0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb}; 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
static const char* jccTextForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna", static const char* jccTextForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge",
"jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"}; "jnb", "jnbe", "jna", "jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static_assert(sizeof(jccTextForCondition) / sizeof(jccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); static_assert(sizeof(jccTextForCondition) / sizeof(jccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
static const char* setccTextForCondition[] = {"seto", "setno", "setc", "setnc", "setb", "setbe", "seta", "setae", "sete", "setl", "setle", "setg", static const char* setccTextForCondition[] = {"seto", "setno", "setc", "setnc", "setb", "setbe", "seta", "setae", "sete",
"setge", "setnb", "setnbe", "setna", "setnae", "setne", "setnl", "setnle", "setng", "setnge", "setz", "setnz", "setp", "setnp"}; "setl", "setle", "setg", "setge", "setnb", "setnbe", "setna", "setnae", "setne",
"setnl", "setnle", "setng", "setnge", "setz", "setnz", "setp", "setnp"};
static_assert(sizeof(setccTextForCondition) / sizeof(setccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); static_assert(sizeof(setccTextForCondition) / sizeof(setccTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
static const char* cmovTextForCondition[] = {"cmovo", "cmovno", "cmovc", "cmovnc", "cmovb", "cmovbe", "cmova", "cmovae", "cmove", "cmovl", "cmovle", static const char* cmovTextForCondition[] = {"cmovo", "cmovno", "cmovc", "cmovnc", "cmovb", "cmovbe", "cmova", "cmovae", "cmove",
"cmovg", "cmovge", "cmovnb", "cmovnbe", "cmovna", "cmovnae", "cmovne", "cmovnl", "cmovnle", "cmovng", "cmovnge", "cmovz", "cmovnz", "cmovp", "cmovl", "cmovle", "cmovg", "cmovge", "cmovnb", "cmovnbe", "cmovna", "cmovnae", "cmovne",
"cmovnp"}; "cmovnl", "cmovnle", "cmovng", "cmovnge", "cmovz", "cmovnz", "cmovp", "cmovnp"};
static_assert(sizeof(cmovTextForCondition) / sizeof(cmovTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered"); static_assert(sizeof(cmovTextForCondition) / sizeof(cmovTextForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
#define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7)) #define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7))
@ -1136,8 +1137,19 @@ unsigned AssemblyBuilderX64::getInstructionCount() const
return instructionCount; return instructionCount;
} }
void AssemblyBuilderX64::placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, void AssemblyBuilderX64::placeBinary(
uint8_t code8rev, uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg) const char* name,
OperandX64 lhs,
OperandX64 rhs,
uint8_t codeimm8,
uint8_t codeimm,
uint8_t codeimmImm8,
uint8_t code8rev,
uint8_t coderev,
uint8_t code8,
uint8_t code,
uint8_t opreg
)
{ {
if (logText) if (logText)
log(name, lhs, rhs); log(name, lhs, rhs);
@ -1292,7 +1304,15 @@ void AssemblyBuilderX64::placeAvx(const char* name, OperandX64 dst, OperandX64 s
} }
void AssemblyBuilderX64::placeAvx( void AssemblyBuilderX64::placeAvx(
const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix) const char* name,
OperandX64 dst,
OperandX64 src,
uint8_t code,
uint8_t coderev,
bool setW,
uint8_t mode,
uint8_t prefix
)
{ {
CODEGEN_ASSERT((dst.cat == CategoryX64::mem && src.cat == CategoryX64::reg) || (dst.cat == CategoryX64::reg && src.cat == CategoryX64::mem)); CODEGEN_ASSERT((dst.cat == CategoryX64::mem && src.cat == CategoryX64::reg) || (dst.cat == CategoryX64::reg && src.cat == CategoryX64::mem));
@ -1316,7 +1336,15 @@ void AssemblyBuilderX64::placeAvx(
} }
void AssemblyBuilderX64::placeAvx( void AssemblyBuilderX64::placeAvx(
const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix) const char* name,
OperandX64 dst,
OperandX64 src1,
OperandX64 src2,
uint8_t code,
bool setW,
uint8_t mode,
uint8_t prefix
)
{ {
CODEGEN_ASSERT(dst.cat == CategoryX64::reg); CODEGEN_ASSERT(dst.cat == CategoryX64::reg);
CODEGEN_ASSERT(src1.cat == CategoryX64::reg); CODEGEN_ASSERT(src1.cat == CategoryX64::reg);
@ -1332,8 +1360,8 @@ void AssemblyBuilderX64::placeAvx(
commit(); commit();
} }
void AssemblyBuilderX64::placeAvx( void AssemblyBuilderX64::
const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix) placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t imm8, uint8_t code, bool setW, uint8_t mode, uint8_t prefix)
{ {
CODEGEN_ASSERT(dst.cat == CategoryX64::reg); CODEGEN_ASSERT(dst.cat == CategoryX64::reg);
CODEGEN_ASSERT(src1.cat == CategoryX64::reg); CODEGEN_ASSERT(src1.cat == CategoryX64::reg);
@ -1735,13 +1763,15 @@ const char* AssemblyBuilderX64::getSizeName(SizeX64 size) const
const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const const char* AssemblyBuilderX64::getRegisterName(RegisterX64 reg) const
{ {
static const char* names[][16] = {{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""}, static const char* names[][16] = {
{"rip", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""},
{"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b"}, {"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b"},
{"ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w"}, {"ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w"},
{"eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d"}, {"eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d"},
{"rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15"}, {"rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15"},
{"xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15"}, {"xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15"},
{"ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15"}}; {"ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15"}
};
CODEGEN_ASSERT(reg.index < 16); CODEGEN_ASSERT(reg.index < 16);
CODEGEN_ASSERT(reg.size <= SizeX64::ymmword); CODEGEN_ASSERT(reg.size <= SizeX64::ymmword);

View File

@ -116,12 +116,17 @@ void loadBytecodeTypeInfo(IrFunction& function)
static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo) static void prepareRegTypeInfoLookups(BytecodeTypeInfo& typeInfo)
{ {
// Sort by register first, then by end PC // Sort by register first, then by end PC
std::sort(typeInfo.regTypes.begin(), typeInfo.regTypes.end(), [](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b) { std::sort(
typeInfo.regTypes.begin(),
typeInfo.regTypes.end(),
[](const BytecodeRegTypeInfo& a, const BytecodeRegTypeInfo& b)
{
if (a.reg != b.reg) if (a.reg != b.reg)
return a.reg < b.reg; return a.reg < b.reg;
return a.endpc < b.endpc; return a.endpc < b.endpc;
}); }
);
// Prepare data for all registers as 'regTypes' might be missing temporaries // Prepare data for all registers as 'regTypes' might be missing temporaries
typeInfo.regTypeOffsets.resize(256 + 1); typeInfo.regTypeOffsets.resize(256 + 1);
@ -805,8 +810,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
@ -837,8 +841,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
} }
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{ {
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
} }
@ -860,8 +863,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
@ -883,8 +885,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
@ -915,8 +916,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
} }
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{ {
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
} }
@ -938,8 +938,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
@ -960,8 +959,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
regTags[ra] = LBC_TYPE_NUMBER; regTags[ra] = LBC_TYPE_NUMBER;
else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
bcType.result = regTags[ra]; bcType.result = regTags[ra];
@ -990,8 +988,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks)
if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_VECTOR; regTags[ra] = LBC_TYPE_VECTOR;
} }
else if (hostHooks.userdataMetamethodBytecodeType && else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
(isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b)))
{ {
regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op));
} }

View File

@ -143,7 +143,14 @@ CodeAllocator::~CodeAllocator()
} }
bool CodeAllocator::allocate( bool CodeAllocator::allocate(
const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart) const uint8_t* data,
size_t dataSize,
const uint8_t* code,
size_t codeSize,
uint8_t*& result,
size_t& resultSize,
uint8_t*& resultCodeStart
)
{ {
// 'Round up' to preserve code alignment // 'Round up' to preserve code alignment
size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1);

Some files were not shown because too many files have changed in this diff Show More