diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 9ebdfc16..8864ef81 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Type.h" +#include "Luau/UnifierSharedState.h" #include #include @@ -11,11 +12,12 @@ namespace Luau template struct TryPair; +struct InternalErrorReporter; class Normalizer; struct NormalizedType; -struct SubtypingGraph +struct SubtypingResult { // Did the test succeed? bool isSubtype = false; @@ -25,39 +27,78 @@ struct SubtypingGraph // If so, what constraints are implied by this relation? // If not, what happened? - SubtypingGraph and_(const SubtypingGraph& other); - SubtypingGraph or_(const SubtypingGraph& other); + void andAlso(const SubtypingResult& other); + void orElse(const SubtypingResult& other); - static SubtypingGraph and_(const std::vector& results); - static SubtypingGraph or_(const std::vector& results); + static SubtypingResult all(const std::vector& results); + static SubtypingResult any(const std::vector& results); }; struct Subtyping { NotNull builtinTypes; + NotNull arena; NotNull normalizer; + NotNull iceReporter; + + enum class Variance + { + Covariant, + Contravariant + }; + + Variance variance = Variance::Covariant; + + struct GenericBounds + { + DenseHashSet lowerBound{nullptr}; + DenseHashSet upperBound{nullptr}; + }; + + /* + * When we encounter a generic over the course of a subtyping test, we need + * to tentatively map that generic onto a type on the other side. + */ + DenseHashMap mappedGenerics{nullptr}; + DenseHashMap mappedGenericPacks{nullptr}; + + using SeenSet = std::unordered_set, TypeIdPairHash>; + + SeenSet seenTypes; // TODO cache // TODO cyclic types // TODO recursion limits - SubtypingGraph isSubtype(TypeId subTy, TypeId superTy); - SubtypingGraph isSubtype(TypePackId subTy, TypePackId superTy); + SubtypingResult isSubtype(TypeId subTy, TypeId superTy); + SubtypingResult isSubtype(TypePackId subTy, TypePackId superTy); private: + SubtypingResult isSubtype_(TypeId subTy, TypeId superTy); + SubtypingResult isSubtype_(TypePackId subTy, TypePackId superTy); + template - SubtypingGraph isSubtype(const TryPair& pair); + SubtypingResult isSubtype_(const TryPair& pair); - SubtypingGraph isSubtype(TypeId subTy, const UnionType* superUnion); - SubtypingGraph isSubtype(const UnionType* subUnion, TypeId superTy); - SubtypingGraph isSubtype(TypeId subTy, const IntersectionType* superIntersection); - SubtypingGraph isSubtype(const IntersectionType* subIntersection, TypeId superTy); - SubtypingGraph isSubtype(const PrimitiveType* subPrim, const PrimitiveType* superPrim); - SubtypingGraph isSubtype(const SingletonType* subSingleton, const PrimitiveType* superPrim); - SubtypingGraph isSubtype(const SingletonType* subSingleton, const SingletonType* superSingleton); - SubtypingGraph isSubtype(const FunctionType* subFunction, const FunctionType* superFunction); + SubtypingResult isSubtype_(TypeId subTy, const UnionType* superUnion); + SubtypingResult isSubtype_(const UnionType* subUnion, TypeId superTy); + SubtypingResult isSubtype_(TypeId subTy, const IntersectionType* superIntersection); + SubtypingResult isSubtype_(const IntersectionType* subIntersection, TypeId superTy); + SubtypingResult isSubtype_(const PrimitiveType* subPrim, const PrimitiveType* superPrim); + SubtypingResult isSubtype_(const SingletonType* subSingleton, const PrimitiveType* superPrim); + SubtypingResult isSubtype_(const SingletonType* subSingleton, const SingletonType* superSingleton); + SubtypingResult isSubtype_(const TableType* subTable, const TableType* superTable); + SubtypingResult isSubtype_(const FunctionType* subFunction, const FunctionType* superFunction); + SubtypingResult isSubtype_(const NormalizedType* subNorm, const NormalizedType* superNorm); - SubtypingGraph isSubtype(const NormalizedType* subNorm, const NormalizedType* superNorm); + bool bindGeneric(TypeId subTp, TypeId superTp); + bool bindGeneric(TypePackId subTp, TypePackId superTp); + + template + TypeId makeAggregateType(const Container& container, TypeId orElse); + + [[noreturn]] + void unexpected(TypePackId tp); }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index c360e4bc..9699c4ae 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -108,7 +108,7 @@ struct TryPair { A first; B second; - operator bool() const + explicit operator bool() const { return bool(first) && bool(second); } diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 6a6f10e8..a71dd592 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -12,6 +12,7 @@ #include LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAGVARIABLE(FixFindBindingAtFunctionName, false); namespace Luau { @@ -148,6 +149,23 @@ struct FindNode : public AstVisitor return false; } + bool visit(AstStatFunction* node) override + { + if (FFlag::FixFindBindingAtFunctionName) + { + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; + } + else + { + return AstVisitor::visit(node); + } + } + bool visit(AstStatBlock* block) override { visit(static_cast(block)); @@ -188,6 +206,23 @@ struct FindFullAncestry final : public AstVisitor return false; } + bool visit(AstStatFunction* node) override + { + if (FFlag::FixFindBindingAtFunctionName) + { + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; + } + else + { + return AstVisitor::visit(node); + } + } + bool visit(AstNode* node) override { if (node->location.contains(pos)) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 9b6f4db7..471fd006 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,7 +15,7 @@ LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled1, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteLastTypecheck, false) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteHideSelfArg, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -283,38 +283,20 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); - if (FFlag::LuauAutocompleteHideSelfArg) - { - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens, - {}, - indexType == PropIndexType::Colon - }; - } - else - { - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens - }; - } + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Property, + type, + prop.deprecated, + isWrongIndexer(type), + typeCorrect, + containingClass, + &prop, + prop.documentationSymbol, + {}, + parens, + {}, + indexType == PropIndexType::Colon + }; } } }; @@ -484,8 +466,19 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } -static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) { + if (FFlag::LuauAutocompleteStringLiteralBounds) + { + if (position == node->location.begin || position == node->location.end) + { + if (auto str = node->as(); str && str->quoteStyle == AstExprConstantString::Quoted) + return; + else if (node->is()) + return; + } + } + auto formatKey = [addQuotes](const std::string& key) { if (addQuotes) return "\"" + escape(key) + "\""; @@ -1238,7 +1231,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, result); + autocompleteStringSingleton(*ty, true, node, position, result); } return AutocompleteContext::Expression; @@ -1719,7 +1712,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), result); + autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); if (!key) { @@ -1731,7 +1724,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // suggest those too. if (auto ttv = get(follow(*it)); ttv && ttv->indexer) { - autocompleteStringSingleton(ttv->indexer->indexType, false, result); + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); } } @@ -1768,7 +1761,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteEntryMap result; if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); if (ancestry.size() >= 2) { @@ -1782,7 +1775,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); } } } diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index 596890f9..b55686b9 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -3,8 +3,12 @@ #include "Luau/Subtyping.h" #include "Luau/Common.h" +#include "Luau/Error.h" #include "Luau/Normalize.h" +#include "Luau/StringUtils.h" +#include "Luau/ToString.h" #include "Luau/Type.h" +#include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" @@ -13,42 +17,109 @@ namespace Luau { -SubtypingGraph SubtypingGraph::and_(const SubtypingGraph& other) +struct VarianceFlipper { - return SubtypingGraph{ - isSubtype && other.isSubtype, - // `||` is intentional here, we want to preserve error-suppressing flag. - isErrorSuppressing || other.isErrorSuppressing, - normalizationTooComplex || other.normalizationTooComplex, - }; + Subtyping::Variance* variance; + Subtyping::Variance oldValue; + + VarianceFlipper(Subtyping::Variance* v) + : variance(v) + , oldValue(*v) + { + switch (oldValue) + { + case Subtyping::Variance::Covariant: + *variance = Subtyping::Variance::Contravariant; + break; + case Subtyping::Variance::Contravariant: + *variance = Subtyping::Variance::Covariant; + break; + } + } + + ~VarianceFlipper() + { + *variance = oldValue; + } +}; + +void SubtypingResult::andAlso(const SubtypingResult& other) +{ + isSubtype &= other.isSubtype; + // `|=` is intentional here, we want to preserve error related flags. + isErrorSuppressing |= other.isErrorSuppressing; + normalizationTooComplex |= other.normalizationTooComplex; } -SubtypingGraph SubtypingGraph::or_(const SubtypingGraph& other) +void SubtypingResult::orElse(const SubtypingResult& other) { - return SubtypingGraph{ - isSubtype || other.isSubtype, - isErrorSuppressing || other.isErrorSuppressing, - normalizationTooComplex || other.normalizationTooComplex, - }; + isSubtype |= other.isSubtype; + isErrorSuppressing |= other.isErrorSuppressing; + normalizationTooComplex |= other.normalizationTooComplex; } -SubtypingGraph SubtypingGraph::and_(const std::vector& results) +SubtypingResult SubtypingResult::all(const std::vector& results) { - SubtypingGraph acc{true, false}; - for (const SubtypingGraph& current : results) - acc = acc.and_(current); + SubtypingResult acc{true, false}; + for (const SubtypingResult& current : results) + acc.andAlso(current); return acc; } -SubtypingGraph SubtypingGraph::or_(const std::vector& results) +SubtypingResult SubtypingResult::any(const std::vector& results) { - SubtypingGraph acc{false, false}; - for (const SubtypingGraph& current : results) - acc = acc.or_(current); + SubtypingResult acc{false, false}; + for (const SubtypingResult& current : results) + acc.orElse(current); return acc; } -SubtypingGraph Subtyping::isSubtype(TypeId subTy, TypeId superTy) +SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy) +{ + mappedGenerics.clear(); + mappedGenericPacks.clear(); + + SubtypingResult result = isSubtype_(subTy, superTy); + + for (const auto& [subTy, bounds]: mappedGenerics) + { + const auto& lb = bounds.lowerBound; + const auto& ub = bounds.upperBound; + + TypeId lowerBound = makeAggregateType(lb, builtinTypes->neverType); + TypeId upperBound = makeAggregateType(ub, builtinTypes->unknownType); + + result.andAlso(isSubtype_(lowerBound, upperBound)); + } + + return result; +} + +SubtypingResult Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) +{ + return isSubtype_(subTp, superTp); +} + +namespace +{ +struct SeenSetPopper +{ + Subtyping::SeenSet* seenTypes; + std::pair pair; + + SeenSetPopper(Subtyping::SeenSet* seenTypes, std::pair pair) + : seenTypes(seenTypes) + , pair(pair) + {} + + ~SeenSetPopper() + { + seenTypes->erase(pair); + } +}; +} + +SubtypingResult Subtyping::isSubtype_(TypeId subTy, TypeId superTy) { subTy = follow(subTy); superTy = follow(superTy); @@ -60,20 +131,25 @@ SubtypingGraph Subtyping::isSubtype(TypeId subTy, TypeId superTy) if (subTy == superTy) return {true}; + std::pair typePair{subTy, superTy}; + if (!seenTypes.insert(typePair).second) + return {true}; + + SeenSetPopper ssp{&seenTypes, typePair}; if (auto superUnion = get(superTy)) - return isSubtype(subTy, superUnion); + return isSubtype_(subTy, superUnion); else if (auto subUnion = get(subTy)) - return isSubtype(subUnion, superTy); + return isSubtype_(subUnion, superTy); else if (auto superIntersection = get(superTy)) - return isSubtype(subTy, superIntersection); + return isSubtype_(subTy, superIntersection); else if (auto subIntersection = get(subTy)) { - SubtypingGraph result = isSubtype(subIntersection, superTy); + SubtypingResult result = isSubtype_(subIntersection, superTy); if (result.isSubtype || result.isErrorSuppressing || result.normalizationTooComplex) return result; else - return isSubtype(normalizer->normalize(subTy), normalizer->normalize(superTy)); + return isSubtype_(normalizer->normalize(subTy), normalizer->normalize(superTy)); } else if (get(superTy)) return {true}; // This is always true. @@ -81,9 +157,11 @@ SubtypingGraph Subtyping::isSubtype(TypeId subTy, TypeId superTy) { // any = unknown | error, so we rewrite this to match. // As per TAPL: A | B <: T iff A <: T && B <: T - return isSubtype(builtinTypes->unknownType, superTy).and_(isSubtype(builtinTypes->errorType, superTy)); + SubtypingResult result = isSubtype_(builtinTypes->unknownType, superTy); + result.andAlso(isSubtype_(builtinTypes->errorType, superTy)); + return result; } - else if (auto superUnknown = get(superTy)) + else if (get(superTy)) { LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. @@ -98,19 +176,31 @@ SubtypingGraph Subtyping::isSubtype(TypeId subTy, TypeId superTy) return {false, true}; else if (get(subTy)) return {false, true}; + else if (auto subGeneric = get(subTy); subGeneric && variance == Variance::Covariant) + { + bool ok = bindGeneric(subTy, superTy); + return {ok}; + } + else if (auto superGeneric = get(superTy); superGeneric && variance == Variance::Contravariant) + { + bool ok = bindGeneric(subTy, superTy); + return {ok}; + } else if (auto p = get2(subTy, superTy)) - return isSubtype(p); + return isSubtype_(p); else if (auto p = get2(subTy, superTy)) - return isSubtype(p); + return isSubtype_(p); else if (auto p = get2(subTy, superTy)) - return isSubtype(p); + return isSubtype_(p); else if (auto p = get2(subTy, superTy)) - return isSubtype(p); + return isSubtype_(p); + else if (auto p = get2(subTy, superTy)) + return isSubtype_(p); return {false}; } -SubtypingGraph Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) +SubtypingResult Subtyping::isSubtype_(TypePackId subTp, TypePackId superTp) { subTp = follow(subTp); superTp = follow(superTp); @@ -120,14 +210,17 @@ SubtypingGraph Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) const size_t headSize = std::min(subHead.size(), superHead.size()); - std::vector results; + std::vector results; results.reserve(std::max(subHead.size(), superHead.size()) + 1); + if (subTp == superTp) + return {true}; + // Match head types pairwise for (size_t i = 0; i < headSize; ++i) { - results.push_back(isSubtype(subHead[i], superHead[i])); + results.push_back(isSubtype_(subHead[i], superHead[i])); if (!results.back().isSubtype) return {false}; } @@ -141,12 +234,40 @@ SubtypingGraph Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) if (auto vt = get(*subTail)) { for (size_t i = headSize; i < superHead.size(); ++i) + results.push_back(isSubtype_(vt->ty, superHead[i])); + } + else if (auto gt = get(*subTail)) + { + if (variance == Variance::Covariant) { - results.push_back(isSubtype(vt->ty, superHead[i])); + // For any non-generic type T: + // + // (X) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(superHead), end(superHead) + headSize); + TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); + + if (TypePackId* other = mappedGenericPacks.find(*subTail)) + results.push_back(isSubtype_(*other, superTailPack)); + else + mappedGenericPacks.try_insert(*subTail, superTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // (T) -> () (X) -> () + // + return {false}; } } else - LUAU_ASSERT(0); // TODO + unexpected(*subTail); } else return {false}; @@ -158,20 +279,43 @@ SubtypingGraph Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) if (auto vt = get(*superTail)) { for (size_t i = headSize; i < subHead.size(); ++i) + results.push_back(isSubtype_(subHead[i], vt->ty)); + } + else if (auto gt = get(*superTail)) + { + if (variance == Variance::Contravariant) { - results.push_back(isSubtype(subHead[i], vt->ty)); + // For any non-generic type T: + // + // (X...) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(subHead), end(subHead) + headSize); + TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); + + if (TypePackId* other = mappedGenericPacks.find(*superTail)) + results.push_back(isSubtype_(*other, subTailPack)); + else + mappedGenericPacks.try_insert(*superTail, subTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // () -> T () -> X... + return {false}; } } else - LUAU_ASSERT(0); // TODO + unexpected(*superTail); } else return {false}; } - else - { - // subHead and superHead are the same size. Nothing more must be done. - } // Handle tails @@ -179,10 +323,43 @@ SubtypingGraph Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) { if (auto p = get2(*subTail, *superTail)) { - results.push_back(isSubtype(p.first->ty, p.second->ty)); + results.push_back(isSubtype_(p.first->ty, p.second->ty)); + } + else if (auto p = get2(*subTail, *superTail)) + { + bool ok = bindGeneric(*subTail, *superTail); + results.push_back({ok}); + } + else if (get2(*subTail, *superTail)) + { + if (variance == Variance::Contravariant) + { + // (A...) -> number <: (...number) -> number + bool ok = bindGeneric(*subTail, *superTail); + results.push_back({ok}); + } + else + { + // (number) -> ...number (number) -> A... + results.push_back({false}); + } + } + else if (get2(*subTail, *superTail)) + { + if (variance == Variance::Contravariant) + { + // (...number) -> number (A...) -> number + results.push_back({false}); + } + else + { + // () -> A... <: () -> ...number + bool ok = bindGeneric(*subTail, *superTail); + results.push_back({ok}); + } } else - LUAU_ASSERT(0); // TODO + iceReporter->ice(format("Subtyping::isSubtype got unexpected type packs %s and %s", toString(*subTail).c_str(), toString(*superTail).c_str())); } else if (subTail) { @@ -190,8 +367,13 @@ SubtypingGraph Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) { return {false}; } - - LUAU_ASSERT(0); // TODO + else if (get(*subTail)) + { + bool ok = bindGeneric(*subTail, builtinTypes->emptyTypePack); + return {ok}; + } + else + unexpected(*subTail); } else if (superTail) { @@ -207,17 +389,27 @@ SubtypingGraph Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) * All variadic type packs are therefore supertypes of the empty type pack. */ } + else if (get(*superTail)) + { + if (variance == Variance::Contravariant) + { + bool ok = bindGeneric(builtinTypes->emptyTypePack, *superTail); + results.push_back({ok}); + } + else + results.push_back({false}); + } else LUAU_ASSERT(0); // TODO } - return SubtypingGraph::and_(results); + return SubtypingResult::all(results); } template -SubtypingGraph Subtyping::isSubtype(const TryPair& pair) +SubtypingResult Subtyping::isSubtype_(const TryPair& pair) { - return isSubtype(pair.first, pair.second); + return isSubtype_(pair.first, pair.second); } /* @@ -251,49 +443,49 @@ SubtypingGraph Subtyping::isSubtype(const TryPair& * other just asks for boolean ~ 'b. We can dispatch this and only commit * boolean ~ 'b. This constraint does not teach us anything about 'a. */ -SubtypingGraph Subtyping::isSubtype(TypeId subTy, const UnionType* superUnion) +SubtypingResult Subtyping::isSubtype_(TypeId subTy, const UnionType* superUnion) { // As per TAPL: T <: A | B iff T <: A || T <: B - std::vector subtypings; + std::vector subtypings; for (TypeId ty : superUnion) - subtypings.push_back(isSubtype(subTy, ty)); - return SubtypingGraph::or_(subtypings); + subtypings.push_back(isSubtype_(subTy, ty)); + return SubtypingResult::any(subtypings); } -SubtypingGraph Subtyping::isSubtype(const UnionType* subUnion, TypeId superTy) +SubtypingResult Subtyping::isSubtype_(const UnionType* subUnion, TypeId superTy) { // As per TAPL: A | B <: T iff A <: T && B <: T - std::vector subtypings; + std::vector subtypings; for (TypeId ty : subUnion) - subtypings.push_back(isSubtype(ty, superTy)); - return SubtypingGraph::and_(subtypings); + subtypings.push_back(isSubtype_(ty, superTy)); + return SubtypingResult::all(subtypings); } -SubtypingGraph Subtyping::isSubtype(TypeId subTy, const IntersectionType* superIntersection) +SubtypingResult Subtyping::isSubtype_(TypeId subTy, const IntersectionType* superIntersection) { // As per TAPL: T <: A & B iff T <: A && T <: B - std::vector subtypings; + std::vector subtypings; for (TypeId ty : superIntersection) - subtypings.push_back(isSubtype(subTy, ty)); - return SubtypingGraph::and_(subtypings); + subtypings.push_back(isSubtype_(subTy, ty)); + return SubtypingResult::all(subtypings); } -SubtypingGraph Subtyping::isSubtype(const IntersectionType* subIntersection, TypeId superTy) +SubtypingResult Subtyping::isSubtype_(const IntersectionType* subIntersection, TypeId superTy) { // TODO: Semantic subtyping here. // As per TAPL: A & B <: T iff A <: T || B <: T - std::vector subtypings; + std::vector subtypings; for (TypeId ty : subIntersection) - subtypings.push_back(isSubtype(ty, superTy)); - return SubtypingGraph::or_(subtypings); + subtypings.push_back(isSubtype_(ty, superTy)); + return SubtypingResult::any(subtypings); } -SubtypingGraph Subtyping::isSubtype(const PrimitiveType* subPrim, const PrimitiveType* superPrim) +SubtypingResult Subtyping::isSubtype_(const PrimitiveType* subPrim, const PrimitiveType* superPrim) { return {subPrim->type == superPrim->type}; } -SubtypingGraph Subtyping::isSubtype(const SingletonType* subSingleton, const PrimitiveType* superPrim) +SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const PrimitiveType* superPrim) { if (get(subSingleton) && superPrim->type == PrimitiveType::String) return {true}; @@ -303,42 +495,123 @@ SubtypingGraph Subtyping::isSubtype(const SingletonType* subSingleton, const Pri return {false}; } -SubtypingGraph Subtyping::isSubtype(const SingletonType* subSingleton, const SingletonType* superSingleton) +SubtypingResult Subtyping::isSubtype_(const SingletonType* subSingleton, const SingletonType* superSingleton) { return {*subSingleton == *superSingleton}; } -SubtypingGraph Subtyping::isSubtype(const FunctionType* subFunction, const FunctionType* superFunction) +SubtypingResult Subtyping::isSubtype_(const TableType* subTable, const TableType* superTable) { - SubtypingGraph argResult = isSubtype(superFunction->argTypes, subFunction->argTypes); - SubtypingGraph retResult = isSubtype(subFunction->retTypes, superFunction->retTypes); + SubtypingResult result{true}; - return argResult.and_(retResult); -} - -SubtypingGraph Subtyping::isSubtype(const NormalizedType* subNorm, const NormalizedType* superNorm) -{ - if (!subNorm || !superNorm) - return {false, true, true}; - - SubtypingGraph result{true}; - result = result.and_(isSubtype(subNorm->tops, superNorm->tops)); - result = result.and_(isSubtype(subNorm->booleans, superNorm->booleans)); - // isSubtype(subNorm->classes, superNorm->classes); - // isSubtype(subNorm->classes, superNorm->tables); - result = result.and_(isSubtype(subNorm->errors, superNorm->errors)); - result = result.and_(isSubtype(subNorm->nils, superNorm->nils)); - result = result.and_(isSubtype(subNorm->numbers, superNorm->numbers)); - result.isSubtype &= Luau::isSubtype(subNorm->strings, superNorm->strings); - // isSubtype(subNorm->strings, superNorm->tables); - result = result.and_(isSubtype(subNorm->threads, superNorm->threads)); - // isSubtype(subNorm->tables, superNorm->tables); - // isSubtype(subNorm->tables, superNorm->strings); - // isSubtype(subNorm->tables, superNorm->classes); - // isSubtype(subNorm->functions, superNorm->functions); - // isSubtype(subNorm->tyvars, superNorm->tyvars); + for (const auto& [name, prop]: superTable->props) + { + auto it = subTable->props.find(name); + if (it != subTable->props.end()) + { + // Table properties are invariant + result.andAlso(isSubtype(it->second.type(), prop.type())); + result.andAlso(isSubtype(prop.type(), it->second.type())); + } + else + return SubtypingResult{false}; + } return result; } +SubtypingResult Subtyping::isSubtype_(const FunctionType* subFunction, const FunctionType* superFunction) +{ + SubtypingResult result; + { + VarianceFlipper vf{&variance}; + result.orElse(isSubtype_(superFunction->argTypes, subFunction->argTypes)); + } + + result.andAlso(isSubtype_(subFunction->retTypes, superFunction->retTypes)); + + return result; +} + +SubtypingResult Subtyping::isSubtype_(const NormalizedType* subNorm, const NormalizedType* superNorm) +{ + if (!subNorm || !superNorm) + return {false, true, true}; + + SubtypingResult result = isSubtype_(subNorm->tops, superNorm->tops); + result.andAlso(isSubtype_(subNorm->booleans, superNorm->booleans)); + // isSubtype_(subNorm->classes, superNorm->classes); + // isSubtype_(subNorm->classes, superNorm->tables); + result.andAlso(isSubtype_(subNorm->errors, superNorm->errors)); + result.andAlso(isSubtype_(subNorm->nils, superNorm->nils)); + result.andAlso(isSubtype_(subNorm->numbers, superNorm->numbers)); + result.isSubtype &= Luau::isSubtype(subNorm->strings, superNorm->strings); + // isSubtype_(subNorm->strings, superNorm->tables); + result.andAlso(isSubtype_(subNorm->threads, superNorm->threads)); + // isSubtype_(subNorm->tables, superNorm->tables); + // isSubtype_(subNorm->tables, superNorm->strings); + // isSubtype_(subNorm->tables, superNorm->classes); + // isSubtype_(subNorm->functions, superNorm->functions); + // isSubtype_(subNorm->tyvars, superNorm->tyvars); + + return result; +} + +bool Subtyping::bindGeneric(TypeId subTy, TypeId superTy) +{ + if (variance == Variance::Covariant) + { + if (!get(subTy)) + return false; + + mappedGenerics[subTy].upperBound.insert(superTy); + } + else + { + if (!get(superTy)) + return false; + + mappedGenerics[superTy].lowerBound.insert(subTy); + } + + return true; +} + +/* + * If, when performing a subtyping test, we encounter a generic on the left + * side, it is permissible to tentatively bind that generic to the right side + * type. + */ +bool Subtyping::bindGeneric(TypePackId subTp, TypePackId superTp) +{ + if (variance == Variance::Contravariant) + std::swap(superTp, subTp); + + if (!get(subTp)) + return false; + + if (TypePackId* m = mappedGenericPacks.find(subTp)) + return *m == superTp; + + mappedGenericPacks[subTp] = superTp; + + return true; +} + +template +TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) +{ + if (container.empty()) + return orElse; + else if (container.size() == 1) + return *begin(container); + else + return arena->addType(T{std::vector(begin(container), end(container))}); +} + +void Subtyping::unexpected(TypePackId tp) +{ + iceReporter->ice(format("Unexpected type pack %s", toString(tp).c_str())); +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 00cf4cd0..569f9720 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,7 +36,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) -LUAU_FASTFLAGVARIABLE(LuauFixCyclicModuleExports, false) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) @@ -1195,16 +1194,13 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedModules[name] = moduleInfo->name; - if (FFlag::LuauFixCyclicModuleExports) + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) { - // Imported types of requires that transitively refer to current module have to be replaced with 'any' - for (const auto& [location, path] : requireCycles) + if (!path.empty() && path.front() == moduleInfo->name) { - if (!path.empty() && path.front() == moduleInfo->name) - { - for (auto& [name, tf] : scope->importedTypeBindings[name]) - tf = TypeFun{{}, {}, anyType}; - } + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, anyType}; } } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index db8e2008..bc8ef018 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -605,6 +605,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // TODO: there are probably cheaper ways to check if any <: T. const NormalizedType* superNorm = normalizer->normalize(superTy); + + if (!superNorm) + return reportError(location, UnificationTooComplex{}); + if (!log.get(superNorm->tops)) failure = true; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index f9f9ab41..7478e15d 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -272,11 +272,18 @@ class AstExprConstantString : public AstExpr public: LUAU_RTTI(AstExprConstantString) - AstExprConstantString(const Location& location, const AstArray& value); + enum QuoteStyle + { + Quoted, + Unquoted + }; + + AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle = Quoted); void visit(AstVisitor* visitor) override; AstArray value; + QuoteStyle quoteStyle = Quoted; }; class AstExprLocal : public AstExpr diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 929402b3..e1415183 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -204,7 +204,9 @@ private: Position position() const; + // consume() assumes current character is not a newline for performance; when that is not known, consumeAny() should be used instead. void consume(); + void consumeAny(); Lexeme readCommentBody(); diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 3c87e36c..9b3acb7f 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -62,9 +62,10 @@ void AstExprConstantNumber::visit(AstVisitor* visitor) visitor->visit(this); } -AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value) +AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle) : AstExpr(ClassIndex(), location) , value(value) + , quoteStyle(quoteStyle) { } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 75b4fe30..fe32e2a1 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauLexerConsumeFast, false) + namespace Luau { @@ -373,7 +375,7 @@ const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { // consume whitespace before the token while (isSpace(peekch())) - consume(); + consumeAny(); if (updatePrevLocation) prevLocation = lexeme.location; @@ -438,7 +440,28 @@ Position Lexer::position() const return Position(line, offset - lineOffset); } +LUAU_FORCEINLINE void Lexer::consume() +{ + if (isNewline(buffer[offset])) + { + // TODO: When the flag is removed, remove the outer condition + if (FFlag::LuauLexerConsumeFast) + { + LUAU_ASSERT(!isNewline(buffer[offset])); + } + else + { + line++; + lineOffset = offset + 1; + } + } + + offset++; +} + +LUAU_FORCEINLINE +void Lexer::consumeAny() { if (isNewline(buffer[offset])) { @@ -524,7 +547,7 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le } else { - consume(); + consumeAny(); } } @@ -540,7 +563,7 @@ void Lexer::readBackslashInString() case '\r': consume(); if (peekch() == '\n') - consume(); + consumeAny(); break; case 0: @@ -549,11 +572,11 @@ void Lexer::readBackslashInString() case 'z': consume(); while (isSpace(peekch())) - consume(); + consumeAny(); break; default: - consume(); + consumeAny(); } } @@ -939,6 +962,9 @@ Lexeme Lexer::readNext() case ';': case ',': case '#': + case '?': + case '&': + case '|': { char ch = peekch(); consume(); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index cc5d7b38..20186dfc 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -15,8 +15,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseDeclareClassIndexer, false) -#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" - namespace Luau { @@ -899,13 +897,13 @@ AstStat* Parser::parseDeclaration(const Location& start) expectAndConsume(':', "property type annotation"); AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null + // since AstName contains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); if (chars && !containsNull) props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); else - report(begin.location, "String literal contains malformed escape sequence"); + report(begin.location, "String literal contains malformed escape sequence or \\0"); } else if (lexer.current().type == '[' && FFlag::LuauParseDeclareClassIndexer) { @@ -1328,13 +1326,13 @@ AstType* Parser::parseTableType() AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null + // since AstName contains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); if (chars && !containsNull) props.push_back({AstName(chars->data), begin.location, type}); else - report(begin.location, "String literal contains malformed escape sequence"); + report(begin.location, "String literal contains malformed escape sequence or \\0"); } else if (lexer.current().type == '[') { @@ -1622,7 +1620,7 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeError(start, {}, "Malformed string")}; + return {reportTypeError(start, {}, "Malformed string; did you forget to finish it?")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1741,7 +1739,8 @@ AstTypePack* Parser::parseTypePack() return allocator.alloc(Location(name.location, end), name.name); } - // No type pack annotation exists here. + // TODO: shouldParseTypePack can be removed and parseTypePack can be called unconditionally instead + LUAU_ASSERT(!"parseTypePack can't be called if shouldParseTypePack() returned false"); return nullptr; } @@ -1826,7 +1825,7 @@ std::optional Parser::checkUnaryConfusables() if (curr.type == '!') { - report(start, "Unexpected '!', did you mean 'not'?"); + report(start, "Unexpected '!'; did you mean 'not'?"); return AstExprUnary::Not; } @@ -1848,20 +1847,20 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr if (curr.type == '&' && next.type == '&' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::And].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '&&', did you mean 'and'?"); + report(Location(start, next.location), "Unexpected '&&'; did you mean 'and'?"); return AstExprBinary::And; } else if (curr.type == '|' && next.type == '|' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::Or].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '||', did you mean 'or'?"); + report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?"); return AstExprBinary::Or; } else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '!=', did you mean '~='?"); + report(Location(start, next.location), "Unexpected '!='; did you mean '~='?"); return AstExprBinary::CompareNe; } @@ -2169,12 +2168,12 @@ AstExpr* Parser::parseSimpleExpr() else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return reportExprError(start, {}, "Malformed string"); + return reportExprError(start, {}, "Malformed string; did you forget to finish it?"); } else if (lexer.current().type == Lexeme::BrokenInterpDoubleBrace) { nextLexeme(); - return reportExprError(start, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(start, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); } else if (lexer.current().type == Lexeme::Dot3) { @@ -2312,7 +2311,7 @@ AstExpr* Parser::parseTableConstructor() nameString.data = const_cast(name.name.value); nameString.size = strlen(name.name.value); - AstExpr* key = allocator.alloc(name.location, nameString); + AstExpr* key = allocator.alloc(name.location, nameString, AstExprConstantString::Unquoted); AstExpr* value = parseExpr(); if (AstExprFunction* func = value->as()) @@ -2661,7 +2660,7 @@ AstExpr* Parser::parseInterpString() { errorWhileChecking = true; nextLexeme(); - expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?")); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '`'?")); break; } default: @@ -2681,10 +2680,10 @@ AstExpr* Parser::parseInterpString() break; case Lexeme::BrokenInterpDoubleBrace: nextLexeme(); - return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(endLocation, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); case Lexeme::BrokenString: nextLexeme(); - return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?"); + return reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '}'?"); default: return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index b2a523d9..19083bfc 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -757,14 +757,6 @@ int replMain(int argc, char** argv) } #endif -#if !LUA_CUSTOM_EXECUTION - if (codegen) - { - fprintf(stderr, "To run with --codegen, Luau has to be built with LUA_CUSTOM_EXECUTION enabled\n"); - return 1; - } -#endif - if (codegenPerf) { #if __linux__ @@ -784,10 +776,7 @@ int replMain(int argc, char** argv) } if (codegen && !Luau::CodeGen::isSupported()) - { - fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n"); - return 1; - } + fprintf(stderr, "Warning: Native code generation is not supported in current configuration\n"); const std::vector files = getSourceFiles(argc, argv); diff --git a/CMakeLists.txt b/CMakeLists.txt index f0f0497d..15a1b8a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,6 @@ option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) -option(LUAU_NATIVE "Enable support for native code generation" OFF) cmake_policy(SET CMP0054 NEW) cmake_policy(SET CMP0091 NEW) @@ -146,13 +145,7 @@ if(LUAU_EXTERN_C) target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1) target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\") target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") -endif() - -if(LUAU_NATIVE) - target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) - if(LUAU_EXTERN_C) - target_compile_definitions(Luau.CodeGen PUBLIC LUACODEGEN_API=extern\"C\") - endif() + target_compile_definitions(Luau.CodeGen PUBLIC LUACODEGEN_API=extern\"C\") endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND MSVC_VERSION GREATER_EQUAL 1924) diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index c11f9628..85f19d01 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -18,6 +18,16 @@ enum CodeGenFlags CodeGen_OnlyNativeModules = 1 << 0, }; +enum class CodeGenCompilationResult +{ + Success, // Successfully generated code for at least one function + NothingToCompile, // There were no new functions to compile + + CodeGenNotInitialized, // Native codegen system is not initialized + CodeGenFailed, // Native codegen failed due to an internal compiler error + AllocationFailed, // Native codegen failed due to an allocation error +}; + struct CompilationStats { size_t bytecodeSizeBytes = 0; @@ -36,7 +46,7 @@ void create(lua_State* L, AllocationCallback* allocationCallback, void* allocati void create(lua_State* L); // Builds target function and all inner functions -void compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); +CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 5ac5b2ac..e8e56d19 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -245,8 +245,8 @@ enum class IrCmd : uint8_t STRING_LEN, // Allocate new table - // A: int (array element count) - // B: int (node element count) + // A: unsigned int (array element count) + // B: unsigned int (node element count) NEW_TABLE, // Duplicate a table @@ -359,12 +359,6 @@ enum class IrCmd : uint8_t // C: tag/undef (tag of the value that was written) SET_UPVALUE, - // Convert TValues into numbers for a numerical for loop - // A: Rn (start) - // B: Rn (end) - // C: Rn (step) - PREPARE_FORN, - // Guards and checks (these instructions are not block terminators even though they jump to fallback) // Guard against tag mismatch @@ -463,6 +457,7 @@ enum class IrCmd : uint8_t // C: Rn (source start) // D: int (count or -1 to assign values up to stack top) // E: unsigned int (table index to start from) + // F: undef/unsigned int (target table known size) SETLIST, // Call specified function diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h index 95930811..665b5229 100644 --- a/CodeGen/include/Luau/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -77,6 +77,7 @@ struct IrRegAllocX64 std::array gprInstUsers; std::array freeXmmMap; std::array xmmInstUsers; + uint8_t usableXmmRegCount = 0; std::bitset<256> usedSpillSlots; unsigned maxUsedSlot = 0; diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index 8a44629f..1ba377ba 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -5,6 +5,7 @@ #include "Luau/RegisterX64.h" #include +#include #include #include @@ -48,7 +49,8 @@ public: // mov rbp, rsp // push reg in the order specified in regs // sub rsp, stackSize - virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) = 0; + virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) = 0; virtual size_t getSize() const = 0; virtual size_t getFunctionCount() const = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 66749bfc..741aaed2 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -30,7 +30,8 @@ public: void finishInfo() override; void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; - void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) override; size_t getSize() const override; size_t getFunctionCount() const override; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 5afed693..3a7e1b5a 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -50,7 +50,8 @@ public: void finishInfo() override; void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; - void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) override; size_t getSize() const override; size_t getFunctionCount() const override; diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 10c3dc79..9d117b1d 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -162,9 +162,6 @@ unsigned int getCpuFeaturesA64() bool isSupported() { - if (!LUA_CUSTOM_EXECUTION) - return false; - if (LUA_EXTRA_SIZE != 1) return false; @@ -247,23 +244,33 @@ void create(lua_State* L) create(L, nullptr, nullptr); } -void compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); + Proto* root = clvalue(func)->l.p; + if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + return CodeGenCompilationResult::NothingToCompile; + // If initialization has failed, do not compile any functions NativeState* data = getNativeState(L); if (!data) - return; - - Proto* root = clvalue(func)->l.p; - if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) - return; + return CodeGenCompilationResult::CodeGenNotInitialized; std::vector protos; gatherFunctions(protos, root); + // Skip protos that have been compiled during previous invocations of CodeGen::compile + protos.erase(std::remove_if(protos.begin(), protos.end(), + [](Proto* p) { + return p == nullptr || p->execdata != nullptr; + }), + protos.end()); + + if (protos.empty()) + return CodeGenCompilationResult::NothingToCompile; + #if defined(__aarch64__) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); @@ -281,11 +288,9 @@ void compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) std::vector results; results.reserve(protos.size()); - // Skip protos that have been compiled during previous invocations of CodeGen::compile for (Proto* p : protos) - if (p && p->execdata == nullptr) - if (std::optional np = createNativeFunction(build, helpers, p)) - results.push_back(*np); + if (std::optional np = createNativeFunction(build, helpers, p)) + results.push_back(*np); // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module if (!build.finalize()) @@ -293,12 +298,12 @@ void compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) for (NativeProto result : results) destroyExecData(result.execdata); - return; + return CodeGenCompilationResult::CodeGenFailed; } // If no functions were assembled, we don't need to allocate/copy executable pages for helpers if (results.empty()) - return; + return CodeGenCompilationResult::CodeGenFailed; uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; @@ -309,7 +314,7 @@ void compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) for (NativeProto result : results) destroyExecData(result.execdata); - return; + return CodeGenCompilationResult::AllocationFailed; } if (gPerfLogFn && results.size() > 0) @@ -348,6 +353,8 @@ void compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) stats->nativeCodeSizeBytes += build.code.size(); stats->nativeDataSizeBytes += build.data.size(); } + + return CodeGenCompilationResult::Success; } void setPerfLog(void* context, PerfLogFn logFn) diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index ef655a24..a8cf2e73 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -16,10 +16,24 @@ * | rdx home space | (unused) * | rcx home space | (unused) * | return address | - * | ... saved non-volatile registers ... <-- rsp + kStackSize + kLocalsSize - * | unused | for 16 byte alignment of the stack + * | ... saved non-volatile registers ... <-- rsp + kStackSizeFull + * | alignment | + * | xmm9 non-vol | + * | xmm9 cont. | + * | xmm8 non-vol | + * | xmm8 cont. | + * | xmm7 non-vol | + * | xmm7 cont. | + * | xmm6 non-vol | + * | xmm6 cont. | + * | spill slot 5 | + * | spill slot 4 | + * | spill slot 3 | + * | spill slot 2 | + * | spill slot 1 | <-- rsp + kStackOffsetToSpillSlots + * | sTemporarySlot | * | sCode | - * | sClosure | <-- rsp + kStackSize + * | sClosure | <-- rsp + kStackOffsetToLocals * | argument 6 | <-- rsp + 40 * | argument 5 | <-- rsp + 32 * | r9 home space | @@ -81,24 +95,43 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde build.push(rdi); build.push(rsi); - // On Windows, rbp is available as a general-purpose non-volatile register; we currently don't use it, but we need to push an even number - // of registers for stack alignment... + // On Windows, rbp is available as a general-purpose non-volatile register and this might be freed up build.push(rbp); - - // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here } - // Allocate stack space (reg home area + local data) - build.sub(rsp, kStackSize + kLocalsSize); + // Allocate stack space + uint8_t usableXmmRegCount = getXmmRegisterCount(build.abi); + unsigned xmmStorageSize = getNonVolXmmStorageSize(build.abi, usableXmmRegCount); + unsigned fullStackSize = getFullStackSize(build.abi, usableXmmRegCount); + + build.sub(rsp, fullStackSize); + + OperandX64 xmmStorageOffset = rsp + (fullStackSize - (kStackAlign + xmmStorageSize)); + + // On Windows, we have to save non-volatile xmm registers + std::vector savedXmmRegs; + + if (build.abi == ABIX64::Windows) + { + if (usableXmmRegCount > kWindowsFirstNonVolXmmReg) + savedXmmRegs.reserve(usableXmmRegCount - kWindowsFirstNonVolXmmReg); + + for (uint8_t i = kWindowsFirstNonVolXmmReg, offset = 0; i < usableXmmRegCount; i++, offset += 16) + { + RegisterX64 xmmReg = RegisterX64{SizeX64::xmmword, i}; + build.vmovaps(xmmword[xmmStorageOffset + offset], xmmReg); + savedXmmRegs.push_back(xmmReg); + } + } locations.prologueEnd = build.setLabel(); uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start); if (build.abi == ABIX64::SystemV) - unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15}); + unwind.prologueX64(prologueSize, fullStackSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15}, {}); else if (build.abi == ABIX64::Windows) - unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp}); + unwind.prologueX64(prologueSize, fullStackSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp}, savedXmmRegs); // Setup native execution environment build.mov(rState, rArg1); @@ -118,8 +151,15 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde // Even though we jumped away, we will return here in the end locations.epilogueStart = build.setLabel(); - // Cleanup and exit - build.add(rsp, kStackSize + kLocalsSize); + // Epilogue and exit + if (build.abi == ABIX64::Windows) + { + // xmm registers are restored before the official epilogue that has to start with 'add rsp/lea rsp' + for (uint8_t i = kWindowsFirstNonVolXmmReg, offset = 0; i < usableXmmRegCount; i++, offset += 16) + build.vmovaps(RegisterX64{SizeX64::xmmword, i}, xmmword[xmmStorageOffset + offset]); + } + + build.add(rsp, fullStackSize); if (build.abi == ABIX64::Windows) { diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 43568035..c4da5467 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -123,16 +123,6 @@ void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, in emitUpdateBase(build); } -void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init) -{ - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, luauRegAddress(limit)); - callWrap.addArgument(SizeX64::qword, luauRegAddress(step)); - callWrap.addArgument(SizeX64::qword, luauRegAddress(init)); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); -} - void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { IrCallWrapperX64 callWrap(regs, build); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 782f2084..dd9b082b 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -42,16 +42,55 @@ constexpr RegisterX64 rBase = r14; // StkId base constexpr RegisterX64 rNativeContext = r13; // NativeContext* context constexpr RegisterX64 rConstants = r12; // TValue* k -// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point -// See CodeGenX64.cpp for layout -constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments -constexpr unsigned kSpillSlots = 4; // locations for register allocator to spill data into -constexpr unsigned kLocalsSize = 24 + 8 * kSpillSlots; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) +constexpr unsigned kExtraLocals = 3; // Number of 8 byte slots available for specialized local variables specified below +constexpr unsigned kSpillSlots = 5; // Number of 8 byte slots available for register allocator to spill data into +static_assert((kExtraLocals + kSpillSlots) * 8 % 16 == 0, "locals have to preserve 16 byte alignment"); -constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl -constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* code -constexpr OperandX64 sTemporarySlot = addr[rsp + kStackSize + 16]; -constexpr OperandX64 sSpillArea = addr[rsp + kStackSize + 24]; +constexpr uint8_t kWindowsFirstNonVolXmmReg = 6; + +constexpr uint8_t kWindowsUsableXmmRegs = 10; // Some xmm regs are non-volatile, we have to balance how many we want to use/preserve +constexpr uint8_t kSystemVUsableXmmRegs = 16; // All xmm regs are volatile + +inline uint8_t getXmmRegisterCount(ABIX64 abi) +{ + return abi == ABIX64::SystemV ? kSystemVUsableXmmRegs : kWindowsUsableXmmRegs; +} + +// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point +// Stack is separated into sections for different data. See CodeGenX64.cpp for layout overview +constexpr unsigned kStackAlign = 8; // Bytes we need to align the stack for non-vol xmm register storage +constexpr unsigned kStackLocalStorage = 8 * kExtraLocals; +constexpr unsigned kStackSpillStorage = 8 * kSpillSlots; +constexpr unsigned kStackExtraArgumentStorage = 2 * 8; // Bytes for 5th and 6th function call arguments used under Windows ABI +constexpr unsigned kStackRegHomeStorage = 4 * 8; // Register 'home' locations that can be used by callees under Windows ABI + +inline unsigned getNonVolXmmStorageSize(ABIX64 abi, uint8_t xmmRegCount) +{ + if (abi == ABIX64::SystemV) + return 0; + + // First 6 are volatile + if (xmmRegCount <= kWindowsFirstNonVolXmmReg) + return 0; + + LUAU_ASSERT(xmmRegCount <= 16); + return (xmmRegCount - kWindowsFirstNonVolXmmReg) * 16; +} + +// Useful offsets to specific parts +constexpr unsigned kStackOffsetToLocals = kStackExtraArgumentStorage + kStackRegHomeStorage; +constexpr unsigned kStackOffsetToSpillSlots = kStackOffsetToLocals + kStackLocalStorage; + +inline unsigned getFullStackSize(ABIX64 abi, uint8_t xmmRegCount) +{ + return kStackOffsetToSpillSlots + kStackSpillStorage + getNonVolXmmStorageSize(abi, xmmRegCount) + kStackAlign; +} + +constexpr OperandX64 sClosure = qword[rsp + kStackOffsetToLocals + 0]; // Closure* cl +constexpr OperandX64 sCode = qword[rsp + kStackOffsetToLocals + 8]; // Instruction* code +constexpr OperandX64 sTemporarySlot = addr[rsp + kStackOffsetToLocals + 16]; + +constexpr OperandX64 sSpillArea = addr[rsp + kStackOffsetToSpillSlots]; inline OperandX64 luauReg(int ri) { @@ -161,7 +200,6 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb); -void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init); void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, int ratag, Label& skip); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index ea511958..bccdc8f0 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -251,7 +251,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i } } -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index, int knownSize) { // TODO: This should use IrCallWrapperX64 RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi; @@ -285,25 +285,28 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int build.add(last, index - 1); } - Label skipResize; - RegisterX64 table = regs.takeReg(rax, kInvalidInstIdx); build.mov(table, luauRegValue(ra)); - // Resize if h->sizearray < last - build.cmp(dword[table + offsetof(Table, sizearray)], last); - build.jcc(ConditionX64::NotBelow, skipResize); + if (count == LUA_MULTRET || knownSize < 0 || knownSize < int(index + count - 1)) + { + Label skipResize; - // Argument setup reordered to avoid conflicts - LUAU_ASSERT(rArg3 != table); - build.mov(dwordReg(rArg3), last); - build.mov(rArg2, table); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_resizearray)]); - build.mov(table, luauRegValue(ra)); // Reload cloberred register value + // Resize if h->sizearray < last + build.cmp(dword[table + offsetof(Table, sizearray)], last); + build.jcc(ConditionX64::NotBelow, skipResize); - build.setLabel(skipResize); + // Argument setup reordered to avoid conflicts + LUAU_ASSERT(rArg3 != table); + build.mov(dwordReg(rArg3), last); + build.mov(rArg2, table); + build.mov(rArg1, rState); + build.call(qword[rNativeContext + offsetof(NativeContext, luaH_resizearray)]); + build.mov(table, luauRegValue(ra)); // Reload clobbered register value + + build.setLabel(skipResize); + } RegisterX64 arrayDst = rdx; RegisterX64 offset = rcx; diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index b248b7e8..59fd8e41 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -19,7 +19,7 @@ struct IrRegAllocX64; void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults, bool functionVariadic); -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index); +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index, int knownSize); void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); } // namespace X64 diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b29927bb..eb4630dd 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -253,15 +253,6 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrBlock& case IrCmd::SET_UPVALUE: visitor.use(inst.b); break; - case IrCmd::PREPARE_FORN: - visitor.use(inst.a); - visitor.use(inst.b); - visitor.use(inst.c); - - visitor.def(inst.a); - visitor.def(inst.b); - visitor.def(inst.c); - break; case IrCmd::INTERRUPT: break; case IrCmd::BARRIER_OBJ: diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index d34dfb57..aebc0ba7 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -373,7 +373,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); + inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1]), undef()); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp index 816e0184..15fabf09 100644 --- a/CodeGen/src/IrCallWrapperX64.cpp +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -13,7 +13,7 @@ namespace CodeGen namespace X64 { -static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + 32], addr[rsp + 40]}; +static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + kStackRegHomeStorage], addr[rsp + kStackRegHomeStorage + 8]}; static const std::array kSystemvGprOrder = {rdi, rsi, rdx, rcx, r8, r9}; static const std::array kXmmOrder = {xmm0, xmm1, xmm2, xmm3}; // Common order for first 4 fp arguments on Windows/SystemV diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 50d5012e..67f77b64 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -213,8 +213,6 @@ const char* getCmdName(IrCmd cmd) return "GET_UPVALUE"; case IrCmd::SET_UPVALUE: return "SET_UPVALUE"; - case IrCmd::PREPARE_FORN: - return "PREPARE_FORN"; case IrCmd::CHECK_TAG: return "CHECK_TAG"; case IrCmd::CHECK_TRUTHY: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 03006e30..d944a766 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -1080,16 +1080,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) } break; } - case IrCmd::PREPARE_FORN: - regs.spill(build, index); - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); - build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); - build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_prepareFORN))); - build.blr(x4); - // note: no emitUpdateBase necessary because prepareFORN does not reallocate stack - break; case IrCmd::CHECK_TAG: { Label fresh; // used when guard aborts execution or jumps to a VM exit diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index ad18b849..2a436d54 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -622,7 +622,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) LUAU_ASSERT(inst.b.kind == IrOpKind::Inst || inst.b.kind == IrOpKind::Constant); OperandX64 opb = inst.b.kind == IrOpKind::Inst ? regOp(inst.b) : OperandX64(tagOp(inst.b)); - build.cmp(memRegTagOp(inst.a), opb); + if (inst.a.kind == IrOpKind::Constant) + build.cmp(opb, tagOp(inst.a)); + else + build.cmp(memRegTagOp(inst.a), opb); if (isFallthroughBlock(blockOp(inst.d), next)) { @@ -997,9 +1000,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c)); break; } - case IrCmd::PREPARE_FORN: - callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); - break; case IrCmd::CHECK_TAG: build.cmp(memRegTagOp(inst.a), tagOp(inst.b)); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next); @@ -1205,7 +1205,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) // Fallbacks to non-IR instruction implementations case IrCmd::SETLIST: regs.assertAllFree(); - emitInstSetList(regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); + emitInstSetList( + regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e), inst.f.kind == IrOpKind::Undef ? -1 : int(uintOp(inst.f))); break; case IrCmd::CALL: regs.assertAllFree(); diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 607c975f..091def39 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -17,6 +17,7 @@ static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) : build(build) , function(function) + , usableXmmRegCount(getXmmRegisterCount(build.abi)) { freeGprMap.fill(true); gprInstUsers.fill(kInvalidInstIdx); @@ -28,7 +29,7 @@ RegisterX64 IrRegAllocX64::allocReg(SizeX64 size, uint32_t instIdx) { if (size == SizeX64::xmmword) { - for (size_t i = 0; i < freeXmmMap.size(); ++i) + for (size_t i = 0; i < usableXmmRegCount; ++i) { if (freeXmmMap[i]) { diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 38922131..5a92132f 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -609,22 +609,19 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos)); - IrOp fallback = build.block(IrBlockKind::Fallback); - IrOp nextStep = build.block(IrBlockKind::Internal); IrOp direct = build.block(IrBlockKind::Internal); IrOp reverse = build.block(IrBlockKind::Internal); + // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails + // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant + // To avoid that overhead for an extreemely rare case (that doesn't even typecheck), we exit to VM to handle it IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); - build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), fallback); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); - build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), fallback); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); - build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), fallback); - build.inst(IrCmd::JUMP, nextStep); - - // After successful conversion of arguments to number in a fallback, we return here - build.beginBlock(nextStep); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); IrOp zero = build.constDouble(0.0); IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); @@ -644,12 +641,6 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(reverse); build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); - // Fallback will try to convert loop variables to numbers or throw an error - build.beginBlock(fallback); - build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::PREPARE_FORN, build.vmReg(ra + 0), build.vmReg(ra + 1), build.vmReg(ra + 2)); - build.inst(IrCmd::JUMP, nextStep); - // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopStart)) build.beginBlock(loopStart); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index e5a55f11..e51dca99 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -111,7 +111,6 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::CONCAT: case IrCmd::GET_UPVALUE: case IrCmd::SET_UPVALUE: - case IrCmd::PREPARE_FORN: case IrCmd::CHECK_TAG: case IrCmd::CHECK_TRUTHY: case IrCmd::CHECK_READONLY: diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index e781bda3..c32d718c 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -54,11 +54,6 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::GET_UPVALUE: invalidateRestoreOp(inst.a); break; - case IrCmd::PREPARE_FORN: - invalidateRestoreOp(inst.a); - invalidateRestoreOp(inst.b); - invalidateRestoreOp(inst.c); - break; case IrCmd::CALL: // Even if result count is limited, all registers starting from function (ra) might be modified invalidateRestoreVmRegs(vmRegOp(inst.a), -1); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 5a71345e..13ef33d3 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -44,7 +44,6 @@ void initFunctions(NativeState& data) data.context.luaV_equalval = luaV_equalval; data.context.luaV_doarith = luaV_doarith; data.context.luaV_dolen = luaV_dolen; - data.context.luaV_prepareFORN = luaV_prepareFORN; data.context.luaV_gettable = luaV_gettable; data.context.luaV_settable = luaV_settable; data.context.luaV_getimport = luaV_getimport; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index a9ba7cfd..85b7a3a3 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -35,7 +35,6 @@ struct NativeContext int (*luaV_equalval)(lua_State* L, const TValue* t1, const TValue* t2) = nullptr; void (*luaV_doarith)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) = nullptr; void (*luaV_dolen)(lua_State* L, StkId ra, const TValue* rb) = nullptr; - void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 9ef57afa..4b09d423 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -992,25 +992,20 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateRegisterRange(vmRegOp(inst.a), function.uintOp(inst.b)); state.invalidateUserCall(); // TODO: if only strings and numbers are concatenated, there will be no user calls break; - case IrCmd::PREPARE_FORN: - state.invalidateValue(inst.a); - state.saveTag(inst.a, LUA_TNUMBER); - state.invalidateValue(inst.b); - state.saveTag(inst.b, LUA_TNUMBER); - state.invalidateValue(inst.c); - state.saveTag(inst.c, LUA_TNUMBER); - break; case IrCmd::INTERRUPT: state.invalidateUserCall(); break; case IrCmd::SETLIST: + if (RegisterInfo* info = state.tryGetRegisterInfo(inst.b); info && info->knownTableArraySize >= 0) + replace(function, inst.f, build.constUint(info->knownTableArraySize)); + state.valueMap.clear(); // TODO: this can be relaxed when x64 emitInstSetList becomes aware of register allocator break; case IrCmd::CALL: state.invalidateRegistersFrom(vmRegOp(inst.a)); state.invalidateUserCall(); - // We cannot guarantee right now that all live values can be remeterialized from non-stack memory locations + // We cannot guarantee right now that all live values can be rematerialized from non-stack memory locations // To prevent earlier values from being propagated to after the call, we have to clear the map // TODO: remove only the values that don't have a guaranteed restore location state.valueMap.clear(); diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index e9df184d..08c8e831 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -225,9 +225,10 @@ void UnwindBuilderDwarf2::prologueA64(uint32_t prologueSize, uint32_t stackSize, } } -void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) { - LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + LUAU_ASSERT(stackSize > 0 && stackSize < 4096 && stackSize % 8 == 0); unsigned int stackOffset = 8; // Return address was pushed by calling the function unsigned int prologueOffset = 0; @@ -247,7 +248,7 @@ void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, } // push reg - for (X64::RegisterX64 reg : regs) + for (X64::RegisterX64 reg : gpr) { LUAU_ASSERT(reg.size == X64::SizeX64::qword); @@ -258,9 +259,11 @@ void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); } + LUAU_ASSERT(simd.empty()); + // sub rsp, stackSize stackOffset += stackSize; - prologueOffset += 4; + prologueOffset += stackSize >= 128 ? 7 : 4; pos = advanceLocation(pos, 4); pos = defineCfaExpressionOffset(pos, stackOffset); diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index f9b927c5..336a4e3f 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -82,7 +82,7 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) if (!unwindCodes.empty()) { // Copy unwind codes in reverse order - // Some unwind codes take up two array slots, but we don't use those atm + // Some unwind codes take up two array slots, we write those in reverse order uint8_t* unwindCodePos = rawDataPos + sizeof(UnwindCodeWin) * (unwindCodes.size() - 1); LUAU_ASSERT(unwindCodePos <= rawData + kRawDataLimit); @@ -109,9 +109,10 @@ void UnwindBuilderWin::prologueA64(uint32_t prologueSize, uint32_t stackSize, st LUAU_ASSERT(!"Not implemented"); } -void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) { - LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + LUAU_ASSERT(stackSize > 0 && stackSize < 4096 && stackSize % 8 == 0); LUAU_ASSERT(prologueSize < 256); unsigned int stackOffset = 8; // Return address was pushed by calling the function @@ -132,7 +133,7 @@ void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bo } // push reg - for (X64::RegisterX64 reg : regs) + for (X64::RegisterX64 reg : gpr) { LUAU_ASSERT(reg.size == X64::SizeX64::qword); @@ -141,10 +142,51 @@ void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bo unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, reg.index}); } + // If frame pointer is used, simd register storage is not implemented, it will require reworking store offsets + LUAU_ASSERT(!setupFrame || simd.size() == 0); + + unsigned int simdStorageSize = unsigned(simd.size()) * 16; + + // It's the responsibility of the caller to provide simd register storage in 'stackSize', including alignment to 16 bytes + if (!simd.empty() && stackOffset % 16 == 8) + simdStorageSize += 8; + // sub rsp, stackSize - stackOffset += stackSize; - prologueOffset += 4; - unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_SMALL, uint8_t((stackSize - 8) / 8)}); + if (stackSize <= 128) + { + stackOffset += stackSize; + prologueOffset += stackSize == 128 ? 7 : 4; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_SMALL, uint8_t((stackSize - 8) / 8)}); + } + else + { + // This command can handle allocations up to 512K-8 bytes, but that potentially requires stack probing + LUAU_ASSERT(stackSize < 4096); + + stackOffset += stackSize; + prologueOffset += 7; + + uint16_t encodedOffset = stackSize / 8; + unwindCodes.push_back(UnwindCodeWin()); + memcpy(&unwindCodes.back(), &encodedOffset, sizeof(encodedOffset)); + + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_LARGE, 0}); + } + + // It's the responsibility of the caller to provide simd register storage in 'stackSize' + unsigned int xmmStoreOffset = stackSize - simdStorageSize; + + // vmovaps [rsp+n], xmm + for (X64::RegisterX64 reg : simd) + { + LUAU_ASSERT(reg.size == X64::SizeX64::xmmword); + LUAU_ASSERT(xmmStoreOffset % 16 == 0 && "simd stores have to be performed to aligned locations"); + + prologueOffset += xmmStoreOffset >= 128 ? 10 : 7; + unwindCodes.push_back({uint8_t(xmmStoreOffset / 16), 0, 0}); + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_SAVE_XMM128, reg.index}); + xmmStoreOffset += 16; + } LUAU_ASSERT(stackOffset % 16 == 0); LUAU_ASSERT(prologueOffset == prologueSize); diff --git a/Common/include/Luau/DenseHash.h b/Common/include/Luau/DenseHash.h index 997e090f..72aa6ec5 100644 --- a/Common/include/Luau/DenseHash.h +++ b/Common/include/Luau/DenseHash.h @@ -282,6 +282,13 @@ public: class const_iterator { public: + using value_type = Item; + using reference = Item&; + using pointer = Item*; + using iterator = pointer; + using difference_type = size_t; + using iterator_category = std::input_iterator_tag; + const_iterator() : set(0) , index(0) diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 4ec083bb..a15c8f08 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,9 +4,6 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" -LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinTonumber, false) -LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinTostring, false) - namespace Luau { namespace Compile @@ -72,9 +69,9 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op if (builtin.isGlobal("setmetatable")) return LBF_SETMETATABLE; - if (FFlag::LuauCompileBuiltinTonumber && builtin.isGlobal("tonumber")) + if (builtin.isGlobal("tonumber")) return LBF_TONUMBER; - if (FFlag::LuauCompileBuiltinTostring && builtin.isGlobal("tostring")) + if (builtin.isGlobal("tostring")) return LBF_TOSTRING; if (builtin.object == "math") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index b673ffc2..f9a00f64 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,8 +26,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileFoldMathK, false) - namespace Luau { @@ -3871,9 +3869,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c { compiler.builtinsFold = &compiler.builtins; - if (FFlag::LuauCompileFoldMathK) - if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) - compiler.builtinsFoldMathK = true; + if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) + compiler.builtinsFoldMathK = true; } if (options.optimizationLevel >= 1) diff --git a/Makefile b/Makefile index d1c2ac90..f0f008be 100644 --- a/Makefile +++ b/Makefile @@ -121,12 +121,11 @@ ifeq ($(protobuf),download) endif ifneq ($(native),) - CXXFLAGS+=-DLUA_CUSTOM_EXECUTION=1 TESTS_ARGS+=--codegen endif ifneq ($(nativelj),) - CXXFLAGS+=-DLUA_CUSTOM_EXECUTION=1 -DLUA_USE_LONGJMP=1 + CXXFLAGS+=-DLUA_USE_LONGJMP=1 TESTS_ARGS+=--codegen endif @@ -142,7 +141,7 @@ $(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IConfig/include -Iextern $(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICodeGen/include +$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICodeGen/include -IConfig/include $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread @@ -219,11 +218,11 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(COMPILE_CLI_TARGET): $(CXX) $^ $(LDFLAGS) -o $@ # executable targets for fuzzing -fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) +fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(CXX) $^ $(LDFLAGS) -o $@ -fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator -fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator # static library targets $(AST_TARGET): $(AST_OBJECTS) diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index dcb785b6..7a1bbb95 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -121,11 +121,6 @@ #define LUA_MAXCAPTURES 32 #endif -// enables callbacks to redirect code execution from Luau VM to a custom implementation -#ifndef LUA_CUSTOM_EXECUTION -#define LUA_CUSTOM_EXECUTION 1 -#endif - // }================================================================== /* diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 951b3028..63da1810 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,8 +11,6 @@ #include -LUAU_FASTFLAG(LuauFasterInterp) - // convert a stack index to positive #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -524,19 +522,10 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { if (luaL_callmeta(L, idx, "__tostring")) // is there a metafield? { - if (FFlag::LuauFasterInterp) - { - const char* s = lua_tolstring(L, -1, len); - if (!s) - luaL_error(L, "'__tostring' must return a string"); - return s; - } - else - { - if (!lua_isstring(L, -1)) - luaL_error(L, "'__tostring' must return a string"); - return lua_tolstring(L, -1, len); - } + const char* s = lua_tolstring(L, -1, len); + if (!s) + luaL_error(L, "'__tostring' must return a string"); + return s; } switch (lua_type(L, idx)) diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index c893d603..a916f73a 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -23,8 +23,6 @@ #endif #endif -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastcallGC, false) - // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -832,7 +830,7 @@ static int luauF_char(lua_State* L, StkId res, TValue* arg0, int nresults, StkId if (nparams < int(sizeof(buffer)) && nresults <= 1) { - if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + if (luaC_needsGC(L)) return -1; // we can't call luaC_checkGC so fall back to C implementation if (nparams >= 1) @@ -904,7 +902,7 @@ static int luauF_sub(lua_State* L, StkId res, TValue* arg0, int nresults, StkId int i = int(nvalue(args)); int j = int(nvalue(args + 1)); - if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + if (luaC_needsGC(L)) return -1; // we can't call luaC_checkGC so fall back to C implementation if (i >= 1 && j >= i && unsigned(j - 1) < unsigned(ts->len)) @@ -1300,7 +1298,7 @@ static int luauF_tostring(lua_State* L, StkId res, TValue* arg0, int nresults, S } case LUA_TNUMBER: { - if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + if (luaC_needsGC(L)) return -1; // we can't call luaC_checkGC so fall back to C implementation char s[LUAI_MAXNUM2STR]; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 90b30ead..ca57786e 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,9 +8,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauFasterInterp, false) -LUAU_FASTFLAGVARIABLE(LuauFasterFormatS, false) - // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -969,7 +966,7 @@ static int str_format(lua_State* L) luaL_addchar(&b, *strfrmt++); else if (*++strfrmt == L_ESC) luaL_addchar(&b, *strfrmt++); // %% - else if (FFlag::LuauFasterInterp && *strfrmt == '*') + else if (*strfrmt == '*') { strfrmt++; if (++arg > top) @@ -1029,49 +1026,22 @@ static int str_format(lua_State* L) { size_t l; const char* s = luaL_checklstring(L, arg, &l); - if (FFlag::LuauFasterFormatS) + // no precision and string is too long to be formatted, or no format necessary to begin with + if (form[2] == '\0' || (!strchr(form, '.') && l >= 100)) { - // no precision and string is too long to be formatted, or no format necessary to begin with - if (form[2] == '\0' || (!strchr(form, '.') && l >= 100)) - { - luaL_addlstring(&b, s, l, -1); - continue; // skip the `luaL_addlstring' at the end - } - else - { - snprintf(buff, sizeof(buff), form, s); - break; - } + luaL_addlstring(&b, s, l, -1); + continue; // skip the `luaL_addlstring' at the end } else { - if (!strchr(form, '.') && l >= 100) - { - /* no precision and string is too long to be formatted; - keep original string */ - lua_pushvalue(L, arg); - luaL_addvalue(&b); - continue; // skip the `luaL_addlstring' at the end - } - else - { - snprintf(buff, sizeof(buff), form, s); - break; - } + snprintf(buff, sizeof(buff), form, s); + break; } } case '*': { - if (FFlag::LuauFasterInterp || formatItemSize != 1) - luaL_error(L, "'%%*' does not take a form"); - - size_t length; - const char* string = luaL_tolstring(L, arg, &length); - - luaL_addlstring(&b, string, length, -2); - lua_pop(L, 1); - - continue; // skip the `luaL_addlstring' at the end + // %* is parsed above, so if we got here we must have %...* + luaL_error(L, "'%%*' does not take a form"); } default: { // also treat cases `pnLlh' diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 2909d477..0d5a53df 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -132,7 +132,7 @@ #endif // Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. -#define VM_HAS_NATIVE LUA_CUSTOM_EXECUTION +#define VM_HAS_NATIVE 1 LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { @@ -2380,7 +2380,7 @@ reentry: else goto exit; #else - LUAU_ASSERT(!"Opcode is only valid when LUA_CUSTOM_EXECUTION is defined"); + LUAU_ASSERT(!"Opcode is only valid when VM_HAS_NATIVE is defined"); LUAU_UNREACHABLE(); #endif } diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index e1ecaf65..1ad7852d 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeGen.h" #include "Luau/Common.h" #include "Luau/Compiler.h" +#include "Luau/Config.h" #include "Luau/Frontend.h" #include "Luau/Linter.h" #include "Luau/ModuleResolver.h" diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 25521e35..723e107e 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -295,4 +295,34 @@ TEST_CASE_FIXTURE(Fixture, "include_types_ancestry") CHECK(ancestryTypes.back()->asType()); } +TEST_CASE_FIXTURE(Fixture, "find_name_ancestry") +{ + ScopedFastFlag sff{"FixFindBindingAtFunctionName", true}; + check(R"( + local tbl = {} + function tbl:abc() end + )"); + const Position pos(2, 18); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), pos); + + REQUIRE(!ancestry.empty()); + CHECK(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "find_expr_ancestry") +{ + ScopedFastFlag sff{"FixFindBindingAtFunctionName", true}; + check(R"( + local tbl = {} + function tbl:abc() end + )"); + const Position pos(2, 29); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), pos); + + REQUIRE(!ancestry.empty()); + CHECK(ancestry.back()->is()); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index b8171a75..fac23e88 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -80,7 +80,7 @@ struct ACFixtureImpl : BaseType { if (prevChar == '@') { - LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Illegal marker character" && ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z'))); LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); markerPosition.insert(std::pair{c, curPos}); } @@ -126,7 +126,6 @@ struct ACFixtureImpl : BaseType LUAU_ASSERT(i != markerPosition.end()); return i->second; } - ScopedFastFlag flag{"LuauAutocompleteHideSelfArg", true}; // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; @@ -3083,6 +3082,86 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") CHECK(ac.entryMap.count("\"down\"")); } +// https://github.com/Roblox/luau/issues/858 +TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement") +{ + ScopedFastFlag sff{"LuauAutocompleteStringLiteralBounds", true}; + + check(R"( + --!strict + + type Direction = "left" | "right" + + local dir: Direction = "left" + + if dir == @1"@2"@3 then end + local a: {[Direction]: boolean} = {[@4"@5"@6]} + + if dir == @7`@8`@9 then end + local a: {[Direction]: boolean} = {[@A`@B`@C]} + )"); + + auto ac = autocomplete('1'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('3'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('4'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('5'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('6'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('7'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('8'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('9'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('A'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('B'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('C'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") { check(R"( diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index b44ca6d5..298035c2 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -187,7 +187,7 @@ TEST_CASE("WindowsUnwindCodesX64") unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}, {}); unwind.finishFunction(0x11223344, 0x55443322); unwind.finishInfo(); @@ -211,7 +211,7 @@ TEST_CASE("Dwarf2UnwindCodesX64") unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}, {}); unwind.finishFunction(0, 0); unwind.finishInfo(); @@ -309,6 +309,11 @@ static void throwing(int64_t arg) throw std::runtime_error("testing"); } +static void nonthrowing(int64_t arg) +{ + CHECK(arg == 25); +} + TEST_CASE("GeneratedCodeExecutionWithThrowX64") { using namespace X64; @@ -339,7 +344,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") uint32_t prologueSize = build.setLabel().location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}, {}); // Body build.mov(rNonVol1, rArg1); @@ -379,6 +384,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") using FunctionType = int64_t(int64_t, void (*)(int64_t)); FunctionType* f = (FunctionType*)nativeEntry; + f(10, nonthrowing); + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here try { @@ -390,6 +397,121 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") } } +static void obscureThrowCase(int64_t (*f)(int64_t, void (*)(int64_t))) +{ + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + +TEST_CASE("GeneratedCodeExecutionWithThrowX64Simd") +{ + // This test requires AVX + if (!Luau::CodeGen::isSupported()) + return; + + using namespace X64; + + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->startInfo(UnwindBuilder::X64); + + Label functionBegin = build.setLabel(); + unwind->startFunction(); + + int stackSize = 32 + 64; + int localsSize = 16; + + // Prologue + build.push(rNonVol1); + build.push(rNonVol2); + build.push(rbp); + build.sub(rsp, stackSize + localsSize); + + if (build.abi == ABIX64::Windows) + { + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x40)], xmm6); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x30)], xmm7); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x20)], xmm8); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x10)], xmm9); + } + + uint32_t prologueSize = build.setLabel().location; + + if (build.abi == ABIX64::Windows) + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rbp}, {xmm6, xmm7, xmm8, xmm9}); + else + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rbp}, {}); + + // Body + build.vxorpd(xmm0, xmm0, xmm0); + build.vmovsd(xmm6, xmm0, xmm0); + build.vmovsd(xmm7, xmm0, xmm0); + build.vmovsd(xmm8, xmm0, xmm0); + build.vmovsd(xmm9, xmm0, xmm0); + + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + if (build.abi == ABIX64::Windows) + { + build.vmovaps(xmm6, xmmword[rsp + ((stackSize + localsSize) - 0x40)]); + build.vmovaps(xmm7, xmmword[rsp + ((stackSize + localsSize) - 0x30)]); + build.vmovaps(xmm8, xmmword[rsp + ((stackSize + localsSize) - 0x20)]); + build.vmovaps(xmm9, xmmword[rsp + ((stackSize + localsSize) - 0x10)]); + } + + build.add(rsp, stackSize + localsSize); + build.pop(rbp); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + f(10, nonthrowing); + + obscureThrowCase(f); +} + TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") { using namespace X64; @@ -425,7 +547,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") uint32_t prologueSize = build.setLabel().location - start1.location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}, {}); // Body build.mov(rNonVol1, rArg1); @@ -464,7 +586,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") uint32_t prologueSize = build.setLabel().location - start2.location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4}, {}); // Body build.mov(rNonVol3, rArg1); @@ -561,7 +683,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") uint32_t prologueSize = build.setLabel().location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15}, {}); // Body build.mov(rax, rArg1); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 078b8af6..93290567 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -63,6 +63,35 @@ static std::string compileTypeTable(const char* source) TEST_SUITE_BEGIN("Compiler"); +TEST_CASE("BytecodeIsStable") +{ + // As noted in Bytecode.h, all enums used for bytecode storage and serialization are order-sensitive + // Adding entries in the middle will typically pass the tests but break compatibility + // This test codifies this by validating that in each enum, the last (or close-to-last) entry has a fixed encoding + + // This test will need to get occasionally revised to "move" the checked enum entries forward as we ship newer versions + // When doing so, please add *new* checks for more recent bytecode versions and keep existing checks in place. + + // Bytecode ops (serialized & in-memory) + CHECK(LOP_FASTCALL2K == 75); // bytecode v1 + CHECK(LOP_JUMPXEQKS == 80); // bytecode v3 + + // Bytecode fastcall ids (serialized & in-memory) + // Note: these aren't strictly bound to specific bytecode versions, but must monotonically increase to keep backwards compat + CHECK(LBF_VECTOR == 54); + CHECK(LBF_TOSTRING == 63); + + // Bytecode capture type (serialized & in-memory) + CHECK(LCT_UPVAL == 2); // bytecode v1 + + // Bytecode constants (serialized) + CHECK(LBC_CONSTANT_CLOSURE == 6); // bytecode v1 + + // Bytecode type encoding (serialized & in-memory) + // Note: these *can* change retroactively *if* type version is bumped, but probably shouldn't + LUAU_ASSERT(LBC_TYPE_VECTOR == 8); // type version 1 +} + TEST_CASE("CompileToBytecode") { Luau::BytecodeBuilder bcb; @@ -5085,7 +5114,7 @@ RETURN R1 1 )"); } -TEST_CASE("InlineBasicProhibited") +TEST_CASE("InlineProhibited") { // we can't inline variadic functions CHECK_EQ("\n" + compileFunction(R"( @@ -5125,6 +5154,66 @@ RETURN R1 1 )"); } +TEST_CASE("InlineProhibitedRecursion") +{ + // we can't inline recursive invocations of functions in the functions + // this is actually profitable in certain cases, but it complicates the compiler as it means a local has multiple registers/values + + // in this example, inlining is blocked because we're compiling fact() and we don't yet have the cost model / profitability data for fact() + CHECK_EQ("\n" + compileFunction(R"( +local function fact(n) + return if n <= 1 then 1 else fact(n-1)*n +end + +return fact +)", + 0, 2), + R"( +LOADN R2 1 +JUMPIFNOTLE R0 R2 L0 +LOADN R1 1 +RETURN R1 1 +L0: GETUPVAL R2 0 +SUBK R3 R0 K0 [1] +CALL R2 1 1 +MUL R1 R2 R0 +RETURN R1 1 +)"); + + // in this example, inlining of fact() succeeds, but the nested call to fact() fails since fact is already on the inline stack + CHECK_EQ("\n" + compileFunction(R"( +local function fact(n) + return if n <= 1 then 1 else fact(n-1)*n +end + +local function factsafe(n) + assert(n >= 1) + return fact(n) +end + +return factsafe +)", + 1, 2), + R"( +LOADN R3 1 +JUMPIFLE R3 R0 L0 +LOADB R2 0 +1 +L0: LOADB R2 1 +L1: FASTCALL1 1 R2 L2 +GETIMPORT R1 1 [assert] +CALL R1 1 0 +L2: LOADN R2 1 +JUMPIFNOTLE R0 R2 L3 +LOADN R1 1 +RETURN R1 1 +L3: GETUPVAL R2 0 +SUBK R3 R0 K2 [1] +CALL R2 1 1 +MUL R1 R2 R0 +RETURN R1 1 +)"); +} + TEST_CASE("InlineNestedLoops") { // functions with basic loops get inlined @@ -7252,10 +7341,31 @@ end )"); } +TEST_CASE("TypeUnionIntersection") +{ + CHECK_EQ("\n" + compileTypeTable(R"( +function myfunc(test: string | nil, foo: nil) +end + +function myfunc2(test: string & nil, foo: nil) +end + +function myfunc3(test: string | number, foo: nil) +end + +function myfunc4(test: string & number, foo: nil) +end +)"), + R"( +0: function(string?, nil) +1: function(any, nil) +2: function(any, nil) +3: function(any, nil) +)"); +} + TEST_CASE("BuiltinFoldMathK") { - ScopedFastFlag sff("LuauCompileFoldMathK", true); - // we can fold math.pi at optimization level 2 CHECK_EQ("\n" + compileFunction(R"( function test() diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index fda0a6f0..ec0c213a 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -446,8 +446,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_exports") { - ScopedFastFlag luauFixCyclicModuleExports{"LuauFixCyclicModuleExports", true}; - fileResolver.source["game/A"] = R"( local b = require(game.B) export type atype = { x: b.btype } diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 3798082b..26a157e4 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -900,7 +900,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_begin") } catch (const ParseErrors& e) { - CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + CHECK_EQ("Double braces are not permitted within interpolated strings; did you mean '\\{'?", e.getErrors().front().getMessage()); } } @@ -915,7 +915,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") } catch (const ParseErrors& e) { - CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + CHECK_EQ("Double braces are not permitted within interpolated strings; did you mean '\\{'?", e.getErrors().front().getMessage()); } } @@ -933,7 +933,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") CHECK_EQ(e.getErrors().size(), 1); auto error = e.getErrors().front(); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", error.getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", error.getMessage()); return error.getLocation().begin.column; } }; @@ -956,7 +956,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace_in_table { CHECK_EQ(e.getErrors().size(), 2); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", e.getErrors().front().getMessage()); CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); } } @@ -974,7 +974,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_mid_without_end_brace_in_t { CHECK_EQ(e.getErrors().size(), 2); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", e.getErrors().front().getMessage()); CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); } } @@ -1041,6 +1041,36 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") } } +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_malformed_escape") +{ + try + { + parse(R"( + local a = `???\xQQ {1}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Interpolated string literal contains malformed escape sequence", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_weird_token") +{ + try + { + parse(R"( + local a = `??? {42 !!}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Malformed interpolated string, got '!'", e.getErrors().front().getMessage()); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") { try @@ -1569,9 +1599,9 @@ TEST_CASE_FIXTURE(Fixture, "string_literals_escapes_broken") TEST_CASE_FIXTURE(Fixture, "string_literals_broken") { - matchParseError("return \"", "Malformed string"); - matchParseError("return \"\\", "Malformed string"); - matchParseError("return \"\r\r", "Malformed string"); + matchParseError("return \"", "Malformed string; did you forget to finish it?"); + matchParseError("return \"\\", "Malformed string; did you forget to finish it?"); + matchParseError("return \"\r\r", "Malformed string; did you forget to finish it?"); } TEST_CASE_FIXTURE(Fixture, "number_literals") @@ -2530,12 +2560,12 @@ TEST_CASE_FIXTURE(Fixture, "incomplete_method_call_still_yields_an_AstExprIndexN TEST_CASE_FIXTURE(Fixture, "recover_confusables") { // Binary - matchParseError("local a = 4 != 10", "Unexpected '!=', did you mean '~='?"); - matchParseError("local a = true && false", "Unexpected '&&', did you mean 'and'?"); - matchParseError("local a = false || true", "Unexpected '||', did you mean 'or'?"); + matchParseError("local a = 4 != 10", "Unexpected '!='; did you mean '~='?"); + matchParseError("local a = true && false", "Unexpected '&&'; did you mean 'and'?"); + matchParseError("local a = false || true", "Unexpected '||'; did you mean 'or'?"); // Unary - matchParseError("local a = !false", "Unexpected '!', did you mean 'not'?"); + matchParseError("local a = !false", "Unexpected '!'; did you mean 'not'?"); // Check that separate tokens are not considered as a single one matchParseError("local a = 4 ! = 10", "Expected identifier when parsing expression, got '!'"); @@ -2880,4 +2910,64 @@ TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_ty CHECK_EQ("Expected type pack after '=', got type", result.errors[1].getMessage()); } +TEST_CASE_FIXTURE(Fixture, "table_type_keys_cant_contain_nul") +{ + ParseResult result = tryParse(R"( + type Foo = { ["\0"]: number } + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 21}, {1, 22}}, result.errors[0].getLocation()); + CHECK_EQ("String literal contains malformed escape sequence or \\0", result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "invalid_escape_literals_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + local foo = "\xQQ" + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 20}, {1, 26}}, result.errors[0].getLocation()); + CHECK_EQ("String literal contains malformed escape sequence", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "unfinished_string_literals_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + local foo = "hi + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 20}, {1, 23}}, result.errors[0].getLocation()); + CHECK_EQ("Malformed string; did you forget to finish it?", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "unfinished_string_literal_types_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + type Foo = "hi + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 19}, {1, 22}}, result.errors[0].getLocation()); + CHECK_EQ("Malformed string; did you forget to finish it?", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + TEST_SUITE_END(); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index f7f24d0f..e1fa0e5a 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -4,16 +4,18 @@ #include "Fixture.h" #include "Luau/Subtyping.h" +#include "Luau/TypePack.h" using namespace Luau; struct SubtypeFixture : Fixture { TypeArena arena; - InternalErrorReporter ice; + InternalErrorReporter iceReporter; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Subtyping subtyping{builtinTypes, NotNull{&normalizer}}; + + Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}}; TypePackId pack(std::initializer_list tys) { @@ -45,7 +47,28 @@ struct SubtypeFixture : Fixture return arena.addType(FunctionType{pack(argHead, std::move(argTail)), pack(retHead, std::move(retTail))}); } - SubtypingGraph isSubtype(TypeId subTy, TypeId superTy) + TypeId tbl(TableType::Props&& props) + { + return arena.addType(TableType{std::move(props), std::nullopt, {}, TableState::Sealed}); + } + + TypeId cyclicTable(std::function&& cb) + { + TypeId res = arena.addType(GenericType{}); + TableType tt{}; + cb(res, &tt); + emplaceType(asMutable(res), std::move(tt)); + return res; + } + + TypeId genericT = arena.addType(GenericType{"T"}); + TypeId genericU = arena.addType(GenericType{"U"}); + + TypePackId genericAs = arena.addTypePack(GenericTypePack{"A"}); + TypePackId genericBs = arena.addTypePack(GenericTypePack{"B"}); + TypePackId genericCs = arena.addTypePack(GenericTypePack{"C"}); + + SubtypingResult isSubtype(TypeId subTy, TypeId superTy) { return subtyping.isSubtype(subTy, superTy); } @@ -57,7 +80,16 @@ struct SubtypeFixture : Fixture TypeId helloOrWorldType = arena.addType(UnionType{{helloType, worldType}}); TypeId trueOrFalseType = arena.addType(UnionType{{builtinTypes->trueType, builtinTypes->falseType}}); + // "hello" | "hello" + TypeId helloOrHelloType = arena.addType(UnionType{{helloType, helloType}}); + + // () -> () + const TypeId nothingToNothingType = fn({}, {}); + + // ("hello") -> "world" TypeId helloAndWorldType = arena.addType(IntersectionType{{helloType, worldType}}); + + // (boolean) -> true TypeId booleanAndTrueType = arena.addType(IntersectionType{{builtinTypes->booleanType, builtinTypes->trueType}}); // (number) -> string @@ -72,6 +104,24 @@ struct SubtypeFixture : Fixture {builtinTypes->stringType} ); + // (number) -> () + const TypeId numberToNothingType = fn( + {builtinTypes->numberType}, + {} + ); + + // () -> number + const TypeId nothingToNumberType = fn( + {}, + {builtinTypes->numberType} + ); + + // (number) -> number + const TypeId numberToNumberType = fn( + {builtinTypes->numberType}, + {builtinTypes->numberType} + ); + // (number) -> unknown const TypeId numberToUnknownType = fn( {builtinTypes->numberType}, @@ -120,6 +170,83 @@ struct SubtypeFixture : Fixture {builtinTypes->stringType} ); + // (...number) -> number + const TypeId numbersToNumberType = arena.addType(FunctionType{ + arena.addTypePack(VariadicTypePack{builtinTypes->numberType}), + arena.addTypePack({builtinTypes->numberType}) + }); + + // (T) -> () + const TypeId genericTToNothingType = arena.addType(FunctionType{ + {genericT}, + {}, + arena.addTypePack({genericT}), + builtinTypes->emptyTypePack + }); + + // (T) -> T + const TypeId genericTToTType = arena.addType(FunctionType{ + {genericT}, + {}, + arena.addTypePack({genericT}), + arena.addTypePack({genericT}) + }); + + // (U) -> () + const TypeId genericUToNothingType = arena.addType(FunctionType{ + {genericU}, + {}, + arena.addTypePack({genericU}), + builtinTypes->emptyTypePack + }); + + // () -> T + const TypeId genericNothingToTType = arena.addType(FunctionType{ + {genericT}, + {}, + builtinTypes->emptyTypePack, + arena.addTypePack({genericT}) + }); + + // (A...) -> A... + const TypeId genericAsToAsType = arena.addType(FunctionType{ + {}, + {genericAs}, + genericAs, + genericAs + }); + + // (A...) -> number + const TypeId genericAsToNumberType = arena.addType(FunctionType{ + {}, + {genericAs}, + genericAs, + arena.addTypePack({builtinTypes->numberType}) + }); + + // (B...) -> B... + const TypeId genericBsToBsType = arena.addType(FunctionType{ + {}, + {genericBs}, + genericBs, + genericBs + }); + + // (B...) -> C... + const TypeId genericBsToCsType = arena.addType(FunctionType{ + {}, + {genericBs, genericCs}, + genericBs, + genericCs + }); + + // () -> A... + const TypeId genericNothingToAsType = arena.addType(FunctionType{ + {}, + {genericAs}, + builtinTypes->emptyTypePack, + genericAs + }); }; #define CHECK_IS_SUBTYPE(left, right) \ @@ -127,7 +254,7 @@ struct SubtypeFixture : Fixture { \ const auto& leftTy = (left); \ const auto& rightTy = (right); \ - SubtypingGraph result = isSubtype(leftTy, rightTy); \ + SubtypingResult result = isSubtype(leftTy, rightTy); \ CHECK_MESSAGE(result.isSubtype, "Expected " << leftTy << " <: " << rightTy); \ } while (0) @@ -136,7 +263,7 @@ struct SubtypeFixture : Fixture { \ const auto& leftTy = (left); \ const auto& rightTy = (right); \ - SubtypingGraph result = isSubtype(leftTy, rightTy); \ + SubtypingResult result = isSubtype(leftTy, rightTy); \ CHECK_MESSAGE(!result.isSubtype, "Expected " << leftTy << " anyType, builtinTypes->unknownType); + SubtypingResult result = isSubtype(builtinTypes->anyType, builtinTypes->unknownType); CHECK(!result.isSubtype); CHECK(result.isErrorSuppressing); } @@ -244,6 +371,11 @@ TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" | \"world\" <: number") CHECK_IS_NOT_SUBTYPE(helloOrWorldType, builtinTypes->numberType); } +TEST_CASE_FIXTURE(SubtypeFixture, "string stringType, helloOrHelloType); +} + TEST_CASE_FIXTURE(SubtypeFixture, "true <: boolean & true") { CHECK_IS_SUBTYPE(builtinTypes->trueType, booleanAndTrueType); @@ -349,4 +481,206 @@ TEST_CASE_FIXTURE(SubtypeFixture, "(number, string) -> string () -> T <: () -> number") +{ + CHECK_IS_SUBTYPE(genericNothingToTType, nothingToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (U) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, genericUToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> number () -> T") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNumberType, genericNothingToTType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (number) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, numberToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> T <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericTToTType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> T string") +{ + CHECK_IS_NOT_SUBTYPE(genericTToTType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (U) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, genericUToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> () (T) -> ()") +{ + CHECK_IS_NOT_SUBTYPE(numberToNothingType, genericTToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> number (A...) -> A...") +{ + CHECK_IS_NOT_SUBTYPE(numberToNumberType, genericAsToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: (B...) -> B...") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, genericBsToBsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(B...) -> C... <: (A...) -> A...") +{ + CHECK_IS_SUBTYPE(genericBsToCsType, genericAsToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... (B...) -> C...") +{ + CHECK_IS_NOT_SUBTYPE(genericAsToAsType, genericBsToCsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> number <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToNumberType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> number (A...) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numberToNumberType, genericAsToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> number <: (...number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToNumberType, numbersToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(...number) -> number (A...) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numbersToNumberType, genericAsToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> A... <: () -> ()") +{ + CHECK_IS_SUBTYPE(genericNothingToAsType, nothingToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> () () -> A...") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNothingType, genericNothingToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: () -> ()") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, nothingToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> () (A...) -> A...") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNothingType, genericAsToAsType); +} + + +TEST_CASE_FIXTURE(SubtypeFixture, "{} <: {}") +{ + CHECK_IS_SUBTYPE(tbl({}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} <: {}") +{ + CHECK_IS_SUBTYPE(tbl({{"x", builtinTypes->numberType}}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} numberType}}), tbl({{"x", builtinTypes->stringType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} numberType}}), tbl({{"x", builtinTypes->optionalNumberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number?} optionalNumberType}}), tbl({{"x", builtinTypes->numberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: (T) -> ()} <: {x: (U) -> ()}") +{ + CHECK_IS_SUBTYPE( + tbl({{"x", genericTToNothingType}}), + tbl({{"x", genericUToNothingType}}) + ); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} <: t2 where t2 = {trim: (t2) -> string}") +{ + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + CHECK_IS_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} t2}") +{ + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {ty}); + }); + + CHECK_IS_NOT_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> t1} string}") +{ + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {ty}); + }); + + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) + { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + CHECK_IS_NOT_SUBTYPE(t1, t2); +} + +/* + * (A) -> A <: (X) -> X + * A can be bound to X. + * + * (A) -> A (X) -> number + * A can be bound to X, but A number (A) -> A + * Only generics on the left side can be bound. + * number (A, B) -> boolean <: (X, X) -> boolean + * It is ok to bind both A and B to X. + * + * (A, A) -> boolean (X, Y) -> boolean + * A cannot be bound to both X and Y. + */ + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index b390a816..c285edf9 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -1017,4 +1017,54 @@ local y = x["Bar"] LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_degenerate_intersections") +{ + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type A = { + x: number?, + } + + type B = { + x: number?, + } + + type C = A & B + local obj: C = { + x = 3, + } + + local x: number = obj.x or 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_more_realistic_intersections") +{ + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type A = { + x: number?, + y: string?, + } + + type B = { + x: number?, + z: string?, + } + + type C = A & B + local obj: C = { + x = 3, + } + + local x: number = obj.x or 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END();