diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index d38976ef..dfe373a5 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -42,6 +42,21 @@ struct ExprOrLocal { return expr ? expr->location : (local ? local->location : std::optional{}); } + std::optional getName() + { + if (expr) + { + if (AstName name = getIdentifier(expr); name.value) + { + return name; + } + } + else if (local) + { + return local->name; + } + return std::nullopt; + } private: AstExpr* expr = nullptr; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 2e41674b..1bf0473c 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -13,6 +13,8 @@ #include #include +LUAU_FASTFLAG(LuauPrepopulateUnionOptionsBeforeAllocation) + namespace Luau { @@ -58,6 +60,12 @@ struct TypeArena template TypeId addType(T tv) { + if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) + { + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); + } + return addTV(TypeVar(std::move(tv))); } diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index aa090014..b843509d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -135,7 +135,8 @@ struct TypeChecker void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr( + const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); @@ -160,14 +161,12 @@ struct TypeChecker // Returns the type of the lvalue. TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); - // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). - // Note: the binding may be null. - // TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding - std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); + // Returns the type of the lvalue. + TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, @@ -322,8 +321,6 @@ private: return addTV(TypeVar(tv)); } - TypeId addType(const UnionTypeVar& utv); - TypeId addTV(TypeVar&& tv); TypePackId addTypePack(TypePackVar&& tp); @@ -349,6 +346,8 @@ public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); private: + void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); + std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 3f5e26d6..11dc9377 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -111,16 +111,16 @@ struct PrimitiveTypeVar // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false -struct BoolSingleton +struct BooleanSingleton { bool value; - bool operator==(const BoolSingleton& rhs) const + bool operator==(const BooleanSingleton& rhs) const { return value == rhs.value; } - bool operator!=(const BoolSingleton& rhs) const + bool operator!=(const BooleanSingleton& rhs) const { return !(*this == rhs); } @@ -145,7 +145,7 @@ struct StringSingleton // No type for float singletons, partly because === isn't any equalivalence on floats // (NaN != NaN). -using SingletonVariant = Luau::Variant; +using SingletonVariant = Luau::Variant; struct SingletonTypeVar { diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index a3be739a..1b1671c0 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -85,6 +85,13 @@ public: Unifier makeChildUnifier(); + // A utility function that appends the given error to the unifier's error log. + // This allows setting a breakpoint wherever the unifier reports an error. + void reportError(TypeError error) + { + errors.push_back(error); + } + private: bool isNonstrictMode() const; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 7a801f97..85099e12 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,9 +14,9 @@ LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); +LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -194,8 +194,6 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) { - LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg); - auto expr = node->asExpr(); if (!expr) return std::nullopt; @@ -266,43 +264,63 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } }; - TypeId expectedType; + auto typeAtPosition = findExpectedTypeAt(module, node, position); - if (FFlag::LuauAutocompleteFirstArg) + if (!typeAtPosition) + return TypeCorrectKind::None; + + TypeId expectedType = follow(*typeAtPosition); + + if (FFlag::PreferToCallFunctionsForIntersects) { - auto typeAtPosition = findExpectedTypeAt(module, node, position); + auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { + auto [retHead, retTail] = flatten(ftv->retType); - if (!typeAtPosition) - return TypeCorrectKind::None; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return true; - expectedType = follow(*typeAtPosition); + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return true; + } + + return false; + }; + + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + for (TypeId id : itv->parts) + { + if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + } + } } else { - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return TypeCorrectKind::None; - - expectedType = follow(*it); - } - - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) - { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return TypeCorrectKind::CorrectFunctionResult; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return TypeCorrectKind::CorrectFunctionResult; + } } } @@ -741,29 +759,12 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - TypeId expectedType; + auto typeAtPosition = findExpectedTypeAt(module, node, position); - if (FFlag::LuauAutocompleteFirstArg) - { - auto typeAtPosition = findExpectedTypeAt(module, node, position); + if (!typeAtPosition) + return std::nullopt; - if (!typeAtPosition) - return std::nullopt; - - expectedType = follow(*typeAtPosition); - } - else - { - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; - - expectedType = follow(*it); - } + TypeId expectedType = follow(*typeAtPosition); if (get(expectedType)) return true; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index fe4b6529..9001b19d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) namespace Luau { @@ -102,8 +101,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy); + persist(globalTy); } for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) @@ -113,8 +111,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy.type); + persist(globalTy.type); } return LoadDefinitionFileResult{true, parseResult, checkedModule}; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 9f352f4b..4fdff8f7 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypeAliasDefaults) +LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false) + namespace Luau { @@ -377,14 +379,28 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - TypeId result = dest.addType(UnionTypeVar{}); - seenTypes[typeId] = result; + if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) + { + std::vector options; + options.reserve(t.options.size()); - UnionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; + } + else + { + TypeId result = dest.addType(UnionTypeVar{}); + seenTypes[typeId] = result; + + UnionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.options) + option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + } } void TypeCloner::operator()(const IntersectionTypeVar& t) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 4b898d3a..5e79b841 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAG(LuauTypeAliasDefaults) /* @@ -374,7 +373,7 @@ struct TypeVarStringifier void operator()(TypeId, const SingletonTypeVar& stv) { - if (const BoolSingleton* bs = Luau::get(&stv)) + if (const BooleanSingleton* bs = Luau::get(&stv)) state.emit(bs->value ? "true" : "false"); else if (const StringSingleton* ss = Luau::get(&stv)) { @@ -617,9 +616,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -675,9 +672,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 2ec02093..2208213f 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -97,7 +97,7 @@ public: AstType* operator()(const SingletonTypeVar& stv) { - if (const BoolSingleton* bs = get(&stv)) + if (const BooleanSingleton* bs = get(&stv)) return allocator->alloc(Location(), bs->value); else if (const StringSingleton* ss = get(&stv)) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e2d8a4fb..23fcc2d5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,8 +26,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) -LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) @@ -37,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) @@ -46,10 +45,8 @@ LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) -LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau { @@ -1139,33 +1136,25 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - auto [leftType, leftTypeBinding] = checkLValueBinding(scope, *function.name); + TypeId leftType = checkLValueBinding(scope, *function.name); checkFunctionBody(funScope, ty, *function.func); unify(ty, leftType, function.location); - if (FFlag::LuauUpdateFunctionNameBinding) - { - LUAU_ASSERT(function.name->is() || function.name->is()); + LUAU_ASSERT(function.name->is() || function.name->is()); - if (auto exprIndexName = function.name->as()) + if (auto exprIndexName = function.name->as()) + { + if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) { - if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) + if (auto ttv = getMutableTableType(*typeIt)) { - if (auto ttv = getMutableTableType(*typeIt)) - { - if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) - it->second.type = follow(quantify(funScope, leftType, function.name->location)); - } + if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) + it->second.type = follow(quantify(funScope, leftType, function.name->location)); } } } - else - { - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); - } } } @@ -1426,7 +1415,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType) +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -1443,14 +1432,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) result = {singletonType(bexpr->value)}; else result = {booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else result = {stringType}; @@ -1488,15 +1477,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result.type = follow(result.type); - if (FFlag::LuauStoreMatchingOverloadFnType) - { - if (!currentModule->astTypes.find(&expr)) - currentModule->astTypes[&expr] = result.type; - } - else - { + if (!currentModule->astTypes.find(&expr)) currentModule->astTypes[&expr] = result.type; - } if (expectedType) currentModule->astExpectedTypes[&expr] = *expectedType; @@ -2242,7 +2224,6 @@ TypeId TypeChecker::checkRelationalOperation( state.log.commit(); } - bool needsMetamethod = !isEquality; TypeId leftType = follow(lhsType); @@ -2250,10 +2231,11 @@ TypeId TypeChecker::checkRelationalOperation( { reportErrors(state.errors); - const PrimitiveTypeVar* ptv = get(leftType); - if (!isEquality && state.errors.empty() && (get(leftType) || (ptv && ptv->type == PrimitiveTypeVar::Boolean))) + if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + { reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())}); + } return booleanType; } @@ -2501,7 +2483,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult rhs = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type), + {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } else if (expr.op == AstExprBinary::Or) { @@ -2513,7 +2496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult rhs = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); + TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type, lhs.predicates); return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2521,8 +2504,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); + ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); PredicateVec predicates; @@ -2621,11 +2604,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) { - auto [ty, binding] = checkLValueBinding(scope, expr); - return ty; + return checkLValueBinding(scope, expr); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkLValueBinding(scope, *a); @@ -2639,22 +2621,22 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return {errorRecoveryType(scope), nullptr}; + return errorRecoveryType(scope); } else ice("Unexpected AST node in checkLValue", expr.location); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) { if (std::optional ty = scope->lookup(expr.local)) - return {*ty, nullptr}; + return *ty; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorRecoveryType(scope), nullptr}; + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) { Name name = expr.name.value; ScopePtr moduleScope = currentModule->getModuleScope(); @@ -2662,7 +2644,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto it = moduleScope->bindings.find(expr.name); if (it != moduleScope->bindings.end()) - return std::pair(it->second.typeId, &it->second.typeId); + return it->second.typeId; TypeId result = freshType(scope); Binding& binding = moduleScope->bindings[expr.name]; @@ -2673,15 +2655,15 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!isNonstrictMode()) reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(result, &binding.typeId); + return result; } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) { TypeId lhs = checkExpr(scope, *expr.expr).type; if (get(lhs) || get(lhs)) - return std::pair(lhs, nullptr); + return lhs; tablify(lhs); @@ -2694,7 +2676,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { @@ -2702,7 +2684,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Property& property = lhsTable->props[name]; property.type = theType; property.location = expr.indexLocation; - return std::pair(theType, &property.type); + return theType; } else if (auto indexer = lhsTable->indexer) { @@ -2720,17 +2702,17 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope else if (FFlag::LuauUseCommittingTxnLog) state.log.commit(); - return std::pair(retType, nullptr); + return retType; } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2739,29 +2721,29 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; } else if (get(lhs)) { if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) - return std::pair(*ty, nullptr); + return *ty; // If intersection has a table part, report that it cannot be extended just as a sealed table if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); @@ -2771,7 +2753,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope TypeId indexType = checkExpr(scope, *expr.index).type; if (get(exprType) || get(exprType)) - return std::pair(exprType, nullptr); + return exprType; AstExprConstantString* value = expr.index->as(); @@ -2783,9 +2765,9 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; } } @@ -2794,7 +2776,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } if (value) @@ -2802,7 +2784,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto& it = exprTable->props.find(value->value.data); if (it != exprTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { @@ -2810,7 +2792,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Property& property = exprTable->props[value->value.data]; property.type = resultType; property.location = expr.index->location; - return std::pair(resultType, &property.type); + return resultType; } } @@ -2818,18 +2800,18 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { const TableIndexer& indexer = *exprTable->indexer; unify(indexType, indexer.indexType, expr.index->location); - return std::pair(indexer.indexResultType, nullptr); + return indexer.indexResultType; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { TypeId resultType = freshType(scope); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; - return std::pair(resultType, nullptr); + return resultType; } else { TypeId resultType = freshType(scope); - return std::pair(resultType, nullptr); + return resultType; } } @@ -3326,7 +3308,7 @@ void TypeChecker::checkArgumentList( } // ok else { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } ++paramIter; @@ -3348,7 +3330,7 @@ void TypeChecker::checkArgumentList( Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } TypePackId tail = state.log.follow(*paramIter.tail()); @@ -3405,7 +3387,7 @@ void TypeChecker::checkArgumentList( if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } } @@ -3520,7 +3502,7 @@ void TypeChecker::checkArgumentList( } // ok else { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } ++paramIter; @@ -3540,7 +3522,7 @@ void TypeChecker::checkArgumentList( Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } TypePackId tail = *paramIter.tail(); @@ -3606,7 +3588,7 @@ void TypeChecker::checkArgumentList( if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } } @@ -3825,22 +3807,11 @@ std::optional> TypeChecker::checkCallOverload(const Scope metaArgLocations = *argLocations; metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); - if (FFlag::LuauFixRecursiveMetatableCall) - { - fn = instantiate(scope, *ty, expr.func->location); + fn = instantiate(scope, *ty, expr.func->location); - argPack = metaCallArgPack; - args = metaCallArgs; - argLocations = &metaArgLocations; - } - else - { - TypeId fn = *ty; - fn = instantiate(scope, fn, expr.func->location); - - return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult, - overloadsThatMatchArgCount, overloadsThatDont, errors); - } + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; } } @@ -3932,8 +3903,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope } } - if (FFlag::LuauStoreMatchingOverloadFnType) - currentModule->astOverloadResolvedTypes[&expr] = fn; + currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload return {{retPack}}; @@ -4776,7 +4746,7 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { // TODO: cache singleton types - return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); } TypeId TypeChecker::singletonType(std::string value) @@ -4813,13 +4783,6 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } -TypeId TypeChecker::addType(const UnionTypeVar& utv) -{ - LUAU_ASSERT(utv.options.size() > 1); - - return addTV(TypeVar(utv)); -} - TypeId TypeChecker::addTV(TypeVar&& tv) { return currentModule->internalTypes.addType(std::move(tv)); @@ -5347,54 +5310,35 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; - if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) + // TODO: CLI-46926 it's not a good idea to rename the type here + TypeId target = follow(instantiated); + bool needsClone = follow(tf.type) == target; + TableTypeVar* ttv = getMutableTableType(target); + + if (ttv && needsClone) { - // TODO: CLI-46926 it's not a good idea to rename the type here - TypeId target = follow(instantiated); - bool needsClone = follow(tf.type) == target; - TableTypeVar* ttv = getMutableTableType(target); - - if (ttv && needsClone) + // Substitution::clone is a shallow clone. If this is a metatable type, we + // want to mutate its table, so we need to explicitly clone that table as + // well. If we don't, we will mutate another module's type surface and cause + // a use-after-free. + if (get(target)) { - // Substitution::clone is a shallow clone. If this is a metatable type, we - // want to mutate its table, so we need to explicitly clone that table as - // well. If we don't, we will mutate another module's type surface and cause - // a use-after-free. - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - MetatableTypeVar* mtv = getMutable(instantiated); - mtv->table = applyTypeFunction.clone(mtv->table); - ttv = getMutable(mtv->table); - } - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutable(instantiated); - } + instantiated = applyTypeFunction.clone(tf.type); + MetatableTypeVar* mtv = getMutable(instantiated); + mtv->table = applyTypeFunction.clone(mtv->table); + ttv = getMutable(mtv->table); } - - if (ttv) + if (get(target)) { - ttv->instantiatedTypeParams = typeParams; - ttv->instantiatedTypePackParams = typePackParams; + instantiated = applyTypeFunction.clone(tf.type); + ttv = getMutable(instantiated); } } - else - { - if (TableTypeVar* ttv = getMutableTableType(instantiated)) - { - if (follow(tf.type) == instantiated) - { - // This can happen if a type alias has generics that it does not use at all. - // ex type FooBar = { a: number } - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutableTableType(instantiated); - } - ttv->instantiatedTypeParams = typeParams; - ttv->instantiatedTypePackParams = typePackParams; - } + if (ttv) + { + ttv->instantiatedTypeParams = typeParams; + ttv->instantiatedTypePackParams = typePackParams; } return instantiated; @@ -5482,6 +5426,85 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st return {generics, genericPacks}; } +void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) +{ + LUAU_ASSERT(FFlag::LuauDiscriminableUnions); + + const LValue* target = &lvalue; + std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. + + auto ty = resolveLValue(scope, *target); + if (!ty) + return; // Do nothing. An error was already reported. + + // If the provided lvalue is a local or global, then that's without a doubt the target. + // However, if there is a base lvalue, then we'll want that to be the target iff the base is a union type. + if (auto base = baseof(lvalue)) + { + std::optional baseTy = resolveLValue(scope, *base); + if (baseTy && get(follow(*baseTy))) + { + ty = baseTy; + target = base; + key = lvalue; + } + } + + // If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset. + if (!key) + { + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, *target, *result); + else + addRefinement(refis, *target, errorRecoveryType(scope)); + + return; + } + + // Otherwise, we'll want to walk each option of ty, get its index type, and filter that. + auto utv = get(follow(*ty)); + LUAU_ASSERT(utv); + + std::unordered_set viableTargetOptions; + std::unordered_set viableChildOptions; // There may be additional refinements that apply. We add those here too. + + for (TypeId option : utv) + { + std::optional discriminantTy; + if (auto field = Luau::get(*key)) // need to fully qualify Luau::get because of ADL. + discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), false); + else + LUAU_ASSERT(!"Unhandled LValue alternative?"); + + if (!discriminantTy) + return; // Do nothing. An error was already reported, as per usual. + + if (std::optional result = filterMap(*discriminantTy, predicate)) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + + auto intoType = [this](const std::unordered_set& s) -> std::optional { + if (s.empty()) + return std::nullopt; + + // TODO: allocate UnionTypeVar and just normalize. + std::vector options(s.begin(), s.end()); + if (options.size() == 1) + return options[0]; + + return addType(UnionTypeVar{std::move(options)}); + }; + + if (std::optional viableTargetType = intoType(viableTargetOptions)) + addRefinement(refis, *target, *viableTargetType); + + if (std::optional viableChildType = intoType(viableChildOptions)) + addRefinement(refis, lvalue, *viableChildType); +} + std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { if (!FFlag::LuauLValueAsKey) @@ -5645,18 +5668,29 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi return std::nullopt; }; - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; + if (FFlag::LuauDiscriminableUnions) + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + refineLValue(truthyP.lvalue, refis, scope, predicate); + } + else + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (!ty) + return; - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); + // This is a hack. :( + // Without this, the expression 'a or b' might refine 'b' to be falsy. + // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. + if (fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, truthyP.lvalue, *result); + } } void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5746,16 +5780,23 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - std::optional ty = resolveLValue(refis, scope, isaP.lvalue); - if (!ty) - return; - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, isaP.lvalue, *result); + if (FFlag::LuauDiscriminableUnions) + { + refineLValue(isaP.lvalue, refis, scope, predicate); + } else { - addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); - errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); + std::optional ty = resolveLValue(refis, scope, isaP.lvalue); + if (!ty) + return; + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, isaP.lvalue, *result); + else + { + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); + errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); + } } } @@ -5814,21 +5855,30 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (std::optional result = filterMap(*ty, it->second(sense))) - addRefinement(refis, typeguardP.lvalue, *result); + if (FFlag::LuauDiscriminableUnions) + { + refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); + return; + } else { - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - if (sense) - errVec.push_back( - TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); - } + if (std::optional result = filterMap(*ty, it->second(sense))) + addRefinement(refis, typeguardP.lvalue, *result); + else + { + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); + if (sense) + errVec.push_back( + TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); + } - return; + return; + } } auto fail = [&](const TypeErrorData& err) { - errVec.push_back(TypeError{typeguardP.location, err}); + if (!FFlag::LuauDiscriminableUnions) + errVec.push_back(TypeError{typeguardP.location, err}); addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; @@ -5853,55 +5903,85 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. - auto options = [](TypeId ty) -> std::vector { if (auto utv = get(follow(ty))) return std::vector(begin(utv), end(utv)); return {ty}; }; - if (FFlag::LuauWeakEqConstraint) + if (FFlag::LuauDiscriminableUnions) { - if (!sense && isNil(eqP.type)) - resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); - - return; - } - - if (FFlag::LuauEqConstraint) - { - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; - - std::vector lhs = options(*ty); std::vector rhs = options(eqP.type); - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) + auto predicate = [&](TypeId option) -> std::optional { + if (sense && isUndecidable(option)) + return FFlag::LuauWeakEqConstraint ? option : eqP.type; + + if (!sense && isNil(eqP.type)) + return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; + + if (maybeSingleton(eqP.type)) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. + if (!sense || canUnify(eqP.type, option, eqP.location).empty()) + return sense ? eqP.type : option; + + return std::nullopt; } + + return option; + }; + + refineLValue(eqP.lvalue, refis, scope, predicate); + } + else + { + if (FFlag::LuauWeakEqConstraint) + { + if (!sense && isNil(eqP.type)) + resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + + return; } - if (set.empty()) - return; + if (FFlag::LuauEqConstraint) + { + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); + + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + { + addRefinement(refis, eqP.lvalue, eqP.type); + return; + } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + + std::unordered_set set; + for (TypeId left : lhs) + { + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } + } + + if (set.empty()) + return; + + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); + } } } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index df5d76ed..5b162b31 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -18,14 +18,15 @@ #include #include +LUAU_FASTFLAG(DebugLuauFreezeArena) + LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauLengthOnCompositeType) LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) -LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -144,7 +145,20 @@ bool isNil(TypeId ty) bool isBoolean(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::Boolean); + if (FFlag::LuauRefactorTypeVarQuestions) + { + if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) + return true; + + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isBoolean); + + return false; + } + else + { + return isPrim(ty, PrimitiveTypeVar::Boolean); + } } bool isNumber(TypeId ty) @@ -154,7 +168,20 @@ bool isNumber(TypeId ty) bool isString(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::String); + if (FFlag::LuauRefactorTypeVarQuestions) + { + if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) + return true; + + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isString); + + return false; + } + else + { + return isPrim(ty, PrimitiveTypeVar::String); + } } bool isThread(TypeId ty) @@ -167,37 +194,45 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - if (!get(follow(ty))) - return false; - - std::unordered_set seen; - std::deque queue{ty}; - while (!queue.empty()) + if (FFlag::LuauRefactorTypeVarQuestions) { - TypeId current = follow(queue.front()); - queue.pop_front(); + auto utv = get(follow(ty)); + if (!utv) + return false; - if (seen.count(current)) - continue; - - seen.insert(current); - - if (isNil(current)) - return true; - - if (auto u = get(current)) + return std::any_of(begin(utv), end(utv), isNil); + } + else + { + std::unordered_set seen; + std::deque queue{ty}; + while (!queue.empty()) { - for (TypeId option : u->options) - { - if (isNil(option)) - return true; + TypeId current = follow(queue.front()); + queue.pop_front(); - queue.push_back(option); + if (seen.count(current)) + continue; + + seen.insert(current); + + if (isNil(current)) + return true; + + if (auto u = get(current)) + { + for (TypeId option : u->options) + { + if (isNil(option)) + return true; + + queue.push_back(option); + } } } - } - return false; + return false; + } } bool isTableIntersection(TypeId ty) @@ -228,13 +263,27 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) + else if (FFlag::LuauRefactorTypeVarQuestions) { - LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); - return primitiveType->metatable; + if (isString(type)) + { + auto ptv = get(getSingletonTypes().stringType); + LUAU_ASSERT(ptv && ptv->metatable); + return ptv->metatable; + } + else + return std::nullopt; } else - return std::nullopt; + { + if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) + { + LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); + return primitiveType->metatable; + } + else + return std::nullopt; + } } const TableTypeVar* getTableType(TypeId type) @@ -696,7 +745,7 @@ TypeId SingletonTypes::makeStringMetatable() {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalString}, {}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}}, {"pack", {arena->addType(FunctionTypeVar{ arena->addTypePack(TypePack{{stringType}, anyTypePack}), @@ -1108,30 +1157,14 @@ static Tags* getTags(TypeId ty) void attachTag(TypeId ty, const std::string& tagName) { - if (!FFlag::LuauRefactorTagging) - { - if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(tagName); - } - else - { - LUAU_ASSERT(!"Got a non functional type"); - } - } + if (auto tags = getTags(ty)) + tags->push_back(tagName); else - { - if (auto tags = getTags(ty)) - tags->push_back(tagName); - else - LUAU_ASSERT(!"This TypeId does not support tags"); - } + LUAU_ASSERT(!"This TypeId does not support tags"); } void attachTag(Property& prop, const std::string& tagName) { - LUAU_ASSERT(FFlag::LuauRefactorTagging); - prop.tags.push_back(tagName); } @@ -1140,7 +1173,6 @@ void attachTag(Property& prop, const std::string& tagName) // Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. bool hasTag(const Tags& tags, const std::string& tagName) { - LUAU_ASSERT(FFlag::LuauRefactorTagging); return std::find(tags.begin(), tags.end(), tagName) != tags.end(); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 2bd9cf83..17d9bf58 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,15 +17,11 @@ LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); -LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) -LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) namespace Luau { @@ -229,8 +225,6 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { - LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError); - type = follow(type); if (auto ttv = get(type)) @@ -291,7 +285,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + reportError(TypeError{location, UnificationTooComplex{}}); return; } @@ -403,7 +397,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); return; } @@ -448,7 +442,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } @@ -561,13 +555,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (failed) { if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } } else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) @@ -582,50 +576,44 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool bool foundHeuristic = false; size_t startIndex = 0; - if (FFlag::LuauUnionHeuristic) + if (const std::string* subName = getName(subTy)) { - if (const std::string* subName = getName(subTy)) + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) + const std::string* optionName = getName(uv->options[i]); + if (optionName && *optionName == *subName) { - const std::string* optionName = getName(uv->options[i]); - if (optionName && *optionName == *subName) - { - foundHeuristic = true; - startIndex = i; - break; - } + foundHeuristic = true; + startIndex = i; + break; } } + } - if (FFlag::LuauExtendedUnionMismatchError) + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - if (auto subMatchTag = getTableMatchTag(subTy)) + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) { - for (size_t i = 0; i < uv->options.size(); ++i) - { - auto optionMatchTag = getTableMatchTag(uv->options[i]); - if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) - { - foundHeuristic = true; - startIndex = i; - break; - } - } + foundHeuristic = true; + startIndex = i; + break; } } + } - if (!foundHeuristic && cacheEnabled) + if (!foundHeuristic && cacheEnabled) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) - { - TypeId type = uv->options[i]; + TypeId type = uv->options[i]; - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + { + startIndex = i; + break; } } } @@ -650,7 +638,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { unificationTooComplex = e; } - else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type)) + else if (!isNil(type)) { failedOptionCount++; @@ -664,15 +652,15 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (unificationTooComplex) { - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); } else if (!found) { - if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) - errors.push_back( + if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + reportError( TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } else if (const IntersectionTypeVar* uv = @@ -702,9 +690,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else if (const IntersectionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) @@ -754,10 +742,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (!found) { - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || @@ -801,7 +789,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); if (FFlag::LuauUseCommittingTxnLog) log.popSeen(superTy, subTy); @@ -1067,7 +1055,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + reportError(TypeError{location, UnificationTooComplex{}}); return; } @@ -1166,7 +1154,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { tryUnify_(*subIter, *superIter); - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + if (!errors.empty() && !firstPackErrorPos) firstPackErrorPos = loopCount; superIter.advance(); @@ -1251,7 +1239,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); while (superIter.good()) { @@ -1272,7 +1260,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } else @@ -1372,7 +1360,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { tryUnify_(*subIter, *superIter); - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + if (!errors.empty() && !firstPackErrorPos) firstPackErrorPos = loopCount; superIter.advance(); @@ -1459,7 +1447,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); while (superIter.good()) { @@ -1480,7 +1468,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } } @@ -1493,7 +1481,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1508,13 +1496,13 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superSingleton && *superSingleton == *subSingleton) return; - if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) return; if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1536,10 +1524,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - if (FFlag::LuauExtendedFunctionMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); } size_t numGenericPacks = superFunction->genericPacks.size(); @@ -1547,10 +1532,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - if (FFlag::LuauExtendedFunctionMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); } for (size_t i = 0; i < numGenerics; i++) @@ -1567,48 +1549,35 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - if (FFlag::LuauExtendedFunctionMismatchError) + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); + + bool reported = !innerState.errors.empty(); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + reportError( + TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + + innerState.ctx = CountMismatch::Result; + innerState.tryUnify_(subFunction->retType, superFunction->retType); + + if (!reported) { - innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - - bool reported = !innerState.errors.empty(); - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); + reportError(*e); + else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - errors.push_back( - TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + reportError( + TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), innerState.errors.front()}}); else if (!innerState.errors.empty()) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); - - innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); - - if (!reported) - { - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); - else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - errors.push_back( - TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); - else if (!innerState.errors.empty()) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); - } - } - else - { - ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - - ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); } if (FFlag::LuauUseCommittingTxnLog) @@ -1716,7 +1685,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } } @@ -1734,7 +1703,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } } @@ -1957,13 +1926,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } @@ -2051,7 +2020,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt return tryUnifySealedTables(subTy, superTy, isIntersection); else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not { TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; @@ -2090,7 +2059,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt { const auto& r = subTable->props.find(name); if (r == subTable->props.end()) - errors.push_back(TypeError{location, UnknownProperty{subTy, name}}); + reportError(TypeError{location, UnknownProperty{subTy, name}}); else tryUnify_(r->second.type, prop.type); } @@ -2113,7 +2082,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt } } else - errors.push_back(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); + reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); } } else if (superTable->state == TableState::Sealed) @@ -2194,7 +2163,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) } } else - errors.push_back(TypeError{location, UnknownProperty{subTy, freeName}}); + reportError(TypeError{location, UnknownProperty{subTy, freeName}}); } } @@ -2268,7 +2237,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } } @@ -2284,7 +2253,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec missingPropertiesInSuper.push_back(it.first); - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -2299,7 +2268,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (oldErrorSize != innerState.errors.size() && !errorReported) { errorReported = true; - errors.push_back(innerState.errors.back()); + reportError(innerState.errors.back()); } } else @@ -2340,7 +2309,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } else - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -2369,7 +2338,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } else - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } } @@ -2386,7 +2355,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } @@ -2413,7 +2382,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!extraPropertiesInSub.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); return; } } @@ -2437,9 +2406,9 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); + reportError(*e); else if (!innerState.errors.empty()) - errors.push_back( + reportError( TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); if (FFlag::LuauUseCommittingTxnLog) @@ -2470,7 +2439,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) case TableState::Sealed: case TableState::Unsealed: case TableState::Generic: - errors.push_back(mismatchError); + reportError(mismatchError); } } else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) @@ -2479,7 +2448,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) } else { - errors.push_back(mismatchError); + reportError(mismatchError); } } @@ -2491,9 +2460,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); else - errors.push_back(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(TypeError{location, TypeMismatch{subTy, superTy}}); }; const ClassTypeVar* superClass = get(superTy); @@ -2538,7 +2507,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!classProp) { ok = false; - errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); + reportError(TypeError{location, UnknownProperty{superTy, propName}}); } else { @@ -2577,7 +2546,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - errors.push_back(TypeError{location, GenericError{msg}}); + reportError(TypeError{location, GenericError{msg}}); } if (!ok) @@ -2695,7 +2664,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - errors.push_back(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); } else if (get(tail)) { @@ -2709,7 +2678,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); } } @@ -2886,7 +2855,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); log.replace(needle, *getSingletonTypes().errorRecoveryType()); return; @@ -2894,17 +2863,6 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (log.getMutable(haystack)) return; - else if (auto a = log.getMutable(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypePackIterator it(a->argTypes, &log); it != end(a->argTypes); ++it) - check(*it); - - for (TypePackIterator it(a->retType, &log); it != end(a->retType); ++it) - check(*it); - } - } else if (auto a = log.getMutable(haystack)) { for (TypeId ty : a->options) @@ -2934,7 +2892,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); DEPRECATED_log(needle); *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); return; @@ -2942,17 +2900,6 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (get(haystack)) return; - else if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypeId ty : a->argTypes) - check(ty); - - for (TypeId ty : a->retType) - check(ty); - } - } else if (auto a = get(haystack)) { for (TypeId ty : a->options) @@ -2988,7 +2935,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (log.getMutable(needle)) return; - if (!get(needle)) + if (!log.getMutable(needle)) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -2997,32 +2944,18 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); return; } - if (auto a = get(haystack)) + if (auto a = get(haystack); a && a->tail) { - for (const auto& ty : a->head) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - if (auto f = log.getMutable(log.follow(ty))) - { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); - } - } - } - - if (a->tail) - { - haystack = follow(*a->tail); - continue; - } + haystack = log.follow(*a->tail); + continue; } + break; } } @@ -3048,31 +2981,17 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); DEPRECATED_log(needle); *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); } - if (auto a = get(haystack)) + if (auto a = get(haystack); a && a->tail) { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (const auto& ty : a->head) - { - if (auto f = get(follow(ty))) - { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); - } - } - } - - if (a->tail) - { - haystack = follow(*a->tail); - continue; - } + haystack = follow(*a->tail); + continue; } + break; } } @@ -3094,17 +3013,17 @@ bool Unifier::isNonstrictMode() const void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) - errors.push_back(*e); + reportError(*e); else if (!innerErrors.empty()) - errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) - errors.push_back(*e); + reportError(*e); else if (!innerErrors.empty()) - errors.push_back( + reportError( TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index e0dc3e0f..10cf17d2 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -43,14 +43,14 @@ static void report(ReportFormat format, const char* name, const Luau::Location& } } -static void reportError(ReportFormat format, const Luau::TypeError& error) +static void reportError(const Luau::Frontend& frontend, ReportFormat format, const Luau::TypeError& error) { - const char* name = error.moduleName.c_str(); + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(error.moduleName); if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) - report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); + report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); else - report(format, name, error.location, "TypeError", Luau::toString(error).c_str()); + report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) @@ -72,14 +72,15 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat } for (auto& error : cr.errors) - reportError(format, error); + reportError(frontend, format, error); Luau::LintResult lr = frontend.lint(name); + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); for (auto& error : lr.errors) - reportWarning(format, name, error); + reportWarning(format, humanReadableName.c_str(), error); for (auto& warning : lr.warnings) - reportWarning(format, name, warning); + reportWarning(format, humanReadableName.c_str(), warning); if (annotate) { @@ -120,11 +121,25 @@ struct CliFileResolver : Luau::FileResolver { std::optional readSource(const Luau::ModuleName& name) override { - std::optional source = readFile(name); + Luau::SourceCode::Type sourceType; + std::optional source = std::nullopt; + + // If the module name is "-", then read source from stdin + if (name == "-") + { + source = readStdin(); + sourceType = Luau::SourceCode::Script; + } + else + { + source = readFile(name); + sourceType = Luau::SourceCode::Module; + } + if (!source) return std::nullopt; - return Luau::SourceCode{*source, Luau::SourceCode::Module}; + return Luau::SourceCode{*source, sourceType}; } std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override @@ -143,6 +158,13 @@ struct CliFileResolver : Luau::FileResolver return std::nullopt; } + + std::string getHumanReadableModuleName(const Luau::ModuleName& name) const override + { + if (name == "-") + return "stdin"; + return name; + } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index cb993dfe..c6807022 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -74,6 +74,21 @@ std::optional readFile(const std::string& name) return result; } +std::optional readStdin() +{ + std::string result; + char buffer[4096] = { }; + + while (fgets(buffer, sizeof(buffer), stdin) != nullptr) + result.append(buffer); + + // If eof was not reached for stdin, then a read error occurred + if (!feof(stdin)) + return std::nullopt; + + return result; +} + template static void joinPaths(std::basic_string& str, const Ch* lhs, const Ch* rhs) { @@ -190,7 +205,10 @@ bool traverseDirectory(const std::string& path, const std::function getSourceFiles(int argc, char** argv) for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') + // Treat '-' as a special file whose source is read from stdin + // All other arguments that start with '-' are skipped + if (argv[i][0] == '-' && argv[i][1] != '\0') continue; if (isDirectory(argv[i])) diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index da11f512..97471cdc 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -7,6 +7,7 @@ #include std::optional readFile(const std::string& name); +std::optional readStdin(); bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index e5042152..ab0f0ed0 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -158,7 +158,7 @@ static int lua_collectgarbage(lua_State* L) luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); } -static void setupState(lua_State* L) +void setupState(lua_State* L) { luaL_openlibs(L); @@ -176,7 +176,7 @@ static void setupState(lua_State* L) luaL_sandbox(L); } -static std::string runCode(lua_State* L, const std::string& source) +std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source, copts()); @@ -206,7 +206,13 @@ static std::string runCode(lua_State* L, const std::string& source) if (n) { luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); - lua_getglobal(T, "print"); + lua_getglobal(T, "_PRETTYPRINT"); + // If _PRETTYPRINT is nil, then use the standard print function instead + if (lua_isnil(T, -1)) + { + lua_pop(T, 1); + lua_getglobal(T, "print"); + } lua_insert(T, 1); lua_pcall(T, n, 0, 0); } @@ -545,7 +551,7 @@ static int assertionHandler(const char* expr, const char* file, int line, const return 1; } -int main(int argc, char** argv) +int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; @@ -696,7 +702,6 @@ int main(int argc, char** argv) case CliMode::Unknown: default: LUAU_ASSERT(!"Unhandled cli mode."); + return 1; } } - - diff --git a/CLI/Repl.h b/CLI/Repl.h new file mode 100644 index 00000000..11a077ae --- /dev/null +++ b/CLI/Repl.h @@ -0,0 +1,12 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lua.h" + +#include + +// Note: These are internal functions which are being exposed in a header +// so they can be included by unit tests. +int replMain(int argc, char** argv); +void setupState(lua_State* L); +std::string runCode(lua_State* L, const std::string& source); diff --git a/CLI/ReplEntry.cpp b/CLI/ReplEntry.cpp new file mode 100644 index 00000000..b3131712 --- /dev/null +++ b/CLI/ReplEntry.cpp @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Repl.h" + + + +int main(int argc, char** argv) +{ + return replMain(argc, argv); +} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 77cf47e8..b9f7a9e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,7 @@ endif() if(LUAU_BUILD_TESTS) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) + add_executable(Luau.CLI.Test) endif() if(LUAU_BUILD_WEB) @@ -109,6 +110,17 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) + + target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) + target_include_directories(Luau.CLI.Test PRIVATE extern CLI) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM) + if(UNIX) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.CLI.Test PRIVATE pthread) + endif() + endif() + endif() if(LUAU_BUILD_WEB) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index d9694d7d..679712f6 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -472,6 +472,9 @@ enum LuauBuiltinFunction // bit32.count LBF_BIT32_COUNTLZ, LBF_BIT32_COUNTRZ, + + // select(_, ...) + LBF_SELECT_VARARG, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index e344eb91..a907271c 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,8 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin, false) + namespace Luau { namespace Compile @@ -62,6 +64,9 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; + if (FFlag::LuauCompileSelectBuiltin && builtin.isGlobal("select")) + return LBF_SELECT_VARARG; + if (builtin.object == "math") { if (builtin.method == "abs") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9758c4a9..7da85244 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -15,6 +15,9 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false) +LUAU_FASTFLAG(LuauCompileSelectBuiltin) + namespace Luau { @@ -261,6 +264,122 @@ struct Compiler bytecode.emitABC(LOP_GETVARARGS, target, multRet ? 0 : uint8_t(targetCount + 1), 0); } + void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs) + { + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + LUAU_ASSERT(targetCount == 1); + LUAU_ASSERT(!expr->self); + LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is()); + + AstExpr* arg = expr->args.data[0]; + + uint8_t argreg; + + if (isExprLocalReg(arg)) + argreg = getLocal(arg->as()->local); + else + { + argreg = uint8_t(regs + 1); + compileExprTempTop(arg, argreg); + } + + size_t fastcallLabel = bytecode.emitLabel(); + + bytecode.emitABC(LOP_FASTCALL1, LBF_SELECT_VARARG, argreg, 0); + + // note, these instructions are normally not executed and are used as a fallback for FASTCALL + // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten + compileExprTemp(expr->func, regs); + + bytecode.emitABC(LOP_GETVARARGS, uint8_t(regs + 2), 0, 0); + + size_t callLabel = bytecode.emitLabel(); + if (!bytecode.patchSkipC(fastcallLabel, callLabel)) + CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile"); + + // note, this is always multCall (last argument is variadic) + bytecode.emitABC(LOP_CALL, regs, 0, multRet ? 0 : uint8_t(targetCount + 1)); + + // if we didn't output results directly to target, we need to move them + if (!targetTop) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0); + } + } + + void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid) + { + LUAU_ASSERT(!expr->self); + LUAU_ASSERT(expr->args.size <= 2); + + LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; + + uint32_t args[2] = {}; + + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0) + { + if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) + { + opc = LOP_FASTCALL2K; + args[i] = cid; + break; + } + } + + if (isExprLocalReg(expr->args.data[i])) + args[i] = getLocal(expr->args.data[i]->as()->local); + else + { + args[i] = uint8_t(regs + 1 + i); + compileExprTempTop(expr->args.data[i], uint8_t(args[i])); + } + } + + size_t fastcallLabel = bytecode.emitLabel(); + + bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); + if (opc != LOP_FASTCALL1) + bytecode.emitAux(args[1]); + + // Set up a traditional Lua stack for the subsequent LOP_CALL. + // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for + // these FASTCALL variants. + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0 && opc == LOP_FASTCALL2K) + { + emitLoadK(uint8_t(regs + 1 + i), args[i]); + break; + } + + if (args[i] != regs + 1 + i) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); + } + + // note, these instructions are normally not executed and are used as a fallback for FASTCALL + // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten + compileExprTemp(expr->func, regs); + + size_t callLabel = bytecode.emitLabel(); + + // FASTCALL will skip over the instructions needed to compute function and jump over CALL which must immediately follow the instruction + // sequence after FASTCALL + if (!bytecode.patchSkipC(fastcallLabel, callLabel)) + CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile"); + + bytecode.emitABC(LOP_CALL, regs, uint8_t(expr->args.size + 1), multRet ? 0 : uint8_t(targetCount + 1)); + + // if we didn't output results directly to target, we need to move them + if (!targetTop) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0); + } + } + void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) { LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); @@ -284,6 +403,25 @@ struct Compiler bfid = getBuiltinFunctionId(builtin, options); } + if (bfid == LBF_SELECT_VARARG) + { + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly + // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases + if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) + return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs); + else + bfid = -1; + } + + // Optimization: for 1/2 argument fast calls use specialized opcodes + if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + { + AstExpr* last = expr->args.data[expr->args.size - 1]; + if (!last->is() && !last->is()) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } + if (expr->self) { AstExprIndexName* fi = expr->func->as(); @@ -309,24 +447,13 @@ struct Compiler compileExprTempTop(expr->func, regs); } - // Note: if the last argument is ExprVararg or ExprCall, we need to route that directly to the called function preserving the # of args bool multCall = false; - bool skipArgs = false; - if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) - { - AstExpr* last = expr->args.data[expr->args.size - 1]; - skipArgs = !(last->is() || last->is()); - } - - if (!skipArgs) - { - for (size_t i = 0; i < expr->args.size; ++i) - if (i + 1 == expr->args.size) - multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); - else - compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); - } + for (size_t i = 0; i < expr->args.size; ++i) + if (i + 1 == expr->args.size) + multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); + else + compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); setDebugLineEnd(expr->func); @@ -347,59 +474,8 @@ struct Compiler } else if (bfid >= 0) { - size_t fastcallLabel; - - if (skipArgs) - { - LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; - - uint32_t args[2] = {}; - for (size_t i = 0; i < expr->args.size; ++i) - { - if (i > 0) - { - if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) - { - opc = LOP_FASTCALL2K; - args[i] = cid; - break; - } - } - - if (isExprLocalReg(expr->args.data[i])) - args[i] = getLocal(expr->args.data[i]->as()->local); - else - { - args[i] = uint8_t(regs + 1 + i); - compileExprTempTop(expr->args.data[i], uint8_t(args[i])); - } - } - - fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); - if (opc != LOP_FASTCALL1) - bytecode.emitAux(args[1]); - - // Set up a traditional Lua stack for the subsequent LOP_CALL. - // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for - // these FASTCALL variants. - for (size_t i = 0; i < expr->args.size; ++i) - { - if (i > 0 && opc == LOP_FASTCALL2K) - { - emitLoadK(uint8_t(regs + 1 + i), args[i]); - break; - } - - if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); - } - } - else - { - fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0); - } + size_t fastcallLabel = bytecode.emitLabel(); + bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0); // note, these instructions are normally not executed and are used as a fallback for FASTCALL // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten @@ -1101,9 +1177,20 @@ struct Compiler for (size_t i = 0; i < expr->items.size; ++i) { const AstExprTable::Item& item = expr->items.data[i]; - AstExprConstantNumber* ckey = item.key->as(); + LUAU_ASSERT(item.key); // no list portion => all items have keys - indexSize += (ckey && ckey->value == double(indexSize + 1)); + if (FFlag::LuauCompileTableIndexOpt) + { + const Constant* ckey = constants.find(item.key); + + indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1)); + } + else + { + AstExprConstantNumber* ckey = item.key->as(); + + indexSize += (ckey && ckey->value == double(indexSize + 1)); + } } // we only perform the optimization if we don't have any other []-keys @@ -1200,37 +1287,47 @@ struct Compiler arrayChunkCurrent = 0; } - // items with a key are set one by one via SETTABLE/SETTABLEKS + // items with a key are set one by one via SETTABLE/SETTABLEKS/SETTABLEN if (key) { RegScope rsi(this); - // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax - if (AstExprConstantString* ckey = key->as()) + if (FFlag::LuauCompileTableIndexOpt) { - BytecodeBuilder::StringRef cname = sref(ckey->value); - int32_t cid = bytecode.addConstantString(cname); - if (cid < 0) - CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - + LValue lv = compileLValueIndex(reg, key, rsi); uint8_t rv = compileExprAuto(value, rsi); - bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); - bytecode.emitAux(cid); - } - else if (AstExprConstantNumber* ckey = key->as(); - ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) - { - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); + compileAssign(lv, rv); } else { - uint8_t rk = compileExprAuto(key, rsi); - uint8_t rv = compileExprAuto(value, rsi); + // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax + if (AstExprConstantString* ckey = key->as()) + { + BytecodeBuilder::StringRef cname = sref(ckey->value); + int32_t cid = bytecode.addConstantString(cname); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); + bytecode.emitAux(cid); + } + else if (AstExprConstantNumber* ckey = key->as(); + ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) + { + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); + } + else + { + uint8_t rk = compileExprAuto(key, rsi); + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); + } } } // items without a key are set using SETLIST so that we can initialize large arrays quickly @@ -1339,6 +1436,9 @@ struct Compiler uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(expr->index); + bytecode.emitABC(LOP_GETTABLEN, target, rt, i); } else if (cv && cv->type == Constant::Type_String) @@ -1350,6 +1450,9 @@ struct Compiler if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(expr->index); + bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); } @@ -1657,6 +1760,40 @@ struct Compiler Location location; }; + LValue compileLValueIndex(uint8_t reg, AstExpr* index, RegScope& rs) + { + const Constant* cv = constants.find(index); + + if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && + double(int(cv->valueNumber)) == cv->valueNumber) + { + LValue result = {LValue::Kind_IndexNumber}; + result.reg = reg; + result.number = uint8_t(int(cv->valueNumber) - 1); + result.location = index->location; + + return result; + } + else if (cv && cv->type == Constant::Type_String) + { + LValue result = {LValue::Kind_IndexName}; + result.reg = reg; + result.name = sref(cv->getString()); + result.location = index->location; + + return result; + } + else + { + LValue result = {LValue::Kind_IndexExpr}; + result.reg = reg; + result.index = compileExprAuto(index, rs); + result.location = index->location; + + return result; + } + } + LValue compileLValue(AstExpr* node, RegScope& rs) { setDebugLine(node); @@ -1699,36 +1836,9 @@ struct Compiler } else if (AstExprIndexExpr* expr = node->as()) { - const Constant* cv = constants.find(expr->index); + uint8_t reg = compileExprAuto(expr->expr, rs); - if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && - double(int(cv->valueNumber)) == cv->valueNumber) - { - LValue result = {LValue::Kind_IndexNumber}; - result.reg = compileExprAuto(expr->expr, rs); - result.number = uint8_t(int(cv->valueNumber) - 1); - result.location = node->location; - - return result; - } - else if (cv && cv->type == Constant::Type_String) - { - LValue result = {LValue::Kind_IndexName}; - result.reg = compileExprAuto(expr->expr, rs); - result.name = sref(cv->getString()); - result.location = node->location; - - return result; - } - else - { - LValue result = {LValue::Kind_IndexExpr}; - result.reg = compileExprAuto(expr->expr, rs); - result.index = compileExprAuto(expr->index, rs); - result.location = node->location; - - return result; - } + return compileLValueIndex(reg, expr->index, rs); } else { @@ -1740,6 +1850,9 @@ struct Compiler void compileLValueUse(const LValue& lv, uint8_t reg, bool set) { + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(lv.location); + switch (lv.kind) { case LValue::Kind_Local: diff --git a/Makefile b/Makefile index b144cac6..638c4c63 100644 --- a/Makefile +++ b/Makefile @@ -23,11 +23,11 @@ VM_SOURCES=$(wildcard VM/src/*.cpp) VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o) VM_TARGET=$(BUILD)/libluauvm.a -TESTS_SOURCES=$(wildcard tests/*.cpp) +TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau @@ -90,11 +90,12 @@ $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -Iextern +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICLI -Iextern $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include +$(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a diff --git a/Sources.cmake b/Sources.cmake index bafe7594..22e7af22 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -176,7 +176,8 @@ if(TARGET Luau.Repl.CLI) CLI/FileUtils.cpp CLI/Profiler.h CLI/Profiler.cpp - CLI/Repl.cpp) + CLI/Repl.cpp + CLI/ReplEntry.cpp) endif() if(TARGET Luau.Analyze.CLI) @@ -243,6 +244,21 @@ if(TARGET Luau.Conformance) tests/main.cpp) endif() +if(TARGET Luau.CLI.Test) + # Luau.CLI.Test Sources + target_sources(Luau.CLI.Test PRIVATE + CLI/Coverage.h + CLI/Coverage.cpp + CLI/FileUtils.h + CLI/FileUtils.cpp + CLI/Profiler.h + CLI/Profiler.cpp + CLI/Repl.cpp + + tests/Repl.test.cpp + tests/main.cpp) +endif() + if(TARGET Luau.Web) # Luau.Web Sources target_sources(Luau.Web PRIVATE diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index d5416285..5cffba63 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauGcForwardMetatableBarrier, false) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -869,7 +871,16 @@ int lua_setmetatable(lua_State* L, int objindex) luaG_runerror(L, "Attempt to modify a readonly table"); hvalue(obj)->metatable = mt; if (mt) - luaC_objbarriert(L, hvalue(obj), mt); + { + if (FFlag::LuauGcForwardMetatableBarrier) + { + luaC_objbarrier(L, hvalue(obj), mt); + } + else + { + luaC_objbarriert(L, hvalue(obj), mt); + } + } break; } case LUA_TUSERDATA: diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 34e9ebc1..ecc14e87 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1087,6 +1087,34 @@ static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, St return -1; } +static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams == 1 && nresults == 1) + { + int n = cast_int(L->base - L->ci->func) - clvalue(L->ci->func)->l.p->numparams - 1; + + if (ttisnumber(arg0)) + { + int i = int(nvalue(arg0)); + + // i >= 1 && i <= n + if (unsigned(i - 1) <= unsigned(n)) + { + setobj2s(L, res, L->base - n + (i - 1)); + return 1; + } + // note: for now we don't handle negative case (wrap around) and defer to fallback + } + else if (ttisstring(arg0) && *svalue(arg0) == '#') + { + setnvalue(res, double(n)); + return 1; + } + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1156,4 +1184,6 @@ luau_FastFunction luauF_table[256] = { luauF_countlz, luauF_countrz, + + luauF_select, }; diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index abcde779..19222861 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,8 +5,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) - #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -235,9 +233,6 @@ static int coyieldable(lua_State* L) static int coclose(lua_State* L) { - if (!FFlag::LuauCoroutineClose) - luaL_error(L, "coroutine.close is not enabled"); - lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 581506a8..a3982bc6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAG(LuauCoroutineClose) - /* ** {====================================================== ** Error-recovery functions @@ -300,7 +298,7 @@ static void resume(lua_State* L, void* ud) { // start coroutine LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); - if (FFlag::LuauCoroutineClose && firstArg == L->base) + if (firstArg == L->base) luaG_runerror(L, "cannot resume dead coroutine"); if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 50859b1e..82ac0009 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -93,10 +93,8 @@ static void finishGcCycleStats(global_State* g) g->gcstats.lastcycle = g->gcstats.currcycle; g->gcstats.currcycle = GCCycleStats(); - g->gcstats.cyclestatsacc.markitems += g->gcstats.lastcycle.markitems; g->gcstats.cyclestatsacc.marktime += g->gcstats.lastcycle.marktime; g->gcstats.cyclestatsacc.atomictime += g->gcstats.lastcycle.atomictime; - g->gcstats.cyclestatsacc.sweepitems += g->gcstats.lastcycle.sweepitems; g->gcstats.cyclestatsacc.sweeptime += g->gcstats.lastcycle.sweeptime; } @@ -492,23 +490,22 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page) } } -#define sweepwholelist(L, p, tc) sweeplist(L, p, SIZE_MAX, tc) +#define sweepwholelist(L, p) sweeplist(L, p, SIZE_MAX) -static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* traversedcount) +static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count) { LUAU_ASSERT(!FFlag::LuauGcPagedSweep); GCObject* curr; global_State* g = L->global; int deadmask = otherwhite(g); - size_t startcount = count; LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); /* make sure we never sweep fixed objects */ while ((curr = *p) != NULL && count-- > 0) { int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; if (curr->gch.tt == LUA_TTHREAD) { - sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ + sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval); /* sweep open upvalues */ lua_State* th = gco2th(curr); @@ -534,10 +531,6 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr } } - // if we didn't reach the end of the list it means that we've stopped because the count dropped below zero - if (traversedcount) - *traversedcount += startcount - (curr ? count + 1 : count); - return p; } @@ -721,8 +714,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; - g->gcstats.currcycle.sweepitems++; - if (gco->gch.tt == LUA_TTHREAD) { lua_State* th = gco2th(gco); @@ -770,11 +761,11 @@ static int sweepgcopage(lua_State* L, lua_Page* page) { // if the last block was removed, page would be removed as well if (--busyBlocks == 0) - return (pos - start) / blockSize + 1; + return int(pos - start) / blockSize + 1; } } - return (end - start) / blockSize; + return int(end - start) / blockSize; } static size_t gcstep(lua_State* L, size_t limit) @@ -793,8 +784,6 @@ static size_t gcstep(lua_State* L, size_t limit) { while (g->gray && cost < limit) { - g->gcstats.currcycle.markitems++; - cost += propagatemark(g); } @@ -812,8 +801,6 @@ static size_t gcstep(lua_State* L, size_t limit) { while (g->gray && cost < limit) { - g->gcstats.currcycle.markitems++; - cost += propagatemark(g); } @@ -842,10 +829,8 @@ static size_t gcstep(lua_State* L, size_t limit) while (g->sweepstrgc < g->strt.size && cost < limit) { - size_t traversedcount = 0; - sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++], &traversedcount); + sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++]); - g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPCOST; } @@ -855,12 +840,10 @@ static size_t gcstep(lua_State* L, size_t limit) // sweep string buffer list and preserve used string count uint32_t nuse = L->global->strt.nuse; - size_t traversedcount = 0; - sweepwholelist(L, (GCObject**)&g->strbufgc, &traversedcount); + sweepwholelist(L, (GCObject**)&g->strbufgc); L->global->strt.nuse = nuse; - g->gcstats.currcycle.sweepitems += traversedcount; g->gcstate = GCSsweep; // end sweep-string phase } break; @@ -893,10 +876,8 @@ static size_t gcstep(lua_State* L, size_t limit) { while (*g->sweepgc && cost < limit) { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX); - g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPMAX * GC_SWEEPCOST; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 4455fec5..528d0944 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -113,6 +113,7 @@ luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } +// TODO: remove with FFlagLuauGcForwardMetatableBarrier #define luaC_objbarriert(L, t, o) \ { \ if (isblack(obj2gco(t)) && iswhite(obj2gco(o))) \ diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 6d3b7772..e1dbce50 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -200,7 +200,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int global_State* g = L->global; - LUAU_ASSERT(pageSize - offsetof(lua_Page, data) >= blockSize * blockCount); + LUAU_ASSERT(pageSize - int(offsetof(lua_Page, data)) >= blockSize * blockCount); lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); if (!page) @@ -376,7 +376,7 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); - LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass]); + LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); void* block; @@ -520,7 +520,7 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) } else { - lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + nsize, nsize, 1); + lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + int(nsize), int(nsize), 1); block = &page->data; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 080f0024..0708b71f 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -96,9 +96,6 @@ struct GCCycleStats double sweeptime = 0.0; - size_t markitems = 0; - size_t sweepitems = 0; - size_t assistwork = 0; size_t explicitwork = 0; diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 41b553b5..292625b0 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -44,10 +44,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") { - ScopedFastFlag sffs[] = { - {"LuauPersistDefinitionFileTypes", true}, - }; - loadDefinition(R"( declare function Connect(fn: (string) -> ()) )"); @@ -63,8 +59,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") { - ScopedFastFlag sffs{"LuauStoreMatchingOverloadFnType", true}; - loadDefinition(R"( declare foo: ((string) -> number) & ((number) -> string) )"); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 211e1be1..e8e3b315 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2626,7 +2626,6 @@ local a: A<(number, s@1> TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true); check(R"( local function foo1() return 1 end @@ -2728,4 +2727,39 @@ end CHECK(ac.entryMap.count("getx")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauRefactorTypeVarQuestions", true}, + }; + + check(R"( + --!strict + local foo: "hello" | "bye" = "hello" + foo:@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("format")); +} + +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag preferToCallFunctionsForIntersects("PreferToCallFunctionsForIntersects", true); + + check(R"( +local bar: ((number) -> number) & (number, number) -> number) +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 3b0d677d..4a28bdde 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -603,6 +603,37 @@ RETURN R0 1 )"); } +TEST_CASE("TableLiteralsIndexConstant") +{ + ScopedFastFlag sff("LuauCompileTableIndexOpt", true); + + // validate that we use SETTTABLEKS for constant variable keys + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = "key", "value" + return {[a] = 42, [b] = 0} +)"), R"( +NEWTABLE R0 2 0 +LOADN R1 42 +SETTABLEKS R1 R0 K0 +LOADN R1 0 +SETTABLEKS R1 R0 K1 +RETURN R0 1 +)"); + + // validate that we use SETTABLEN for constant variable keys *and* that we predict array size + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = 1, 2 + return {[a] = 42, [b] = 0} +)"), R"( +NEWTABLE R0 0 2 +LOADN R1 42 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionBasic") { CHECK_EQ("\n" + compileFunction0(R"( @@ -2450,6 +2481,37 @@ return )"); } +TEST_CASE("DebugLineInfoAssignment") +{ + ScopedFastFlag sff("LuauCompileTableIndexOpt", true); + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( + local a = { b = { c = { d = 3 } } } + +a +["b"] +["c"] +["d"] = 4 +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: DUPTABLE R0 1 +2: DUPTABLE R1 3 +2: DUPTABLE R2 5 +2: LOADN R3 3 +2: SETTABLEKS R3 R2 K4 +2: SETTABLEKS R2 R1 K2 +2: SETTABLEKS R1 R0 K0 +5: GETTABLEKS R2 R0 K0 +6: GETTABLEKS R1 R2 K2 +7: LOADN R2 4 +7: SETTABLEKS R2 R1 K4 +8: RETURN R0 0 +)"); +} + TEST_CASE("DebugSource") { const char* source = R"( @@ -2763,6 +2825,75 @@ RETURN R1 -1 )"); } +TEST_CASE("FastcallSelect") +{ + ScopedFastFlag sff("LuauCompileSelectBuiltin", true); + + // select(_, ...) compiles to a builtin call + CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( +LOADK R1 K0 +FASTCALL1 57 R1 +3 +GETIMPORT R0 2 +GETVARARGS R2 -1 +CALL R0 -1 1 +RETURN R0 1 +)"); + + // more complex example: select inside a for loop bound + select from a iterator + CHECK_EQ("\n" + compileFunction0(R"( +local sum = 0 +for i=1, select('#', ...) do + sum += select(i, ...) +end +return sum +)"), R"( +LOADN R0 0 +LOADN R3 1 +LOADK R5 K0 +FASTCALL1 57 R5 +3 +GETIMPORT R4 2 +GETVARARGS R6 -1 +CALL R4 -1 1 +MOVE R1 R4 +LOADN R2 1 +FORNPREP R1 +7 +FASTCALL1 57 R3 +3 +GETIMPORT R4 2 +GETVARARGS R6 -1 +CALL R4 -1 1 +ADD R0 R0 R4 +FORNLOOP R1 -7 +RETURN R0 1 +)"); + + // currently we assume a single value return to avoid dealing with stack resizing + CHECK_EQ("\n" + compileFunction0("return select('#', ...)"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETVARARGS R2 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#')"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +CALL R0 1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#', foo())"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETIMPORT R2 4 +CALL R2 0 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); +} + TEST_CASE("LotsOfParameters") { const char* source = R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5222af33..914b881f 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -331,8 +331,6 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { - ScopedFastFlag sff("LuauCoroutineClose", true); - runConformance("coroutine.lua"); } diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 405f26e0..ea1a08fe 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -956,7 +956,6 @@ TEST_CASE("no_use_after_free_with_type_fun_instantiation") { // This flag forces this test to crash if there's a UAF in this code. ScopedFastFlag sff_DebugLuauFreezeArena("DebugLuauFreezeArena", true); - ScopedFastFlag sff_LuauCloneCorrectlyBeforeMutatingTableType("LuauCloneCorrectlyBeforeMutatingTableType", true); FrontendFixture fix; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index ac81005c..90831ee9 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2000,6 +2000,73 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); } +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ + { + AstStat* stat = parse("return if true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + } + + { + AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use "else if" as opposed to elseif + { + AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use an if-else expression as the conditional expression of an if-else expression + { + AstStat* stat = parse("return if if true then false else true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + auto* nestedIfElseExpr = ifElseExpr->condition->as(); + REQUIRE(nestedIfElseExpr != nullptr); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2504,71 +2571,4 @@ type Y = (T...) -> U... CHECK_EQ(1, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") -{ - { - AstStat* stat = parse("return if true then 1 else 2"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - } - - { - AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } - - // Use "else if" as opposed to elseif - { - AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } - - // Use an if-else expression as the conditional expression of an if-else expression - { - AstStat* stat = parse("return if if true then false else true then 1 else 2"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - auto* nestedIfElseExpr = ifElseExpr->condition->as(); - REQUIRE(nestedIfElseExpr != nullptr); - } -} - -TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") -{ - AstStat* stat = parse(R"( -type Packed = () -> T... - -type A = Packed -type B = Packed<...number> -type C = Packed<(number, X...)> - )"); - REQUIRE(stat != nullptr); -} - TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp new file mode 100644 index 00000000..f660bcd3 --- /dev/null +++ b/tests/Repl.test.cpp @@ -0,0 +1,117 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Repl.h" + +#include "doctest.h" + +#include +#include +#include +#include + + +class ReplFixture +{ +public: + ReplFixture() + : luaState(luaL_newstate(), lua_close) + { + L = luaState.get(); + setupState(L); + luaL_sandboxthread(L); + + std::string result = runCode(L, prettyPrintSource); + } + + // Returns all of the output captured from the pretty printer + std::string getCapturedOutput() + { + lua_getglobal(L, "capturedoutput"); + const char* str = lua_tolstring(L, -1, nullptr); + std::string result(str); + lua_pop(L, 1); + return result; + } + lua_State* L; + +private: + std::unique_ptr luaState; + + // This is a simplicitic and incomplete pretty printer. + // It is included here to test that the pretty printer hook is being called. + // More elaborate tests to ensure correct output can be added if we introduce + // a more feature rich pretty printer. + std::string prettyPrintSource = R"( +-- Accumulate pretty printer output in `capturedoutput` +capturedoutput = "" + +function arraytostring(arr) + local strings = {} + table.foreachi(arr, function(k,v) table.insert(strings, pptostring(v)) end ) + return "{" .. table.concat(strings, ", ") .. "}" +end + +function pptostring(x) + if type(x) == "table" then + -- Just assume array-like tables for now. + return arraytostring(x) + elseif type(x) == "string" then + return '"' .. x .. '"' + else + return tostring(x) + end +end + +-- Note: Instead of calling print, the pretty printer just stores the output +-- in `capturedoutput` so we can check for the correct results. +function _PRETTYPRINT(...) + local args = table.pack(...) + local strings = {} + for i=1, args.n do + local item = args[i] + local str = pptostring(item, customoptions) + if i == 1 then + capturedoutput = capturedoutput .. str + else + capturedoutput = capturedoutput .. "\t" .. str + end + end +end +)"; +}; + +TEST_SUITE_BEGIN("ReplPrettyPrint"); + +TEST_CASE_FIXTURE(ReplFixture, "AdditionStatement") +{ + runCode(L, "return 30 + 12"); + CHECK(getCapturedOutput() == "42"); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableLiteral") +{ + runCode(L, "return {1, 2, 3, 4}"); + CHECK(getCapturedOutput() == "{1, 2, 3, 4}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "StringLiteral") +{ + runCode(L, "return 'str'"); + CHECK(getCapturedOutput() == "\"str\""); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithStringLiterals") +{ + runCode(L, "return {1, 'two', 3, 'four'}"); + CHECK(getCapturedOutput() == "{1, \"two\", 3, \"four\"}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "MultipleArguments") +{ + runCode(L, "return 3, 'three'"); + CHECK(getCapturedOutput() == "3\t\"three\""); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 445ee532..bbb26291 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -435,8 +435,6 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type F = ((() -> number)?) -> F? local function f(p) return f end @@ -450,8 +448,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union" TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_intersection") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f() return f end local a: ((number) -> ()) & typeof(f) diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 822bd727..76ab23b3 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -11,8 +11,6 @@ TEST_SUITE_BEGIN("TypeAliases"); TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type F = () -> F? local function f() @@ -194,8 +192,6 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type A = () -> (number, B) type B = () -> (string, A) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 114679e3..a7f27551 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) - TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -656,11 +654,7 @@ local d: D = c LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauExtendedFunctionMismatchError) - CHECK_EQ( - toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") @@ -675,11 +669,8 @@ local d: D = c LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauExtendedFunctionMismatchError) - CHECK_EQ(toString(result.errors[0]), - R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e6d3d4d4..47c13be9 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -271,6 +271,32 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } +// Also belongs in TypeInfer.refinements.test.cpp. +// Just needs to fully support equality refinement. Which is annoying without type states. +TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") +{ + ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + + CheckResult result = check(R"( + type T = {x: string, y: number} | {x: nil, y: nil} + + local function f(t: T) + if t.x ~= nil then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: string, y: number |}", toString(requireTypeAtPosition({5, 28}))); + + // Should be {| x: nil, y: nil |} + CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); +} + TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; @@ -590,8 +616,6 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - // Mutability in type function application right now can create strange recursive types // TODO: instantiation right now is problematic, in this example should either leave the Table type alone // or it should rename the type to 'Self' so that the result will be 'Self' diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index d76b920b..f346ddfd 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -6,11 +6,77 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauDiscriminableUnions) LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; +namespace +{ +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + if (expr.args.size != 1) + return std::nullopt; + + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; + + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; + + unfreeze(typeChecker.globalTypes); + TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); + freeze(typeChecker.globalTypes); + return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; +} + +struct RefinementClassFixture : Fixture +{ + RefinementClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); + getMutable(vec3)->props = { + {"X", Property{typeChecker.numberType}}, + {"Y", Property{typeChecker.numberType}}, + {"Z", Property{typeChecker.numberType}}, + }; + + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); + + TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); + TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); + getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + + getMutable(inst)->props = { + {"Name", Property{typeChecker.stringType}}, + {"IsA", Property{isA}}, + }; + + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); + getMutable(part)->props = { + {"Position", Property{vec3}}, + }; + + typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + freeze(typeChecker.globalTypes); + } +}; +} // namespace + TEST_SUITE_BEGIN("RefinementTest"); TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") @@ -196,8 +262,18 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); + if (FFlag::LuauDiscriminableUnions) + { + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -237,7 +313,6 @@ TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") { - CheckResult result = check(R"( local t: {x: number?} = {x = 1} @@ -254,7 +329,6 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") { - CheckResult result = check(R"( local t: {x: {y: string}?} = {x = {y = "hello!"}} @@ -360,7 +434,10 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauSingletonTypes", true}, + }; CheckResult result = check(R"( local function f(a: (string | number)?) @@ -374,16 +451,8 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "string"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" } TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") @@ -416,7 +485,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; + ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; CheckResult result = check(R"( local function f(a, b: string?) @@ -428,16 +498,8 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") @@ -527,9 +589,17 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); + if (FFlag::LuauDiscriminableUnions) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); + } + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); } @@ -614,214 +684,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" } -namespace -{ -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) -{ - if (expr.args.size != 1) - return std::nullopt; - - auto index = expr.func->as(); - auto str = expr.args.data[0]->as(); - if (!index || !str) - return std::nullopt; - - std::optional lvalue = tryGetLValue(*index->expr); - std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); - if (!lvalue || !tfun) - return std::nullopt; - - unfreeze(typeChecker.globalTypes); - TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); - freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; -} - -struct RefinementClassFixture : Fixture -{ - RefinementClassFixture() - { - TypeArena& arena = typeChecker.globalTypes; - - unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); - getMutable(vec3)->props = { - {"X", Property{typeChecker.numberType}}, - {"Y", Property{typeChecker.numberType}}, - {"Z", Property{typeChecker.numberType}}, - }; - - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); - - TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); - TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); - TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - - getMutable(inst)->props = { - {"Name", Property{typeChecker.stringType}}, - {"IsA", Property{isA}}, - }; - - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); - getMutable(part)->props = { - {"Position", Property{vec3}}, - }; - - typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; - typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; - typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; - freeze(typeChecker.globalTypes); - } -}; -} // namespace - -TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") -{ - CheckResult result = check(R"( - local function f(vec) - local X, Y, Z = vec.X, vec.Y, vec.Z - - if type(vec) == "vector" then - local foo = vec - elseif typeof(vec) == "Instance" then - local foo = vec - else - local foo = vec - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - else - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); - - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" - else - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") -{ - CheckResult result = check(R"( - local function f(x: Instance | Vector3) - if typeof(x) == "Vector3" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") -{ - CheckResult result = check(R"( - local function f(x: string | number | Instance | Vector3) - if type(x) == "userdata" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | string) - if typeof(x) == "Instance" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | Instance | string | Vector3 | any) - if typeof(x) == "Instance" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") -{ - CheckResult result = check(R"( - --!nonstrict - - local function f(x) - if typeof(x) == "Instance" and x:IsA("Folder") then - local foo = x - elseif typeof(x) == "table" then - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | string) - if typeof(x) ~= "Instance" or not x:IsA("Part") then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); -} - TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") { CheckResult result = check(R"( @@ -1145,4 +1007,259 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type T = {tag: "missing", x: nil} | {tag: "exists", x: string} + + local function f(t: T) + if t.x then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_tag") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Cat = {tag: "Cat", name: string, catfood: string} + type Dog = {tag: "Dog", name: string, dogfood: string} + type Animal = Cat | Dog + + local function f(animal: Animal) + if animal.tag == "Cat" then + local cat: Cat = animal + elseif animal.tag == "Dog" then + local dog: Dog = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); +} + +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") +{ + CheckResult result = check(R"( + local function len(a: {any}) + return a and #a or nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder} + + local function f(t: T) + if t.x:IsA("Part") then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") +{ + CheckResult result = check(R"( + local function f(vec) + local X, Y, Z = vec.X, vec.Y, vec.Z + + if type(vec) == "vector" then + local foo = vec + elseif typeof(vec) == "Instance" then + local foo = vec + else + local foo = vec + end + end + )"); + + if (FFlag::LuauDiscriminableUnions) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); + else + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + } + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" + + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + else + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") +{ + CheckResult result = check(R"( + local function f(x: Instance | Vector3) + if typeof(x) == "Vector3" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") +{ + CheckResult result = check(R"( + local function f(x: string | number | Instance | Vector3) + if type(x) == "userdata" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | Instance | string | Vector3 | any) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") +{ + CheckResult result = check(R"( + --!nonstrict + + local function f(x) + if typeof(x) == "Instance" and x:IsA("Folder") then + local foo = x + elseif typeof(x) == "table" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) ~= "Instance" or not x:IsA("Part") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 94cfb643..df365fda 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -379,9 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, }; CheckResult result = check(R"( @@ -404,9 +402,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, }; CheckResult result = check(R"( @@ -429,9 +425,7 @@ TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, {"LuauIfElseExpectedType2", true}, {"LuauIfElseBranchTypeUnion", true}, }; diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 644efed7..48310921 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -12,8 +12,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) - TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -2075,22 +2073,11 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); - } - else - { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' -caused by: - Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' -caused by: - Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); - } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2166,7 +2153,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") { ScopedFastFlag sff[]{ - {"LuauFixRecursiveMetatableCall", true}, {"LuauUnsealedTableLiteral", true}, }; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7ee5253c..c9b30e1a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) using namespace Luau; @@ -959,8 +958,6 @@ TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f() return f @@ -973,8 +970,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f(g) return f(f) @@ -1699,8 +1694,6 @@ TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict local s @@ -1711,8 +1704,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict function u(t, w) @@ -3326,11 +3317,12 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") +TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") { CheckResult result = check(R"( - local x - print((x == true and (x .. "y")) .. 1) + local function f(x) + return x .. "y" + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -3340,13 +3332,14 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") { CheckResult result = check(R"( - local x - print("foo" .. x) + local function f(x) + return "foo" .. x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("(string) -> string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") @@ -4374,8 +4367,6 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") TEST_CASE_FIXTURE(Fixture, "record_matching_overload") { - ScopedFastFlag sffs("LuauStoreMatchingOverloadFnType", true); - CheckResult result = check(R"( type Overload = ((string) -> string) & ((number) -> number) local abc: Overload @@ -4475,17 +4466,10 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' caused by: Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0])); - } + toString(result.errors[0])); // Infer from variadic packs into elements result = check(R"( @@ -4618,17 +4602,9 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ( - "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); - } - else - { - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); - } + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") @@ -4741,8 +4717,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { - ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - CheckResult result = check(R"( type A = { x: number } local a: A = { x = 1 } @@ -4965,8 +4939,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_ TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number) -> string @@ -4983,8 +4955,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number, string) -> string @@ -5001,8 +4971,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> (number) type B = (number, number) -> (number, boolean) @@ -5019,8 +4987,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number, number) -> number @@ -5037,8 +5003,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> (number, string) type B = (number, number) -> (number, boolean) @@ -5069,8 +5033,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") { - ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; - CheckResult result = check(R"( local t = {} diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index d4878d14..079870f5 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -931,7 +931,6 @@ type R = { m: F } TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") { ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; CheckResult result = check(R"( local a: () -> (number, ...string) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index b54ba996..759794e6 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -464,8 +464,6 @@ local a: XYZ = { w = 4 } TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") { - ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 2e0d149e..329e7b1f 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -268,8 +268,6 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TEST_CASE("tagging_tables") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar ttv{TableTypeVar{}}; CHECK(!Luau::hasTag(&ttv, "foo")); Luau::attachTag(&ttv, "foo"); @@ -278,8 +276,6 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); @@ -288,8 +284,6 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; @@ -307,8 +301,6 @@ TEST_CASE("tagging_subclasses") TEST_CASE("tagging_functions") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypePackVar empty{TypePack{}}; TypeVar ftv{FunctionTypeVar{&empty, &empty}}; CHECK(!Luau::hasTag(&ftv, "foo")); @@ -318,8 +310,6 @@ TEST_CASE("tagging_functions") TEST_CASE("tagging_props") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - Property prop{}; CHECK(!Luau::hasTag(prop, "foo")); Luau::attachTag(prop, "foo"); @@ -370,4 +360,66 @@ local b: (T, T, T) -> T CHECK_EQ(count, 1); } +TEST_CASE("isString_on_string_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + CHECK(isString(&helloString)); +} + +TEST_CASE("isString_on_unions_of_various_string_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString}}}; + + CHECK(isString(&union_)); +} + +TEST_CASE("proof_that_isString_uses_all_of") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString, &booleanType}}}; + + CHECK(!isString(&union_)); +} + +TEST_CASE("isBoolean_on_boolean_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + CHECK(isBoolean(&trueBool)); +} + +TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}}; + + CHECK(isBoolean(&union_)); +} + +TEST_CASE("proof_that_isBoolean_uses_all_of") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool, &stringType}}}; + + CHECK(!isBoolean(&union_)); +} + TEST_SUITE_END();