From 99c0db3b0845b6e9450753d5ed45a2cf6f3e6e68 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 28 Oct 2022 01:22:49 +0300 Subject: [PATCH] Sync to upstream/release/551 --- .../include/Luau/ConstraintGraphBuilder.h | 55 +- Analysis/include/Luau/Error.h | 24 +- Analysis/include/Luau/Normalize.h | 75 +- Analysis/include/Luau/RecursionCounter.h | 22 +- Analysis/include/Luau/ToString.h | 2 + Analysis/include/Luau/TypeInfer.h | 13 +- Analysis/include/Luau/Unifier.h | 2 + Analysis/src/AstQuery.cpp | 85 ++- Analysis/src/ConstraintGraphBuilder.cpp | 437 +++++------ Analysis/src/ConstraintSolver.cpp | 44 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 58 +- Analysis/src/Error.cpp | 43 ++ Analysis/src/Frontend.cpp | 153 +++- Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/Normalize.cpp | 406 +++++++++- Analysis/src/ToString.cpp | 40 +- Analysis/src/TopoSortStatements.cpp | 5 +- Analysis/src/TypeAttach.cpp | 2 +- Analysis/src/TypeChecker2.cpp | 61 +- Analysis/src/TypeInfer.cpp | 81 +- Analysis/src/TypePack.cpp | 3 +- Analysis/src/TypeVar.cpp | 8 +- Analysis/src/Unifier.cpp | 61 +- Ast/include/Luau/ParseResult.h | 2 + Ast/include/Luau/Parser.h | 4 +- Ast/src/Parser.cpp | 64 +- CLI/Repl.cpp | 19 +- CMakeLists.txt | 5 + CodeGen/include/Luau/CodeGen.h | 2 +- CodeGen/src/CodeGen.cpp | 691 ++++++++++-------- CodeGen/src/CodeGenUtils.cpp | 76 ++ CodeGen/src/CodeGenUtils.h | 17 + CodeGen/src/EmitCommonX64.h | 15 +- CodeGen/src/EmitInstructionX64.cpp | 383 +++++++++- CodeGen/src/EmitInstructionX64.h | 16 +- CodeGen/src/NativeState.cpp | 23 +- CodeGen/src/NativeState.h | 8 + Compiler/include/Luau/BytecodeBuilder.h | 2 +- Compiler/src/BytecodeBuilder.cpp | 18 +- Makefile | 4 +- Sources.cmake | 3 + VM/include/luaconf.h | 8 + VM/src/lbuiltins.cpp | 26 +- VM/src/lbuiltins.h | 2 +- VM/src/lnumutils.h | 1 + VM/src/lvmexecute.cpp | 52 +- bench/tests/voxelgen.lua | 456 ++++++++++++ tests/AstQuery.test.cpp | 71 ++ tests/Fixture.cpp | 9 + tests/Fixture.h | 2 + tests/Frontend.test.cpp | 27 + tests/Module.test.cpp | 10 +- tests/Normalize.test.cpp | 141 +++- tests/Parser.test.cpp | 2 - tests/ToString.test.cpp | 2 + tests/TypeInfer.aliases.test.cpp | 10 + tests/TypeInfer.annotations.test.cpp | 10 +- tests/TypeInfer.anyerror.test.cpp | 2 - tests/TypeInfer.definitions.test.cpp | 17 + tests/TypeInfer.functions.test.cpp | 53 +- tests/TypeInfer.modules.test.cpp | 18 +- tests/TypeInfer.negations.test.cpp | 52 ++ tests/TypeInfer.operators.test.cpp | 42 +- tests/TypeInfer.provisional.test.cpp | 13 +- tests/TypeInfer.tables.test.cpp | 61 +- tests/VisitTypeVar.test.cpp | 9 +- tools/faillist.txt | 23 +- 67 files changed, 3222 insertions(+), 931 deletions(-) create mode 100644 CodeGen/src/CodeGenUtils.cpp create mode 100644 CodeGen/src/CodeGenUtils.h create mode 100644 bench/tests/voxelgen.lua create mode 100644 tests/TypeInfer.negations.test.cpp diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index dc5d4598..6106717c 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -23,6 +23,30 @@ using ScopePtr = std::shared_ptr; struct DcrLogger; +struct Inference +{ + TypeId ty = nullptr; + + Inference() = default; + + explicit Inference(TypeId ty) + : ty(ty) + { + } +}; + +struct InferencePack +{ + TypePackId tp = nullptr; + + InferencePack() = default; + + explicit InferencePack(TypePackId tp) + : tp(tp) + { + } +}; + struct ConstraintGraphBuilder { // A list of all the scopes in the module. This vector holds ownership of the @@ -130,8 +154,10 @@ struct ConstraintGraphBuilder void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); void visit(const ScopePtr& scope, AstStatError* error); - TypePackId checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); - TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + + InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes); /** * Checks an expression that is expected to evaluate to one type. @@ -141,18 +167,19 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); + Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); - TypeId check(const ScopePtr& scope, AstExprLocal* local); - TypeId check(const ScopePtr& scope, AstExprGlobal* global); - TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); - TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); - TypeId check(const ScopePtr& scope, AstExprUnary* unary); - TypeId check_(const ScopePtr& scope, AstExprUnary* unary); - TypeId check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); - TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); - TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprLocal* local); + Inference check(const ScopePtr& scope, AstExprGlobal* global); + Inference check(const ScopePtr& scope, AstExprIndexName* indexName); + Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprUnary* unary); + Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); @@ -202,7 +229,7 @@ struct ConstraintGraphBuilder std::vector> createGenerics(const ScopePtr& scope, AstArray generics); std::vector> createGenericPacks(const ScopePtr& scope, AstArray packs); - TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp); + Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 7338627c..f7bd9d50 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -7,6 +7,8 @@ #include "Luau/Variant.h" #include "Luau/TypeArena.h" +LUAU_FASTFLAG(LuauIceExceptionInheritanceChange) + namespace Luau { struct TypeError; @@ -302,12 +304,20 @@ struct NormalizationTooComplex } }; +struct TypePackMismatch +{ + TypePackId wantedTp; + TypePackId givenTp; + + bool operator==(const TypePackMismatch& rhs) const; +}; + using TypeErrorData = Variant; + TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch>; struct TypeError { @@ -374,6 +384,10 @@ struct InternalErrorReporter class InternalCompilerError : public std::exception { public: + explicit InternalCompilerError(const std::string& message) + : message(message) + { + } explicit InternalCompilerError(const std::string& message, const std::string& moduleName) : message(message) , moduleName(moduleName) @@ -388,8 +402,14 @@ public: virtual const char* what() const throw(); const std::string message; - const std::string moduleName; + const std::optional moduleName; const std::optional location; }; +// These two function overloads only exist to facilitate fast flagging a change to InternalCompilerError +// Both functions can be removed when FFlagLuauIceExceptionInheritanceChange is removed and calling code +// can directly throw InternalCompilerError. +[[noreturn]] void throwRuntimeError(const std::string& message); +[[noreturn]] void throwRuntimeError(const std::string& message, const std::string& moduleName); + } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index a23d0fda..f98442dd 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -106,9 +106,68 @@ struct std::equal_to namespace Luau { -// A normalized string type is either `string` (represented by `nullopt`) -// or a union of string singletons. -using NormalizedStringType = std::optional>; +/** A normalized string type is either `string` (represented by `nullopt`) or a + * union of string singletons. + * + * When FFlagLuauNegatedStringSingletons is unset, the representation is as + * follows: + * + * * The `string` data type is represented by the option `singletons` having the + * value `std::nullopt`. + * * The type `never` is represented by `singletons` being populated with an + * empty map. + * * A union of string singletons is represented by a map populated by the names + * and TypeIds of the singletons contained therein. + * + * When FFlagLuauNegatedStringSingletons is set, the representation is as + * follows: + * + * * A union of string singletons is finite and includes the singletons named by + * the `singletons` field. + * * An intersection of negated string singletons is cofinite and includes the + * singletons excluded by the `singletons` field. It is implied that cofinite + * values are exclusions from `string` itself. + * * The `string` data type is a cofinite set minus zero elements. + * * The `never` data type is a finite set plus zero elements. + */ +struct NormalizedStringType +{ + // When false, this type represents a union of singleton string types. + // eg "a" | "b" | "c" + // + // When true, this type represents string intersected with negated string + // singleton types. + // eg string & ~"a" & ~"b" & ... + bool isCofinite = false; + + // TODO: This field cannot be nullopt when FFlagLuauNegatedStringSingletons + // is set. When clipping that flag, we can remove the wrapping optional. + std::optional> singletons; + + void resetToString(); + void resetToNever(); + + bool isNever() const; + bool isString() const; + + /// Returns true if the string has finite domain. + /// + /// Important subtlety: This method returns true for `never`. The empty set + /// is indeed an empty set. + bool isUnion() const; + + /// Returns true if the string has infinite domain. + bool isIntersection() const; + + bool includes(const std::string& str) const; + + static const NormalizedStringType never; + + NormalizedStringType() = default; + NormalizedStringType(bool isCofinite, std::optional> singletons); +}; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); // A normalized function type is either `never` (represented by `nullopt`) // or an intersection of function types. @@ -157,7 +216,7 @@ struct NormalizedType // The string part of the type. // This may be the `string` type, or a union of singletons. - NormalizedStringType strings = std::map{}; + NormalizedStringType strings; // The thread part of the type. // This type is either never or thread. @@ -231,8 +290,14 @@ public: bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + // ------- Negations + NormalizedType negateNormal(const NormalizedType& here); + TypeIds negateAll(const TypeIds& theres); + TypeId negate(TypeId there); + void subtractPrimitive(NormalizedType& here, TypeId ty); + void subtractSingleton(NormalizedType& here, TypeId ty); + // ------- Normalizing intersections - void intersectTysWithTy(TypeIds& here, TypeId there); TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there); void intersectClasses(TypeIds& heres, const TypeIds& theres); diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index f964dbfe..632afd19 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Common.h" +#include "Luau/Error.h" #include #include @@ -9,10 +10,20 @@ namespace Luau { -struct RecursionLimitException : public std::exception +struct RecursionLimitException : public InternalCompilerError +{ + RecursionLimitException() + : InternalCompilerError("Internal recursion counter limit exceeded") + { + LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); + } +}; + +struct RecursionLimitException_DEPRECATED : public std::exception { const char* what() const noexcept { + LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); return "Internal recursion counter limit exceeded"; } }; @@ -42,7 +53,14 @@ struct RecursionLimiter : RecursionCounter { if (limit > 0 && *count > limit) { - throw RecursionLimitException(); + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw RecursionLimitException(); + } + else + { + throw RecursionLimitException_DEPRECATED(); + } } } }; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index dd2d709b..ff2561e6 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -117,6 +117,8 @@ inline std::string toStringNamedFunction(const std::string& funcName, const Func return toStringNamedFunction(funcName, ftv, opts); } +std::optional getFunctionNameAsString(const AstExpr& expr); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression std::string dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 384637bb..c5d7501d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -48,7 +48,17 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -class TimeLimitError : public std::exception +class TimeLimitError : public InternalCompilerError +{ +public: + explicit TimeLimitError(const std::string& moduleName) + : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) + { + LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange); + } +}; + +class TimeLimitError_DEPRECATED : public std::exception { public: virtual const char* what() const throw(); @@ -236,6 +246,7 @@ public: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + [[noreturn]] void throwTimeLimitError(); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childScope(const ScopePtr& parent, const Location& location); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index c15cae31..7bf4d50b 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -96,6 +96,8 @@ private: void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy); + void tryUnifyNegationWithType(TypeId subTy, TypeId superTy); TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 50299704..b93c2cc2 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false) + namespace Luau { @@ -427,6 +429,38 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) return findVisitor.result; } +static std::optional checkOverloadedDocumentationSymbol( + const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) +{ + LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol); + + if (!documentationSymbol) + return std::nullopt; + + // This might be an overloaded function. + if (get(follow(ty))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + + return documentationSymbol; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -436,31 +470,38 @@ std::optional getDocumentationSymbolAtPosition(const Source if (std::optional binding = findBindingAtPosition(module, source, position)) { - if (binding->documentationSymbol) + if (FFlag::LuauCheckOverloadedDocSymbol) { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); + } + else + { + if (binding->documentationSymbol) { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) + // This might be an overloaded function binding. + if (get(follow(binding->typeId))) { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) { - matchingOverload = *it; + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; } } - - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } } - } - return binding->documentationSymbol; + return binding->documentationSymbol; + } } if (targetExpr) @@ -474,14 +515,20 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) { - return propIt->second.documentationSymbol; + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; } } else if (const ClassTypeVar* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) { - return propIt->second.documentationSymbol; + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; } } } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index de2b0a4e..455fc221 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -263,7 +263,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (hasAnnotation) expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); - TypePackId exprPack = checkPack(scope, value, expectedTypes); + TypePackId exprPack = checkPack(scope, value, expectedTypes).tp; if (i < local->vars.size) { @@ -292,7 +292,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (hasAnnotation) expectedType = varTypes.at(i); - TypeId exprType = check(scope, value, expectedType); + TypeId exprType = check(scope, value, expectedType).ty; if (i < varTypes.size()) { if (varTypes[i]) @@ -350,7 +350,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) if (!expr) return; - TypeId t = check(scope, expr); + TypeId t = check(scope, expr).ty; addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType}); }; @@ -368,7 +368,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) { ScopePtr loopScope = childScope(forIn, scope); - TypePackId iterator = checkPack(scope, forIn->values); + TypePackId iterator = checkPack(scope, forIn->values).tp; std::vector variableTypes; variableTypes.reserve(forIn->vars.size); @@ -489,7 +489,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { - TypeId containingTableType = check(scope, indexName->expr); + TypeId containingTableType = check(scope, indexName->expr).ty; functionType = arena->addType(BlockedTypeVar{}); @@ -531,7 +531,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) for (TypeId ty : scope->returnType) expectedTypes.push_back(ty); - TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes); + TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); } @@ -545,7 +545,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { TypePackId varPackId = checkLValues(scope, assign->vars); - TypePackId valuePack = checkPack(scope, assign->values); + TypePackId valuePack = checkPack(scope, assign->values).tp; addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); } @@ -732,7 +732,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) { - std::vector> generics = createGenerics(scope, global->generics); std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); @@ -779,7 +778,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) check(scope, expr); } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) { std::vector head; std::optional tail; @@ -792,201 +791,180 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray expectedType; if (i < expectedTypes.size()) expectedType = expectedTypes[i]; - head.push_back(check(scope, expr)); + head.push_back(check(scope, expr).ty); } else { std::vector expectedTailTypes; if (i < expectedTypes.size()) expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); - tail = checkPack(scope, expr, expectedTailTypes); + tail = checkPack(scope, expr, expectedTailTypes).tp; } } if (head.empty() && tail) - return *tail; + return InferencePack{*tail}; else - return arena->addTypePack(TypePack{std::move(head), tail}); + return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})}; } -TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) { RecursionCounter counter{&recursionCount}; if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return singletonTypes->errorRecoveryTypePack(); + return InferencePack{singletonTypes->errorRecoveryTypePack()}; } - TypePackId result = nullptr; + InferencePack result; if (AstExprCall* call = expr->as()) - { - TypeId fnType = check(scope, call->func); - const size_t constraintIndex = scope->constraints.size(); - const size_t scopeIndex = scopes.size(); - - std::vector args; - - for (AstExpr* arg : call->args) - { - args.push_back(check(scope, arg)); - } - - // TODO self - - if (matchSetmetatable(*call)) - { - LUAU_ASSERT(args.size() == 2); - TypeId target = args[0]; - TypeId mt = args[1]; - - MetatableTypeVar mtv{target, mt}; - TypeId resultTy = arena->addType(mtv); - result = arena->addTypePack({resultTy}); - } - else - { - const size_t constraintEndIndex = scope->constraints.size(); - const size_t scopeEndIndex = scopes.size(); - - astOriginalCallTypes[call->func] = fnType; - - TypeId instantiatedType = arena->addType(BlockedTypeVar{}); - // TODO: How do expectedTypes play into this? Do they? - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = arena->addTypePack(TypePack{args, {}}); - FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); - TypeId inferredFnType = arena->addType(ftv); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(scope->unqueuedConstraints.back().get()); - - scope->unqueuedConstraints.push_back( - std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); - - // We force constraints produced by checking function arguments to wait - // until after we have resolved the constraint on the function itself. - // This ensures, for instance, that we start inferring the contents of - // lambdas under the assumption that their arguments and return types - // will be compatible with the enclosing function call. - for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) - scope->constraints[ci]->dependencies.push_back(sc); - - for (size_t si = scopeIndex; si < scopeEndIndex; ++si) - { - for (auto& c : scopes[si].second->constraints) - { - c->dependencies.push_back(sc); - } - } - - addConstraint(scope, call->func->location, - FunctionCallConstraint{ - {ic, sc}, - fnType, - argPack, - rets, - call, - }); - - result = rets; - } - } + result = {checkPack(scope, call, expectedTypes)}; else if (AstExprVarargs* varargs = expr->as()) { if (scope->varargPack) - result = *scope->varargPack; + result = InferencePack{*scope->varargPack}; else - result = singletonTypes->errorRecoveryTypePack(); + result = InferencePack{singletonTypes->errorRecoveryTypePack()}; } else { std::optional expectedType; if (!expectedTypes.empty()) expectedType = expectedTypes[0]; - TypeId t = check(scope, expr, expectedType); - result = arena->addTypePack({t}); + TypeId t = check(scope, expr, expectedType).ty; + result = InferencePack{arena->addTypePack({t})}; } - LUAU_ASSERT(result); - astTypePacks[expr] = result; + LUAU_ASSERT(result.tp); + astTypePacks[expr] = result.tp; return result; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) +{ + TypeId fnType = check(scope, call->func).ty; + const size_t constraintIndex = scope->constraints.size(); + const size_t scopeIndex = scopes.size(); + + std::vector args; + + for (AstExpr* arg : call->args) + { + args.push_back(check(scope, arg).ty); + } + + // TODO self + + if (matchSetmetatable(*call)) + { + LUAU_ASSERT(args.size() == 2); + TypeId target = args[0]; + TypeId mt = args[1]; + + AstExpr* targetExpr = call->args.data[0]; + + MetatableTypeVar mtv{target, mt}; + TypeId resultTy = arena->addType(mtv); + + if (AstExprLocal* targetLocal = targetExpr->as()) + scope->bindings[targetLocal->local].typeId = resultTy; + + return InferencePack{arena->addTypePack({resultTy})}; + } + else + { + const size_t constraintEndIndex = scope->constraints.size(); + const size_t scopeEndIndex = scopes.size(); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = arena->addTypePack(TypePack{args, {}}); + FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); + TypeId inferredFnType = arena->addType(ftv); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); + NotNull ic(scope->unqueuedConstraints.back().get()); + + scope->unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); + NotNull sc(scope->unqueuedConstraints.back().get()); + + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) + scope->constraints[ci]->dependencies.push_back(sc); + + for (size_t si = scopeIndex; si < scopeEndIndex; ++si) + { + for (auto& c : scopes[si].second->constraints) + { + c->dependencies.push_back(sc); + } + } + + addConstraint(scope, call->func->location, + FunctionCallConstraint{ + {ic, sc}, + fnType, + argPack, + rets, + call, + }); + + return InferencePack{rets}; + } +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) { RecursionCounter counter{&recursionCount}; if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(expr->location); - return singletonTypes->errorRecoveryType(); + return Inference{singletonTypes->errorRecoveryType()}; } - TypeId result = nullptr; + Inference result; if (auto group = expr->as()) result = check(scope, group->expr, expectedType); else if (auto stringExpr = expr->as()) - { - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - if (get(expectedTy) || get(expectedTy)) - { - result = arena->addType(BlockedTypeVar{}); - TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringExpr->value.data, stringExpr->value.size)})); - addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->stringType}); - } - else if (maybeSingleton(expectedTy)) - result = arena->addType(SingletonTypeVar{StringSingleton{std::string{stringExpr->value.data, stringExpr->value.size}}}); - else - result = singletonTypes->stringType; - } - else - result = singletonTypes->stringType; - } + result = check(scope, stringExpr, expectedType); else if (expr->is()) - result = singletonTypes->numberType; + result = Inference{singletonTypes->numberType}; else if (auto boolExpr = expr->as()) - { - if (expectedType) - { - const TypeId expectedTy = follow(*expectedType); - const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; - - if (get(expectedTy) || get(expectedTy)) - { - result = arena->addType(BlockedTypeVar{}); - addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->booleanType}); - } - else if (maybeSingleton(expectedTy)) - result = singletonType; - else - result = singletonTypes->booleanType; - } - else - result = singletonTypes->booleanType; - } + result = check(scope, boolExpr, expectedType); else if (expr->is()) - result = singletonTypes->nilType; + result = Inference{singletonTypes->nilType}; else if (auto local = expr->as()) result = check(scope, local); else if (auto global = expr->as()) result = check(scope, global); else if (expr->is()) result = flattenPack(scope, expr->location, checkPack(scope, expr)); - else if (expr->is()) - result = flattenPack(scope, expr->location, checkPack(scope, expr)); // TODO: needs predicates too + else if (auto call = expr->as()) + { + std::vector expectedTypes; + if (expectedType) + expectedTypes.push_back(*expectedType); + result = flattenPack(scope, expr->location, checkPack(scope, call, expectedTypes)); // TODO: needs predicates too + } else if (auto a = expr->as()) { FunctionSignature sig = checkFunctionSignature(scope, a); checkFunctionBody(sig.bodyScope, a); - return sig.signature; + return Inference{sig.signature}; } else if (auto indexName = expr->as()) result = check(scope, indexName); @@ -1008,20 +986,63 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std:: for (AstExpr* subExpr : err->expressions) check(scope, subExpr); - result = singletonTypes->errorRecoveryType(); + result = Inference{singletonTypes->errorRecoveryType()}; } else { LUAU_ASSERT(0); - result = freshType(scope); + result = Inference{freshType(scope)}; } - LUAU_ASSERT(result); - astTypes[expr] = result; + LUAU_ASSERT(result.ty); + astTypes[expr] = result.ty; return result; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType) +{ + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(string->value.data, string->value.size)})); + addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->stringType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + return Inference{singletonTypes->stringType}; + } + + return Inference{singletonTypes->stringType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType) +{ + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->booleanType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{singletonType}; + + return Inference{singletonTypes->booleanType}; + } + + return Inference{singletonTypes->booleanType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) { std::optional resultTy; @@ -1035,26 +1056,26 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) } if (!resultTy) - return singletonTypes->errorRecoveryType(); // TODO: replace with ice, locals should never exist before its definition. + return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. - return *resultTy; + return Inference{*resultTy}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) { if (std::optional ty = scope->lookup(global->name)) - return *ty; + return Inference{*ty}; /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any * global that is not already in-scope is definitely an unknown symbol. */ reportError(global->location, UnknownSymbol{global->name.value}); - return singletonTypes->errorRecoveryType(); + return Inference{singletonTypes->errorRecoveryType()}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { - TypeId obj = check(scope, indexName->expr); + TypeId obj = check(scope, indexName->expr).ty; TypeId result = freshType(scope); TableTypeVar::Props props{{indexName->index.value, Property{result}}}; @@ -1065,13 +1086,13 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* in addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); - return result; + return Inference{result}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) { - TypeId obj = check(scope, indexExpr->expr); - TypeId indexType = check(scope, indexExpr->index); + TypeId obj = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; TypeId result = freshType(scope); @@ -1081,61 +1102,49 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); - return result; + return Inference{result}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check_(scope, unary); + TypeId operandType = check(scope, unary->expr).ty; TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - return resultType; + return Inference{resultType}; } -TypeId ConstraintGraphBuilder::check_(const ScopePtr& scope, AstExprUnary* unary) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { - if (unary->op == AstExprUnary::Not) - { - TypeId ty = check(scope, unary->expr, std::nullopt); - - return ty; - } - - return check(scope, unary->expr); -} - -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) -{ - TypeId leftType = check(scope, binary->left, expectedType); - TypeId rightType = check(scope, binary->right, expectedType); + TypeId leftType = check(scope, binary->left, expectedType).ty; + TypeId rightType = check(scope, binary->right, expectedType).ty; TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return resultType; + return Inference{resultType}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) { check(scope, ifElse->condition); - TypeId thenType = check(scope, ifElse->trueExpr, expectedType); - TypeId elseType = check(scope, ifElse->falseExpr, expectedType); + TypeId thenType = check(scope, ifElse->trueExpr, expectedType).ty; + TypeId elseType = check(scope, ifElse->falseExpr, expectedType).ty; if (ifElse->hasElse) { TypeId resultType = expectedType ? *expectedType : freshType(scope); addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType}); addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType}); - return resultType; + return Inference{resultType}; } - return thenType; + return Inference{thenType}; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { check(scope, typeAssert->expr, std::nullopt); - return resolveType(scope, typeAssert->annotation); + return Inference{resolveType(scope, typeAssert->annotation)}; } TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) @@ -1286,22 +1295,22 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) auto dottedPath = extractDottedName(expr); if (!dottedPath) - return check(scope, expr); + return check(scope, expr).ty; const auto [sym, segments] = std::move(*dottedPath); if (!sym.local) - return check(scope, expr); + return check(scope, expr).ty; auto lookupResult = scope->lookupEx(sym); if (!lookupResult) - return check(scope, expr); + return check(scope, expr).ty; const auto [ty, symbolScope] = std::move(*lookupResult); TypeId replaceTy = arena->freshType(scope.get()); std::optional updatedType = updateTheTableType(arena, ty, segments, replaceTy); if (!updatedType) - return check(scope, expr); + return check(scope, expr).ty; std::optional def = dfg->getDef(sym); LUAU_ASSERT(def); @@ -1310,7 +1319,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) return replaceTy; } -TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -1344,16 +1353,14 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, } } - TypeId itemTy = check(scope, item.value, expectedValueType); - if (get(follow(itemTy))) - return ty; + TypeId itemTy = check(scope, item.value, expectedValueType).ty; if (item.key) { // Even though we don't need to use the type of the item's key if // it's a string constant, we still want to check it to populate // astTypes. - TypeId keyTy = check(scope, item.key); + TypeId keyTy = check(scope, item.key).ty; if (AstExprConstantString* key = item.key->as()) { @@ -1373,7 +1380,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, } } - return ty; + return Inference{ty}; } ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn) @@ -1541,9 +1548,18 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } } - std::optional alias = scope->lookupType(ref->name.value); + std::optional alias; - if (alias.has_value() || ref->prefix.has_value()) + if (ref->prefix.has_value()) + { + alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); + } + else + { + alias = scope->lookupType(ref->name.value); + } + + if (alias.has_value()) { // If the alias is not generic, we don't need to set up a blocked // type and an instantiation constraint. @@ -1586,7 +1602,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b } else { - reportError(ty->location, UnknownSymbol{ref->name.value, UnknownSymbol::Context::Type}); + std::string typeName; + if (ref->prefix) + typeName = std::string(ref->prefix->value) + "."; + typeName += ref->name.value; + result = singletonTypes->errorRecoveryType(); } } @@ -1685,7 +1705,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b else if (auto tof = ty->as()) { // TODO: Recursion limit. - TypeId exprType = check(scope, tof->expr); + TypeId exprType = check(scope, tof->expr).ty; result = exprType; } else if (auto unionAnnotation = ty->as()) @@ -1694,7 +1714,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : unionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part)); + parts.push_back(resolveType(scope, part, topLevel)); } result = arena->addType(UnionTypeVar{parts}); @@ -1705,7 +1725,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (AstType* part : intersectionAnnotation->types) { // TODO: Recursion limit. - parts.push_back(resolveType(scope, part)); + parts.push_back(resolveType(scope, part, topLevel)); } result = arena->addType(IntersectionTypeVar{parts}); @@ -1795,10 +1815,7 @@ std::vector> ConstraintGraphBuilder::crea if (generic.defaultValue) defaultTy = resolveType(scope, generic.defaultValue); - result.push_back({generic.name.value, GenericTypeDefinition{ - genericTy, - defaultTy, - }}); + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); } return result; @@ -1816,19 +1833,17 @@ std::vector> ConstraintGraphBuilder:: if (generic.defaultValue) defaultTy = resolveTypePack(scope, generic.defaultValue); - result.push_back({generic.name.value, GenericTypePackDefinition{ - genericTy, - defaultTy, - }}); + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); } return result; } -TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, TypePackId tp) +Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) { + auto [tp] = pack; if (auto f = first(tp)) - return *f; + return Inference{*f}; TypeId typeResult = freshType(scope); TypePack onePack{{typeResult}, freshTypePack(scope)}; @@ -1836,7 +1851,7 @@ TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location locat addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); - return typeResult; + return Inference{typeResult}; } void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 60f4666a..5e43be0f 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -544,6 +544,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNullty.emplace(singletonTypes->numberType); return true; } @@ -552,13 +553,46 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull(operandType) || get(operandType)) { asMutable(c.resultType)->ty.emplace(c.operandType); - return true; } - break; + else if (std::optional mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location)) + { + const FunctionTypeVar* ftv = get(follow(*mm)); + + if (!ftv) + { + if (std::optional callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location)) + { + ftv = get(follow(*callMm)); + } + } + + if (!ftv) + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + return true; + } + + TypePackId argsPack = arena->addTypePack({operandType}); + unify(ftv->argTypes, argsPack, constraint->scope); + + TypeId result = singletonTypes->errorRecoveryType(); + if (ftv) + { + result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType()); + } + + asMutable(c.resultType)->ty.emplace(result); + } + else + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + } + + return true; } } - LUAU_ASSERT(false); // TODO metatable handling + LUAU_ASSERT(false); return false; } @@ -862,6 +896,10 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNullname = c.name; else if (MetatableTypeVar* mtv = getMutable(target)) mtv->syntheticName = c.name; + else if (get(target) || get(target)) + { + // nothing (yet) + } else return block(c.namedType, constraint); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 0f04ace0..67abbff1 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -26,34 +26,34 @@ declare bit32: { } declare math: { - frexp: (number) -> (number, number), - ldexp: (number, number) -> number, - fmod: (number, number) -> number, - modf: (number) -> (number, number), - pow: (number, number) -> number, - exp: (number) -> number, + frexp: (n: number) -> (number, number), + ldexp: (s: number, e: number) -> number, + fmod: (x: number, y: number) -> number, + modf: (n: number) -> (number, number), + pow: (x: number, y: number) -> number, + exp: (n: number) -> number, - ceil: (number) -> number, - floor: (number) -> number, - abs: (number) -> number, - sqrt: (number) -> number, + ceil: (n: number) -> number, + floor: (n: number) -> number, + abs: (n: number) -> number, + sqrt: (n: number) -> number, - log: (number, number?) -> number, - log10: (number) -> number, + log: (n: number, base: number?) -> number, + log10: (n: number) -> number, - rad: (number) -> number, - deg: (number) -> number, + rad: (n: number) -> number, + deg: (n: number) -> number, - sin: (number) -> number, - cos: (number) -> number, - tan: (number) -> number, - sinh: (number) -> number, - cosh: (number) -> number, - tanh: (number) -> number, - atan: (number) -> number, - acos: (number) -> number, - asin: (number) -> number, - atan2: (number, number) -> number, + sin: (n: number) -> number, + cos: (n: number) -> number, + tan: (n: number) -> number, + sinh: (n: number) -> number, + cosh: (n: number) -> number, + tanh: (n: number) -> number, + atan: (n: number) -> number, + acos: (n: number) -> number, + asin: (n: number) -> number, + atan2: (y: number, x: number) -> number, min: (number, ...number) -> number, max: (number, ...number) -> number, @@ -61,13 +61,13 @@ declare math: { pi: number, huge: number, - randomseed: (number) -> (), + randomseed: (seed: number) -> (), random: (number?, number?) -> number, - sign: (number) -> number, - clamp: (number, number, number) -> number, - noise: (number, number?, number?) -> number, - round: (number) -> number, + sign: (n: number) -> number, + clamp: (n: number, min: number, max: number) -> number, + noise: (x: number, y: number?, z: number?) -> number, + round: (n: number) -> number, } type DateTypeArg = { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index e5553003..ed1a49cd 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauIceExceptionInheritanceChange, false) + static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { @@ -460,6 +462,11 @@ struct ErrorConverter { return "Code is too complex to typecheck! Consider simplifying the code around this area"; } + + std::string operator()(const TypePackMismatch& e) const + { + return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + } }; struct InvalidNameChecker @@ -718,6 +725,11 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const return left == rhs.left && right == rhs.right; } +bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const +{ + return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -869,6 +881,11 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + e.wantedTp = clone(e.wantedTp); + e.givenTp = clone(e.givenTp); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } @@ -913,4 +930,30 @@ const char* InternalCompilerError::what() const throw() return this->message.data(); } +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void throwRuntimeError(const std::string& message) +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw InternalCompilerError(message); + } + else + { + throw std::runtime_error(message); + } +} + +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void throwRuntimeError(const std::string& message, const std::string& moduleName) +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw InternalCompilerError(message, moduleName); + } + else + { + throw std::runtime_error(message); + } +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8f2a3ebd..39e6428d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -30,6 +30,8 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false) +LUAU_FASTFLAGVARIABLE(LuauPersistTypesAfterGeneratingDocSyms, false) namespace Luau { @@ -110,24 +112,57 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c CloneState cloneState; - for (const auto& [name, ty] : checkedModule->declaredGlobals) + if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) { - TypeId globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - persist(globalTy); + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } } - - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + else { - TypeFun globalTy = clone(ty, globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - persist(globalTy.type); + persist(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; + + persist(globalTy.type); + } } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -159,24 +194,57 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t CloneState cloneState; - for (const auto& [name, ty] : checkedModule->declaredGlobals) + if (FFlag::LuauPersistTypesAfterGeneratingDocSyms) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); - persist(globalTy); + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } } - - for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + else { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - persist(globalTy.type); + persist(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + persist(globalTy.type); + } } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -425,13 +493,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + throwRuntimeError("Frontend::modules does not have data for " + name, name); } else { auto it2 = moduleResolver.modules.find(name); if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + throwRuntimeError("Frontend::modules does not have data for " + name, name); } return CheckResult{ @@ -538,7 +606,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional* marked sourceNode.dirtyModule = true; sourceNode.dirtyModuleForAutocomplete = true; - if (0 == reverseDeps.count(name)) - continue; + if (FFlag::LuauFixMarkDirtyReverseDeps) + { + if (0 == reverseDeps.count(next)) + continue; - sourceModules.erase(name); + sourceModules.erase(next); - const std::vector& dependents = reverseDeps[name]; - queue.insert(queue.end(), dependents.begin(), dependents.end()); + const std::vector& dependents = reverseDeps[next]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } + else + { + if (0 == reverseDeps.count(name)) + continue; + + sourceModules.erase(name); + + const std::vector& dependents = reverseDeps[name]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } } } @@ -993,11 +1074,11 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const double timestamp = getTimestamp(); - auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); + Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); stats.timeParse += getTimestamp() - timestamp; stats.files++; - stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); + stats.lines += parseResult.lines; if (!parseResult.errors.empty()) sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index e4fac455..b47270a0 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -188,6 +188,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; else if constexpr (std::is_same_v) stream << "NormalizationTooComplex { }"; + else if constexpr (std::is_same_v) + stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index cea159c3..5ef4b7e7 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -7,6 +7,7 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/TypeVar.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -18,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); @@ -107,12 +109,110 @@ bool TypeIds::operator==(const TypeIds& there) const return hash == there.hash && types == there.types; } +NormalizedStringType::NormalizedStringType(bool isCofinite, std::optional> singletons) + : isCofinite(isCofinite) + , singletons(std::move(singletons)) +{ + if (!FFlag::LuauNegatedStringSingletons) + LUAU_ASSERT(!isCofinite); +} + +void NormalizedStringType::resetToString() +{ + if (FFlag::LuauNegatedStringSingletons) + { + isCofinite = true; + singletons->clear(); + } + else + singletons.reset(); +} + +void NormalizedStringType::resetToNever() +{ + if (FFlag::LuauNegatedStringSingletons) + { + isCofinite = false; + singletons.emplace(); + } + else + { + if (singletons) + singletons->clear(); + else + singletons.emplace(); + } +} + +bool NormalizedStringType::isNever() const +{ + if (FFlag::LuauNegatedStringSingletons) + return !isCofinite && singletons->empty(); + else + return singletons && singletons->empty(); +} + +bool NormalizedStringType::isString() const +{ + if (FFlag::LuauNegatedStringSingletons) + return isCofinite && singletons->empty(); + else + return !singletons; +} + +bool NormalizedStringType::isUnion() const +{ + if (FFlag::LuauNegatedStringSingletons) + return !isCofinite; + else + return singletons.has_value(); +} + +bool NormalizedStringType::isIntersection() const +{ + if (FFlag::LuauNegatedStringSingletons) + return isCofinite; + else + return false; +} + +bool NormalizedStringType::includes(const std::string& str) const +{ + if (isString()) + return true; + else if (isUnion() && singletons->count(str)) + return true; + else if (isIntersection() && !singletons->count(str)) + return true; + else + return false; +} + +const NormalizedStringType NormalizedStringType::never{false, {{}}}; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) +{ + if (subStr.isUnion() && superStr.isUnion()) + { + for (auto [name, ty] : *subStr.singletons) + { + if (!superStr.singletons->count(name)) + return false; + } + } + else if (subStr.isString() && superStr.isUnion()) + return false; + + return true; +} + NormalizedType::NormalizedType(NotNull singletonTypes) : tops(singletonTypes->neverType) , booleans(singletonTypes->neverType) , errors(singletonTypes->neverType) , nils(singletonTypes->neverType) , numbers(singletonTypes->neverType) + , strings{NormalizedStringType::never} , threads(singletonTypes->neverType) { } @@ -120,7 +220,7 @@ NormalizedType::NormalizedType(NotNull singletonTypes) static bool isInhabited(const NormalizedType& norm) { return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || - !get(norm.nils) || !get(norm.numbers) || !norm.strings || !norm.strings->empty() || + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); } @@ -183,10 +283,10 @@ static bool isNormalizedNumber(TypeId ty) static bool isNormalizedString(const NormalizedStringType& ty) { - if (!ty) + if (ty.isString()) return true; - for (auto& [str, ty] : *ty) + for (auto& [str, ty] : *ty.singletons) { if (const SingletonTypeVar* stv = get(ty)) { @@ -317,10 +417,7 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.errors = singletonTypes->neverType; norm.nils = singletonTypes->neverType; norm.numbers = singletonTypes->neverType; - if (norm.strings) - norm.strings->clear(); - else - norm.strings.emplace(); + norm.strings.resetToNever(); norm.threads = singletonTypes->neverType; norm.tables.clear(); norm.functions = std::nullopt; @@ -495,10 +592,56 @@ void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) { - if (!there) - here.reset(); - else if (here) - here->insert(there->begin(), there->end()); + if (FFlag::LuauNegatedStringSingletons) + { + if (there.isString()) + here.resetToString(); + else if (here.isUnion() && there.isUnion()) + here.singletons->insert(there.singletons->begin(), there.singletons->end()); + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& pair : *there.singletons) + { + auto it = here.singletons->find(pair.first); + if (it != end(*here.singletons)) + here.singletons->erase(it); + else + here.singletons->insert(pair); + } + } + else if (here.isIntersection() && there.isUnion()) + { + for (const auto& [name, ty] : *there.singletons) + here.singletons->erase(name); + } + else if (here.isIntersection() && there.isIntersection()) + { + auto iter = begin(*here.singletons); + auto endIter = end(*here.singletons); + + while (iter != endIter) + { + if (!there.singletons->count(iter->first)) + { + auto eraseIt = iter; + ++iter; + here.singletons->erase(eraseIt); + } + else + ++iter; + } + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + { + if (there.isString()) + here.resetToString(); + else if (here.isUnion()) + here.singletons->insert(there.singletons->begin(), there.singletons->end()); + } } std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) @@ -858,7 +1001,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (ptv->type == PrimitiveTypeVar::Number) here.numbers = there; else if (ptv->type == PrimitiveTypeVar::String) - here.strings = std::nullopt; + here.strings.resetToString(); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = there; else @@ -870,12 +1013,33 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.booleans = unionOfBools(here.booleans, there); else if (const StringSingleton* sstv = get(stv)) { - if (here.strings) - here.strings->insert({sstv->value, there}); + if (FFlag::LuauNegatedStringSingletons) + { + if (here.strings.isCofinite) + { + auto it = here.strings.singletons->find(sstv->value); + if (it != here.strings.singletons->end()) + here.strings.singletons->erase(it); + } + else + here.strings.singletons->insert({sstv->value, there}); + } + else + { + if (here.strings.isUnion()) + here.strings.singletons->insert({sstv->value, there}); + } } else LUAU_ASSERT(!"Unreachable"); } + else if (const NegationTypeVar* ntv = get(there)) + { + const NormalizedType* thereNormal = normalize(ntv->ty); + NormalizedType tn = negateNormal(*thereNormal); + if (!unionNormals(here, tn)) + return false; + } else LUAU_ASSERT(!"Unreachable"); @@ -887,6 +1051,159 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor return true; } +// ------- Negations + +NormalizedType Normalizer::negateNormal(const NormalizedType& here) +{ + NormalizedType result{singletonTypes}; + if (!get(here.tops)) + { + // The negation of unknown or any is never. Easy. + return result; + } + + if (!get(here.errors)) + { + // Negating an error yields the same error. + result.errors = here.errors; + return result; + } + + if (get(here.booleans)) + result.booleans = singletonTypes->booleanType; + else if (get(here.booleans)) + result.booleans = singletonTypes->neverType; + else if (auto stv = get(here.booleans)) + { + auto boolean = get(stv); + LUAU_ASSERT(boolean != nullptr); + if (boolean->value) + result.booleans = singletonTypes->falseType; + else + result.booleans = singletonTypes->trueType; + } + + result.classes = negateAll(here.classes); + result.nils = get(here.nils) ? singletonTypes->nilType : singletonTypes->neverType; + result.numbers = get(here.numbers) ? singletonTypes->numberType : singletonTypes->neverType; + + result.strings = here.strings; + result.strings.isCofinite = !result.strings.isCofinite; + + result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + + // TODO: negating tables + // TODO: negating functions + // TODO: negating tyvars? + + return result; +} + +TypeIds Normalizer::negateAll(const TypeIds& theres) +{ + TypeIds tys; + for (TypeId there : theres) + tys.insert(negate(there)); + return tys; +} + +TypeId Normalizer::negate(TypeId there) +{ + there = follow(there); + if (get(there)) + return there; + else if (get(there)) + return singletonTypes->neverType; + else if (get(there)) + return singletonTypes->unknownType; + else if (auto ntv = get(there)) + return ntv->ty; // TODO: do we want to normalize this? + else if (auto utv = get(there)) + { + std::vector parts; + for (TypeId option : utv) + parts.push_back(negate(option)); + return arena->addType(IntersectionTypeVar{std::move(parts)}); + } + else if (auto itv = get(there)) + { + std::vector options; + for (TypeId part : itv) + options.push_back(negate(part)); + return arena->addType(UnionTypeVar{std::move(options)}); + } + else + return there; +} + +void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) +{ + const PrimitiveTypeVar* ptv = get(follow(ty)); + LUAU_ASSERT(ptv); + switch (ptv->type) + { + case PrimitiveTypeVar::NilType: + here.nils = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Boolean: + here.booleans = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Number: + here.numbers = singletonTypes->neverType; + break; + case PrimitiveTypeVar::String: + here.strings.resetToNever(); + break; + case PrimitiveTypeVar::Thread: + here.threads = singletonTypes->neverType; + break; + } +} + +void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); + + const SingletonTypeVar* stv = get(ty); + LUAU_ASSERT(stv); + + if (const StringSingleton* ss = get(stv)) + { + if (here.strings.isCofinite) + here.strings.singletons->insert({ss->value, ty}); + else + { + auto it = here.strings.singletons->find(ss->value); + if (it != here.strings.singletons->end()) + here.strings.singletons->erase(it); + } + } + else if (const BooleanSingleton* bs = get(stv)) + { + if (get(here.booleans)) + { + // Nothing + } + else if (get(here.booleans)) + here.booleans = bs->value ? singletonTypes->falseType : singletonTypes->trueType; + else if (auto hereSingleton = get(here.booleans)) + { + const BooleanSingleton* hereBooleanSingleton = get(hereSingleton); + LUAU_ASSERT(hereBooleanSingleton); + + // Crucial subtlety: ty (and thus bs) are the value that is being + // negated out. We therefore reduce to never when the values match, + // rather than when they differ. + if (bs->value == hereBooleanSingleton->value) + here.booleans = singletonTypes->neverType; + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); +} + // ------- Normalizing intersections TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) { @@ -971,17 +1288,17 @@ void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { - if (!there) + if (there.isString()) return; - if (!here) - here.emplace(); + if (here.isString()) + here.resetToNever(); - for (auto it = here->begin(); it != here->end();) + for (auto it = here.singletons->begin(); it != here.singletons->end();) { - if (there->count(it->first)) + if (there.singletons->count(it->first)) it++; else - it = here->erase(it); + it = here.singletons->erase(it); } } @@ -1646,12 +1963,35 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.booleans = intersectionOfBools(booleans, there); else if (const StringSingleton* sstv = get(stv)) { - if (!strings || strings->count(sstv->value)) - here.strings->insert({sstv->value, there}); + if (strings.includes(sstv->value)) + here.strings.singletons->insert({sstv->value, there}); } else LUAU_ASSERT(!"Unreachable"); } + else if (const NegationTypeVar* ntv = get(there); FFlag::LuauNegatedStringSingletons && ntv) + { + TypeId t = follow(ntv->ty); + if (const PrimitiveTypeVar* ptv = get(t)) + subtractPrimitive(here, ntv->ty); + else if (const SingletonTypeVar* stv = get(t)) + subtractSingleton(here, follow(ntv->ty)); + else if (const UnionTypeVar* itv = get(t)) + { + for (TypeId part : itv->options) + { + const NormalizedType* normalPart = normalize(part); + NormalizedType negated = negateNormal(*normalPart); + intersectNormals(here, negated); + } + } + else + { + // TODO negated unions, intersections, table, and function. + // Report a TypeError for other types. + LUAU_ASSERT(!"Unimplemented"); + } + } else LUAU_ASSERT(!"Unreachable"); @@ -1691,11 +2031,25 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.push_back(norm.nils); if (!get(norm.numbers)) result.push_back(norm.numbers); - if (norm.strings) - for (auto& [_, ty] : *norm.strings) - result.push_back(ty); - else + if (norm.strings.isString()) result.push_back(singletonTypes->stringType); + else if (norm.strings.isUnion()) + { + for (auto& [_, ty] : *norm.strings.singletons) + result.push_back(ty); + } + else if (FFlag::LuauNegatedStringSingletons && norm.strings.isIntersection()) + { + std::vector parts; + parts.push_back(singletonTypes->stringType); + for (const auto& [name, ty] : *norm.strings.singletons) + parts.push_back(arena->addType(NegationTypeVar{ty})); + + result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + } + if (!get(norm.threads)) + result.push_back(singletonTypes->threadType); + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); for (auto& [tyvar, intersect] : norm.tyvars) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5897ca21..44000647 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauLvaluelessPath) LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) @@ -434,7 +435,7 @@ struct TypeVarStringifier return; default: LUAU_ASSERT(!"Unknown primitive type"); - throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type)); + throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); } } @@ -451,7 +452,7 @@ struct TypeVarStringifier else { LUAU_ASSERT(!"Unknown singleton type"); - throw std::runtime_error("Unknown singleton type"); + throwRuntimeError("Unknown singleton type"); } } @@ -1538,6 +1539,8 @@ std::string dump(const Constraint& c) std::string toString(const LValue& lvalue) { + LUAU_ASSERT(!FFlag::LuauLvaluelessPath); + std::string s; for (const LValue* current = &lvalue; current; current = baseof(*current)) { @@ -1552,4 +1555,37 @@ std::string toString(const LValue& lvalue) return s; } +std::optional getFunctionNameAsString(const AstExpr& expr) +{ + LUAU_ASSERT(FFlag::LuauLvaluelessPath); + + const AstExpr* curr = &expr; + std::string s; + + for (;;) + { + if (auto local = curr->as()) + return local->local->name.value + s; + + if (auto global = curr->as()) + return global->name.value + s; + + if (auto indexname = curr->as()) + { + curr = indexname->expr; + + s = "." + std::string(indexname->index.value) + s; + } + else if (auto group = curr->as()) + { + curr = group->expr; + } + else + { + return std::nullopt; + } + } + + return s; +} } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 1ea2e27d..052c10de 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TopoSortStatements.h" +#include "Luau/Error.h" /* Decide the order in which we typecheck Lua statements in a block. * * Algorithm: @@ -149,7 +150,7 @@ Identifier mkName(const AstStatFunction& function) auto name = mkName(*function.name); LUAU_ASSERT(bool(name)); if (!name) - throw std::runtime_error("Internal error: Function declaration has a bad name"); + throwRuntimeError("Internal error: Function declaration has a bad name"); return *name; } @@ -255,7 +256,7 @@ struct ArcCollector : public AstVisitor { auto name = mkName(*node->name); if (!name) - throw std::runtime_error("Internal error: AstStatFunction has a bad name"); + throwRuntimeError("Internal error: AstStatFunction has a bad name"); add(*name); return true; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f2613cae..179846d7 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -347,7 +347,7 @@ public: AstType* operator()(const NegationTypeVar& ntv) { // FIXME: do the same thing we do with ErrorTypeVar - throw std::runtime_error("Cannot convert NegationTypeVar into AstNode"); + throwRuntimeError("Cannot convert NegationTypeVar into AstNode"); } private: diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index bd220e9c..a2673158 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -934,8 +934,62 @@ struct TypeChecker2 void visit(AstExprUnary* expr) { - // TODO! visit(expr->expr); + + NotNull scope = stack.back(); + TypeId operandType = lookupType(expr->expr); + + if (get(operandType) || get(operandType) || get(operandType)) + return; + + if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) + { + std::optional mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location); + if (mm) + { + if (const FunctionTypeVar* ftv = get(follow(*mm))) + { + TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); + reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + + if (std::optional ret = first(ftv->retTypes)) + { + if (expr->op == AstExprUnary::Op::Len) + { + reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType)); + } + } + else + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + + return; + } + } + + if (expr->op == AstExprUnary::Op::Len) + { + DenseHashSet seen{nullptr}; + int recursionCount = 0; + + if (!hasLength(operandType, seen, &recursionCount)) + { + reportError(NotATable{operandType}, expr->location); + } + } + else if (expr->op == AstExprUnary::Op::Minus) + { + reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType)); + } + else if (expr->op == AstExprUnary::Op::Not) + { + } + else + { + LUAU_ASSERT(!"Unhandled unary operator"); + } } void visit(AstExprBinary* expr) @@ -1240,9 +1294,8 @@ struct TypeChecker2 Scope* scope = findInnermostScope(ty->location); LUAU_ASSERT(scope); - // TODO: Imported types - - std::optional alias = scope->lookupType(ty->name.value); + std::optional alias = + (ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value); if (alias.has_value()) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index d5c6b2c4..ccb1490a 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) +LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) @@ -43,15 +44,15 @@ LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) -LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) namespace Luau { - -const char* TimeLimitError::what() const throw() +const char* TimeLimitError_DEPRECATED::what() const throw() { + LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange); return "Typeinfer failed to complete in allotted time"; } @@ -264,6 +265,11 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona reportErrorCodeTooComplex(module.root->location); return std::move(currentModule); } + catch (const RecursionLimitException_DEPRECATED&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } } ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -308,6 +314,10 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo { currentModule->timeout = true; } + catch (const TimeLimitError_DEPRECATED&) + { + currentModule->timeout = true; + } if (FFlag::DebugLuauSharedSelf) { @@ -415,7 +425,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) ice("Unknown AstStat"); if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(); + throwTimeLimitError(); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -442,6 +452,11 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + catch (const RecursionLimitException_DEPRECATED&) + { + reportErrorCodeTooComplex(block.location); + return; + } } struct InplaceDemoter : TypeVarOnceVisitor @@ -2456,11 +2471,8 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) { - if (FFlag::LuauUnionOfTypesFollow) - { - a = follow(a); - b = follow(b); - } + a = follow(a); + b = follow(b); if (unifyFreeTypes && (get(a) || get(b))) { @@ -3596,8 +3608,17 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam location = {state.location.begin, argLocations.back().end}; std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); + + if (FFlag::LuauLvaluelessPath) + { + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + } + else + { + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + } auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); state.reportError(TypeError{location, @@ -3706,11 +3727,28 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam bool isVariadic = tail && Luau::isVariadic(*tail); std::string namePath; - if (std::optional lValue = tryGetLValue(funName)) - namePath = toString(*lValue); - state.reportError(TypeError{ - state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + if (FFlag::LuauLvaluelessPath) + { + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + } + else + { + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + } + + if (FFlag::LuauArgMismatchReportFunctionLocation) + { + state.reportError(TypeError{ + funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } + else + { + state.reportError(TypeError{ + state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } return; } ++paramIter; @@ -4647,6 +4685,19 @@ void TypeChecker::ice(const std::string& message) iceHandler->ice(message); } +// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted. +void TypeChecker::throwTimeLimitError() +{ + if (FFlag::LuauIceExceptionInheritanceChange) + { + throw TimeLimitError(iceHandler->moduleName); + } + else + { + throw TimeLimitError_DEPRECATED(); + } +} + void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) { // Remove errors with names that were generated by recovery from a parse error diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 0fa4df60..0852f053 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypePack.h" +#include "Luau/Error.h" #include "Luau/TxnLog.h" #include @@ -234,7 +235,7 @@ TypePackId follow(TypePackId tp, std::function mapper) cycleTester = nullptr; if (tp == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); } } } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 19d3d266..94d633c7 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -61,7 +61,7 @@ TypeId follow(TypeId t, std::function mapper) { std::optional ty = utv->scope->lookup(utv->def); if (!ty) - throw std::runtime_error("UseTypeVar must map to another TypeId"); + throwRuntimeError("UseTypeVar must map to another TypeId"); return *ty; } else @@ -73,7 +73,7 @@ TypeId follow(TypeId t, std::function mapper) { TypeId res = ltv->thunk(); if (get(res)) - throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar"); + throwRuntimeError("Lazy TypeVar cannot resolve to another Lazy TypeVar"); *asMutable(ty) = BoundTypeVar(res); } @@ -111,7 +111,7 @@ TypeId follow(TypeId t, std::function mapper) cycleTester = nullptr; if (t == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throwRuntimeError("Luau::follow detected a TypeVar cycle!!"); } } } @@ -946,7 +946,7 @@ void persist(TypeId ty) queue.push_back(mtv->table); queue.push_back(mtv->metatable); } - else if (get(t) || get(t) || get(t) || get(t) || get(t)) + else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) { } else diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index e23e6161..b5eba980 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -16,6 +16,7 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false) LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) @@ -273,7 +274,7 @@ TypeId Widen::clean(TypeId ty) TypePackId Widen::clean(TypePackId) { - throw std::runtime_error("Widen attempted to clean a dirty type pack?"); + throwRuntimeError("Widen attempted to clean a dirty type pack?"); } bool Widen::ignoreChildren(TypeId ty) @@ -551,6 +552,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(subTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ true); + else if (log.get(superTy)) + tryUnifyTypeWithNegation(subTy, superTy); + + else if (log.get(subTy)) + tryUnifyNegationWithType(subTy, superTy); + else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -866,13 +873,7 @@ void Unifier::tryUnifyNormalizedTypes( if (!get(superNorm.numbers)) return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - if (subNorm.strings && superNorm.strings) - { - for (auto [name, ty] : *subNorm.strings) - if (!superNorm.strings->count(name)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - } - else if (!subNorm.strings && superNorm.strings) + if (!isSubtype(subNorm.strings, superNorm.strings)) return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); if (get(subNorm.threads)) @@ -1392,7 +1393,10 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); + if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) + reportError(TypeError{location, TypePackMismatch{subTp, superTp}}); + else + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } @@ -1441,7 +1445,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); - if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate) + // TODO: This is unsound when the context is invariant, but the annotation burden without allowing it and without + // read-only properties is too high for lua-apps. Read-only properties _should_ resolve their issue by allowing + // generic methods in tables to be marked read-only. + if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { Instantiation instantiation{&log, types, scope->level, scope}; @@ -1576,6 +1583,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); + TableTypeVar* instantiatedSubTable = subTable; if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1593,6 +1601,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (instantiated.has_value()) { subTable = log.getMutable(*instantiated); + instantiatedSubTable = subTable; if (!subTable) ice("instantiation made a table type into a non-table type in tryUnifyTables"); @@ -1696,7 +1705,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // txn log. TableTypeVar* newSuperTable = log.getMutable(superTy); TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -1758,7 +1767,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // txn log. TableTypeVar* newSuperTable = log.getMutable(superTy); TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -2098,6 +2107,34 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) return fail(); } +void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) +{ + const NegationTypeVar* ntv = get(superTy); + if (!ntv) + ice("tryUnifyTypeWithNegation superTy must be a negation type"); + + const NormalizedType* subNorm = normalizer->normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(TypeError{location, UnificationTooComplex{}}); + + // T (subTy); + if (!ntv) + ice("tryUnifyNegationWithType subTy must be a negation type"); + + // TODO: ~T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { while (true) diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h index 17ce2e3b..9c0a9527 100644 --- a/Ast/include/Luau/ParseResult.h +++ b/Ast/include/Luau/ParseResult.h @@ -58,6 +58,8 @@ struct Comment struct ParseResult { AstStatBlock* root; + size_t lines = 0; + std::vector hotcomments; std::vector errors; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 848d7117..8b7eb73c 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -302,8 +302,8 @@ private: AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) - LUAU_PRINTF_ATTR(5, 6); + AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) + LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error // `astErrorLocation` is associated with the AstTypeError created // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 7150b18f..85c5f5c6 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -23,7 +23,6 @@ LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) @@ -164,15 +163,16 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n try { AstStatBlock* root = p.parseChunk(); + size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; } catch (ParseError& err) { // when catching a fatal error, append it to the list of non-fatal errors and return p.parseErrors.push_back(err); - return ParseResult{nullptr, {}, p.parseErrors}; + return ParseResult{nullptr, 0, {}, p.parseErrors}; } } @@ -811,9 +811,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { - return AstDeclaredClassProp{fnName.name, - reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"), - true}; + return AstDeclaredClassProp{ + fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -824,8 +823,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError( - Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) @@ -1537,7 +1535,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false, + return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } @@ -1623,18 +1621,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; + return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Malformed string")}; + return {reportTypeAnnotationError(start, {}, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1693,33 +1691,20 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, + return {reportTypeAnnotationError(start, {}, "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "...any'"), {}}; } else { - if (FFlag::LuauTypeAnnotationLocationChange) - { - // For a missing type annotation, capture 'space' between last token and the next one - Location astErrorlocation(lexer.previousLocation().end, start.begin); - // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). - // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. - Location parseErrorLocation(lexer.previousLocation().end, start.end); - return { - reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), - {}}; - } - else - { - Location location = lexer.current().location; - - // For a missing type annotation, capture 'space' between last token and the next one - location = Location(lexer.previousLocation().end, lexer.current().location.begin); - - return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; - } + // For a missing type annotation, capture 'space' between last token and the next one + Location astErrorlocation(lexer.previousLocation().end, start.begin); + // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). + // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. + Location parseErrorLocation(lexer.previousLocation().end, start.end); + return { + reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -3033,27 +3018,18 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) +AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) { - if (FFlag::LuauTypeAnnotationLocationChange) - { - // Missing type annotations should be using `reportMissingTypeAnnotationError` when LuauTypeAnnotationLocationChange is enabled - // Note: `isMissing` can be removed once FFlag::LuauTypeAnnotationLocationChange is removed since it will always be true. - LUAU_ASSERT(!isMissing); - } - va_list args; va_start(args, format); report(location, format, args); va_end(args); - return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); + return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { - LUAU_ASSERT(FFlag::LuauTypeAnnotationLocationChange); - va_list args; va_start(args, format); report(parseErrorLocation, format, args); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index a9dd8970..87e19db8 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -16,7 +16,6 @@ #include "isocline.h" -#include #include #ifdef _WIN32 @@ -688,11 +687,11 @@ static std::string getCodegenAssembly(const char* name, const std::string& bytec return ""; } -static void annotateInstruction(void* context, std::string& text, int fid, int instid) +static void annotateInstruction(void* context, std::string& text, int fid, int instpos) { Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; - bcb.annotateInstruction(text, fid, instid); + bcb.annotateInstruction(text, fid, instpos); } struct CompileStats @@ -711,7 +710,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st return false; } - stats.lines += std::count(source->begin(), source->end(), '\n'); + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) + // This function is much more complicated because it supports many output human-readable formats through internal interfaces try { @@ -736,7 +736,16 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st bcb.setDumpSource(*source); } - Luau::compileOrThrow(bcb, *source, copts()); + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + stats.lines += result.lines; + + Luau::compileOrThrow(bcb, result, names, copts()); stats.bytecode += bcb.getBytecode().size(); switch (format) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0016160a..05d701ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,11 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +if (NOT MSVC) + # disable support for math_errno which allows compilers to lower sqrt() into a single CPU instruction + target_compile_options(Luau.VM PRIVATE -fno-math-errno) +endif() + if(MSVC AND LUAU_BUILD_CLI) # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index e8b30195..cef9ec7c 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -17,7 +17,7 @@ void create(lua_State* L); // Builds target function and all inner functions void compile(lua_State* L, int idx); -using annotatorFn = void (*)(void* context, std::string& result, int fid, int instid); +using annotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); struct AssemblyOptions { diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index f78ead59..78645766 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -34,7 +34,350 @@ namespace CodeGen constexpr uint32_t kFunctionAlignment = 32; -static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, Proto* proto, AssemblyOptions options) +struct InstructionOutline +{ + int pcpos; + int length; +}; + +static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) +{ + if (build.logText) + build.logAppend("; exitContinueVm\n"); + helpers.exitContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ true); + + if (build.logText) + build.logAppend("; exitNoContinueVm\n"); + helpers.exitNoContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ false); +} + +static int emitInst( + AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, Label* labelarr, Label& fallback) +{ + int skip = 0; + + switch (op) + { + case LOP_NOP: + break; + case LOP_LOADNIL: + emitInstLoadNil(build, pc); + break; + case LOP_LOADB: + emitInstLoadB(build, pc, i, labelarr); + break; + case LOP_LOADN: + emitInstLoadN(build, pc); + break; + case LOP_LOADK: + emitInstLoadK(build, pc); + break; + case LOP_LOADKX: + emitInstLoadKX(build, pc); + break; + case LOP_MOVE: + emitInstMove(build, pc); + break; + case LOP_GETGLOBAL: + emitInstGetGlobal(build, pc, i, fallback); + break; + case LOP_SETGLOBAL: + emitInstSetGlobal(build, pc, i, labelarr, fallback); + break; + case LOP_RETURN: + emitInstReturn(build, helpers, pc, i, labelarr); + break; + case LOP_GETTABLE: + emitInstGetTable(build, pc, i, fallback); + break; + case LOP_SETTABLE: + emitInstSetTable(build, pc, i, labelarr, fallback); + break; + case LOP_GETTABLEKS: + emitInstGetTableKS(build, pc, i, fallback); + break; + case LOP_SETTABLEKS: + emitInstSetTableKS(build, pc, i, labelarr, fallback); + break; + case LOP_GETTABLEN: + emitInstGetTableN(build, pc, i, fallback); + break; + case LOP_SETTABLEN: + emitInstSetTableN(build, pc, i, labelarr, fallback); + break; + case LOP_JUMP: + emitInstJump(build, pc, i, labelarr); + break; + case LOP_JUMPBACK: + emitInstJumpBack(build, pc, i, labelarr); + break; + case LOP_JUMPIF: + emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false); + break; + case LOP_JUMPIFNOT: + emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true); + break; + case LOP_JUMPIFEQ: + emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback); + break; + case LOP_JUMPIFLE: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::LessEqual, fallback); + break; + case LOP_JUMPIFLT: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::Less, fallback); + break; + case LOP_JUMPIFNOTEQ: + emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback); + break; + case LOP_JUMPIFNOTLE: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLessEqual, fallback); + break; + case LOP_JUMPIFNOTLT: + emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLess, fallback); + break; + case LOP_JUMPX: + emitInstJumpX(build, pc, i, labelarr); + break; + case LOP_JUMPXEQKNIL: + emitInstJumpxEqNil(build, pc, i, labelarr); + break; + case LOP_JUMPXEQKB: + emitInstJumpxEqB(build, pc, i, labelarr); + break; + case LOP_JUMPXEQKN: + emitInstJumpxEqN(build, pc, proto->k, i, labelarr); + break; + case LOP_JUMPXEQKS: + emitInstJumpxEqS(build, pc, i, labelarr); + break; + case LOP_ADD: + emitInstBinary(build, pc, i, TM_ADD, fallback); + break; + case LOP_SUB: + emitInstBinary(build, pc, i, TM_SUB, fallback); + break; + case LOP_MUL: + emitInstBinary(build, pc, i, TM_MUL, fallback); + break; + case LOP_DIV: + emitInstBinary(build, pc, i, TM_DIV, fallback); + break; + case LOP_MOD: + emitInstBinary(build, pc, i, TM_MOD, fallback); + break; + case LOP_POW: + emitInstBinary(build, pc, i, TM_POW, fallback); + break; + case LOP_ADDK: + emitInstBinaryK(build, pc, i, TM_ADD, fallback); + break; + case LOP_SUBK: + emitInstBinaryK(build, pc, i, TM_SUB, fallback); + break; + case LOP_MULK: + emitInstBinaryK(build, pc, i, TM_MUL, fallback); + break; + case LOP_DIVK: + emitInstBinaryK(build, pc, i, TM_DIV, fallback); + break; + case LOP_MODK: + emitInstBinaryK(build, pc, i, TM_MOD, fallback); + break; + case LOP_POWK: + emitInstPowK(build, pc, proto->k, i, fallback); + break; + case LOP_NOT: + emitInstNot(build, pc); + break; + case LOP_MINUS: + emitInstMinus(build, pc, i, fallback); + break; + case LOP_LENGTH: + emitInstLength(build, pc, i, fallback); + break; + case LOP_NEWTABLE: + emitInstNewTable(build, pc, i, labelarr); + break; + case LOP_DUPTABLE: + emitInstDupTable(build, pc, i, labelarr); + break; + case LOP_SETLIST: + emitInstSetList(build, pc, i, labelarr); + break; + case LOP_GETUPVAL: + emitInstGetUpval(build, pc, i); + break; + case LOP_SETUPVAL: + emitInstSetUpval(build, pc, i, labelarr); + break; + case LOP_CLOSEUPVALS: + emitInstCloseUpvals(build, pc, i, labelarr); + break; + case LOP_FASTCALL: + skip = emitInstFastCall(build, pc, i, labelarr); + break; + case LOP_FASTCALL1: + skip = emitInstFastCall1(build, pc, i, labelarr); + break; + case LOP_FASTCALL2: + skip = emitInstFastCall2(build, pc, i, labelarr); + break; + case LOP_FASTCALL2K: + skip = emitInstFastCall2K(build, pc, i, labelarr); + break; + case LOP_FORNPREP: + emitInstForNPrep(build, pc, i, labelarr); + break; + case LOP_FORNLOOP: + emitInstForNLoop(build, pc, i, labelarr); + break; + case LOP_FORGLOOP: + emitinstForGLoop(build, pc, i, labelarr, fallback); + break; + case LOP_FORGPREP_NEXT: + emitInstForGPrepNext(build, pc, i, labelarr, fallback); + break; + case LOP_FORGPREP_INEXT: + emitInstForGPrepInext(build, pc, i, labelarr, fallback); + break; + case LOP_AND: + emitInstAnd(build, pc); + break; + case LOP_ANDK: + emitInstAndK(build, pc); + break; + case LOP_OR: + emitInstOr(build, pc); + break; + case LOP_ORK: + emitInstOrK(build, pc); + break; + case LOP_GETIMPORT: + emitInstGetImport(build, pc, fallback); + break; + case LOP_CONCAT: + emitInstConcat(build, pc, i, labelarr); + break; + default: + emitFallback(build, data, op, i); + break; + } + + return skip; +} + +static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr) +{ + switch (op) + { + case LOP_GETIMPORT: + emitInstGetImportFallback(build, pc, i); + break; + case LOP_GETTABLE: + emitInstGetTableFallback(build, pc, i); + break; + case LOP_SETTABLE: + emitInstSetTableFallback(build, pc, i); + break; + case LOP_GETTABLEN: + emitInstGetTableNFallback(build, pc, i); + break; + case LOP_SETTABLEN: + emitInstSetTableNFallback(build, pc, i); + break; + case LOP_JUMPIFEQ: + emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false); + break; + case LOP_JUMPIFLE: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::LessEqual); + break; + case LOP_JUMPIFLT: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::Less); + break; + case LOP_JUMPIFNOTEQ: + emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true); + break; + case LOP_JUMPIFNOTLE: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLessEqual); + break; + case LOP_JUMPIFNOTLT: + emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLess); + break; + case LOP_ADD: + emitInstBinaryFallback(build, pc, i, TM_ADD); + break; + case LOP_SUB: + emitInstBinaryFallback(build, pc, i, TM_SUB); + break; + case LOP_MUL: + emitInstBinaryFallback(build, pc, i, TM_MUL); + break; + case LOP_DIV: + emitInstBinaryFallback(build, pc, i, TM_DIV); + break; + case LOP_MOD: + emitInstBinaryFallback(build, pc, i, TM_MOD); + break; + case LOP_POW: + emitInstBinaryFallback(build, pc, i, TM_POW); + break; + case LOP_ADDK: + emitInstBinaryKFallback(build, pc, i, TM_ADD); + break; + case LOP_SUBK: + emitInstBinaryKFallback(build, pc, i, TM_SUB); + break; + case LOP_MULK: + emitInstBinaryKFallback(build, pc, i, TM_MUL); + break; + case LOP_DIVK: + emitInstBinaryKFallback(build, pc, i, TM_DIV); + break; + case LOP_MODK: + emitInstBinaryKFallback(build, pc, i, TM_MOD); + break; + case LOP_POWK: + emitInstBinaryKFallback(build, pc, i, TM_POW); + break; + case LOP_MINUS: + emitInstMinusFallback(build, pc, i); + break; + case LOP_LENGTH: + emitInstLengthFallback(build, pc, i); + break; + case LOP_FORGLOOP: + emitinstForGLoopFallback(build, pc, i, labelarr); + break; + case LOP_FORGPREP_NEXT: + case LOP_FORGPREP_INEXT: + emitInstForGPrepXnextFallback(build, pc, i, labelarr); + break; + case LOP_GETGLOBAL: + // TODO: luaV_gettable + cachedslot update instead of full fallback + emitFallback(build, data, op, i); + break; + case LOP_SETGLOBAL: + // TODO: luaV_settable + cachedslot update instead of full fallback + emitFallback(build, data, op, i); + break; + case LOP_GETTABLEKS: + // Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access + // It is also required to perform cached slot update + // TODO: extra fast-paths could be lowered before the full fallback + emitFallback(build, data, op, i); + break; + case LOP_SETTABLEKS: + // TODO: luaV_settable + cachedslot update instead of full fallback + emitFallback(build, data, op, i); + break; + default: + LUAU_ASSERT(!"Expected fallback for instruction"); + } +} + +static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { NativeProto* result = new NativeProto(); @@ -59,222 +402,65 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat std::vector(a) -> ()' could not be converted into 't1 where t1 = ({- Clone: t1 -}) -> (a...)'; different number of generic type parameters)", - toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_NO_ERRORS(result); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp new file mode 100644 index 00000000..1035eda4 --- /dev/null +++ b/tests/TypeInfer.negations.test.cpp @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/Common.h" +#include "ScopedFlags.h" + +using namespace Luau; + +namespace +{ +struct NegationFixture : Fixture +{ + TypeArena arena; + ScopedFastFlag sff[2] { + {"LuauNegatedStringSingletons", true}, + {"LuauSubtypeNormalizer", true}, + }; + + NegationFixture() + { + registerNotType(*this, arena); + } +}; +} + +TEST_SUITE_BEGIN("Negations"); + +TEST_CASE_FIXTURE(NegationFixture, "negated_string_is_a_subtype_of_string") +{ + CheckResult result = check(R"( + function foo(arg: string) end + local a: string & Not<"Hello"> + foo(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string") +{ + CheckResult result = check(R"( + function foo(arg: string & Not<"hello">) end + local a: string + foo(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index e572c87a..b2516f6d 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -434,16 +434,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } + local foo local mt = {} - setmetatable(foo, mt) mt.__unm = function(val: typeof(foo)): string - return val.value .. "test" + return tostring(val.value) .. "test" end + foo = setmetatable({ + value = 10 + }, mt) + local a = -foo local b = 1+-1 @@ -459,25 +460,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") CHECK_EQ("string", toString(requireType("a"))); CHECK_EQ("number", toString(requireType("b"))); - GenericError* gen = get(result.errors[0]); - REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(toString(result.errors[0]) == "Type '{ value: number }' could not be converted into 'number'"); + } + else + { + GenericError* gen = get(result.errors[0]); + REQUIRE(gen); + REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } - local mt = {} - setmetatable(foo, mt) mt.__unm = function(val: boolean): string return "test" end + local foo = setmetatable({ + value = 10 + }, mt) + local a = -foo )"); @@ -494,16 +502,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") { CheckResult result = check(R"( --!strict - local foo = { - value = 10 - } local mt = {} - setmetatable(foo, mt) - mt.__len = function(val: any): string + mt.__len = function(val): string return "test" end + local foo = setmetatable({ + value = 10, + }, mt) + local a = #foo )"); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index ccc4d775..8e04c165 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -624,15 +624,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") CHECK_EQ("{string | string}", toString(requireType("t"))); } -struct NormalizeFixture : Fixture +namespace +{ +struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); } }; +} -TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities") { check(R"( type A = (any) -> () @@ -653,7 +656,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arit CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity") { check(R"( local a: (number) -> () @@ -676,7 +679,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") CHECK(!isSubtype(b, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_optional_parameters") { /* * (T0..TN) <: (T0..TN, A?) @@ -736,7 +739,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_option // CHECK(!isSubtype(b, c)); } -TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param") +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param") { check(R"( local a: (number?) -> () diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2a208cce..7de412ff 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1,5 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/Frontend.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -14,6 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) TEST_SUITE_BEGIN("TableTests"); @@ -1957,7 +1961,11 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - LUAU_REQUIRE_ERRORS(result); + // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping + if (FFlag::LuauInstantiateInSubtyping) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") @@ -3262,11 +3270,13 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' -caused by: - Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); - // this error message is not great since the underlying issue is that the context is invariant, + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping +// LUAU_REQUIRE_ERRORS(result); +// CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' +// caused by: +// Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); +// // this error message is not great since the underlying issue is that the context is invariant, // and `(number) -> number` cannot be a subtype of `(a) -> a`. } @@ -3292,4 +3302,43 @@ local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f CHECK_EQ("r", error->properties[0]); } +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_has_a_side_effect") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mt = { + __add = function(x, y) + return 123 + end, + } + + local foo = {} + setmetatable(foo, mt) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("foo")) == "{ @metatable { __add: (a, b) -> number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") +{ + CheckResult result = check(R"( + local t = { + x = 5 :: NonexistingTypeWhichEndsUpReturningAnErrorType, + y = 5 + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + ToStringOptions opts; + opts.exhaustive = true; + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts)); + else + CHECK_EQ("{ x: , y: number }", toString(requireType("t"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp index 4fba694a..589c3bad 100644 --- a/tests/VisitTypeVar.test.cpp +++ b/tests/VisitTypeVar.test.cpp @@ -22,7 +22,14 @@ TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") TypeId tType = requireType("t"); - CHECK_THROWS_AS(toString(tType), RecursionLimitException); + if (FFlag::LuauIceExceptionInheritanceChange) + { + CHECK_THROWS_AS(toString(tType), RecursionLimitException); + } + else + { + CHECK_THROWS_AS(toString(tType), RecursionLimitException_DEPRECATED); + } } TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") diff --git a/tools/faillist.txt b/tools/faillist.txt index c869e0c4..a4c05b7b 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,14 +1,14 @@ -AnnotationTests.builtin_types_are_not_exported AnnotationTests.corecursive_types_error_on_tight_loop AnnotationTests.duplicate_type_param_name AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.generic_aliases_are_cloned_properly AnnotationTests.instantiation_clone_has_to_follow +AnnotationTests.luau_print_is_not_special_without_the_flag AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar AnnotationTests.too_many_type_params AnnotationTests.two_type_params -AnnotationTests.use_type_required_from_another_file +AnnotationTests.unknown_type_reference_generates_error AstQuery.last_argument_function_call_type AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AutocompleteTest.autocomplete_first_function_arg_expected_type @@ -86,7 +86,6 @@ BuiltinTests.table_pack BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic BuiltinTests.tonumber_returns_optional_number_type -BuiltinTests.tonumber_returns_optional_number_type2 DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props DefinitionTests.declaring_generic_functions @@ -96,7 +95,6 @@ FrontendTest.imported_table_modification_2 FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.nocheck_cycle_used_by_checked FrontendTest.reexport_cyclic_type -FrontendTest.reexport_type_alias FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 @@ -105,7 +103,6 @@ GenericsTests.calling_self_generic_methods GenericsTests.check_generic_typepack_function GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions -GenericsTests.do_not_always_instantiate_generic_intersection_types GenericsTests.do_not_infer_generic_functions GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_types @@ -143,7 +140,6 @@ IntersectionTypes.table_write_sealed_indirect ModuleTests.any_persistance_does_not_leak ModuleTests.clone_self_property ModuleTests.deepClone_cyclic_table -ModuleTests.do_not_clone_reexports NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.function_parameters_are_any NonstrictModeTests.inconsistent_module_return_types_are_ok @@ -158,7 +154,6 @@ NonstrictModeTests.parameters_having_type_any_are_optional NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_props_are_any Normalize.cyclic_table_normalizes_sensibly -Normalize.intersection_combine_on_bound_self ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.recovery_of_parenthesized_expressions ParserTests.parse_nesting_based_end_detection_failsafe_earlier @@ -249,7 +244,6 @@ TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index -TableTests.dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back TableTests.dont_leak_free_table_props TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys @@ -279,7 +273,6 @@ TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors -TableTests.length_operator_union_errors TableTests.less_exponential_blowup_please TableTests.meta_add TableTests.meta_add_both_ways @@ -347,9 +340,9 @@ TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.variadics_should_use_reversed_properly +TypeAliases.cannot_create_cyclic_type_with_unknown_module TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.generic_param_remap -TypeAliases.mismatched_generic_pack_type_param TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 @@ -363,7 +356,7 @@ TypeAliases.type_alias_fwd_declaration_is_precise TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_of_an_imported_recursive_generic_type -TypeAliases.type_alias_of_an_imported_recursive_type +TypeInfer.check_type_infer_recursion_count TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError @@ -394,6 +387,7 @@ TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument +TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict @@ -439,12 +433,9 @@ TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types -TypeInferModules.do_not_modify_imported_types_2 -TypeInferModules.do_not_modify_imported_types_3 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.require_a_variadic_function -TypeInferModules.require_types TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.CheckMethodsOfSealed TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works @@ -468,9 +459,6 @@ TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with TypeInferOperators.refine_and_or TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs -TypeInferOperators.typecheck_unary_len_error -TypeInferOperators.typecheck_unary_minus -TypeInferOperators.typecheck_unary_minus_error TypeInferOperators.UnknownGlobalCompoundAssign TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.singleton_types @@ -489,6 +477,7 @@ TypeInferUnknownNever.math_operators_and_never TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never +TypePackTests.detect_cyclic_typepacks2 TypePackTests.higher_order_function TypePackTests.pack_tail_unification_check TypePackTests.parenthesized_varargs_returns_any