From 4a95f2201ec1f879dfa3f1b8bc015e0a30df005a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 7 Jul 2022 18:05:31 -0700 Subject: [PATCH] Sync to upstream/release/535 --- Analysis/include/Luau/AstQuery.h | 1 + Analysis/include/Luau/TypeInfer.h | 25 +- Analysis/include/Luau/TypePack.h | 1 + Analysis/include/Luau/TypeVar.h | 157 +- Analysis/include/Luau/VisitTypeVar.h | 8 + Analysis/src/AstQuery.cpp | 107 +- Analysis/src/Autocomplete.cpp | 201 +- Analysis/src/BuiltinDefinitions.cpp | 43 +- Analysis/src/Clone.cpp | 12 + Analysis/src/EmbeddedBuiltinDefinitions.cpp | 9 +- Analysis/src/Frontend.cpp | 2 + Analysis/src/Normalize.cpp | 38 +- Analysis/src/Substitution.cpp | 10 +- Analysis/src/ToString.cpp | 23 +- Analysis/src/Transpiler.cpp | 20 +- Analysis/src/TxnLog.cpp | 59 +- Analysis/src/TypeAttach.cpp | 8 + Analysis/src/TypeInfer.cpp | 420 +++- Analysis/src/TypePack.cpp | 51 +- Analysis/src/TypeUtils.cpp | 2 +- Analysis/src/TypeVar.cpp | 269 ++- Analysis/src/Unifier.cpp | 66 +- Ast/src/Parser.cpp | 17 +- CLI/Analyze.cpp | 18 +- CLI/Repl.cpp | 37 + CodeGen/include/Luau/AssemblyBuilderX64.h | 3 + CodeGen/src/AssemblyBuilderX64.cpp | 23 + Makefile | 4 + Sources.cmake | 8 +- VM/src/ltable.cpp | 8 +- VM/src/lvmexecute.cpp | 66 +- bench/bench.py | 37 +- bench/bench_support.lua | 10 + bench/other/LuauPolyfillMap.lua | 961 +++++++++ bench/other/regex.lua | 2089 +++++++++++++++++++ tests/AssemblyBuilderX64.test.cpp | 27 + tests/AstQuery.test.cpp | 33 + tests/Fixture.h | 1 + tests/Frontend.test.cpp | 2 + tests/Module.test.cpp | 2 - tests/Normalize.test.cpp | 42 +- tests/Parser.test.cpp | 1 - tests/ToString.test.cpp | 33 +- tests/Transpiler.test.cpp | 2 +- tests/TypeInfer.anyerror.test.cpp | 8 +- tests/TypeInfer.builtins.test.cpp | 143 +- tests/TypeInfer.functions.test.cpp | 54 +- tests/TypeInfer.generics.test.cpp | 39 +- tests/TypeInfer.loops.test.cpp | 2 +- tests/TypeInfer.modules.test.cpp | 45 +- tests/TypeInfer.operators.test.cpp | 22 + tests/TypeInfer.primitives.test.cpp | 2 +- tests/TypeInfer.provisional.test.cpp | 15 +- tests/TypeInfer.refinements.test.cpp | 45 +- tests/TypeInfer.tables.test.cpp | 14 + tests/TypeInfer.test.cpp | 33 +- tests/TypeInfer.tryUnify.test.cpp | 4 +- tests/TypeInfer.unionTypes.test.cpp | 2 +- tests/TypeInfer.unknownnever.test.cpp | 280 +++ tests/TypePack.test.cpp | 2 - tests/TypeVar.test.cpp | 2 - tests/conformance/vector.lua | 16 + tools/natvis/VM.natvis | 32 +- 63 files changed, 5097 insertions(+), 619 deletions(-) create mode 100644 bench/other/LuauPolyfillMap.lua create mode 100644 bench/other/regex.lua create mode 100644 tests/TypeInfer.unknownnever.test.cpp diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index dfe373a5..950a19da 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -63,6 +63,7 @@ private: AstLocal* local = nullptr; }; +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos); AstNode* findNodeAtPosition(const SourceModule& source, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 455654d9..3fb710bb 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -153,7 +153,7 @@ struct TypeChecker const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); @@ -180,8 +180,12 @@ struct TypeChecker const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); - WithPredicate checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); + std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); @@ -236,10 +240,11 @@ struct TypeChecker void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); - std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); - std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); + std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); + std::optional getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); // Reduces the union to its simplest possible shape. // (A | B) | B | C yields A | B | C @@ -316,11 +321,12 @@ private: TypeIdPredicate mkTruthyPredicate(bool sense); - // Returns nullopt if the predicate filters down the TypeId to 0 options. - std::optional filterMap(TypeId type, TypeIdPredicate predicate); + // TODO: Return TypeId only. + std::optional filterMapImpl(TypeId type, TypeIdPredicate predicate); + std::pair, bool> filterMap(TypeId type, TypeIdPredicate predicate); public: - std::optional pickTypesFromSense(TypeId type, bool sense); + std::pair, bool> pickTypesFromSense(TypeId type, bool sense); private: TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); @@ -345,6 +351,7 @@ private: TypePackId freshTypePack(TypeLevel level); TypeId resolveType(const ScopePtr& scope, const AstType& annotation); + TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, @@ -412,8 +419,12 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index c1de242f..b17003b1 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -173,5 +173,6 @@ std::pair, std::optional> flatten(TypePackId tp, bool isVariadic(TypePackId tp); bool isVariadic(TypePackId tp, const TxnLog& log); +bool containsNever(TypePackId tp); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6ad6b927..fb6093df 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -460,10 +460,18 @@ struct LazyTypeVar std::function thunk; }; +struct UnknownTypeVar +{ +}; + +struct NeverTypeVar +{ +}; + using ErrorTypeVar = Unifiable::Error; using TypeVariant = Unifiable::Variant; + MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar>; struct TypeVar final { @@ -575,8 +583,12 @@ struct SingletonTypes const TypeId trueType; const TypeId falseType; const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; SingletonTypes(); ~SingletonTypes(); @@ -632,12 +644,30 @@ T* getMutable(TypeId tv) return get_if(&asMutable(tv)->ty); } -/* Traverses the UnionTypeVar yielding each TypeId. - * If the iterator encounters a nested UnionTypeVar, it will instead yield each TypeId within. - * - * Beware: the iterator does not currently filter for unique TypeIds. This may change in the future. +const std::vector& getTypes(const UnionTypeVar* utv); +const std::vector& getTypes(const IntersectionTypeVar* itv); +const std::vector& getTypes(const ConstrainedTypeVar* ctv); + +template +struct TypeIterator; + +using UnionTypeVarIterator = TypeIterator; +UnionTypeVarIterator begin(const UnionTypeVar* utv); +UnionTypeVarIterator end(const UnionTypeVar* utv); + +using IntersectionTypeVarIterator = TypeIterator; +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); + +using ConstrainedTypeVarIterator = TypeIterator; +ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv); +ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv); + +/* Traverses the type T yielding each TypeId. + * If the iterator encounters a nested type T, it will instead yield each TypeId within. */ -struct UnionTypeVarIterator +template +struct TypeIterator { using value_type = Luau::TypeId; using pointer = value_type*; @@ -645,33 +675,116 @@ struct UnionTypeVarIterator using difference_type = size_t; using iterator_category = std::input_iterator_tag; - explicit UnionTypeVarIterator(const UnionTypeVar* utv); + explicit TypeIterator(const T* t) + { + LUAU_ASSERT(t); - UnionTypeVarIterator& operator++(); - UnionTypeVarIterator operator++(int); - bool operator!=(const UnionTypeVarIterator& rhs); - bool operator==(const UnionTypeVarIterator& rhs); + const std::vector& types = getTypes(t); + if (!types.empty()) + stack.push_front({t, 0}); - const TypeId& operator*(); + seen.insert(t); + } - friend UnionTypeVarIterator end(const UnionTypeVar* utv); + TypeIterator& operator++() + { + advance(); + descend(); + return *this; + } + + TypeIterator operator++(int) + { + TypeIterator copy = *this; + ++copy; + return copy; + } + + bool operator==(const TypeIterator& rhs) const + { + if (!stack.empty() && !rhs.stack.empty()) + return stack.front() == rhs.stack.front(); + + return stack.empty() && rhs.stack.empty(); + } + + bool operator!=(const TypeIterator& rhs) const + { + return !(*this == rhs); + } + + const TypeId& operator*() + { + LUAU_ASSERT(!stack.empty()); + + descend(); + + auto [t, currentIndex] = stack.front(); + LUAU_ASSERT(t); + const std::vector& types = getTypes(t); + LUAU_ASSERT(currentIndex < types.size()); + + const TypeId& ty = types[currentIndex]; + LUAU_ASSERT(!get(follow(ty))); + return ty; + } + + // Normally, we'd have `begin` and `end` be a template but there's too much trouble + // with templates portability in this area, so not worth it. Thanks MSVC. + friend UnionTypeVarIterator end(const UnionTypeVar*); + friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); + friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*); private: - UnionTypeVarIterator() = default; + TypeIterator() = default; - // (UnionTypeVar* utv, size_t currentIndex) - using SavedIterInfo = std::pair; + // (T* t, size_t currentIndex) + using SavedIterInfo = std::pair; std::deque stack; - std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. + std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. - void advance(); - void descend(); + void advance() + { + while (!stack.empty()) + { + auto& [t, currentIndex] = stack.front(); + ++currentIndex; + + const std::vector& types = getTypes(t); + if (currentIndex >= types.size()) + stack.pop_front(); + else + break; + } + } + + void descend() + { + while (!stack.empty()) + { + auto [current, currentIndex] = stack.front(); + const std::vector& types = getTypes(current); + if (auto inner = get(follow(types[currentIndex]))) + { + // If we're about to descend into a cyclic type, we should skip over this. + // Ideally this should never happen, but alas it does from time to time. :( + if (seen.find(inner) != seen.end()) + advance(); + else + { + seen.insert(inner); + stack.push_front({inner, 0}); + } + + continue; + } + + break; + } + } }; -UnionTypeVarIterator begin(const UnionTypeVar* utv); -UnionTypeVarIterator end(const UnionTypeVar* utv); - using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 5fd43f0b..ab4a397d 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -129,6 +129,14 @@ struct GenericTypeVarVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const UnknownTypeVar& atv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const NeverTypeVar& atv) + { + return visit(ty); + } virtual bool visit(TypeId ty, const UnionTypeVar& utv) { return visit(ty); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 0522b1fa..1124c29e 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -17,6 +17,104 @@ namespace Luau namespace { + +struct AutocompleteNodeFinder : public AstVisitor +{ + const Position pos; + std::vector ancestry; + + explicit AutocompleteNodeFinder(Position pos, AstNode* root) + : pos(pos) + { + } + + bool visit(AstExpr* expr) override + { + if (expr->location.begin < pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + + bool visit(AstStat* stat) override + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + return false; + } + + bool visit(AstType* type) override + { + if (type->location.begin < pos && pos <= type->location.end) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(AstTypeError* type) override + { + // For a missing type, match the whole range including the start position + if (type->isMissing && type->location.containsClosed(pos)) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(class AstTypePack* typePack) override + { + return true; + } + + bool visit(AstStatBlock* block) override + { + // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. + if (ancestry.empty()) + { + ancestry.push_back(block); + return true; + } + + // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. + // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // Type annotation error might intersect the block statement when the function header is being written, + // annotation takes priority + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, + // the expression or type wins out. + // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to + // be within the block. + if (block->location.begin == pos && !ancestry.empty()) + { + if (ancestry.back()->asExpr() && !ancestry.back()->is()) + return false; + + if (ancestry.back()->asType()) + return false; + } + + if (block->location.begin <= pos && pos <= block->location.end) + { + ancestry.push_back(block); + return true; + } + return false; + } +}; + struct FindNode : public AstVisitor { const Position pos; @@ -102,6 +200,13 @@ struct FindFullAncestry final : public AstVisitor } // namespace +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) +{ + AutocompleteNodeFinder finder{pos, source.root}; + source.root->visit(&finder); + return finder.ancestry; +} + std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) { const Position end = source.root->location.end; @@ -110,7 +215,7 @@ std::vector findAstAncestryOfPosition(const SourceModule& source, Posi FindFullAncestry finder(pos, end); source.root->visit(&finder); - return std::move(finder.nodes); + return finder.nodes; } AstNode* findNodeAtPosition(const SourceModule& source, Position pos) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 8a63901f..cc54d499 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -21,102 +21,6 @@ static const std::unordered_set kStatementStartingKeywords = { namespace Luau { -struct NodeFinder : public AstVisitor -{ - const Position pos; - std::vector ancestry; - - explicit NodeFinder(Position pos, AstNode* root) - : pos(pos) - { - } - - bool visit(AstExpr* expr) override - { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - - bool visit(AstStat* stat) override - { - if (stat->location.begin < pos && pos <= stat->location.end) - { - ancestry.push_back(stat); - return true; - } - return false; - } - - bool visit(AstType* type) override - { - if (type->location.begin < pos && pos <= type->location.end) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(AstTypeError* type) override - { - // For a missing type, match the whole range including the start position - if (type->isMissing && type->location.containsClosed(pos)) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(class AstTypePack* typePack) override - { - return true; - } - - bool visit(AstStatBlock* block) override - { - // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. - if (ancestry.empty()) - { - ancestry.push_back(block); - return true; - } - - // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. - // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // Type annotation error might intersect the block statement when the function header is being written, - // annotation takes priority - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, - // the expression or type wins out. - // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to - // be within the block. - if (block->location.begin == pos && !ancestry.empty()) - { - if (ancestry.back()->asExpr() && !ancestry.back()->is()) - return false; - - if (ancestry.back()->asType()) - return false; - } - - if (block->location.begin <= pos && pos <= block->location.end) - { - ancestry.push_back(block); - return true; - } - return false; - } -}; static bool alreadyHasParens(const std::vector& nodes) { @@ -905,7 +809,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } AstNode* parent = nullptr; - AstType* topType = nullptr; + AstType* topType = nullptr; // TODO: rename? for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) { @@ -1477,21 +1381,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (isWithinComment(sourceModule, position)) return {}; - NodeFinder finder{position, sourceModule.root}; - sourceModule.root->visit(&finder); - LUAU_ASSERT(!finder.ancestry.empty()); - AstNode* node = finder.ancestry.back(); + std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); + LUAU_ASSERT(!ancestry.empty()); + AstNode* node = ancestry.back(); AstExprConstantNil dummy{Location{}}; - AstNode* parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) { - finder.ancestry.pop_back(); + ancestry.pop_back(); - node = finder.ancestry.back(); - parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + node = ancestry.back(); + parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; } if (auto indexName = node->as()) @@ -1504,47 +1407,47 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) - return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), - finder.ancestry}; + return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), + ancestry}; else - return {autocompleteProps(*module, typeArena, ty, indexType, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry}; } else if (auto typeReference = node->as()) { if (typeReference->prefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), finder.ancestry}; + return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry}; else - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry}; } else if (node->is()) { - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry}; } else if (AstStatLocal* statLocal = node->as()) { if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) - return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else return {}; } - else if (AstStatFor* statFor = extractStat(finder.ancestry)) + else if (AstStatFor* statFor = extractStat(ancestry)) { if (!statFor->hasDo || position < statFor->doLocation.begin) { if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || (statFor->step && statFor->step->location.containsClosed(position))) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; return {}; } - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) @@ -1560,7 +1463,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; } - return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; } if (!statForIn->hasDo || position <= statForIn->doLocation.begin) @@ -1569,58 +1472,58 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; if (lastExpr->location.containsClosed(position)) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; if (position > lastExpr->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; return {}; // Not sure what this means } } - else if (AstStatForIn* statForIn = extractStat(finder.ancestry)) + else if (AstStatForIn* statForIn = extractStat(ancestry)) { // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. // ex "for f in f do" if (!statForIn->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; if (statWhile->hasDo && position > statWhile->doLocation.end) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } - else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + else if (AstStatWhile* statWhile = extractStat(ancestry); statWhile && !statWhile->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - finder.ancestry}; + ancestry}; } else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { if (statIf->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; } - else if (AstStatIf* statIf = extractStat(finder.ancestry); + else if (AstStatIf* statIf = extractStat(ancestry); statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; - else if (AstStatRepeat* statRepeat = extractStat(finder.ancestry); statRepeat) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; + else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) @@ -1630,7 +1533,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1644,9 +1547,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) - autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); + autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position, result); - return {result, finder.ancestry}; + return {result, ancestry}; } break; @@ -1654,11 +1557,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } else if (isIdentifier(node) && (parent->is() || parent->is())) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; - if (std::optional ret = autocompleteStringParams(sourceModule, module, finder.ancestry, position, callback)) + if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) { - return {*ret, finder.ancestry}; + return {*ret, ancestry}; } else if (node->is()) { @@ -1667,14 +1570,14 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto it = module->astExpectedTypes.find(node->asExpr())) autocompleteStringSingleton(*it, false, result); - if (finder.ancestry.size() >= 2) + if (ancestry.size() >= 2) { - if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); + autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, ancestry, result); } - else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) { if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { @@ -1684,7 +1587,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } - return {result, finder.ancestry}; + return {result, ancestry}; } if (node->is()) @@ -1693,9 +1596,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } if (node->asExpr()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else if (node->asStat()) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; return {}; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 2f57e23c..aeba2c13 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -9,6 +9,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -222,14 +223,14 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - // setmetatable({ @metatable MT }, MT) -> { @metatable MT } // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } addGlobalBinding(typeChecker, "setmetatable", arena.addType( FunctionTypeVar{ {genericMT}, {}, - arena.addTypePack(TypePack{{tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), arena.addTypePack(TypePack{{tableMetaMT}}) } ), "@luau" @@ -309,6 +310,12 @@ static std::optional> magicFunctionSetMetaTable( { auto [paramPack, _predicates] = withPredicate; + if (FFlag::LuauUnknownAndNeverType) + { + if (size(paramPack) < 2 && finite(paramPack)) + return std::nullopt; + } + TypeArena& arena = typechecker.currentModule->internalTypes; std::vector expectedArgs = typechecker.unTypePack(scope, paramPack, 2, expr.location); @@ -316,6 +323,12 @@ static std::optional> magicFunctionSetMetaTable( TypeId target = follow(expectedArgs[0]); TypeId mt = follow(expectedArgs[1]); + if (FFlag::LuauUnknownAndNeverType) + { + typechecker.tablify(target); + typechecker.tablify(mt); + } + if (const auto& tab = get(target)) { if (target->persistent) @@ -324,7 +337,8 @@ static std::optional> magicFunctionSetMetaTable( } else { - typechecker.tablify(mt); + if (!FFlag::LuauUnknownAndNeverType) + typechecker.tablify(mt); const TableTypeVar* mtTtv = get(mt); MetatableTypeVar mtv{target, mt}; @@ -343,7 +357,10 @@ static std::optional> magicFunctionSetMetaTable( if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - return WithPredicate{}; + if (FFlag::LuauUnknownAndNeverType) + return std::nullopt; + else + return WithPredicate{}; } if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) @@ -390,11 +407,21 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { - std::optional newhead = typechecker.pickTypesFromSense(head[0], true); - if (!newhead) - head = {typechecker.nilType}; + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true); + if (FFlag::LuauUnknownAndNeverType) + { + if (get(*ty)) + head = {*ty}; + else + head[0] = *ty; + } else - head[0] = *newhead; + { + if (!ty) + head = {typechecker.nilType}; + else + head[0] = *ty; + } } return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index df4e0a6b..88c50318 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -59,6 +59,8 @@ struct TypeCloner void operator()(const UnionTypeVar& t); void operator()(const IntersectionTypeVar& t); void operator()(const LazyTypeVar& t); + void operator()(const UnknownTypeVar& t); + void operator()(const NeverTypeVar& t); }; struct TypePackCloner @@ -310,6 +312,16 @@ void TypeCloner::operator()(const LazyTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const UnknownTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const NeverTypeVar& t) +{ + defaultClone(t); +} + } // anonymous namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 1b5275fd..f93f65dd 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauCheckLenMT) namespace Luau @@ -116,8 +117,6 @@ declare function typeof(value: T): string -- `assert` has a magic function attached that will give more detailed type information declare function assert(value: T, errorMessage: string?): T -declare function error(message: T, level: number?) - declare function tostring(value: T): string declare function tonumber(value: T, radix: number?): number? @@ -204,12 +203,18 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { + std::string result = kBuiltinDefinitionLuaSrc; // TODO: move this into kBuiltinDefinitionLuaSrc if (FFlag::LuauCheckLenMT) result += "declare function rawlen(obj: {[K]: V} | string): number\n"; + if (FFlag::LuauUnknownAndNeverType) + result += "declare function error(message: T, level: number?): never\n"; + else + result += "declare function error(message: T, level: number?)\n"; + return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4cfaa112..9195363d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -496,6 +496,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalastTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); module->scopes.resize(1); } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 8ce7f742..ce8f96cc 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -14,7 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -182,7 +182,6 @@ struct Normalize final : TypeVarVisitor { if (!ty->normal) asMutable(ty)->normal = true; - return false; } @@ -193,6 +192,20 @@ struct Normalize final : TypeVarVisitor return false; } + bool visit(TypeId ty, const UnknownTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const NeverTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override { CHECK_ITERATION_LIMIT(false); @@ -416,7 +429,13 @@ struct Normalize final : TypeVarVisitor std::vector result; for (TypeId part : options) + { + // AnyTypeVar always win the battle no matter what we do, so we're done. + if (FFlag::LuauUnknownAndNeverType && get(follow(part))) + return {part}; + combineIntoUnion(result, part); + } return result; } @@ -427,7 +446,17 @@ struct Normalize final : TypeVarVisitor if (auto utv = get(ty)) { for (TypeId t : utv) + { + // AnyTypeVar always win the battle no matter what we do, so we're done. + if (FFlag::LuauUnknownAndNeverType && get(t)) + { + result = {t}; + return; + } + combineIntoUnion(result, t); + } + return; } @@ -571,8 +600,7 @@ struct Normalize final : TypeVarVisitor */ TypeId combine(Replacer& replacer, TypeId a, TypeId b) { - if (FFlag::LuauNormalizeCombineEqFix) - b = follow(b); + b = follow(b); if (FFlag::LuauNormalizeCombineTableFix && a == b) return a; @@ -592,7 +620,7 @@ struct Normalize final : TypeVarVisitor } else if (auto ttv = getMutable(a)) { - if (FFlag::LuauNormalizeCombineTableFix && !get(FFlag::LuauNormalizeCombineEqFix ? b : follow(b))) + if (FFlag::LuauNormalizeCombineTableFix && !get(b)) return arena.addType(IntersectionTypeVar{{a, b}}); combineIntoTable(replacer, ttv, b); return a; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 9c4ce829..7245403c 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,8 +8,10 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauAnyificationMustClone, false) LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) +LUAU_FASTFLAG(LuauUnknownAndNeverType) namespace Luau { @@ -154,7 +156,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (childLimit > 0 && childLimit < childCount) + if (childLimit > 0 && (FFlag::LuauUnknownAndNeverType ? childLimit <= childCount : childLimit < childCount)) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -439,6 +441,9 @@ void Substitution::replaceChildren(TypeId ty) if (ignoreChildren(ty)) return; + if (FFlag::LuauAnyificationMustClone && ty->owningArena != arena) + return; + if (FunctionTypeVar* ftv = getMutable(ty)) { ftv->argTypes = replace(ftv->argTypes); @@ -490,6 +495,9 @@ void Substitution::replaceChildren(TypePackId tp) if (ignoreChildren(tp)) return; + if (FFlag::LuauAnyificationMustClone && tp->owningArena != arena) + return; + if (TypePack* tpp = getMutable(tp)) { for (TypeId& tv : tpp->head) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index fe940d5e..c67d6397 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauUnknownAndNeverType) /* * Prefix generic typenames with gen- @@ -699,6 +700,12 @@ struct TypeVarStringifier void operator()(TypeId, const MetatableTypeVar& mtv) { state.result.invalid = true; + if (!state.exhaustive && mtv.syntheticName) + { + state.emit(*mtv.syntheticName); + return; + } + state.emit("{ @metatable "); stringify(mtv.metatable); state.emit(","); @@ -834,7 +841,7 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - state.emit("*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -843,7 +850,17 @@ struct TypeVarStringifier state.emit("lazy?"); } -}; // namespace + void operator()(TypeId, const UnknownTypeVar& ttv) + { + state.emit("unknown"); + } + + void operator()(TypeId, const NeverTypeVar& ttv) + { + state.emit("never"); + } + +}; struct TypePackStringifier { @@ -947,7 +964,7 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - state.emit("*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 1577bd63..9feff1c0 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -205,20 +205,6 @@ struct Printer } } - void visualizeWithSelf(AstExpr& expr, bool self) - { - if (!self) - return visualize(expr); - - AstExprIndexName* func = expr.as(); - LUAU_ASSERT(func); - - visualize(*func->expr); - writer.symbol(":"); - advance(func->indexLocation.begin); - writer.identifier(func->index.value); - } - void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { advance(annotation.location.begin); @@ -366,7 +352,7 @@ struct Printer } else if (const auto& a = expr.as()) { - visualizeWithSelf(*a->func, a->self); + visualize(*a->func); writer.symbol("("); bool first = true; @@ -385,7 +371,7 @@ struct Printer else if (const auto& a = expr.as()) { visualize(*a->expr); - writer.symbol("."); + writer.symbol(std::string(1, a->op)); writer.write(a->index.value); } else if (const auto& a = expr.as()) @@ -766,7 +752,7 @@ struct Printer else if (const auto& a = program.as()) { writer.keyword("function"); - visualizeWithSelf(*a->name, a->func->self != nullptr); + visualize(*a->name); visualizeFunctionBody(*a->func); } else if (const auto& a = program.as()) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 4c6d54e0..b3f60d30 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,7 +7,7 @@ #include #include -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) +LUAU_FASTFLAG(LuauUnknownAndNeverType) namespace Luau { @@ -81,34 +81,10 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) - { - if (FFlag::LuauNonCopyableTypeVarFields) - { - asMutable(ty)->reassign(rep.get()->pending); - } - else - { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; - } - } + asMutable(ty)->reassign(rep.get()->pending); for (auto& [tp, rep] : typePackChanges) - { - if (FFlag::LuauNonCopyableTypeVarFields) - { - asMutable(tp)->reassign(rep.get()->pending); - } - else - { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; - } - } + asMutable(tp)->reassign(rep.get()->pending); clear(); } @@ -196,9 +172,7 @@ PendingType* TxnLog::queue(TypeId ty) if (!pending) { pending = std::make_unique(*ty); - - if (FFlag::LuauNonCopyableTypeVarFields) - pending->pending.owningArena = nullptr; + pending->pending.owningArena = nullptr; } return pending.get(); @@ -214,9 +188,7 @@ PendingTypePack* TxnLog::queue(TypePackId tp) if (!pending) { pending = std::make_unique(*tp); - - if (FFlag::LuauNonCopyableTypeVarFields) - pending->pending.owningArena = nullptr; + pending->pending.owningArena = nullptr; } return pending.get(); @@ -255,24 +227,14 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { PendingType* newTy = queue(ty); - - if (FFlag::LuauNonCopyableTypeVarFields) - newTy->pending.reassign(replacement); - else - newTy->pending = replacement; - + newTy->pending.reassign(replacement); return newTy; } PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { PendingTypePack* newTp = queue(tp); - - if (FFlag::LuauNonCopyableTypeVarFields) - newTp->pending.reassign(replacement); - else - newTp->pending = replacement; - + newTp->pending.reassign(replacement); return newTp; } @@ -289,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -305,6 +267,11 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { ftv->level = newLevel; } + else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) + { + if (FFlag::LuauUnknownAndNeverType) + ctv->level = newLevel; + } return newTy; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 6cca7127..2bc89cf6 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -335,6 +335,14 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("")); } + AstType* operator()(const UnknownTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}); + } + AstType* operator()(const NeverTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"never"}); + } private: Allocator* allocator; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index d9486a4f..01939fda 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -31,6 +31,7 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAGVARIABLE(LuauIndexSilenceErrors, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) @@ -41,10 +42,12 @@ LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) -LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false) +LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) +LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) namespace Luau { @@ -258,7 +261,11 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(getSingletonTypes().booleanType) , threadType(getSingletonTypes().threadType) , anyType(getSingletonTypes().anyType) + , unknownType(getSingletonTypes().unknownType) + , neverType(getSingletonTypes().neverType) , anyTypePack(getSingletonTypes().anyTypePack) + , neverTypePack(getSingletonTypes().neverTypePack) + , uninhabitableTypePack(getSingletonTypes().uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -269,6 +276,11 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan globalScope->exportedTypeBindings["string"] = TypeFun{{}, stringType}; globalScope->exportedTypeBindings["boolean"] = TypeFun{{}, booleanType}; globalScope->exportedTypeBindings["thread"] = TypeFun{{}, threadType}; + if (FFlag::LuauUnknownAndNeverType) + { + globalScope->exportedTypeBindings["unknown"] = TypeFun{{}, unknownType}; + globalScope->exportedTypeBindings["never"] = TypeFun{{}, neverType}; + } } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -456,6 +468,59 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } } +struct InplaceDemoter : TypeVarOnceVisitor +{ + TypeLevel newLevel; + TypeArena* arena; + + InplaceDemoter(TypeLevel level, TypeArena* arena) + : newLevel(level) + , arena(arena) + { + } + + bool demote(TypeId ty) + { + if (auto level = getMutableLevel(ty)) + { + if (level->subsumesStrict(newLevel)) + { + *level = newLevel; + return true; + } + } + + return false; + } + + bool visit(TypeId ty, const BoundTypeVar& btyRef) override + { + return true; + } + + bool visit(TypeId ty) override + { + if (ty->owningArena != arena) + return false; + return demote(ty); + } + + bool visit(TypePackId tp, const FreeTypePack& ftpRef) override + { + if (tp->owningArena != arena) + return false; + + FreeTypePack* ftp = &const_cast(ftpRef); + if (ftp->level.subsumesStrict(newLevel)) + { + ftp->level = newLevel; + return true; + } + + return false; + } +}; + void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) { int subLevel = 0; @@ -559,7 +624,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A tablify(baseTy); if (!fun->func->self) - expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, false); + expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false); else if (auto ttv = getMutableTableType(baseTy)) { if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) @@ -579,7 +644,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A if (auto name = fun->name->as()) { TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); } } } @@ -634,15 +699,8 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - TypeVar* mty = asMutable(follow(type)); - mty->reassign(*errorRecoveryType(anyType)); - } - else - { - *asMutable(type) = *errorRecoveryType(anyType); - } + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } @@ -1206,7 +1264,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } - if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) { // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments @@ -1253,7 +1311,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (TypeId var : varTypes) unify(varTy, var, forin.location); - if (!get(iterTy) && !get(iterTy) && !get(iterTy)) + if (!get(iterTy) && !get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(firstValue->location, CannotCallNonFunction{iterTy}); return check(loopScope, *forin.body); @@ -1350,7 +1408,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); - if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false)) { if (ttv || isTableIntersection(exprTy)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); @@ -1376,6 +1434,12 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); + if (FFlag::LuauUnknownAndNeverType) + { + InplaceDemoter demoter{funScope->level, ¤tModule->internalTypes}; + demoter.traverse(ty); + } + if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } @@ -1729,7 +1793,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - result = checkExpr(scope, *a); + result = checkExpr(scope, *a, FFlag::LuauBinaryNeedsExpectedTypesToo ? expectedType : std::nullopt); else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1851,41 +1915,56 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp lhsType = stripFromNilAndReport(lhsType, expr.expr->location); - if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) + if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true)) return {*ty}; return {errorRecoveryType(scope)}; } -std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) +std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) { ErrorVec errors; auto result = Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); - reportErrors(errors); + if (!FFlag::LuauIndexSilenceErrors || addErrors) + reportErrors(errors); return result; } -std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location) +std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors) { ErrorVec errors; auto result = Luau::findMetatableEntry(errors, type, entry, location); - reportErrors(errors); + if (!FFlag::LuauIndexSilenceErrors || addErrors) + reportErrors(errors); return result; } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const std::string& name, const Location& location, bool addErrors) + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) +{ + size_t errorCount = currentModule->errors.size(); + + std::optional result = getIndexTypeFromTypeImpl(scope, type, name, location, addErrors); + + if (FFlag::LuauIndexSilenceErrors && !addErrors) + LUAU_ASSERT(errorCount == currentModule->errors.size()); + + return result; +} + +std::optional TypeChecker::getIndexTypeFromTypeImpl( + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) { type = follow(type); - if (get(type) || get(type)) + if (get(type) || get(type) || get(type)) return type; tablify(type); if (isString(type)) { - std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + std::optional mtIndex = findMetatableEntry(stringType, "__index", location, addErrors); LUAU_ASSERT(mtIndex); type = *mtIndex; } @@ -1919,7 +1998,7 @@ std::optional TypeChecker::getIndexTypeFromType( return result; } - if (auto found = findTablePropertyRespectingMeta(type, name, location)) + if (auto found = findTablePropertyRespectingMeta(type, name, location, addErrors)) return *found; } else if (const ClassTypeVar* cls = get(type)) @@ -1941,7 +2020,7 @@ std::optional TypeChecker::getIndexTypeFromType( if (get(follow(t))) return t; - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) goodOptions.push_back(*ty); else badOptions.push_back(t); @@ -1972,6 +2051,8 @@ std::optional TypeChecker::getIndexTypeFromType( else { std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + return neverType; if (result.size() == 1) return result[0]; @@ -1987,7 +2068,7 @@ std::optional TypeChecker::getIndexTypeFromType( { RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) parts.push_back(*ty); } @@ -2017,6 +2098,9 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) for (TypeId t : types) { t = follow(t); + if (get(t)) + continue; + if (get(t) || get(t)) return {t}; @@ -2028,6 +2112,8 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) { if (FFlag::LuauNormalizeFlagIsConservative) ty = follow(ty); + if (get(ty)) + continue; if (get(ty) || get(ty)) return {ty}; @@ -2041,6 +2127,8 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) for (TypeId ty : r) { ty = follow(ty); + if (get(ty)) + continue; if (get(ty) || get(ty)) return {ty}; @@ -2314,14 +2402,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {booleanType, {NotPredicate{std::move(result.predicates)}}}; case AstExprUnary::Minus: { - const bool operandIsAny = get(operandType) || get(operandType); + const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) return {operandType}; if (typeCouldHaveMetatable(operandType)) { - if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location)) + if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location, /* addErrors= */ true)) { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); @@ -2355,14 +2443,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp operandType = stripFromNilAndReport(operandType, expr.location); - if (get(operandType)) - return {errorRecoveryType(scope)}; + if (get(operandType) || get(operandType)) + return {!FFlag::LuauUnknownAndNeverType ? errorRecoveryType(scope) : operandType}; DenseHashSet seen{nullptr}; if (FFlag::LuauCheckLenMT && typeCouldHaveMetatable(operandType)) { - if (auto fnt = findMetatableEntry(operandType, "__len", expr.location)) + if (auto fnt = findMetatableEntry(operandType, "__len", expr.location, /* addErrors= */ true)) { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); @@ -2433,6 +2521,9 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b return a; std::vector types = reduceUnion({a, b}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return neverType; + if (types.size() == 1) return types[0]; @@ -2485,7 +2576,7 @@ TypeId TypeChecker::checkRelationalOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType); + const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); // Peephole check for `cond and a or b -> type(a)|type(b)` // TODO: Kill this when singleton types arrive. :( @@ -2508,7 +2599,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isNonstrictMode() && (isNil(lhsType) || isNil(rhsType))) return booleanType; - const bool rhsIsAny = get(rhsType) || get(rhsType); + const bool rhsIsAny = get(rhsType) || get(rhsType) || get(rhsType); if (lhsIsAny || rhsIsAny) return booleanType; @@ -2596,7 +2687,7 @@ TypeId TypeChecker::checkRelationalOperation( if (leftMetatable) { - std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); + std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { if (const FunctionTypeVar* ftv = get(*metamethod)) @@ -2757,9 +2848,9 @@ TypeId TypeChecker::checkBinaryOperation( }; std::string op = opToMetaTableEntry(expr.op); - if (auto fnt = findMetatableEntry(lhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(lhsType, op, expr.location, /* addErrors= */ true)) return checkMetatableCall(*fnt, lhsType, rhsType); - if (auto fnt = findMetatableEntry(rhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(rhsType, op, expr.location, /* addErrors= */ true)) { // Note the intentionally reversed arguments here. return checkMetatableCall(*fnt, rhsType, lhsType); @@ -2793,27 +2884,27 @@ TypeId TypeChecker::checkBinaryOperation( } } -WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType) { if (expr.op == AstExprBinary::And) { - auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); ScopePtr innerScope = childScope(scope, expr.location); resolve(lhsPredicates, innerScope, true); - auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); return {checkBinaryOperation(scope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { - auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); ScopePtr innerScope = childScope(scope, expr.location); resolve(lhsPredicates, innerScope, false); - auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. TypeId result = checkBinaryOperation(scope, expr, lhsTy, rhsTy, lhsPredicates); @@ -2824,6 +2915,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; + // For these, passing expectedType is worse than simply forcing them, because their implementation + // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); @@ -2842,6 +2935,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else { + // Expected types are not useful for other binary operators. WithPredicate lhs = checkExpr(scope, *expr.left); WithPredicate rhs = checkExpr(scope, *expr.right); @@ -2896,6 +2990,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {trueType.type}; std::vector types = reduceUnion({trueType.type, falseType.type}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return {neverType}; return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; } @@ -2927,7 +3023,10 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& exp TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) { if (std::optional ty = scope->lookup(expr.local)) - return *ty; + { + ty = follow(*ty); + return get(*ty) ? unknownType : *ty; + } reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); return errorRecoveryType(scope); @@ -2941,7 +3040,10 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGloba const auto it = moduleScope->bindings.find(expr.name); if (it != moduleScope->bindings.end()) - return it->second.typeId; + { + TypeId ty = follow(it->second.typeId); + return get(ty) ? unknownType : ty; + } TypeId result = freshType(scope); Binding& binding = moduleScope->bindings[expr.name]; @@ -2962,6 +3064,9 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (get(lhs) || get(lhs)) return lhs; + if (get(lhs)) + return unknownType; + tablify(lhs); Name name = expr.index.value; @@ -3023,7 +3128,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (get(lhs)) { - if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) + if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, /* addErrors= */ false)) return *ty; // If intersection has a table part, report that it cannot be extended just as a sealed table @@ -3050,6 +3155,9 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (get(exprType) || get(exprType)) return exprType; + if (get(exprType)) + return unknownType; + AstExprConstantString* value = expr.index->as(); if (value) @@ -3156,7 +3264,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T if (!ttv || ttv->state == TableState::Sealed) { - if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, /* addErrors= */ false)) return *ty; return errorRecoveryType(scope); @@ -3228,9 +3336,12 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& } } - // We do not infer type binders, so if a generic function is required we do not propagate - if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) - expectedFunctionType = nullptr; + if (!FFlag::LuauCheckGenericHOFTypes) + { + // We do not infer type binders, so if a generic function is required we do not propagate + if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) + expectedFunctionType = nullptr; + } } auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); @@ -3240,7 +3351,8 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& retPack = resolveTypePack(funScope, *expr.returnAnnotation); else if (FFlag::LuauReturnTypeInferenceInNonstrict ? (!FFlag::LuauLowerBoundsCalculation && isNonstrictMode()) : isNonstrictMode()) retPack = anyTypePack; - else if (expectedFunctionType) + else if (expectedFunctionType && + (!FFlag::LuauCheckGenericHOFTypes || (expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty()))) { auto [head, tail] = flatten(expectedFunctionType->retTypes); @@ -3371,16 +3483,50 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); std::vector genericTys; - genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + // if we have a generic expected function type and no generics, we should use the expected ones. + if (FFlag::LuauCheckGenericHOFTypes) + { + if (expectedFunctionType && generics.empty()) + { + genericTys = expectedFunctionType->generics; + } + else + { + genericTys.reserve(generics.size()); + for (const GenericTypeDefinition& generic : generics) + genericTys.push_back(generic.ty); + } + } + else + { + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + } std::vector genericTps; - genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + // if we have a generic expected function type and no generic typepacks, we should use the expected ones. + if (FFlag::LuauCheckGenericHOFTypes) + { + if (expectedFunctionType && genericPacks.empty()) + { + genericTps = expectedFunctionType->genericPacks; + } + else + { + genericTps.reserve(genericPacks.size()); + for (const GenericTypePackDefinition& generic : genericPacks) + genericTps.push_back(generic.tp); + } + } + else + { + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + } TypeId funTy = addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); @@ -3474,9 +3620,22 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE } WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +{ + if (FFlag::LuauUnknownAndNeverType) + { + WithPredicate result = checkExprPackHelper(scope, expr); + if (containsNever(result.type)) + return {uninhabitableTypePack}; + return result; + } + else + return checkExprPackHelper(scope, expr); +} + +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) - return checkExprPack(scope, *a); + return checkExprPackHelper(scope, *a); else if (expr.is()) { if (!scope->varargPack) @@ -3739,7 +3898,7 @@ void TypeChecker::checkArgumentList( } } -WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3763,7 +3922,7 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons selfType = checkExpr(scope, *indexExpr->expr).type; selfType = stripFromNilAndReport(selfType, expr.func->location); - if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) + if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, /* addErrors= */ true)) { functionType = *propTy; actualFunctionType = instantiate(scope, functionType, expr.func->location); @@ -3813,11 +3972,25 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons if (get(argPack)) return {errorRecoveryTypePack(scope)}; - TypePack* args = getMutable(argPack); - LUAU_ASSERT(args != nullptr); + TypePack* args = nullptr; + if (FFlag::LuauUnknownAndNeverType) + { + if (expr.self) + { + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; + } + args = getMutable(argPack); + LUAU_ASSERT(args); + } + else + { + args = getMutable(argPack); + LUAU_ASSERT(args != nullptr); - if (expr.self) - args->head.insert(args->head.begin(), selfType); + if (expr.self) + args->head.insert(args->head.begin(), selfType); + } std::vector argLocations; argLocations.reserve(expr.args.size + 1); @@ -3876,7 +4049,10 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st else { std::vector result = reduceUnion({*el, ty}); - el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + el = neverType; + else + el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } }; @@ -3930,6 +4106,9 @@ std::optional> TypeChecker::checkCallOverload(const Sc return {{errorRecoveryTypePack(scope)}}; } + if (get(fn)) + return {{uninhabitableTypePack}}; + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which @@ -3975,7 +4154,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc // Might be a callable table if (const MetatableTypeVar* mttv = get(fn)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false)) { // Construct arguments with 'self' added in front TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); @@ -4202,6 +4381,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { + bool uninhabitable = false; TypePackId pack = addTypePack(TypePack{}); PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? @@ -4232,7 +4412,13 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); - if (std::optional firstTy = first(typePack)) + if (FFlag::LuauUnknownAndNeverType && containsNever(typePack)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, ...never) + uninhabitable = true; + continue; + } + else if (std::optional firstTy = first(typePack)) { if (!currentModule->astTypes.find(expr)) currentModule->astTypes[expr] = follow(*firstTy); @@ -4248,6 +4434,13 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); + if (FFlag::LuauUnknownAndNeverType && get(type)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, ...never) + uninhabitable = true; + continue; + } + TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; if (instantiateGenerics.size() > i && instantiateGenerics[i]) @@ -4272,6 +4465,8 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons for (TxnLog& log : inverseLogs) log.commit(); + if (FFlag::LuauUnknownAndNeverType && uninhabitable) + return {uninhabitableTypePack}; return {pack, predicates}; } @@ -4830,7 +5025,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) }; } -std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +std::optional TypeChecker::filterMapImpl(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); if (!types.empty()) @@ -4838,7 +5033,21 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } -std::optional TypeChecker::pickTypesFromSense(TypeId type, bool sense) +std::pair, bool> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +{ + if (FFlag::LuauUnknownAndNeverType) + { + TypeId ty = filterMapImpl(type, predicate).value_or(neverType); + return {ty, !bool(get(ty))}; + } + else + { + std::optional ty = filterMapImpl(type, predicate); + return {ty, bool(ty)}; + } +} + +std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense) { return filterMap(type, mkTruthyPredicate(sense)); } @@ -4884,6 +5093,13 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) } TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) +{ + TypeId ty = resolveTypeWorker(scope, annotation); + currentModule->astResolvedTypes[&annotation] = ty; + return ty; +} + +TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& annotation) { if (const auto& lit = annotation.as()) { @@ -5200,9 +5416,10 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypeList TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation) { + TypePackId result; if (const AstTypePackVariadic* variadic = annotation.as()) { - return addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); + result = addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); } else if (const AstTypePackGeneric* generic = annotation.as()) { @@ -5216,10 +5433,12 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return errorRecoveryTypePack(scope); + result = errorRecoveryTypePack(scope); + } + else + { + result = *genericTy; } - - return *genericTy; } else if (const AstTypePackExplicit* explicitTp = annotation.as()) { @@ -5229,14 +5448,17 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack types.push_back(resolveType(scope, *type)); if (auto tailType = explicitTp->typeList.tailType) - return addTypePack(types, resolveTypePack(scope, *tailType)); - - return addTypePack(types); + result = addTypePack(types, resolveTypePack(scope, *tailType)); + else + result = addTypePack(types); } else { ice("Unknown AstTypePack kind"); } + + currentModule->astResolvedTypePacks[&annotation] = result; + return result; } bool ApplyTypeFunction::isDirty(TypeId ty) @@ -5452,10 +5674,18 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const // If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset. if (!key) { - if (std::optional result = filterMap(*ty, predicate)) + auto [result, ok] = filterMap(*ty, predicate); + if (FFlag::LuauUnknownAndNeverType) + { addRefinement(refis, *target, *result); + } else - addRefinement(refis, *target, errorRecoveryType(scope)); + { + if (ok) + addRefinement(refis, *target, *result); + else + addRefinement(refis, *target, errorRecoveryType(scope)); + } return; } @@ -5471,17 +5701,29 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const { std::optional discriminantTy; if (auto field = Luau::get(*key)) // need to fully qualify Luau::get because of ADL. - discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), false); + discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), /* addErrors= */ false); else LUAU_ASSERT(!"Unhandled LValue alternative?"); if (!discriminantTy) return; // Do nothing. An error was already reported, as per usual. - if (std::optional result = filterMap(*discriminantTy, predicate)) + auto [result, ok] = filterMap(*discriminantTy, predicate); + if (FFlag::LuauUnknownAndNeverType) { - viableTargetOptions.insert(option); - viableChildOptions.insert(*result); + if (!get(*result)) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + else + { + if (ok) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } } } @@ -5560,7 +5802,7 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV continue; else if (auto field = get(key)) { - found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + found = getIndexTypeFromType(scope, *found, field->key, Location(), /* addErrors= */ false); if (!found) return std::nullopt; // Turns out this type doesn't have the property at all. We're done. } @@ -5740,6 +5982,9 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r auto mkFilter = [](ConditionFunc f, std::optional other = std::nullopt) -> SenseToTypeIdPredicate { return [f, other](bool sense) -> TypeIdPredicate { return [f, other, sense](TypeId ty) -> std::optional { + if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) + return other.value_or(ty); + if (f(ty) == sense) return ty; @@ -5847,8 +6092,15 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp for (size_t i = 0; i < expectedLength; ++i) expectedPack->head.push_back(freshType(scope)); + size_t oldErrorsSize = currentModule->errors.size(); + unify(tp, expectedTypePack, location); + // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but + // we want to tie up free types to be error types, so we do this instead. + if (FFlag::LuauUnknownAndNeverType) + currentModule->errors.resize(oldErrorsSize); + for (TypeId& tp : expectedPack->head) tp = follow(tp); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 82451bd1..d4544483 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) - namespace Luau { @@ -40,19 +38,10 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - LUAU_ASSERT(owningArena == rhs.owningArena); - LUAU_ASSERT(!rhs.persistent); + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); - reassign(rhs); - } - else - { - ty = rhs.ty; - persistent = rhs.persistent; - owningArena = rhs.owningArena; - } + reassign(rhs); return *this; } @@ -294,6 +283,16 @@ std::optional first(TypePackId tp, bool ignoreHiddenVariadics) return std::nullopt; } +TypePackVar* asMutable(TypePackId tp) +{ + return const_cast(tp); +} + +TypePack* asMutable(const TypePack* tp) +{ + return const_cast(tp); +} + bool isEmpty(TypePackId tp) { tp = follow(tp); @@ -360,13 +359,25 @@ bool isVariadic(TypePackId tp, const TxnLog& log) return false; } -TypePackVar* asMutable(TypePackId tp) +bool containsNever(TypePackId tp) { - return const_cast(tp); + auto it = begin(tp); + auto endIt = end(tp); + + while (it != endIt) + { + if (get(follow(*it))) + return true; + ++it; + } + + if (auto tail = it.tail()) + { + if (auto vtp = get(*tail); vtp && get(follow(vtp->ty))) + return true; + } + + return false; } -TypePack* asMutable(const TypePack* tp) -{ - return const_cast(tp); -} } // namespace Luau diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 3d97e6eb..66b38cf3 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -24,7 +24,7 @@ std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::str const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); + errors.push_back(TypeError{location, GenericError{"Metatable was not a table"}}); return std::nullopt; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index ade70d72..f884ad77 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,7 +23,9 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false) +LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) namespace Luau { @@ -31,6 +33,9 @@ namespace Luau std::optional> magicFunctionFormat( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeId follow(TypeId t) { return follow(t, [](TypeId t) { @@ -173,8 +178,8 @@ bool maybeString(TypeId ty) { ty = follow(ty); - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) - return true; + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; if (auto utv = get(ty)) return std::any_of(begin(utv), end(utv), maybeString); @@ -194,7 +199,7 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (get(ty)) + if (get(ty) || (FFlag::LuauUnknownAndNeverType && get(ty))) return true; auto utv = get(ty); @@ -334,6 +339,28 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { + if (FFlag::LuauMaybeGenericIntersectionTypes) + { + ty = follow(ty); + + if (get(ty)) + return true; + + if (auto ttv = get(ty)) + { + // TODO: recurse on table types CLI-39914 + (void)ttv; + return true; + } + + if (auto itv = get(ty)) + { + return std::any_of(begin(itv), end(itv), maybeGeneric); + } + + return isGeneric(ty); + } + ty = follow(ty); if (get(ty)) return true; @@ -646,20 +673,10 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) TypeVar& TypeVar::operator=(const TypeVar& rhs) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - LUAU_ASSERT(owningArena == rhs.owningArena); - LUAU_ASSERT(!rhs.persistent); + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); - reassign(rhs); - } - else - { - ty = rhs.ty; - persistent = rhs.persistent; - normal = rhs.normal; - owningArena = rhs.owningArena; - } + reassign(rhs); return *this; } @@ -676,10 +693,14 @@ static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persist static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; +static TypeVar unknownType_{UnknownTypeVar{}, /*persistent*/ true}; +static TypeVar neverType_{NeverTypeVar{}, /*persistent*/ true}; static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; -static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; -static TypePackVar errorTypePack_{Unifiable::Error{}}; +static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, /*persistent*/ true}; +static TypePackVar errorTypePack_{Unifiable::Error{}, /*persistent*/ true}; +static TypePackVar neverTypePack_{VariadicTypePack{&neverType_}, /*persistent*/ true}; +static TypePackVar uninhabitableTypePack_{TypePack{{&neverType_}, &neverTypePack_}, /*persistent*/ true}; SingletonTypes::SingletonTypes() : nilType(&nilType_) @@ -690,7 +711,11 @@ SingletonTypes::SingletonTypes() , trueType(&trueType_) , falseType(&falseType_) , anyType(&anyType_) + , unknownType(&unknownType_) + , neverType(&neverType_) , anyTypePack(&anyTypePack_) + , neverTypePack(&neverTypePack_) + , uninhabitableTypePack(&uninhabitableTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -738,6 +763,7 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, @@ -911,6 +937,8 @@ const TypeLevel* getLevel(TypeId ty) return &ttv->level; else if (auto ftv = get(ty)) return &ftv->level; + else if (auto ctv = get(ty)) + return &ctv->level; else return nullptr; } @@ -965,94 +993,19 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) +const std::vector& getTypes(const UnionTypeVar* utv) { - LUAU_ASSERT(utv); - - if (!utv->options.empty()) - stack.push_front({utv, 0}); - - seen.insert(utv); + return utv->options; } -UnionTypeVarIterator& UnionTypeVarIterator::operator++() +const std::vector& getTypes(const IntersectionTypeVar* itv) { - advance(); - descend(); - return *this; + return itv->parts; } -UnionTypeVarIterator UnionTypeVarIterator::operator++(int) +const std::vector& getTypes(const ConstrainedTypeVar* ctv) { - UnionTypeVarIterator copy = *this; - ++copy; - return copy; -} - -bool UnionTypeVarIterator::operator!=(const UnionTypeVarIterator& rhs) -{ - return !(*this == rhs); -} - -bool UnionTypeVarIterator::operator==(const UnionTypeVarIterator& rhs) -{ - if (!stack.empty() && !rhs.stack.empty()) - return stack.front() == rhs.stack.front(); - - return stack.empty() && rhs.stack.empty(); -} - -const TypeId& UnionTypeVarIterator::operator*() -{ - LUAU_ASSERT(!stack.empty()); - - descend(); - - auto [utv, currentIndex] = stack.front(); - LUAU_ASSERT(utv); - LUAU_ASSERT(currentIndex < utv->options.size()); - - const TypeId& ty = utv->options[currentIndex]; - LUAU_ASSERT(!get(follow(ty))); - return ty; -} - -void UnionTypeVarIterator::advance() -{ - while (!stack.empty()) - { - auto& [utv, currentIndex] = stack.front(); - ++currentIndex; - - if (currentIndex >= utv->options.size()) - stack.pop_front(); - else - break; - } -} - -void UnionTypeVarIterator::descend() -{ - while (!stack.empty()) - { - auto [utv, currentIndex] = stack.front(); - if (auto innerUnion = get(follow(utv->options[currentIndex]))) - { - // If we're about to descend into a cyclic UnionTypeVar, we should skip over this. - // Ideally this should never happen, but alas it does from time to time. :( - if (seen.find(innerUnion) != seen.end()) - advance(); - else - { - seen.insert(innerUnion); - stack.push_front({innerUnion, 0}); - } - - continue; - } - - break; - } + return ctv->parts; } UnionTypeVarIterator begin(const UnionTypeVar* utv) @@ -1065,6 +1018,27 @@ UnionTypeVarIterator end(const UnionTypeVar* utv) return UnionTypeVarIterator{}; } +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv) +{ + return IntersectionTypeVarIterator{itv}; +} + +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) +{ + return IntersectionTypeVarIterator{}; +} + +ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv) +{ + return ConstrainedTypeVarIterator{ctv}; +} + +ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv) +{ + return ConstrainedTypeVarIterator{}; +} + + static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs"; @@ -1144,6 +1118,101 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } +static std::vector parsePatternString(TypeChecker& typechecker, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(typechecker.numberType); + continue; + } + + ++depth; + result.push_back(typechecker.stringType); + } + else if (data[i] == ')') + { + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; + } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(typechecker.stringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + if (!FFlag::LuauDeduceGmatchReturnTypes) + return std::nullopt; + + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionTypeVar{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} + std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 0792a350..44a3b854 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -47,33 +48,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor } } - // TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped - template - void cycle(TID) - { - } - template - bool operator()(TID ty, const T&) - { - return visit(ty); - } - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const FunctionTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - bool visit(TypeId ty) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -103,6 +77,15 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor return true; } + bool visit(TypeId ty, const ConstrainedTypeVar&) override + { + if (!FFlag::LuauUnknownAndNeverType) + return visit(ty); + + promote(ty, log.getMutable(ty)); + return true; + } + bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -445,6 +428,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else if (subFree) { + if (FFlag::LuauUnknownAndNeverType) + { + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (get(superTy)) + return; + } + TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); @@ -468,7 +459,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } - if (get(superTy) || get(superTy)) + if (get(superTy) || get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); if (get(subTy)) @@ -485,6 +476,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(subTy)) return tryUnifyWithAny(superTy, subTy); + if (get(subTy)) + return tryUnifyWithAny(superTy, subTy); + auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before @@ -1862,6 +1856,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas if (state.log.getMutable(ty)) { + // TODO: Only bind if the anyType isn't any, unknown, or error (?) state.log.replace(ty, BoundTypeVar{anyType}); } else if (auto fun = state.log.getMutable(ty)) @@ -1901,22 +1896,27 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(anyTy) || get(anyTy)); + LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); // These types are not visited in general loop below if (get(subTy) || get(subTy) || get(subTy)) return; - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); - - const TypePackId anyTP = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId anyTp; + if (FFlag::LuauUnknownAndNeverType) + anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); + else + { + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); + anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + } std::vector queue = {subTy}; sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, getSingletonTypes().anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : getSingletonTypes().anyType, anyTp); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 70c92555..b7fa7889 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -15,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) -LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) @@ -1134,10 +1133,9 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && - (lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow))) + if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) { - if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow) + if (lexer.current().type == Lexeme::SkinnyArrow) report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); nextLexeme(); @@ -1373,12 +1371,10 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) if (FFlag::LuauFixNamedFunctionParse && !names.empty()) forceFunctionType = true; - bool returnTypeIntroducer = - FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; + bool returnTypeIntroducer = lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':'; // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && !forceFunctionType && - (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) + if (params.size() == 1 && !varargAnnotation && !forceFunctionType && !returnTypeIntroducer) { if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) lua_telemetry_parsed_named_non_function_type = true; @@ -1389,8 +1385,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {params[0], {}}; } - if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType && - allowPack) + if (!forceFunctionType && !returnTypeIntroducer && allowPack) { if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) lua_telemetry_parsed_named_non_function_type = true; @@ -1409,7 +1404,7 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' instead of ':'"); lexer.next(); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 4bc8cab5..479eb167 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -8,6 +8,10 @@ #include "FileUtils.h" +#ifdef CALLGRIND +#include +#endif + LUAU_FASTFLAG(DebugLuauTimeTracing) LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) @@ -112,6 +116,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --mode=strict: default to strict mode when typechecking\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -178,9 +183,9 @@ struct CliConfigResolver : Luau::ConfigResolver mutable std::unordered_map configCache; mutable std::vector> configErrors; - CliConfigResolver() + CliConfigResolver(Luau::Mode mode) { - defaultConfig.mode = Luau::Mode::Nonstrict; + defaultConfig.mode = mode; } const Luau::Config& getConfig(const Luau::ModuleName& name) const override @@ -229,6 +234,7 @@ int main(int argc, char** argv) } ReportFormat format = ReportFormat::Default; + Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; for (int i = 1; i < argc; ++i) @@ -240,6 +246,8 @@ int main(int argc, char** argv) format = ReportFormat::Luacheck; else if (strcmp(argv[i], "--formatter=gnu") == 0) format = ReportFormat::Gnu; + else if (strcmp(argv[i], "--mode=strict") == 0) + mode = Luau::Mode::Strict; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; else if (strcmp(argv[i], "--timetrace") == 0) @@ -258,12 +266,16 @@ int main(int argc, char** argv) frontendOptions.retainFullTypeGraphs = annotate; CliFileResolver fileResolver; - CliConfigResolver configResolver; + CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); +#ifdef CALLGRIND + CALLGRIND_ZERO_STATS; +#endif + std::vector files = getSourceFiles(argc, argv); int failed = 0; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 83060f5b..5fe12bec 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -21,6 +21,10 @@ #include #endif +#ifdef CALLGRIND +#include +#endif + #include LUAU_FASTFLAG(DebugLuauTimeTracing) @@ -166,6 +170,36 @@ static int lua_collectgarbage(lua_State* L) luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); } +#ifdef CALLGRIND +static int lua_callgrind(lua_State* L) +{ + const char* option = luaL_checkstring(L, 1); + + if (strcmp(option, "running") == 0) + { + int r = RUNNING_ON_VALGRIND; + lua_pushboolean(L, r); + return 1; + } + + if (strcmp(option, "zero") == 0) + { + CALLGRIND_ZERO_STATS; + return 0; + } + + if (strcmp(option, "dump") == 0) + { + const char* name = luaL_checkstring(L, 2); + + CALLGRIND_DUMP_STATS_AT(name); + return 0; + } + + luaL_error(L, "callgrind must be called with one of 'running', 'zero', 'dump'"); +} +#endif + void setupState(lua_State* L) { luaL_openlibs(L); @@ -174,6 +208,9 @@ void setupState(lua_State* L) {"loadstring", lua_loadstring}, {"require", lua_require}, {"collectgarbage", lua_collectgarbage}, +#ifdef CALLGRIND + {"callgrind", lua_callgrind}, +#endif {NULL, NULL}, }; diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index c5979d3c..65883b49 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -58,6 +58,9 @@ public: void jmp(Label& label); void jmp(OperandX64 op); + void call(Label& label); + void call(OperandX64 op); + // AVX void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 27e01781..26347225 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -286,11 +286,34 @@ void AssemblyBuilderX64::jmp(OperandX64 op) if (logText) log("jmp", op); + placeRex(op); place(0xff); placeModRegMem(op, 4); commit(); } +void AssemblyBuilderX64::call(Label& label) +{ + place(0xe8); + placeLabel(label); + + if (logText) + log("call", label); + + commit(); +} + +void AssemblyBuilderX64::call(OperandX64 op) +{ + if (logText) + log("call", op); + + placeRex(op); + place(0xff); + placeModRegMem(op, 2); + commit(); +} + void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2) { placeAvx("vaddpd", dst, src1, src2, 0x58, false, AVX_0F, AVX_66); diff --git a/Makefile b/Makefile index d4ef7c79..a0a52208 100644 --- a/Makefile +++ b/Makefile @@ -93,6 +93,10 @@ ifeq ($(config),fuzz) LDFLAGS+=-fsanitize=address,fuzzer endif +ifneq ($(CALLGRIND),) + CXXFLAGS+=-DCALLGRIND=$(CALLGRIND) +endif + # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include diff --git a/Sources.cmake b/Sources.cmake index f261cba6..44bed8f7 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -247,12 +247,15 @@ if(TARGET Luau.UnitTest) tests/IostreamOptional.h tests/ScopedFlags.h tests/Fixture.cpp + tests/AssemblyBuilderX64.test.cpp tests/AstQuery.test.cpp tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp + tests/ConstraintGraphBuilder.test.cpp + tests/ConstraintSolver.test.cpp tests/CostModel.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp @@ -262,8 +265,7 @@ if(TARGET Luau.UnitTest) tests/Module.test.cpp tests/NonstrictMode.test.cpp tests/Normalize.test.cpp - tests/ConstraintGraphBuilder.test.cpp - tests/ConstraintSolver.test.cpp + tests/NotNull.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/RuntimeLimits.test.cpp @@ -295,11 +297,11 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.tryUnify.test.cpp tests/TypeInfer.typePacks.cpp tests/TypeInfer.unionTypes.test.cpp + tests/TypeInfer.unknownnever.test.cpp tests/TypePack.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp tests/VisitTypeVar.test.cpp - tests/AssemblyBuilderX64.test.cpp tests/main.cpp) endif() diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 2316cc3d..79e65919 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -108,9 +108,9 @@ static LuaNode* hashvec(const Table* t, const float* v) memcpy(i, v, sizeof(i)); // convert -0 to 0 to make sure they hash to the same value - i[0] = (i[0] == 0x8000000) ? 0 : i[0]; - i[1] = (i[1] == 0x8000000) ? 0 : i[1]; - i[2] = (i[2] == 0x8000000) ? 0 : i[2]; + i[0] = (i[0] == 0x80000000) ? 0 : i[0]; + i[1] = (i[1] == 0x80000000) ? 0 : i[1]; + i[2] = (i[2] == 0x80000000) ? 0 : i[2]; // scramble bits to make sure that integer coordinates have entropy in lower bits i[0] ^= i[0] >> 17; @@ -121,7 +121,7 @@ static LuaNode* hashvec(const Table* t, const float* v) unsigned int h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791); #if LUA_VECTOR_SIZE == 4 - i[3] = (i[3] == 0x8000000) ? 0 : i[3]; + i[3] = (i[3] == 0x80000000) ? 0 : i[3]; i[3] ^= i[3] >> 17; h ^= i[3] * 39916801; #endif diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 85829ca1..02b39313 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -640,20 +640,16 @@ static void luau_execute(lua_State* L) VM_PATCH_C(pc - 2, L->cachedslot); VM_NEXT(); } - else - { - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - VM_NEXT(); - } - } - else - { - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - VM_NEXT(); + + // fall through to slow path } + + // fall through to slow path } + + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + VM_NEXT(); } VM_CASE(LOP_SETTABLEKS) @@ -753,19 +749,13 @@ static void luau_execute(lua_State* L) setobj2s(L, ra, &h->array[unsigned(index - 1)]); VM_NEXT(); } - else - { - // slow-path: handles out of bounds array lookups and non-integer numeric keys - VM_PROTECT(luaV_gettable(L, rb, rc, ra)); - VM_NEXT(); - } - } - else - { - // slow-path: handles non-array table lookup as well as __index MT calls - VM_PROTECT(luaV_gettable(L, rb, rc, ra)); - VM_NEXT(); + + // fall through to slow path } + + // slow-path: handles out of bounds array lookups, non-integer numeric keys, non-array table lookup, __index MT calls + VM_PROTECT(luaV_gettable(L, rb, rc, ra)); + VM_NEXT(); } VM_CASE(LOP_SETTABLE) @@ -790,19 +780,13 @@ static void luau_execute(lua_State* L) luaC_barriert(L, h, ra); VM_NEXT(); } - else - { - // slow-path: handles out of bounds array assignments and non-integer numeric keys - VM_PROTECT(luaV_settable(L, rb, rc, ra)); - VM_NEXT(); - } - } - else - { - // slow-path: handles non-array table access as well as __newindex MT calls - VM_PROTECT(luaV_settable(L, rb, rc, ra)); - VM_NEXT(); + + // fall through to slow path } + + // slow-path: handles out of bounds array assignments, non-integer numeric keys, non-array table access, __newindex MT calls + VM_PROTECT(luaV_settable(L, rb, rc, ra)); + VM_NEXT(); } VM_CASE(LOP_GETTABLEN) @@ -822,6 +806,8 @@ static void luau_execute(lua_State* L) setobj2s(L, ra, &h->array[c]); VM_NEXT(); } + + // fall through to slow path } // slow-path: handles out of bounds array lookups @@ -849,6 +835,8 @@ static void luau_execute(lua_State* L) luaC_barriert(L, h, ra); VM_NEXT(); } + + // fall through to slow path } // slow-path: handles out of bounds array lookups @@ -2176,8 +2164,10 @@ static void luau_execute(lua_State* L) if (!ttisnumber(ra + 0) || !ttisnumber(ra + 1) || !ttisnumber(ra + 2)) { // slow-path: can convert arguments to numbers and trigger Lua errors - // Note: this doesn't reallocate stack so we don't need to recompute ra - VM_PROTECT(luau_prepareFORN(L, ra + 0, ra + 1, ra + 2)); + // Note: this doesn't reallocate stack so we don't need to recompute ra/base + VM_PROTECT_PC(); + + luau_prepareFORN(L, ra + 0, ra + 1, ra + 2); } double limit = nvalue(ra + 0); diff --git a/bench/bench.py b/bench/bench.py index 67fc8cf7..bb3ea5f7 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -40,6 +40,7 @@ argumentParser.add_argument('--results', dest='results',type=str,nargs='*',help= argumentParser.add_argument('--run-test', action='store', default=None, help='Regex test filter') argumentParser.add_argument('--extra-loops', action='store',type=int,default=0, help='Amount of times to loop over one test (one test already performs multiple runs)') argumentParser.add_argument('--filename', action='store',type=str,default='bench', help='File name for graph and results file') +argumentParser.add_argument('--callgrind', dest='callgrind',action='store_const',const=1,default=0,help='Use callgrind to run benchmarks') if matplotlib != None: argumentParser.add_argument('--absolute', dest='absolute',action='store_const',const=1,default=0,help='Display absolute values instead of relative (enabled by default when benchmarking a single VM)') @@ -55,6 +56,9 @@ argumentParser.add_argument('--no-print-influx-debugging', action='store_false', argumentParser.add_argument('--no-print-final-summary', action='store_false', dest='print_final_summary', help="Don't print a table summarizing the results after all tests are run") +# Assume 2.5 IPC on a 4 GHz CPU; this is obviously incorrect but it allows us to display simulated instruction counts using regular time units +CALLGRIND_INSN_PER_SEC = 2.5 * 4e9 + def arrayRange(count): result = [] @@ -71,6 +75,21 @@ def arrayRangeOffset(count, offset): return result +def getCallgrindOutput(lines): + result = [] + name = None + + for l in lines: + if l.startswith("desc: Trigger: Client Request: "): + name = l[31:].strip() + elif l.startswith("summary: ") and name != None: + insn = int(l[9:]) + # Note: we only run each bench once under callgrind so we only report a single time per run; callgrind instruction count variance is ~0.01% so it might as well be zero + result += "|><|" + name + "|><|" + str(insn / CALLGRIND_INSN_PER_SEC * 1000.0) + "||_||" + name = None + + return "".join(result) + def getVmOutput(cmd): if os.name == "nt": try: @@ -79,6 +98,16 @@ def getVmOutput(cmd): exit(1) except: return "" + elif arguments.callgrind: + try: + subprocess.check_call("valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=scriptdir) + path = os.path.join(scriptdir, "callgrind.out") + with open(path, "r") as file: + lines = file.readlines() + os.unlink(path) + return getCallgrindOutput(lines) + except: + return "" else: with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=scriptdir) as p: # Try to lock to a single processor @@ -375,12 +404,12 @@ def analyzeResult(subdir, main, comparisons): continue - pooledStdDev = math.sqrt((main.unbiasedEst + compare.unbiasedEst) / 2) + if main.count > 1 and stats: + pooledStdDev = math.sqrt((main.unbiasedEst + compare.unbiasedEst) / 2) - tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) - degreesOfFreedom = 2 * main.count - 2 + tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) + degreesOfFreedom = 2 * main.count - 2 - if stats: # Two-tailed distribution with 95% conf. tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) diff --git a/bench/bench_support.lua b/bench/bench_support.lua index 171b8da7..a9608ecc 100644 --- a/bench/bench_support.lua +++ b/bench/bench_support.lua @@ -5,6 +5,16 @@ bench.runs = 20 bench.extraRuns = 4 function bench.runCode(f, description) + -- Under Callgrind, run the test only once and measure just the execution cost + if callgrind and callgrind("running") then + if collectgarbage then collectgarbage() end + + callgrind("zero") + f() -- unfortunately we can't easily separate setup cost from runtime cost in f unless it calls callgrind() + callgrind("dump", description) + return + end + local timeTable = {} for i = 1,bench.runs + bench.extraRuns do diff --git a/bench/other/LuauPolyfillMap.lua b/bench/other/LuauPolyfillMap.lua new file mode 100644 index 00000000..1f957d47 --- /dev/null +++ b/bench/other/LuauPolyfillMap.lua @@ -0,0 +1,961 @@ +-- This file is part of the Roblox luau-polyfill repository and is licensed under MIT License; see LICENSE.txt for details +-- #region Array +-- Array related +local Array = {} +local Object = {} +local Map = {} + +type Array = { [number]: T } +type callbackFn = (element: V, key: K, map: Map) -> () +type callbackFnWithThisArg = (thisArg: Object, value: V, key: K, map: Map) -> () +type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + forEach: (self: Map, callback: callbackFn | callbackFnWithThisArg, thisArg: Object?) -> (), + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} +type mapFn = (element: T, index: number) -> U +type mapFnWithThisArg = (thisArg: any, element: T, index: number) -> U +type Object = { [string]: any } +type Table = { [T]: V } +type Tuple = Array + +local Set = {} + +-- #region Array +function Array.isArray(value: any): boolean + if typeof(value) ~= "table" then + return false + end + if next(value) == nil then + -- an empty table is an empty array + return true + end + + local length = #value + + if length == 0 then + return false + end + + local count = 0 + local sum = 0 + for key in pairs(value) do + if typeof(key) ~= "number" then + return false + end + if key % 1 ~= 0 or key < 1 then + return false + end + count += 1 + sum += key + end + + return sum == (count * (count + 1) / 2) +end + +function Array.from( + value: string | Array | Object, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? +): Array + if value == nil then + error("cannot create array from a nil value") + end + local valueType = typeof(value) + + local array = {} + + if valueType == "table" and Array.isArray(value) then + if mapFn then + for i = 1, #(value :: Array) do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: Array)[i], i) + else + array[i] = (mapFn :: mapFn)((value :: Array)[i], i) + end + end + else + for i = 1, #(value :: Array) do + array[i] = (value :: Array)[i] + end + end + elseif instanceOf(value, Set) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif instanceOf(value, Map) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif valueType == "string" then + if mapFn then + for i = 1, (value :: string):len() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: any):sub(i, i), i) + else + array[i] = (mapFn :: mapFn)((value :: any):sub(i, i), i) + end + end + else + for i = 1, (value :: string):len() do + array[i] = (value :: any):sub(i, i) + end + end + end + + return array +end + +type callbackFnArrayMap = (element: T, index: number, array: Array) -> U +type callbackFnWithThisArgArrayMap = (thisArg: V, element: T, index: number, array: Array) -> U + +-- Implements Javascript's `Array.prototype.map` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/map +function Array.map( + t: Array, + callback: callbackFnArrayMap | callbackFnWithThisArgArrayMap, + thisArg: V? +): Array + if typeof(t) ~= "table" then + error(string.format("Array.map called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local A = {} + local k = 1 + + while k <= len do + local kValue = t[k] + + if kValue ~= nil then + local mappedValue + + if thisArg ~= nil then + mappedValue = (callback :: callbackFnWithThisArgArrayMap)(thisArg, kValue, k, t) + else + mappedValue = (callback :: callbackFnArrayMap)(kValue, k, t) + end + + A[k] = mappedValue + end + k += 1 + end + + return A +end + +type Function = (any, any, number, any) -> any + +-- Implements Javascript's `Array.prototype.reduce` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/reduce +function Array.reduce(array: Array, callback: Function, initialValue: any?): any + if typeof(array) ~= "table" then + error(string.format("Array.reduce called on %s", typeof(array))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local length = #array + + local value + local initial = 1 + + if initialValue ~= nil then + value = initialValue + else + initial = 2 + if length == 0 then + error("reduce of empty array with no initial value") + end + value = array[1] + end + + for i = initial, length do + value = callback(value, array[i], i, array) + end + + return value +end + +type callbackFnArrayForEach = (element: T, index: number, array: Array) -> () +type callbackFnWithThisArgArrayForEach = (thisArg: U, element: T, index: number, array: Array) -> () + +-- Implements Javascript's `Array.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/forEach +function Array.forEach( + t: Array, + callback: callbackFnArrayForEach | callbackFnWithThisArgArrayForEach, + thisArg: U? +): () + if typeof(t) ~= "table" then + error(string.format("Array.forEach called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local k = 1 + + while k <= len do + local kValue = t[k] + + if thisArg ~= nil then + (callback :: callbackFnWithThisArgArrayForEach)(thisArg, kValue, k, t) + else + (callback :: callbackFnArrayForEach)(kValue, k, t) + end + + if #t < len then + -- don't iterate on removed items, don't iterate more than original length + len = #t + end + k += 1 + end +end +-- #endregion + +-- #region Set +Set.__index = Set + +type callbackFnSet = (value: T, key: T, set: Set) -> () +type callbackFnWithThisArgSet = (thisArg: Object, value: T, key: T, set: Set) -> () + +export type Set = { + size: number, + -- method definitions + add: (self: Set, T) -> Set, + clear: (self: Set) -> (), + delete: (self: Set, T) -> boolean, + forEach: (self: Set, callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?) -> (), + has: (self: Set, T) -> boolean, + ipairs: (self: Set) -> any, +} + +type Iterable = { ipairs: (any) -> any } + +function Set.new(iterable: Array | Set | Iterable | string | nil): Set + local array = {} + local map = {} + if iterable ~= nil then + local arrayIterable: Array + -- ROBLOX TODO: remove type casting from (iterable :: any).ipairs in next release + if typeof(iterable) == "table" then + if Array.isArray(iterable) then + arrayIterable = Array.from(iterable :: Array) + elseif typeof((iterable :: Iterable).ipairs) == "function" then + -- handle in loop below + elseif _G.__DEV__ then + error("cannot create array from an object-like table") + end + elseif typeof(iterable) == "string" then + arrayIterable = Array.from(iterable :: string) + else + error(("cannot create array from value of type `%s`"):format(typeof(iterable))) + end + + if arrayIterable then + for _, element in ipairs(arrayIterable) do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + elseif typeof(iterable) == "table" and typeof((iterable :: Iterable).ipairs) == "function" then + for _, element in (iterable :: Iterable):ipairs() do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Set) :: any) :: Set +end + +function Set:add(value) + if not self._map[value] then + -- Luau FIXME: analyze should know self is Set which includes size as a number + self.size = self.size :: number + 1 + self._map[value] = true + table.insert(self._array, value) + end + return self +end + +function Set:clear() + self.size = 0 + table.clear(self._map) + table.clear(self._array) +end + +function Set:delete(value): boolean + if not self._map[value] then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[value] = nil + local index = table.find(self._array, value) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/forEach +function Set:forEach(callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(value: T) + if thisArg ~= nil then + (callback :: callbackFnWithThisArgSet)(thisArg, value, value, self) + else + (callback :: callbackFnSet)(value, value, self) + end + end) +end + +function Set:has(value): boolean + return self._map[value] ~= nil +end + +function Set:ipairs() + return ipairs(self._array) +end + +-- #endregion Set + +-- #region Object +function Object.entries(value: string | Object | Array): Array + assert(value :: any ~= nil, "cannot get entries from a nil value") + local valueType = typeof(value) + + local entries: Array> = {} + if valueType == "table" then + for key, keyValue in pairs(value :: Object) do + -- Luau FIXME: Luau should see entries as Array, given object is [string]: any, but it sees it as Array> despite all the manual annotation + table.insert(entries, { key :: string, keyValue :: any }) + end + elseif valueType == "string" then + for i = 1, string.len(value :: string) do + entries[i] = { tostring(i), string.sub(value :: string, i, i) } + end + end + + return entries +end + +-- #endregion + +-- #region instanceOf + +-- ROBLOX note: Typed tbl as any to work with strict type analyze +-- polyfill for https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/instanceof +function instanceOf(tbl: any, class) + assert(typeof(class) == "table", "Received a non-table as the second argument for instanceof") + + if typeof(tbl) ~= "table" then + return false + end + + local ok, hasNew = pcall(function() + return class.new ~= nil and tbl.new == class.new + end) + if ok and hasNew then + return true + end + + local seen = { tbl = true } + + while tbl and typeof(tbl) == "table" do + tbl = getmetatable(tbl) + if typeof(tbl) == "table" then + tbl = tbl.__index + + if tbl == class then + return true + end + end + + -- if we still have a valid table then check against seen + if typeof(tbl) == "table" then + if seen[tbl] then + return false + end + seen[tbl] = true + end + end + + return false +end +-- #endregion + +function Map.new(iterable: Array>?): Map + local array = {} + local map = {} + if iterable ~= nil then + local arrayFromIterable + local iterableType = typeof(iterable) + if iterableType == "table" then + if #iterable > 0 and typeof(iterable[1]) ~= "table" then + error("cannot create Map from {K, V} form, it must be { {K, V}... }") + end + + arrayFromIterable = Array.from(iterable) + else + error(("cannot create array from value of type `%s`"):format(iterableType)) + end + + for _, entry in ipairs(arrayFromIterable) do + local key = entry[1] + if _G.__DEV__ then + if key == nil then + error("cannot create Map from a table that isn't an array.") + end + end + local val = entry[2] + -- only add to array if new + if map[key] == nil then + table.insert(array, key) + end + -- always assign + map[key] = val + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Map) :: any) :: Map +end + +function Map:set(key: K, value: V): Map + -- preserve initial insertion order + if self._map[key] == nil then + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number + 1 + table.insert(self._array, key) + end + -- always update value + self._map[key] = value + return self +end + +function Map:get(key) + return self._map[key] +end + +function Map:clear() + local table_: any = table + self.size = 0 + table_.clear(self._map) + table_.clear(self._array) +end + +function Map:delete(key): boolean + if self._map[key] == nil then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[key] = nil + local index = table.find(self._array, key) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map/forEach +function Map:forEach(callback: callbackFn | callbackFnWithThisArg, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(key: K) + local value: V = self._map[key] :: V + + if thisArg ~= nil then + (callback :: callbackFnWithThisArg)(thisArg, value, key, self) + else + (callback :: callbackFn)(value, key, self) + end + end) +end + +function Map:has(key): boolean + return self._map[key] ~= nil +end + +function Map:keys() + return self._array +end + +function Map:values() + return Array.map(self._array, function(key) + return self._map[key] + end) +end + +function Map:entries() + return Array.map(self._array, function(key) + return { key, self._map[key] } + end) +end + +function Map:ipairs() + return ipairs(self:entries()) +end + +function Map.__index(self, key) + local mapProp = rawget(Map, key) + if mapProp ~= nil then + return mapProp + end + + return Map.get(self, key) +end + +function Map.__newindex(table_, key, value) + table_:set(key, value) +end + +local function coerceToMap(mapLike: Map | Table): Map + return instanceOf(mapLike, Map) and mapLike :: Map -- ROBLOX: order is preservered + or Map.new(Object.entries(mapLike)) -- ROBLOX: order is not preserved +end + +-- local function coerceToTable(mapLike: Map | Table): Table +-- if not instanceOf(mapLike, Map) then +-- return mapLike +-- end + +-- -- create table from map +-- return Array.reduce(mapLike:entries(), function(tbl, entry) +-- tbl[entry[1]] = entry[2] +-- return tbl +-- end, {}) +-- end + +-- #region Tests to verify it works as expected +local function it(description: string, fn: () -> ()) + local ok, result = pcall(fn) + + if not ok then + error("Failed test: " .. description .. "\n" .. result) + end +end + +local AN_ITEM = "bar" +local ANOTHER_ITEM = "baz" + +-- #region [Describe] "Map" +-- #region [Child Describe] "constructors" +it("creates an empty array", function() + local foo = Map.new() + assert(foo.size == 0) +end) + +it("creates a Map from an array", function() + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + }) + assert(foo.size == 2) + assert(foo:has(AN_ITEM) == true) + assert(foo:has(ANOTHER_ITEM) == true) +end) + +it("creates a Map from an array with duplicate keys", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == "foo2") + + assert(#foo:keys() == 1 and foo:keys()[1] == AN_ITEM) + assert(#foo:values() == 1 and foo:values()[1] == "foo2") + assert(#foo:entries() == 1) + assert(#foo:entries()[1] == 2) + + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") +end) + +it("preserves the order of keys first assignment", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { ANOTHER_ITEM, "bar" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 2) + assert(foo:get(AN_ITEM) == "foo2") + assert(foo:get(ANOTHER_ITEM) == "bar") + + assert(foo:keys()[1] == AN_ITEM) + assert(foo:keys()[2] == ANOTHER_ITEM) + assert(foo:values()[1] == "foo2") + assert(foo:values()[2] == "bar") + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") + assert(foo:entries()[2][1] == ANOTHER_ITEM) + assert(foo:entries()[2][2] == "bar") +end) +-- #endregion + +-- #region [Child Describe] "type" +it("instanceOf return true for an actual Map object", function() + local foo = Map.new() + assert(instanceOf(foo, Map) == true) +end) + +it("instanceOf return false for an regular plain object", function() + local foo = {} + assert(instanceOf(foo, Map) == false) +end) +-- #endregion + +-- #region [Child Describe] "set" +it("returns the Map object", function() + local foo = Map.new() + assert(foo:set(1, "baz") == foo) +end) + +it("increments the size if the element is added for the first time", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo.size == 1) +end) + +it("does not increment the size the second time an element is added", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(AN_ITEM, "val") + assert(foo.size == 1) +end) + +it("sets values correctly to true/false", function() + -- Luau FIXME: Luau insists that arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) + + foo:set(AN_ITEM, true) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == true) + + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) +end) + +-- #endregion + +-- #region [Child Describe] "get" +it("returns value of item from provided key", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:get(AN_ITEM) == "foo") +end) + +it("returns nil if the item is not in the Map", function() + local foo = Map.new() + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "clear" +it("sets the size to zero", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo.size == 0) +end) + +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo:has(AN_ITEM) == false) +end) +-- #endregion + +-- #region [Child Describe] "delete" +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo:has(AN_ITEM) == false) +end) + +it("returns true if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:delete(AN_ITEM) == true) +end) + +it("returns false if the item was not in the Map", function() + local foo = Map.new() + assert(foo:delete(AN_ITEM) == false) +end) + +it("decrements the size if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo.size == 0) +end) + +it("does not decrement the size if the item was not in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(ANOTHER_ITEM) + assert(foo.size == 1) +end) + +it("deletes value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + foo:delete(AN_ITEM) + + assert(foo.size == 0) + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "has" +it("returns true if the item is in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:has(AN_ITEM) == true) +end) + +it("returns false if the item is not in the Map", function() + local foo = Map.new() + assert(foo:has(AN_ITEM) == false) +end) + +it("returns correctly with value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + assert(foo:has(AN_ITEM) == true) +end) +-- #endregion + +-- #region [Child Describe] "keys / values / entries" +it("returns array of elements", function() + local myMap = Map.new() + myMap:set(AN_ITEM, "foo") + myMap:set(ANOTHER_ITEM, "val") + + assert(myMap:keys()[1] == AN_ITEM) + assert(myMap:keys()[2] == ANOTHER_ITEM) + + assert(myMap:values()[1] == "foo") + assert(myMap:values()[2] == "val") + + assert(myMap:entries()[1][1] == AN_ITEM) + assert(myMap:entries()[1][2] == "foo") + assert(myMap:entries()[2][1] == ANOTHER_ITEM) + assert(myMap:entries()[2][2] == "val") +end) +-- #endregion + +-- #region [Child Describe] "__index" +it("can access fields directly without using get", function() + local typeName = "size" + + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + { typeName, "buzz" }, + }) + + assert(foo.size == 3) + assert(foo[AN_ITEM] == "foo") + assert(foo[ANOTHER_ITEM] == "val") + assert(foo:get(typeName) == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "__newindex" +it("can set fields directly without using set", function() + local foo = Map.new() + + assert(foo.size == 0) + + foo[AN_ITEM] = "foo" + foo[ANOTHER_ITEM] = "val" + foo.fizz = "buzz" + + assert(foo.size == 3) + assert(foo:get(AN_ITEM) == "foo") + assert(foo:get(ANOTHER_ITEM) == "val") + assert(foo:get("fizz") == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "ipairs" +local function makeArray(...) + local array = {} + for _, item in ... do + table.insert(array, item) + end + return array +end + +it("iterates on the elements by their insertion order", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + assert(makeArray(foo:ipairs())[1][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "foo") + assert(makeArray(foo:ipairs())[2][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "val") +end) + +it("does not iterate on removed elements", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") +end) + +it("iterates on elements if the added back to the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + foo:set(AN_ITEM, "food") + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") + assert(makeArray(foo:ipairs())[2][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "food") +end) +-- #endregion + +-- #region [Child Describe] "Integration Tests" +-- it("MDN Examples", function() +-- local myMap = Map.new() :: Map + +-- local keyString = "a string" +-- local keyObj = {} +-- local keyFunc = function() end + +-- -- setting the values +-- myMap:set(keyString, "value associated with 'a string'") +-- myMap:set(keyObj, "value associated with keyObj") +-- myMap:set(keyFunc, "value associated with keyFunc") + +-- assert(myMap.size == 3) + +-- -- getting the values +-- assert(myMap:get(keyString) == "value associated with 'a string'") +-- assert(myMap:get(keyObj) == "value associated with keyObj") +-- assert(myMap:get(keyFunc) == "value associated with keyFunc") + +-- assert(myMap:get("a string") == "value associated with 'a string'") + +-- assert(myMap:get({}) == nil) -- nil, because keyObj !== {} +-- assert(myMap:get(function() -- nil because keyFunc !== function () {} +-- end) == nil) +-- end) + +it("handles non-traditional keys", function() + local myMap = Map.new() :: Map + + local falseKey = false + local trueKey = true + local negativeKey = -1 + local emptyKey = "" + + myMap:set(falseKey, "apple") + myMap:set(trueKey, "bear") + myMap:set(negativeKey, "corgi") + myMap:set(emptyKey, "doge") + + assert(myMap.size == 4) + + assert(myMap:get(falseKey) == "apple") + assert(myMap:get(trueKey) == "bear") + assert(myMap:get(negativeKey) == "corgi") + assert(myMap:get(emptyKey) == "doge") + + myMap:delete(falseKey) + myMap:delete(trueKey) + myMap:delete(negativeKey) + myMap:delete(emptyKey) + + assert(myMap.size == 0) +end) +-- #endregion + +-- #endregion [Describe] "Map" + +-- #region [Describe] "coerceToMap" +it("returns the same object if instance of Map", function() + local map = Map.new() + assert(coerceToMap(map) == map) + + map = Map.new({}) + assert(coerceToMap(map) == map) + + map = Map.new({ { AN_ITEM, "foo" } }) + assert(coerceToMap(map) == map) +end) +-- #endregion [Describe] "coerceToMap" + +-- #endregion Tests to verify it works as expected diff --git a/bench/other/regex.lua b/bench/other/regex.lua new file mode 100644 index 00000000..eb659a5e --- /dev/null +++ b/bench/other/regex.lua @@ -0,0 +1,2089 @@ +--[[ + PCRE2-based RegEx implemention for Luau + Version 1.0.0a2 + BSD 2-Clause Licence + Copyright © 2020 - Blockzez (devforum /u/Blockzez and github.com/Blockzez) + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +]] +--[[ Settings ]]-- +-- You can change them here +local options = { + -- The maximum cache size for regex so the patterns are cached so it doesn't recompile the pattern + -- The only accepted value are number values >= 0, strings that can be automatically coered to numbers that are >= 0, false and nil + -- Do note that empty regex patterns (comment-only patterns included) are never cached regardless + -- The default is 256 + cacheSize = 256, + + -- A boolean that determines whether this use unicode data + -- If this value evalulates to false, you can remove _unicodechar_category, _scripts and _xuc safely and it'll now error if: + -- - You try to compile a RegEx with unicode flag + -- - You try to use the \p pattern + -- The default is true + unicodeData = false, +}; + +-- +local u_categories = options.unicodeData and require(script:WaitForChild("_unicodechar_category")); +local chr_scripts = options.unicodeData and require(script:WaitForChild("_scripts")); +local xuc_chr = options.unicodeData and require(script:WaitForChild("_xuc")); +local proxy = setmetatable({ }, { __mode = 'k' }); +local re, re_m, match_m = { }, { }, { }; +local lockmsg; + +--[[ Functions ]]-- +local function to_str_arr(self, init) + if init then + self = string.sub(self, utf8.offset(self, init)); + end; + local len = utf8.len(self); + if len <= 1999 then + return { n = len, s = self, utf8.codepoint(self, 1, #self) }; + end; + local clen = math.ceil(len / 1999); + local ret = table.create(len); + local p = 1; + for i = 1, clen do + local c = table.pack(utf8.codepoint(self, utf8.offset(self, i * 1999 - 1998), utf8.offset(self, i * 1999 - (i == clen and 1998 - ((len - 1) % 1999 + 1) or - 1)) - 1)); + table.move(c, 1, c.n, p, ret); + p += c.n; + end; + ret.s, ret.n = self, len; + return ret; +end; + +local function from_str_arr(self) + local len = self.n or #self; + if len <= 7997 then + return utf8.char(table.unpack(self)); + end; + local clen = math.ceil(len / 7997); + local r = table.create(clen); + for i = 1, clen do + r[i] = utf8.char(table.unpack(self, i * 7997 - 7996, i * 7997 - (i == clen and 7997 - ((len - 1) % 7997 + 1) or 0))); + end; + return table.concat(r); +end; + +local function utf8_sub(self, i, j) + j = utf8.offset(self, j); + return string.sub(self, utf8.offset(self, i), j and j - 1); +end; + +-- +local flag_map = { + a = 'anchored', i = 'caseless', m = 'multiline', s = 'dotall', u = 'unicode', U = 'ungreedy', x ='extended', +}; + +local posix_class_names = { + alnum = true, alpha = true, ascii = true, blank = true, cntrl = true, digit = true, graph = true, lower = true, print = true, punct = true, space = true, upper = true, word = true, xdigit = true, +}; + +local escape_chars = { + -- grouped + -- digit, spaces and words + [0x44] = { "class", "digit", true }, [0x53] = { "class", "space", true }, [0x57] = { "class", "word", true }, + [0x64] = { "class", "digit", false }, [0x73] = { "class", "space", false }, [0x77] = { "class", "word", false }, + -- horizontal/vertical whitespace and newline + [0x48] = { "class", "blank", true }, [0x56] = { "class", "vertical_tab", true }, + [0x68] = { "class", "blank", false }, [0x76] = { "class", "vertical_tab", false }, + [0x4E] = { 0x4E }, [0x52] = { 0x52 }, + + -- not grouped + [0x42] = 0x08, + [0x6E] = 0x0A, [0x72] = 0x0D, [0x74] = 0x09, +}; + +local b_escape_chars = { + -- word boundary and not word boundary + [0x62] = { 0x62, { "class", "word", false } }, [0x42] = { 0x42, { "class", "word", false } }, + + -- keep match out + [0x4B] = { 0x4B }, + + -- start & end of string + [0x47] = { 0x47 }, [0x4A] = { 0x4A }, [0x5A] = { 0x5A }, [0x7A] = { 0x7A }, +}; + +local valid_categories = { + C = true, Cc = true, Cf = true, Cn = true, Co = true, Cs = true, + L = true, Ll = true, Lm = true, Lo = true, Lt = true, Lu = true, + M = true, Mc = true, Me = true, Mn = true, + N = true, Nd = true, Nl = true, No = true, + P = true, Pc = true, Pd = true, Pe = true, Pf = true, Pi = true, Po = true, Ps = true, + S = true, Sc = true, Sk = true, Sm = true, So = true, + Z = true, Zl = true, Zp = true, Zs = true, + + Xan = true, Xps = true, Xsp = true, Xuc = true, Xwd = true, +}; + +local class_ascii_punct = { + [0x21] = true, [0x22] = true, [0x23] = true, [0x24] = true, [0x25] = true, [0x26] = true, [0x27] = true, [0x28] = true, [0x29] = true, [0x2A] = true, [0x2B] = true, [0x2C] = true, [0x2D] = true, [0x2E] = true, [0x2F] = true, + [0x3A] = true, [0x3B] = true, [0x3C] = true, [0x3D] = true, [0x3E] = true, [0x3F] = true, [0x40] = true, [0x5B] = true, [0x5C] = true, [0x5D] = true, [0x5E] = true, [0x5F] = true, [0x60] = true, [0x7B] = true, [0x7C] = true, + [0x7D] = true, [0x7E] = true, +}; + +local end_str = { 0x24 }; +local dot = { 0x2E }; +local beginning_str = { 0x5E }; +local alternation = { 0x7C }; + +local function check_re(re_type, name, func) + if re_type == "Match" then + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (Match expected)", 2); + end; + local arg0, arg1 = ...; + if not (proxy[arg0] and proxy[arg0].name == "Match") then + error(string.format("invalid argument #1 to %q (Match expected, got %s)", name, typeof(arg0)), 2); + else + arg0 = proxy[arg0]; + end; + if name == "group" or name == "span" then + if arg1 == nil then + arg1 = 0; + end; + end; + return func(arg0, arg1); + end; + end; + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (RegEx expected)", 2); + elseif arg_n < 2 then + error("missing argument #2 (string expected)", 2); + end; + local arg0, arg1, arg2, arg3, arg4, arg5 = ...; + if not (proxy[arg0] and proxy[arg0].name == "RegEx") then + if type(arg0) ~= "string" and type(arg0) ~= "number" then + error(string.format("invalid argument #1 to %q (RegEx expected, got %s)", name, typeof(arg0)), 2); + end; + arg0 = re.fromstring(arg0); + elseif name == "sub" then + if type(arg2) == "number" then + arg2 ..= ''; + elseif type(arg2) ~= "string" then + error(string.format("invalid argument #3 to 'sub' (string expected, got %s)", typeof(arg2)), 2); + end; + elseif type(arg1) == "number" then + arg1 ..= ''; + elseif type(arg1) ~= "string" then + error(string.format("invalid argument #2 to %q (string expected, got %s)", name, typeof(arg1)), 2); + end; + if name ~= "sub" and name ~= "split" then + local init_type = typeof(arg2); + if init_type ~= 'nil' then + arg2 = tonumber(arg2); + if not arg2 then + error(string.format("invalid argument #3 to %q (number expected, got %s)", name, init_type), 2); + elseif arg2 < 0 then + arg2 = #arg1 + math.floor(arg2 + 0.5) + 1; + else + arg2 = math.max(math.floor(arg2 + 0.5), 1); + end; + end; + end; + arg0 = proxy[arg0]; + if name == "match" or name == "matchiter" then + arg3 = ...; + elseif name == "sub" then + arg5 = ...; + end; + return func(arg0, arg1, arg2, arg3, arg4, arg5); + end; +end; + +--[[ Matches ]]-- +local function match_tostr(self) + local spans = proxy[self].spans; + local s_start, s_end = spans[0][1], spans[0][2]; + if s_end <= s_start then + return string.format("Match (%d..%d, empty)", s_start, s_end - 1); + end; + return string.format("Match (%d..%d): %s", s_start, s_end - 1, utf8_sub(spans.input, s_start, s_end)); +end; + +local function new_match(span_arr, group_id, re, str) + span_arr.source, span_arr.input = re, str; + local object = newproxy(true); + local object_mt = getmetatable(object); + object_mt.__metatable = lockmsg; + object_mt.__index = setmetatable(span_arr, match_m); + object_mt.__tostring = match_tostr; + + proxy[object] = { name = "Match", spans = span_arr, group_id = group_id }; + return object; +end; + +match_m.group = check_re('Match', 'group', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return utf8_sub(self.spans.input, span[1], span[2]); +end); + +match_m.span = check_re('Match', 'span', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return span[1], span[2] - 1; +end); + +match_m.groups = check_re('Match', 'groups', function(self) + local spans = self.spans; + if spans.n > 0 then + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return table.unpack(ret, 1, spans.n); + end; + return utf8_sub(spans.input, spans[0][1], spans[0][2]); +end); + +match_m.groupdict = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = { }; + for k, v in pairs(self.group_id) do + v = spans[v]; + if v then + ret[k] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return ret; +end); + +match_m.grouparr = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + ret.n = spans.n; + return ret; +end); + +-- +local line_verbs = { + CR = 0, LF = 1, CRLF = 2, ANYRLF = 3, ANY = 4, NUL = 5, +}; +local function is_newline(str_arr, i, verb_flags) + local line_verb_n = verb_flags.newline; + local chr = str_arr[i]; + if line_verb_n == 0 then + -- carriage return + return chr == 0x0D; + elseif line_verb_n == 2 then + -- carriage return followed by line feed + return chr == 0x0A and str_arr[i - 1] == 0x20; + elseif line_verb_n == 3 then + -- any of the above + return chr == 0x0A or chr == 0x0D; + elseif line_verb_n == 4 then + -- any of Unicode newlines + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + elseif line_verb_n == 5 then + -- null + return chr == 0; + end; + -- linefeed + return chr == 0x0A; +end; + + +local function tkn_char_match(tkn_part, str_arr, i, flags, verb_flags) + local chr = str_arr[i]; + if not chr then + return false; + elseif flags.ignoreCase and chr >= 0x61 and chr <= 0x7A then + chr -= 0x20; + end; + if type(tkn_part) == "number" then + return tkn_part == chr; + elseif tkn_part[1] == "charset" then + for _, v in ipairs(tkn_part[3]) do + if tkn_char_match(v, str_arr, i, flags, verb_flags) then + return not tkn_part[2]; + end; + end; + return tkn_part[2]; + elseif tkn_part[1] == "range" then + return chr >= tkn_part[2] and chr <= tkn_part[3] or flags.ignoreCase and chr >= 0x41 and chr <= 0x5A and (chr + 0x20) >= tkn_part[2] and (chr + 0x20) <= tkn_part[3]; + elseif tkn_part[1] == "class" then + local char_class = tkn_part[2]; + local negate = tkn_part[3]; + local match = false; + -- if and elseifs :( + -- Might make these into tables in the future + if char_class == "xdigit" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x46 or chr >= 0x61 and chr <= 0x66; + elseif char_class == "ascii" then + match = chr <= 0x7F; + -- cannot be accessed through POSIX classes + elseif char_class == "vertical_tab" then + match = chr >= 0x0A and chr <= 0x0D or chr == 0x2028 or chr == 0x2029; + -- + elseif flags.unicode then + local current_category = u_categories[chr] or 'Cn'; + local first_category = current_category:sub(1, 1); + if char_class == "alnum" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd'; + elseif char_class == "alpha" then + match = first_category == 'L' or current_category == 'Nl'; + elseif char_class == "blank" then + match = current_category == 'Zs' or chr == 0x09; + elseif char_class == "cntrl" then + match = current_category == 'Cc'; + elseif char_class == "digit" then + match = current_category == 'Nd'; + elseif char_class == "graph" then + match = first_category ~= 'P' and first_category ~= 'C'; + elseif char_class == "lower" then + match = current_category == 'Ll'; + elseif char_class == "print" then + match = first_category ~= 'C'; + elseif char_class == "punct" then + match = first_category == 'P'; + elseif char_class == "space" then + match = first_category == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif char_class == "upper" then + match = current_category == 'Lu'; + elseif char_class == "word" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd' or current_category == 'Pc'; + end; + elseif char_class == "alnum" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "alpha" then + match = chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "blank" then + match = chr == 0x09 or chr == 0x20; + elseif char_class == "cntrl" then + match = chr <= 0x1F or chr == 0x7F; + elseif char_class == "digit" then + match = chr >= 0x30 and chr <= 0x39; + elseif char_class == "graph" then + match = chr >= 0x21 and chr <= 0x7E; + elseif char_class == "lower" then + match = chr >= 0x61 and chr <= 0x7A; + elseif char_class == "print" then + match = chr >= 0x20 and chr <= 0x7E; + elseif char_class == "punct" then + match = class_ascii_punct[chr]; + elseif char_class == "space" then + match = chr >= 0x09 and chr <= 0x0D or chr == 0x20; + elseif char_class == "upper" then + match = chr >= 0x41 and chr <= 0x5A; + elseif char_class == "word" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A or chr == 0x5F; + end; + if negate then + return not match; + end; + return match; + elseif tkn_part[1] == "category" then + local chr_category = u_categories[chr] or 'Cn'; + local category_v = tkn_part[3]; + local category_len = #category_v; + if category_len == 3 then + local match = false; + if category_v == "Xan" or category_v == "Xwd" then + match = chr_category:find("^[LN]") or category_v == "Xwd" and chr == 0x5F; + elseif category_v == "Xps" or category_v == "Xsp" then + match = chr_category:sub(1, 1) == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif category_v == "Xuc" then + match = tkn_char_match(xuc_chr, str_arr, i, flags, verb_flags); + end; + if tkn_part[2] then + return not match; + end + return match; + elseif chr_category:sub(1, category_len) == category_v then + return not tkn_part[2]; + end; + return tkn_part[2]; + elseif tkn_part[1] == 0x2E then + return flags.dotAll or not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x4E then + return not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x52 then + if verb_flags.newline_seq == 0 then + -- CR, LF or CRLF + return chr == 0x0A or chr == 0x0D; + end; + -- any unicode newline + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + end; + return false; +end; + +local function find_alternation(token, i, count) + while true do + local v = token[i]; + local is_table = type(v) == "table"; + if v == alternation then + return i, count; + elseif is_table and v[1] == 0x28 then + if count then + count += v.count; + end; + i = v[3]; + elseif is_table and v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28 then + if count then + count += v[5].count; + end; + i = v[5][3]; + elseif not v or is_table and v[1] == 0x29 then + return nil, count; + elseif count then + if is_table and v[1] == "quantifier" then + count += v[3]; + else + count += 1; + end; + end; + i += 1; + end; +end; + +local function re_rawfind(token, str_arr, init, flags, verb_flags, as_bool) + local tkn_i, str_i, start_i = 0, init, init; + local states = { }; + while tkn_i do + if tkn_i == 0 then + tkn_i += 1; + local next_alt = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + continue; + end; + local ctkn = token[tkn_i]; + local tkn_type = type(ctkn) == "table" and ctkn[1]; + if not ctkn then + break; + elseif ctkn == "ACCEPT" then + local not_lookaround = true; + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + elseif is_table and close_i_tkn[1] == 0x29 and (close_i_tkn[4] == 0x21 or close_i_tkn[4] == 0x3D) then + not_lookaround = false; + tkn_i = close_i; + break; + end; + until not close_i_tkn; + if not_lookaround then + break; + end; + elseif ctkn == "PRUNE" or ctkn == "SKIP" then + table.insert(states, 1, { ctkn, str_i }); + tkn_i += 1; + elseif tkn_type == 0x28 then + table.insert(states, 1, { "group", tkn_i, str_i, nil, ctkn[2], ctkn[3], ctkn[4] }); + tkn_i += 1; + local next_alt, count = find_alternation(token, tkn_i, (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and ctkn[5] and 0); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + if count then + str_i -= count; + end; + elseif tkn_type == 0x29 and ctkn[4] ~= 0x21 then + if ctkn[4] == 0x21 or ctkn[4] == 0x3D then + while true do + local selected_match_start; + local selected_state = table.remove(states, 1); + if selected_state[1] == "group" and selected_state[2] == ctkn[3] then + if (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and not ctkn[5] then + str_i = selected_state[3]; + end; + if selected_match_start then + table.insert(states, 1, selected_match_start); + end; + break; + elseif selected_state[1] == "matchStart" and not selected_match_start and ctkn[4] == 0x3D then + selected_match_start = selected_state; + end; + end; + elseif ctkn[4] == 0x3E then + repeat + local selected_state = table.remove(states, 1); + until not selected_state or selected_state[1] == "group" and selected_state[2] == ctkn[3]; + else + for i, v in ipairs(states) do + if v[1] == "group" and v[2] == ctkn[3] then + if v.jmp then + -- recursive match + tkn_i = v.jmp; + end; + v[4] = str_i; + if v[7] == "quantifier" and v[10] + 1 < v[9] then + if token[ctkn[3]][4] ~= "lazy" or v[10] + 1 < v[8] then + tkn_i = ctkn[3]; + end; + local ctkn1 = token[ctkn[3]]; + local new_group = { "group", v[2], str_i, nil, ctkn1[5][2], ctkn1[5][3], "quantifier", ctkn1[2], ctkn1[3], v[10] + 1, v[11], ctkn1[4] }; + table.insert(states, 1, new_group); + if v[11] then + table.insert(states, 1, { "alternation", v[11], str_i }); + end; + end; + break; + end; + end; + end; + tkn_i += 1; + elseif tkn_type == 0x4B then + table.insert(states, 1, { "matchStart", str_i }); + tkn_i += 1; + elseif tkn_type == 0x7C then + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + end; + until is_table and close_i_tkn[1] == 0x29 or not close_i_tkn; + if token[close_i] then + for _, v in ipairs(states) do + if v[1] == "group" and v[6] == close_i then + tkn_i = v[6]; + break; + end; + end; + else + tkn_i = close_i; + end; + elseif tkn_type == "recurmatch" then + table.insert(states, 1, { "group", ctkn[3], str_i, nil, nil, token[ctkn[3]][3], nil, jmp = tkn_i }); + tkn_i = ctkn[3] + 1; + local next_alt, count = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + else + local match; + if ctkn == "FAIL" then + match = false; + elseif tkn_type == 0x29 then + repeat + local selected_state = table.remove(states, 1); + until selected_state[1] == "group" and selected_state[2] == ctkn[3]; + elseif tkn_type == "quantifier" then + if type(ctkn[5]) == "table" and ctkn[5][1] == 0x28 then + local next_alt = find_alternation(token, tkn_i + 1); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + table.insert(states, next_alt and 2 or 1, { "group", tkn_i, str_i, nil, ctkn[5][2], ctkn[5][3], "quantifier", ctkn[2], ctkn[3], 0, next_alt, ctkn[4] }); + if ctkn[4] == "lazy" and ctkn[2] == 0 then + tkn_i = ctkn[5][3]; + end; + match = true; + else + local start_i, end_i; + local pattern_count = 1; + local is_backref = type(ctkn[5]) == "table" and ctkn[5][1] == "backref"; + if is_backref then + pattern_count = 0; + local group_n = ctkn[5][2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + pattern_count = end_i - start_i; + break; + end; + end; + end; + local min_max_i = str_i + ctkn[2] * pattern_count; + local mcount = 0; + while mcount < ctkn[3] do + if is_backref then + if start_i and end_i then + local org_i = str_i; + if utf8_sub(str_arr.s, start_i, end_i) ~= utf8_sub(str_arr.s, org_i, str_i + pattern_count) then + break; + end; + else + break; + end; + elseif not tkn_char_match(ctkn[5], str_arr, str_i, flags, verb_flags) then + break; + end; + str_i += pattern_count; + mcount += 1; + end; + match = mcount >= ctkn[2]; + if match and ctkn[4] ~= "possessive" then + if ctkn[4] == "lazy" then + min_max_i, str_i = str_i, min_max_i; + end; + table.insert(states, 1, { "quantifier", tkn_i, str_i, math.min(min_max_i, str_arr.n + 1), (ctkn[4] == "lazy" and 1 or -1) * pattern_count }); + end; + end; + elseif tkn_type == "backref" then + local start_i, end_i; + local group_n = ctkn[2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + break; + end; + end; + if start_i and end_i then + local org_i = str_i; + str_i += end_i - start_i; + match = utf8_sub(str_arr.s, start_i, end_i) == utf8_sub(str_arr.s, org_i, str_i); + end; + else + local chr = str_arr[str_i]; + if tkn_type == 0x24 or tkn_type == 0x5A or tkn_type == 0x7A then + match = str_i == str_arr.n + 1 or tkn_type == 0x24 and flags.multiline and is_newline(str_arr, str_i + 1, verb_flags) or tkn_type == 0x5A and str_i == str_arr.n and is_newline(str_arr, str_i, verb_flags); + elseif tkn_type == 0x5E or tkn_type == 0x41 or tkn_type == 0x47 then + match = str_i == 1 or tkn_type == 0x5E and flags.multiline and is_newline(str_arr, str_i - 1, verb_flags) or tkn_type == 0x47 and str_i == init; + elseif tkn_type == 0x42 or tkn_type == 0x62 then + local start_m = str_i == 1 or flags.multiline and is_newline(str_arr, str_i - 1, verb_flags); + local end_m = str_i == str_arr.n + 1 or flags.multiline and is_newline(str_arr, str_i, verb_flags); + local w_m = tkn_char_match(ctkn[2], str_arr[str_i - 1], flags) and 0 or tkn_char_match(ctkn[2], chr, flags) and 1; + if w_m == 0 then + match = end_m or not tkn_char_match(ctkn[2], chr, flags); + elseif w_m then + match = start_m or not tkn_char_match(ctkn[2], str_arr[str_i - 1], flags); + end; + if tkn_type == 0x42 then + match = not match; + end; + else + match = tkn_char_match(ctkn, str_arr, str_i, flags, verb_flags); + str_i += 1; + end; + end; + if not match then + while true do + local prev_type, prev_state = states[1] and states[1][1], states[1]; + if not prev_type or prev_type == "PRUNE" or prev_type == "SKIP" then + if prev_type then + table.clear(states); + end; + if start_i > str_arr.n then + if as_bool then + return false; + end; + return nil; + end; + start_i = prev_type == "SKIP" and prev_state[2] or start_i + 1; + tkn_i, str_i = 0, start_i; + break; + elseif prev_type == "alternation" then + tkn_i, str_i = prev_state[2], prev_state[3]; + local next_alt, count = find_alternation(token, tkn_i + 1); + if next_alt then + prev_state[2] = next_alt; + else + table.remove(states, 1); + end; + if count then + str_i -= count; + end; + break; + elseif prev_type == "group" then + if prev_state[7] == "quantifier" then + if prev_state[12] == "greedy" and prev_state[10] >= prev_state[8] + or prev_state[12] == "lazy" and prev_state[10] < prev_state[9] and not prev_state[13] then + tkn_i, str_i = prev_state[12] == "greedy" and prev_state[6] or prev_state[2], prev_state[3]; + if prev_state[12] == "greedy" then + table.remove(states, 1); + break; + elseif prev_state[10] >= prev_state[8] then + prev_state[13] = true; + break; + end; + end; + elseif prev_state[7] == 0x21 then + table.remove(states, 1); + tkn_i, str_i = prev_state[6], prev_state[3]; + break; + end; + elseif prev_type == "quantifier" then + if math.sign(prev_state[4] - prev_state[3]) == math.sign(prev_state[5]) then + prev_state[3] += prev_state[5]; + tkn_i, str_i = prev_state[2], prev_state[3]; + break; + end; + end; + -- keep match out state and recursive state, can be safely removed + -- prevents infinite loop + table.remove(states, 1); + end; + end; + tkn_i += 1; + end; + end; + if as_bool then + return true; + end; + local match_start_ran = false; + local span = table.create(token.group_n); + span[0], span.n = { start_i, str_i }, token.group_n; + for _, v in ipairs(states) do + if v[1] == "matchStart" and not match_start_ran then + span[0][1], match_start_ran = v[2], true; + elseif v[1] == "group" and v[5] and not span[v[5]] then + span[v[5]] = { v[3], v[4] }; + end; + end; + return span; +end; + +--[[ Methods ]]-- +re_m.test = check_re('RegEx', 'test', function(self, str, init) + return re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, true); +end); + +re_m.match = check_re('RegEx', 'match', function(self, str, init, source) + local span = re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + return new_match(span, self.group_id, source, str); +end); + +re_m.matchall = check_re('RegEx', 'matchall', function(self, str, init, source) + str = to_str_arr(str, init); + local i = 1; + return function() + local span = i <= str.n + 1 and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + i = span[0][2] + (span[0][1] >= span[0][2] and 1 or 0); + return new_match(span, self.group_id, source, str.s); + end; +end); + +local function insert_tokenized_sub(repl_r, str, span, tkn) + for _, v in ipairs(tkn) do + if type(v) == "table" then + if v[1] == "condition" then + if span[v[2]] then + if v[3] then + insert_tokenized_sub(repl_r, str, span, v[3]); + else + table.move(str, span[v[2]][1], span[v[2]][2] - 1, #repl_r + 1, repl_r); + end; + elseif v[4] then + insert_tokenized_sub(repl_r, str, span, v[4]); + end; + else + table.move(v, 1, #v, #repl_r + 1, repl_r); + end; + elseif span[v] then + table.move(str, span[v][1], span[v][2] - 1, #repl_r + 1, repl_r); + end; + end; + repl_r.n = #repl_r; + return repl_r; +end; + +re_m.sub = check_re('RegEx', 'sub', function(self, repl, str, n, repl_flag_str, source) + if repl_flag_str ~= nil and type(repl_flag_str) ~= "number" and type(repl_flag_str) ~= "string" then + error(string.format("invalid argument #5 to 'sub' (string expected, got %s)", typeof(repl_flag_str)), 3); + end + local repl_flags = { + l = false, o = false, u = false, + }; + for f in string.gmatch(repl_flag_str or '', utf8.charpattern) do + if repl_flags[f] ~= false then + error("invalid regular expression substitution flag " .. f, 3); + end; + repl_flags[f] = true; + end; + local repl_type = type(repl); + if repl_type == "number" then + repl ..= ''; + elseif repl_type ~= "string" and repl_type ~= "function" and (not repl_flags.o or repl_type ~= "table") then + error(string.format("invalid argument #2 to 'sub' (string/function%s expected, got %s)", repl_flags.o and "/table" or '', typeof(repl)), 3); + end; + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #4 to 'sub' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + if n < 1 then + return str, 0; + end; + local min_repl_n = 0; + if repl_type == "string" then + repl = to_str_arr(repl); + if not repl_flags.l then + local i1 = 0; + local repl_r = table.create(3); + local group_n = self.token.group_n; + local conditional_c = { }; + while i1 < repl.n do + local i2 = i1; + repeat + i2 += 1; + until not repl[i2] or repl[i2] == 0x24 or repl[i2] == 0x5C or (repl[i2] == 0x3A or repl[i2] == 0x7D) and conditional_c[1]; + min_repl_n += i2 - i1 - 1; + if i2 - i1 > 1 then + table.insert(repl_r, table.move(repl, i1 + 1, i2 - 1, 1, table.create(i2 - i1 - 1))); + end; + if repl[i2] == 0x3A then + local current_conditional_c = conditional_c[1]; + if current_conditional_c[2] then + error("malformed substitution pattern", 3); + end; + current_conditional_c[2] = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + elseif repl[i2] == 0x7D then + local current_conditional_c = table.remove(conditional_c, 1); + local second_c = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + table.insert(repl_r, { "condition", current_conditional_c[1], current_conditional_c[2] ~= true and (current_conditional_c[2] or second_c), current_conditional_c[2] and second_c }); + elseif repl[i2] then + i2 += 1; + local subst_c = repl[i2]; + if not subst_c then + if repl[i2 - 1] == 0x5C then + error("replacement string must not end with a trailing backslash", 3); + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, repl[i2 - 1]); + else + table.insert(repl_r, { repl[i2 - 1] }); + end; + elseif subst_c == 0x5C and repl[i2 - 1] == 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + i2 -= 1; + min_repl_n += 1; + elseif subst_c == 0x30 then + table.insert(repl_r, 0); + elseif subst_c > 0x30 and subst_c <= 0x39 then + local start_i2 = i2; + local group_i = subst_c - 0x30; + while repl[i2 + 1] and repl[i2 + 1] >= 0x30 and repl[i2 + 1] <= 0x39 do + group_i ..= repl[i2 + 1] - 0x30; + i2 += 1; + end; + group_i = tonumber(group_i); + if not repl_flags.u and group_i > group_n then + error("reference to non-existent subpattern", 3); + end; + table.insert(repl_r, group_i); + elseif subst_c == 0x7B and repl[i2 - 1] == 0x24 then + i2 += 1; + local start_i2 = i2; + while repl[i2] and + (repl[i2] >= 0x30 and repl[i2] <= 0x39 + or repl[i2] >= 0x41 and repl[i2] <= 0x5A + or repl[i2] >= 0x61 and repl[i2] <= 0x7A + or repl[i2] == 0x5F) do + i2 += 1; + end; + if (repl[i2] == 0x7D or repl[i2] == 0x3A and (repl[i2 + 1] == 0x2B or repl[i2 + 1] == 0x2D)) and i2 ~= start_i2 then + local group_k = utf8_sub(repl.s, start_i2, i2); + if repl[start_i2] >= 0x30 and repl[start_i2] <= 0x39 then + group_k = tonumber(group_k); + if not repl_flags.u and group_k > group_n then + error("reference to non-existent subpattern", 3); + end; + else + group_k = self.group_id[group_k]; + if not repl_flags.u and (not group_k or group_k > group_n) then + error("reference to non-existent subpattern", 3); + end; + end; + if repl[i2] == 0x3A then + i2 += 1; + table.insert(conditional_c, { group_k, repl[i2] == 0x2D, #repl_r + 1 }); + else + table.insert(repl_r, group_k); + end; + else + error("malformed substitution pattern", 3); + end; + else + local c_escape_char; + if repl[i2 - 1] == 0x24 then + if subst_c ~= 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + end; + else + c_escape_char = escape_chars[repl[i2]]; + if type(c_escape_char) ~= "number" then + c_escape_char = nil; + end; + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, c_escape_char or repl[i2]); + else + table.insert(repl_r, { c_escape_char or repl[i2] }); + end; + min_repl_n += 1; + end; + end; + i1 = i2; + end; + if conditional_c[1] then + error("malformed substitution pattern", 3); + end; + if not repl_r[2] and type(repl_r[1]) == "table" and repl_r[1][1] ~= "condition" then + repl, repl.n = repl_r[1], #repl_r[1]; + else + repl, repl_type = repl_r, "subst_string"; + end; + end; + end; + str = to_str_arr(str); + local incr, i0, count = 0, 1, 0; + while i0 <= str.n + incr + 1 do + local span = re_rawfind(self.token, str, i0, self.flags, self.verb_flags, false); + if not span then + break; + end; + local repl_r; + if repl_type == "string" then + repl_r = repl; + elseif repl_type == "subst_string" then + repl_r = insert_tokenized_sub(table.create(min_repl_n), str, span, repl); + else + local re_match; + local repl_c; + if repl_type == "table" then + re_match = utf8_sub(str.s, span[0][1], span[0][2]); + repl_c = repl[re_match]; + else + re_match = new_match(span, self.group_id, source, str.s); + repl_c = repl(re_match); + end; + if repl_c == re_match or repl_flags.o and not repl_c then + local repl_n = span[0][2] - span[0][1]; + repl_r = table.move(str, span[0][1], span[0][2] - 1, 1, table.create(repl_n)); + repl_r.n = repl_n; + elseif type(repl_c) == "string" then + repl_r = to_str_arr(repl_c); + elseif type(repl_c) == "number" then + repl_r = to_str_arr(repl_c .. ''); + elseif repl_flags.o then + error(string.format("invalid replacement value (a %s)", type(repl_c)), 3); + else + repl_r = { n = 0 }; + end; + end; + local match_len = span[0][2] - span[0][1]; + local repl_len = math.min(repl_r.n, match_len); + for i1 = 0, repl_len - 1 do + str[span[0][1] + i1] = repl_r[i1 + 1]; + end; + local i1 = span[0][1] + repl_len; + i0 = span[0][2]; + if match_len > repl_r.n then + for i2 = 1, match_len - repl_r.n do + table.remove(str, i1); + incr -= 1; + i0 -= 1; + end; + elseif repl_r.n > match_len then + for i2 = 1, repl_r.n - match_len do + table.insert(str, i1 + i2 - 1, repl_r[repl_len + i2]); + incr += 1; + i0 += 1; + end; + end; + if match_len <= 0 then + i0 += 1; + end; + count += 1; + if n < count + 1 then + break; + end; + end; + return from_str_arr(str), count; +end); + +re_m.split = check_re('RegEx', 'split', function(self, str, n) + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #3 to 'split' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + str = to_str_arr(str); + local i, count = 1, 0; + local ret = { }; + local prev_empty = 0; + while i <= str.n + 1 do + count += 1; + local span = n >= count and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + break; + end; + table.insert(ret, utf8_sub(str.s, i - prev_empty, span[0][1])); + prev_empty = span[0][1] >= span[0][2] and 1 or 0; + i = span[0][2] + prev_empty; + end; + table.insert(ret, string.sub(str.s, utf8.offset(str.s, i - prev_empty))); + return ret; +end); + +-- +local function re_index(self, index) + return re_m[index] or proxy[self].flags[index]; +end; + +local function re_tostr(self) + return proxy[self].pattern_repr .. proxy[self].flag_repr; +end; +-- + +local other_valid_group_char = { + -- non-capturing group + [0x3A] = true, + -- lookarounds + [0x21] = true, [0x3D] = true, + -- atomic + [0x3E] = true, + -- branch reset + [0x7C] = true, +}; + +local function tokenize_ptn(codes, flags) + if flags.unicode and not options.unicodeData then + return "options.unicodeData cannot be turned off while having unicode flag"; + end; + local i, len = 1, codes.n; + local group_n = 0; + local outln, group_id, verb_flags = { }, { }, { + newline = 1, newline_seq = 1, not_empty = 0, + }; + while i <= len do + local c = codes[i]; + if c == 0x28 then + -- Match + local ret; + if codes[i + 1] == 0x2A then + i += 2; + local start_i = i; + while codes[i] + and (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F or codes[i] == 0x3A) do + i += 1; + end; + if codes[i] ~= 0x29 and codes[i - 1] ~= 0x3A then + -- fallback as normal and ( can't be repeated + return "quantifier doesn't follow a repeatable pattern"; + end; + local selected_verb = utf8_sub(codes.s, start_i, i); + if selected_verb == "positive_lookahead:" or selected_verb == "negative_lookhead:" + or selected_verb == "positive_lookbehind:" or selected_verb == "negative_lookbehind:" + or selected_verb:find("^[pn]l[ab]:$") then + ret = { 0x28, nil, nil, selected_verb:find('^n') and 0x21 or 0x3D, selected_verb:find('b', 3, true) and 1 }; + elseif selected_verb == "atomic:" then + ret = { 0x28, nil, nil, 0x3E, nil }; + elseif selected_verb == "ACCEPT" or selected_verb == "FAIL" or selected_verb == 'F' or selected_verb == "PRUNE" or selected_verb == "SKIP" then + ret = selected_verb == 'F' and "FAIL" or selected_verb; + else + if line_verbs[selected_verb] then + verb_flags.newline = selected_verb; + elseif selected_verb == "BSR_ANYCRLF" or selected_verb == "BSR_UNICODE" then + verb_flags.newline_seq = selected_verb == "BSR_UNICODE" and 1 or 0; + elseif selected_verb == "NOTEMPTY" or selected_verb == "NOTEMPTY_ATSTART" then + verb_flags.not_empty = selected_verb == "NOTEMPTY" and 1 or 2; + else + return "unknown or malformed verb"; + end; + if outln[1] then + return "this verb must be placed at the beginning of the regex"; + end; + end; + elseif codes[i + 1] == 0x3F then + -- ? syntax + i += 2; + if codes[i] == 0x23 then + -- comments + i = table.find(codes, 0x29, i); + if not i then + return "unterminated parenthetical"; + end; + i += 1; + continue; + elseif not codes[i] then + return "unterminated parenthetical"; + end; + ret = { 0x28, nil, nil, codes[i], nil }; + if codes[i] == 0x30 and codes[i + 1] == 0x29 then + -- recursive match entire pattern + ret[1], ret[2], ret[3], ret[5] = "recurmatch", 0, 0, nil; + elseif codes[i] > 0x30 and codes[i] <= 0x39 then + -- recursive match + local org_i = i; + i += 1; + while codes[i] >= 0x30 and codes[i] <= 0x30 do + i += 1; + end; + if codes[i] ~= 0x29 then + return "invalid group structure"; + end; + ret[1], ret[2], ret[4] = "recurmatch", tonumber(utf8_sub(codes.s, org_i, i)), nil; + elseif codes[i] == 0x3C and codes[i + 1] == 0x21 or codes[i + 1] == 0x3D then + -- lookbehinds + i += 1; + ret[4], ret[5] = codes[i], 1; + elseif codes[i] == 0x7C then + -- branch reset + ret[5] = group_n; + elseif codes[i] == 0x50 or codes[i] == 0x3C or codes[i] == 0x27 then + if codes[i] == 0x50 then + i += 1; + end; + if codes[i] == 0x3D then + -- backref + local start_i = i + 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= 0x29 or i == start_i then + return "invalid group structure"; + end; + ret = { "backref", utf8_sub(codes.s, start_i, i) }; + elseif codes[i] == 0x3C or codes[i - 1] ~= 0x50 and codes[i] == 0x27 then + -- named capture + local delimiter = codes[i] == 0x27 and 0x27 or 0x3E; + local start_i = i + 1; + i += 1; + if codes[i] == 0x29 then + return "missing character in subpattern"; + elseif codes[i] >= 0x30 and codes[i] <= 0x39 then + return "subpattern name must not begin with a digit"; + elseif not (codes[i] >= 0x41 and codes[i] <= 0x5A or codes[i] >= 0x61 and codes[i] <= 0x7A or codes[i] == 0x5F) then + return "invalid character in subpattern"; + end; + i += 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= delimiter then + return "invalid character in subpattern"; + end; + local name = utf8_sub(codes.s, start_i, i); + group_n += 1; + if (group_id[name] or group_n) ~= group_n then + return "subpattern name already exists"; + end; + for name1, group_n1 in pairs(group_id) do + if name ~= name1 and group_n == group_n1 then + return "different names for subpatterns of the same number aren't permitted"; + end; + end; + group_id[name] = group_n; + ret[2], ret[4] = group_n, nil; + else + return "invalid group structure"; + end; + elseif not other_valid_group_char[codes[i]] then + return "invalid group structure"; + end; + else + group_n += 1; + ret = { 0x28, group_n, nil, nil }; + end; + if ret then + table.insert(outln, ret); + end; + elseif c == 0x29 then + -- Close parenthesis + local i1 = #outln + 1; + local lookbehind_c = -1; + local current_lookbehind_c = 0; + local max_c, group_c = 0, 0; + repeat + i1 -= 1; + local v, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v[1] == 0x28 then + group_c += 1; + if current_lookbehind_c and v.count then + current_lookbehind_c += v.count; + end; + if not v[3] then + if v[4] == 0x7C then + group_n = v[5] + math.max(max_c, group_c); + end; + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c = nil; + else + lookbehind_c = current_lookbehind_c; + end; + break; + end; + elseif v == alternation then + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c, current_lookbehind_c = nil, nil; + else + lookbehind_c, current_lookbehind_c = current_lookbehind_c, 0; + end; + max_c, group_c = math.max(max_c, group_c), 0; + elseif current_lookbehind_c then + if is_table and v[1] == "quantifier" then + if v[2] == v[3] then + current_lookbehind_c += v[2]; + else + current_lookbehind_c = nil; + end; + else + current_lookbehind_c += 1; + end; + end; + until i1 < 1; + if i1 < 1 then + return "unmatched ) in regular expression"; + end; + local v = outln[i1]; + local outln_len_p_1 = #outln + 1; + local ret = { 0x29, v[2], i1, v[4], v[5], count = lookbehind_c }; + if (v[4] == 0x21 or v[4] == 0x3D) and v[5] and not lookbehind_c then + return "lookbehind assertion is not fixed width"; + end; + v[3] = outln_len_p_1; + table.insert(outln, ret); + elseif c == 0x2E then + table.insert(outln, dot); + elseif c == 0x5B then + -- Character set + local negate, char_class = false, nil; + i += 1; + local start_i = i; + if codes[i] == 0x5E then + negate = true; + i += 1; + elseif codes[i] == 0x2E or codes[i] == 0x3A or codes[i] == 0x3D then + -- POSIX character classes + char_class = codes[i]; + end; + local ret; + if codes[i] == 0x5B or codes[i] == 0x5C then + ret = { }; + else + ret = { codes[i] }; + i += 1; + end; + while codes[i] ~= 0x5D do + if not codes[i] then + return "unterminated character class"; + elseif codes[i] == 0x2D and ret[1] and type(ret[1]) == "number" then + if codes[i + 1] == 0x5D then + table.insert(ret, 1, 0x2D); + else + i += 1; + local ret_c = codes[i]; + if ret_c == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + -- Check for POSIX character class, name does not matter + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] == codes[i + 1] and i1 - 1 ~= i + 1 then + return "invalid range in character class"; + end; + end; + if ret[1] > 0x5B then + return "invalid range in character class"; + end; + elseif ret_c == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0; + else + ret_c = escape_chars[codes[i]] or codes[i]; + if type(ret_c) ~= "number" then + return "invalid range in character class"; + end; + end; + elseif ret[1] > ret_c then + return "invalid range in character class"; + end; + ret[1] = { "range", ret[1], ret_c }; + end; + elseif codes[i] == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] ~= codes[i + 1] or i1 - 1 == i + 1 then + table.insert(ret, 1, 0x5B); + elseif codes[i1 - 1] == 0x2E or codes[i1 - 1] == 0x3D then + return "POSIX collating elements aren't supported"; + elseif codes[i1 - 1] == 0x3A then + -- I have no plans to support escape codes (\) in character class names + local negate = codes[i + 3] == 0x5E; + local class_name = utf8_sub(codes.s, i + (negate and 3 or 2), i1 - 1); + -- If not valid then throw an error + if not posix_class_names[class_name] then + return "unknown POSIX class name"; + end; + table.insert(ret, 1, { "class", class_name, negate }); + i = i1; + end; + else + table.insert(ret, 1, 0x5B); + end; + elseif codes[i] == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal character"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(ret, 1, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0); + elseif codes[i] == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif codes[i] == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif codes[i] == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(ret, 1, code_point); + else + return "invalid escape sequence"; + end; + elseif codes[i] == 0x50 or codes[i] == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(ret, 1, { "category", false, c_name }); + else + local negate = codes[i] == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(ret, 1, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(ret, 1, { "category", negate, c_name }); + end; + end; + elseif codes[i] == 0x6F then + i += 1; + if codes[i] ~= 0x7B then + return "malformed octal code"; + end; + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(ret, 1, ret_chr); + else + local esc_char = escape_chars[codes[i]]; + table.insert(ret, 1, type(esc_char) == "string" and { "class", esc_char, false } or esc_char or codes[i]); + end; + elseif flags.ignoreCase and codes[i] >= 0x61 and codes[i] <= 0x7A then + table.insert(ret, 1, codes[i] - 0x20); + else + table.insert(ret, 1, codes[i]); + end; + i += 1; + end; + if codes[i - 1] == char_class and i - 1 ~= start_i then + return char_class == 0x3A and "POSIX named classes are only support within a character set" or "POSIX collating elements aren't supported"; + end; + if not ret[2] and not negate then + table.insert(outln, ret[1]); + else + table.insert(outln, { "charset", negate, ret }); + end; + elseif c == 0x5C then + -- Escape char + i += 1; + local escape_c = codes[i]; + if not escape_c then + return "pattern may not end with a trailing backslash"; + elseif escape_c >= 0x30 and escape_c <= 0x39 then + local org_i = i; + while codes[i + 1] and codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 do + i += 1; + end; + local escape_d = tonumber(utf8_sub(codes.s, org_i, i + 1)); + if escape_d > group_n and i ~= org_i then + i = org_i; + local radix0, radix1, radix2; + if codes[i] <= 0x37 then + radix0 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + end; + table.insert(outln, radix0 and (radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0) or codes[org_i]); + else + table.insert(outln, { "backref", escape_d }); + end; + elseif escape_c == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif escape_c == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif escape_c == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(outln, code_point); + else + table.insert(outln, escape_chars[0x4E]); + end; + elseif escape_c == 0x50 or escape_c == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(outln, { "category", false, c_name }); + else + local negate = escape_c == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(outln, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(outln, { "category", negate, c_name }); + end; + end; + elseif escape_c == 0x67 and (codes[i + 1] == 0x7B or codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39) then + local is_grouped = false; + i += 1; + if codes[i] == 0x7B then + i += 1; + is_grouped = true; + elseif codes[i] < 0x30 or codes[i] > 0x39 then + return "malformed reference code"; + end; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if is_grouped and codes[i] ~= 0x7D then + return "malformed reference code"; + end; + local ref_name = tonumber(utf8_sub(codes.s, org_i, i + (is_grouped and 0 or 1))); + table.insert(outln, { "backref", ref_name }); + if not is_grouped then + i -= 1; + end; + elseif escape_c == 0x6F then + i += 1; + if codes[i + 1] ~= 0x7B then + return "malformed octal code"; + end + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(outln, ret_chr); + elseif escape_c == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal code"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(outln, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(outln, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + else + local esc_char = b_escape_chars[escape_c] or escape_chars[escape_c]; + table.insert(outln, esc_char or escape_c); + end; + elseif c == 0x2A or c == 0x2B or c == 0x3F or c == 0x7B then + -- Quantifier + local start_q, end_q; + if c == 0x7B then + local org_i = i + 1; + local start_i; + while codes[i + 1] and (codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 or codes[i + 1] == 0x2C and not start_i and i + 1 ~= org_i) do + i += 1; + if codes[i] == 0x2C then + start_i = i; + end; + end; + if codes[i + 1] == 0x7D then + i += 1; + if not start_i then + start_q = tonumber(utf8_sub(codes.s, org_i, i)); + end_q = start_q; + else + start_q, end_q = tonumber(utf8_sub(codes.s, org_i, start_i)), start_i + 1 == i and math.huge or tonumber(utf8_sub(codes.s, start_i + 1, i)); + if end_q < start_q then + return "numbers out of order in {} quantifier"; + end; + end; + else + table.move(codes, org_i - 1, i, #outln + 1, outln); + end; + else + start_q, end_q = c == 0x2B and 1 or 0, c == 0x3F and 1 or math.huge; + end; + if start_q then + local quantifier_type = flags.ungreedy and "lazy" or "greedy"; + if codes[i + 1] == 0x2B or codes[i + 1] == 0x3F then + i += 1; + quantifier_type = codes[i] == 0x2B and "possessive" or flags.ungreedy and "greedy" or "lazy"; + end; + local outln_len = #outln; + local last_outln_value = outln[outln_len]; + if not last_outln_value or type(last_outln_value) == "table" and (last_outln_value[1] == "quantifier" or last_outln_value[1] == 0x28 or b_escape_chars[last_outln_value[1]]) + or last_outln_value == alternation or type(last_outln_value) == "string" then + return "quantifier doesn't follow a repeatable pattern"; + end; + if end_q == 0 then + table.remove(outln); + elseif start_q ~= 1 or end_q ~= 1 then + if type(last_outln_value) == "table" and last_outln_value[1] == 0x29 then + outln_len = last_outln_value[3]; + end; + outln[outln_len] = { "quantifier", start_q, end_q, quantifier_type, outln[outln_len] }; + end; + end; + elseif c == 0x7C then + -- Alternation + table.insert(outln, alternation); + local i1 = #outln; + repeat + i1 -= 1; + local v1, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v1[1] == 0x29 then + i1 = outln[i1][3]; + elseif is_table and v1[1] == 0x28 then + if v1[4] == 0x7C then + group_n = v1[5]; + end; + break; + end; + until not v1; + elseif c == 0x24 or c == 0x5E then + table.insert(outln, c == 0x5E and beginning_str or end_str); + elseif flags.ignoreCase and c >= 0x61 and c <= 0x7A then + table.insert(outln, c - 0x20); + elseif flags.extended and (c >= 0x09 and c <= 0x0D or c == 0x20 or c == 0x23) then + if c == 0x23 then + repeat + i += 1; + until not codes[i] or codes[i] == 0x0A or codes[i] == 0x0D; + end; + else + table.insert(outln, c); + end; + i += 1; + end; + local max_group_n = 0; + for i, v in ipairs(outln) do + if type(v) == "table" and (v[1] == 0x28 or v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28) then + if v[1] == "quantifier" then + v = v[5]; + end; + if not v[3] then + return "unterminated parenthetical"; + elseif v[2] then + max_group_n = math.max(max_group_n, v[2]); + end; + elseif type(v) == "table" and (v[1] == "backref" or v[1] == "recurmatch") then + if not group_id[v[2]] and (type(v[2]) ~= "number" or v[2] > group_n) then + return "reference to a non-existent or invalid subpattern"; + elseif v[1] == "recurmatch" and v[2] ~= 0 then + for i1, v1 in ipairs(outln) do + if type(v1) == "table" and v1[1] == 0x28 and v1[2] == v[2] then + v[3] = i1; + break; + end; + end; + elseif type(v[2]) == "string" then + v[2] = group_id[v[2]]; + end; + end; + end; + outln.group_n = max_group_n; + return outln, group_id, verb_flags; +end; + +if not tonumber(options.cacheSize) then + error(string.format("expected number for options.cacheSize, got %s", typeof(options.cacheSize)), 2); +end; +local cacheSize = math.floor(options.cacheSize or 0) ~= 0 and tonumber(options.cacheSize); +local cache_pattern, cache_pattern_names; +if not cacheSize then +elseif cacheSize < 0 or cacheSize ~= cacheSize then + error("cache size cannot be a negative number or a NaN", 2); +elseif cacheSize == math.huge then + cache_pattern, cache_pattern_names = { nil }, { nil }; +elseif cacheSize >= 2 ^ 32 then + error("cache size too large", 2); +else + cache_pattern, cache_pattern_names = table.create(options.cacheSize), table.create(options.cacheSize); +end; +if cacheSize then + function re.pruge() + table.clear(cache_pattern_names); + table.clear(cache_pattern); + end; +end; + +local function new_re(str_arr, flags, flag_repr, pattern_repr) + local tokenized_ptn, group_id, verb_flags; + local cache_format = cacheSize and string.format("%s|%s", str_arr.s, flag_repr); + local cached_token = cacheSize and cache_pattern[table.find(cache_pattern_names, cache_format)]; + if cached_token then + tokenized_ptn, group_id, verb_flags = table.unpack(cached_token, 1, 3); + else + tokenized_ptn, group_id, verb_flags = tokenize_ptn(str_arr, flags); + if type(tokenized_ptn) == "string" then + error(tokenized_ptn, 2); + end; + if cacheSize and tokenized_ptn[1] then + table.insert(cache_pattern_names, 1, cache_format); + table.insert(cache_pattern, 1, { tokenized_ptn, group_id, verb_flags }); + if cacheSize ~= math.huge then + table.remove(cache_pattern_names, cacheSize + 1); + table.remove(cache_pattern, cacheSize + 1); + end; + end; + end; + + local object = newproxy(true); + proxy[object] = { name = "RegEx", flags = flags, flag_repr = flag_repr, pattern_repr = pattern_repr, token = tokenized_ptn, group_id = group_id, verb_flags = verb_flags }; + local object_mt = getmetatable(object); + object_mt.__index = setmetatable(flags, re_m); + object_mt.__tostring = re_tostr; + object_mt.__metatable = lockmsg; + + return object; +end; + +local function escape_fslash(pre) + return (#pre % 2 == 0 and '\\' or '') .. pre .. '.'; +end; + +local function sort_flag_chr(a, b) + return a:lower() < b:lower(); +end; + +function re.new(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn, flags_str = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn)), 2); + end; + if type(flags_str) ~= "string" and type(flags_str) ~= "number" and flags_str ~= nil then + error(string.format("invalid argument #2 (string expected, got %s)", typeof(flags_str)), 2); + end; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + for f in string.gmatch(flags_str or '', utf8.charpattern) do + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + return new_re(to_str_arr(ptn), flags, flag_repr, string.format("/%s/", ptn:gsub("(\\*)/", escape_fslash))); +end; + +function re.fromstring(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn), 2)); + end; + local str_arr = to_str_arr(ptn); + local delimiter = str_arr[1]; + if not delimiter then + error("empty regex", 2); + elseif delimiter == 0x5C or (delimiter >= 0x30 and delimiter <= 0x39) or (delimiter >= 0x41 and delimiter <= 0x5A) or (delimiter >= 0x61 and delimiter <= 0x7A) then + error("delimiter must not be alphanumeric or a backslash", 2); + end; + + local i0 = 1; + repeat + i0 = table.find(str_arr, delimiter, i0 + 1); + if not i0 then + error(string.format("no ending delimiter ('%s') found", utf8.char(delimiter)), 2); + end; + local escape_count = 1; + while str_arr[i0 - escape_count] == 0x5C do + escape_count += 1; + end; + until escape_count % 2 == 1; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + while str_arr.n > i0 do + local f = utf8.char(table.remove(str_arr)); + str_arr.n -= 1; + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + table.remove(str_arr, 1); + table.remove(str_arr); + str_arr.n -= 2; + str_arr.s = string.sub(str_arr.s, 2, 1 + str_arr.n); + return new_re(str_arr, flags, flag_repr, string.sub(ptn, 1, 2 + str_arr.n)); +end; + +local re_escape_line_chrs = { + ['\0'] = '\\x00', ['\n'] = '\\n', ['\t'] = '\\t', ['\r'] = '\\r', ['\f'] = '\\f', +}; + +function re.escape(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local str, extended, delimiter = ...; + if type(str) == "number" then + str ..= ''; + elseif type(str) ~= "string" then + error(string.format("invalid argument #1 to 'escape' (string expected, got %s)", typeof(str)), 2); + end; + if delimiter == nil then + delimiter = ''; + elseif type(delimiter) == "number" then + delimiter ..= ''; + elseif type(delimiter) ~= "string" then + error(string.format("invalid argument #3 to 'escape' (string expected, got %s)", typeof(delimiter)), 2); + end; + if utf8.len(delimiter) > 1 or delimiter:match("^[%a\\]$") then + error("delimiter have not be alphanumeric", 2); + end; + return (string.gsub(str, "[\0\f\n\r\t]", re_escape_line_chrs):gsub(string.format("[\\%s#()%%%%*+.?[%%]^{|%s]", extended and '%s' or '', (delimiter:find'^[%%%]]$' and '%' or '') .. delimiter), "\\%1")); +end; + +function re.type(...) + if select('#', ...) == 0 then + error("missing argument #1", 2); + end; + return proxy[...] and proxy[...].name; +end; + +for k, f in pairs(re_m) do + re[k] = f; +end; + +re_m = { __index = re_m }; + +lockmsg = re.fromstring([[/The\s*metatable\s*is\s*(?:locked|inaccessible)(?#Nice try :])/i]]); +getmetatable(lockmsg).__metatable = lockmsg; + +local function readonly_table() + error("Attempt to modify a readonly table", 2); +end; + +match_m = { + __index = match_m, + __metatable = lockmsg, + __newindex = readonly_table, +}; + +re.Match = setmetatable({ }, match_m); + +return setmetatable({ }, { + __index = re, + __metatable = lockmsg, + __newindex = readonly_table, +}); diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 7f863c6f..15813ae9 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -213,6 +213,16 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") SINGLE_COMPARE(lea(rax, qword[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") +{ + SINGLE_COMPARE(jmp(rax), 0x48, 0xff, 0xe0); + SINGLE_COMPARE(jmp(r14), 0x49, 0xff, 0xe6); + SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x49, 0xff, 0x24, 0x96); + SINGLE_COMPARE(call(rax), 0x48, 0xff, 0xd0); + SINGLE_COMPARE(call(r14), 0x49, 0xff, 0xd6); + SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x49, 0xff, 0x14, 0x96); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") { // Jump back @@ -260,6 +270,23 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e}); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall") +{ + check( + [](AssemblyBuilderX64& build) { + Label fnB; + + build.and_(rcx, 0x3e); + build.call(fnB); + build.ret(); + + build.setLabel(fnB); + build.lea(rax, qword[rcx + 0x1f]); + build.ret(); + }, + {0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3}); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") { SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa9, 0x58, 0xc6); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index f0017509..6ec1426c 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -105,4 +105,37 @@ if true then REQUIRE(parentStat->is()); } +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_at_number_const") +{ + check(R"( +print(3.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 8)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_dot") +{ + check(R"( +print(workspace.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_colon") +{ + check(R"( +print(workspace:) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.h b/tests/Fixture.h index 1bc573da..4bd6f1ea 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -128,6 +128,7 @@ struct Fixture std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; + ScopedFastFlag sff_UnknownNever{"LuauUnknownAndNeverType", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index b9c24704..4f331399 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -791,6 +791,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") CHECK_EQ(0, module->internalTypes.typeVars.size()); CHECK_EQ(0, module->internalTypes.typePacks.size()); CHECK_EQ(0, module->astTypes.size()); + CHECK_EQ(0, module->astResolvedTypes.size()); + CHECK_EQ(0, module->astResolvedTypePacks.size()); } TEST_CASE_FIXTURE(FrontendFixture, "it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 7c2f4d1c..dd94e9d7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -301,8 +301,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - fileResolver.source["Module/A"] = R"( export type A = B type B = A diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index a474b6e7..fb0a899e 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -1055,8 +1055,6 @@ export type t1 = { a: typeof(string.byte) } TEST_CASE_FIXTURE(Fixture, "intersection_combine_on_bound_self") { - ScopedFastFlag luauNormalizeCombineEqFix{"LuauNormalizeCombineEqFix", true}; - CheckResult result = check(R"( export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,})) )"); @@ -1064,6 +1062,46 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_never") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = string | never + local foo: Foo + )"); + + CHECK_EQ("string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_unknown") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = string | unknown + local foo: Foo + )"); + + CHECK_EQ("unknown", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "any_wins_the_battle_over_unknown_in_unions") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = unknown | any + local foo: Foo + + type Bar = any | unknown + local bar: Bar + )"); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") { ScopedFastFlag sff[]{ diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c3c75998..c517853f 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2648,7 +2648,6 @@ type Z = { a: string | T..., b: number } TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") { - ScopedFastFlag sff{"LuauReturnTypeTokenConfusion", true}; ParseResult result = tryParse(R"( type Custom = { x: A, y: B, z: C } type Packed = { x: (A...) -> () } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 387e07cd..52a29bcc 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -96,6 +96,37 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") //clang-format on } +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable})}; + CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); +} + +TEST_CASE_FIXTURE(Fixture, "named_metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable, "NamedMetatable"})}; + CHECK_EQ("NamedMetatable", toString(&mtv)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") +{ + CheckResult result = check(R"( + local function createTbl(): NamedMetatable + return setmetatable({}, {}) + end + type NamedMetatable = typeof(createTbl()) + )"); + + TypeId ty = requireType("createTbl"); + const FunctionTypeVar* ftv = get(follow(ty)); + REQUIRE(ftv); + CHECK_EQ("createTbl(): NamedMetatable", toStringNamedFunction("createTbl", *ftv)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( @@ -468,7 +499,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); + CHECK_EQ("(nil) -> ()", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index b02a52b2..d2ed9aef 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -583,7 +583,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") auto names = AstNameTable{allocator}; ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); - CHECK_EQ("local a = (error-expr: f.%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); + CHECK_EQ("local a = (error-expr: f:%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); } TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index bc55940e..4c5309e0 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -94,7 +94,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -110,7 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -225,7 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -234,7 +234,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 2f0266ec..7d1bd6ba 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -925,7 +925,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(nil) -> nil", toString(requireType("f"))); + CHECK_EQ("(nil) -> (never, ...never)", toString(requireType("f"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") @@ -952,7 +952,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("", toString(requireType("d"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") @@ -965,8 +965,8 @@ a:b() a:b({}) )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 0}, {2, 5}}, CountMismatch{2, 0}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{3, 0}, {3, 5}}, CountMismatch{2, 1}})); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function expects 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function expects 2 arguments, but only 1 is specified"); } TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") @@ -1008,4 +1008,139 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.gmatch("This is a string", "(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = ("This is a string"):gmatch("(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his)() is a string", ".")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his) is a string", "((.)%b()())")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 3); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "string"); + CHECK_EQ(toString(requireType("c")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.gmatch("T(his)() is a string", "(T[()])()")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 3); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b = string.gmatch("[[[", "()([[])")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "number"); + CHECK_EQ(toString(requireType("b")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") +{ + CheckResult result = check(R"END( + -- An immediate right-bracket following a left-bracket is included within the set; + -- thus, '[]]'' is the set containing ']', and '[]' is an invalid set missing an enclosing + -- right-bracket. We detect an invalid set in this case and fall back to to default gmatch + -- typing. + local foo = string.gmatch("T[hi%]s]]]() is a string", "([]s)") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", ")") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin2") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", "[") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 401a6c64..6e6549d3 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -916,13 +916,13 @@ TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") REQUIRE(tm1); CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); + CHECK_EQ("(string, ) -> number", toString(tm1->givenType)); auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); + CHECK_EQ("(string, ) -> number", toString(tm2->givenType)); } TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") @@ -1535,7 +1535,7 @@ function t:b() return 2 end -- not OK )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(*unknown*) -> number' could not be converted into '() -> number' + CHECK_EQ(R"(Type '() -> number' could not be converted into '() -> number' caused by: Argument count mismatch. Function expects 1 argument, but none are specified)", toString(result.errors[0])); @@ -1692,4 +1692,52 @@ TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantifie // TODO: check the normalized type of f } +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_unknown") +{ + CheckResult result = check(R"( + local function foo(f: (unknown) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((unknown) -> (), a) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_infer_parameter_types_for_functions_from_their_call_site") +{ + CheckResult result = check(R"( + local t = {} + + function t.f(x) + return x + end + + t.__index = t + + function g(s) + local q = s.p and s.p.q or nil + return q and t.f(q) or nil + end + + local f = t.f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(a) -> a", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_calling_with_self") +{ + CheckResult result = check(R"( + local t = {} + function t:m(x) end + function f(): never return 5 :: never end + t:m(f()) + t:m(f()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index e9e94cfb..46258072 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,6 +9,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauCheckGenericHOFTypes) + using namespace Luau; TEST_SUITE_BEGIN("GenericsTests"); @@ -1001,7 +1003,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); + CHECK_EQ("", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -1095,10 +1097,18 @@ local b = sumrec(sum) -- ok local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); + if (FFlag::LuauCheckGenericHOFTypes) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ( + "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") @@ -1185,4 +1195,23 @@ TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_gen CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types") +{ + ScopedFastFlag sff[] = { + {"LuauMaybeGenericIntersectionTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type Array = { [number]: T } + + type Array_Statics = { + new: () -> Array, + } + + local _Arr : Array & Array_Statics = {} :: Array_Statics + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 1c6fe1d8..56b807f8 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -142,7 +142,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ("*unknown*", toString(p)); + CHECK_EQ("", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index a0f670f1..2343a7fa 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -143,7 +143,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_EQ("*unknown*", toString(hootyType)); + CHECK_EQ("", toString(hootyType)); } TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") @@ -244,7 +244,7 @@ local ModuleA = require(game.A) LUAU_REQUIRE_NO_ERRORS(result); std::optional oty = requireType("ModuleA"); - CHECK_EQ("*unknown*", toString(*oty)); + CHECK_EQ("", toString(*oty)); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types") @@ -302,6 +302,30 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_4") +{ + fileResolver.source["game/A"] = R"( +export type Array = {T} +local arrayops = {} +function arrayops.foo(x: Array) end +return arrayops + )"; + + CheckResult result = check(R"( +local arrayops = require(game.A) + +local tbl = {} +tbl.a = 2 +function tbl:foo(b: number, c: number) + -- introduce BoundTypeVar to imported type + arrayops.foo(self._regions) +end +type Table = typeof(tbl) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict") { fileResolver.source["game/A"] = R"( @@ -363,4 +387,21 @@ caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_anyification_clone_immutable_types") +{ + ScopedFastFlag luauAnyificationMustClone{"LuauAnyificationMustClone", true}; + + fileResolver.source["game/A"] = R"( +return function(...) end + )"; + + fileResolver.source["game/B"] = R"( +local l0 = require(game.A) +return l0 + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index e6174df2..c90f0a4d 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -871,4 +871,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_bra CHECK(toString(result2.errors[0]) == "Types Foo and Bar cannot be compared with == because they do not have the same metatable"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_and") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 and "a" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 or "b" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index e1684df7..9e8e2503 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -47,7 +47,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK_EQ("*unknown*", toString(requireType("t"))); + CHECK_EQ("", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "string_method") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 059aed2e..dc686890 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -225,7 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); } -TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) +TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; @@ -499,6 +499,17 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_any") +{ + CheckResult result = check(R"( + local function foo(f: (any) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((any) -> (), any) -> ()", toString(requireType("foo"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") { ScopedFastFlag sff{"DebugLuauSharedSelf", true}; @@ -518,7 +529,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_f )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0])); + CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 3f5dad3d..cc8cdee3 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -272,8 +272,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); + CHECK_EQ("never", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("never", toString(requireTypeAtPosition({9, 38}))); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") @@ -651,7 +651,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") { CheckResult result = check(R"( local function f(t: {x: number}) @@ -666,7 +666,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_onl LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") @@ -1074,7 +1074,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + CHECK_EQ("never", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } @@ -1206,6 +1206,24 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") +{ + CheckResult result = check(R"( + local function f(x: unknown) + if type(x) == "string" then + local foo = x + else + local bar = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") { ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; @@ -1227,4 +1245,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_ni CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "what_nonsensical_condition") +{ + CheckResult result = check(R"( + local function f(x) + if type(x) == "string" and type(x) == "number" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index eead5b30..3e830f2a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3070,4 +3070,18 @@ TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "leaking_bad_metatable_errors") +{ + ScopedFastFlag luauIndexSilenceErrors{"LuauIndexSilenceErrors", true}; + + CheckResult result = check(R"( +local a = setmetatable({}, 1) +local b = a.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Metatable was not a table", toString(result.errors[0])); + CHECK_EQ("Type 'a' does not have key 'x'", toString(result.errors[1])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index efdfe0b1..7d1fb56b 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -238,10 +238,10 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") // TODO: Should we assert anything about these tests when DCR is being used? if (!FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); + CHECK_EQ("", toString(requireType("c"))); + CHECK_EQ("", toString(requireType("d"))); + CHECK_EQ("", toString(requireType("e"))); + CHECK_EQ("", toString(requireType("f"))); } } @@ -622,7 +622,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); + CHECK_EQ("", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -1003,4 +1003,27 @@ TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_ )"); } +TEST_CASE_FIXTURE(Fixture, "types stored in astResolvedTypes") +{ + CheckResult result = check(R"( +type alias = typeof("hello") +local function foo(param: alias) +end + )"); + + auto node = findNodeAtPosition(*getMainSourceModule(), {2, 16}); + auto ty = lookupType("alias"); + REQUIRE(node); + REQUIRE(node->is()); + REQUIRE(ty); + + auto func = node->as(); + REQUIRE(func->args.size == 1); + + auto arg = *func->args.begin(); + auto annotation = arg->annotation; + + CHECK_EQ(*getMainModule()->astResolvedTypes.find(annotation), *ty); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 49deae71..d51a38f8 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -121,7 +121,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("", toString(requireType("b"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") @@ -136,7 +136,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("c"))); } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 2b48133d..94918692 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -199,7 +199,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - CHECK_EQ("*unknown*", toString(requireType("r"))); + CHECK_EQ("", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp new file mode 100644 index 00000000..bc742b03 --- /dev/null +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -0,0 +1,280 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferUnknownNever"); + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_unknown_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: string = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_never_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: never = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "never_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: string = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "never_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: never = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_optional_because_it_too_encompasses_nil") +{ + CheckResult result = check(R"( + local t: {x: unknown} = {} + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_uninhabitable") +{ + CheckResult result = check(R"( + local t: {x: never} = {} + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_also_reflexive") +{ + CheckResult result = check(R"( + local t: {x: never} = {x = 5 :: never} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "array_like_table_of_never_is_inhabitable") +{ + CheckResult result = check(R"( + local t: {never} = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable") +{ + CheckResult result = check(R"( + local function f() return "foo", 5 :: never end + + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable2") +{ + CheckResult result = check(R"( + local function f(): (string, never) return "", 5 :: never end + local function g(): (never, string) return 5 :: never, "" end + + local x1, x2 = f() + local y1, y2 = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x1"))); + CHECK_EQ("never", toString(requireType("x2"))); + CHECK_EQ("never", toString(requireType("y1"))); + CHECK_EQ("never", toString(requireType("y2"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_never") +{ + CheckResult result = check(R"( + local x: never = 5 :: never + local z = x.y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_never") +{ + CheckResult result = check(R"( + local f: never = 5 :: never + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_local_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t = 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_global_which_is_never") +{ + CheckResult result = check(R"( + --!nonstrict + t = 5 :: never + t = "" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_prop_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t.x = 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t[5] = 7 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + for i, v in (5 :: never) do + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "pick_never_from_variadic_type_pack") +{ + CheckResult result = check(R"( + local function f(...: never) + local x, y = (...) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: never, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_sorta_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: string, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "unary_minus_of_never") +{ + CheckResult result = check(R"( + local x = -(5 :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_never") +{ + CheckResult result = check(R"( + local x = #({} :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); +} + +TEST_SUITE_END(); diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 8a5a65fe..35852c05 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -199,8 +199,6 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") TEST_CASE("content_reassignment") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; TypeArena arena; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 4f8fc502..f4670048 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -418,8 +418,6 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") TEST_CASE("content_reassignment") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; myAny.normal = true; myAny.documentationSymbol = "@global/any"; diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 22d6adfc..6164e929 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -101,4 +101,20 @@ if vector_size == 4 then assert(vector(1, 2, 3, 4).W == 4) end +-- negative zero should hash the same as zero +-- note: our earlier test only really checks the low hash bit, so in absence of perfect avalanche it's insufficient +do + local larget = {} + for i = 1, 2^14 do + larget[vector(0, 0, i)] = true + end + + larget[vector(0, 0, 0)] = 42 + + assert(larget[vector(0, 0, 0)] == 42) + assert(larget[vector(0, 0, -0)] == 42) + assert(larget[vector(0, -0, 0)] == 42) + assert(larget[vector(-0, 0, 0)] == 42) +end + return 'OK' diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index ccc7e390..cb2f355f 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -137,16 +137,28 @@ {data,s} + + + + + + + + + empty + none + + + {proto()->source->data,sb}:{line()} function {proto()->debugname->data,sb}() + {proto()->source->data,sb}:{line()} function() + + + =[C] function {cl().c.debugname,sb}() {cl().c.f,na} + =[C] {cl().c.f,na} + + - - {ci->func->value.gc->cl.c.f,na} - - - {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} {ci->func->value.gc->cl.l.p->debugname->data,sb} - - - {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} - + {ci,na} thread @@ -156,7 +168,7 @@ ci-base_ci - base_ci[ci-base_ci - $i].func->value.gc->cl,view(short) + base_ci[ci-base_ci - $i]