Merge branch 'upstream' into merge

This commit is contained in:
Aaron Weiss 2024-03-08 16:05:03 -08:00
commit 1ebdfe093a
48 changed files with 2365 additions and 277 deletions

View File

@ -25,19 +25,21 @@ TypeId makeOption(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId t
/** Small utility function for building up type definitions from C++. /** Small utility function for building up type definitions from C++.
*/ */
TypeId makeFunction( // Monomorphic TypeId makeFunction( // Monomorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes); TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes,
bool checked = false);
TypeId makeFunction( // Polymorphic TypeId makeFunction( // Polymorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks, TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes); std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked = false);
TypeId makeFunction( // Monomorphic TypeId makeFunction( // Monomorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes); std::initializer_list<TypeId> retTypes, bool checked = false);
TypeId makeFunction( // Polymorphic TypeId makeFunction( // Polymorphic
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks, TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, std::initializer_list<TypePackId> genericPacks,
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes); std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes,
bool checked = false);
void attachMagicFunction(TypeId ty, MagicFunction fn); void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);

View File

@ -341,6 +341,13 @@ struct UninhabitedTypeFamily
bool operator==(const UninhabitedTypeFamily& rhs) const; bool operator==(const UninhabitedTypeFamily& rhs) const;
}; };
struct ExplicitFunctionAnnotationRecommended
{
std::vector<std::pair<std::string, TypeId>> recommendedArgs;
TypeId recommendedReturn;
bool operator==(const ExplicitFunctionAnnotationRecommended& rhs) const;
};
struct UninhabitedTypePackFamily struct UninhabitedTypePackFamily
{ {
TypePackId tp; TypePackId tp;
@ -416,14 +423,15 @@ struct UnexpectedTypePackInSubtyping
bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; bool operator==(const UnexpectedTypePackInSubtyping& rhs) const;
}; };
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, using TypeErrorData =
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, DuplicateTypeDefinition,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, IncorrectGenericParameterCount, SyntaxError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError, CannotCallNonFunction, ExtraInformation,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, DuplicateGenericParameter,
TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe, UninhabitedTypeFamily, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated,
UninhabitedTypePackFamily, WhereClauseNeeded, PackWhereClauseNeeded, CheckedFunctionCallError, NonStrictFunctionDefinitionError, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe, UninhabitedTypeFamily, UninhabitedTypePackFamily,
PropertyAccessViolation, CheckedFunctionIncorrectArgs, UnexpectedTypeInSubtyping, UnexpectedTypePackInSubtyping>; WhereClauseNeeded, PackWhereClauseNeeded, CheckedFunctionCallError, NonStrictFunctionDefinitionError, PropertyAccessViolation,
CheckedFunctionIncorrectArgs, UnexpectedTypeInSubtyping, UnexpectedTypePackInSubtyping, ExplicitFunctionAnnotationRecommended>;
struct TypeErrorSummary struct TypeErrorSummary
{ {

View File

@ -307,6 +307,9 @@ struct NormalizedType
bool hasTables() const; bool hasTables() const;
bool hasFunctions() const; bool hasFunctions() const;
bool hasTyvars() const; bool hasTyvars() const;
bool isFalsy() const;
bool isTruthy() const;
}; };

View File

@ -67,4 +67,36 @@ private:
void add(Analysis analysis, TypeId ty, ErrorVec&& errors); void add(Analysis analysis, TypeId ty, ErrorVec&& errors);
}; };
struct SolveResult
{
enum OverloadCallResult {
Ok,
CodeTooComplex,
OccursCheckFailed,
NoMatchingOverload,
};
OverloadCallResult result;
std::optional<TypePackId> typePackId; // nullopt if result != Ok
TypeId overloadToUse = nullptr;
TypeId inferredTy = nullptr;
DenseHashMap<TypeId, std::vector<TypeId>> expandedFreeTypes{nullptr};
};
// Helper utility, presently used for binary operator type families.
//
// Given a function and a set of arguments, select a suitable overload.
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
const Location& location,
TypeId fn,
TypePackId argsPack
);
} // namespace Luau } // namespace Luau

View File

@ -219,23 +219,7 @@ private:
template<typename T, typename Container> template<typename T, typename Container>
TypeId makeAggregateType(const Container& container, TypeId orElse); TypeId makeAggregateType(const Container& container, TypeId orElse);
std::pair<TypeId, ErrorVec> handleTypeFamilyReductionResult(const TypeFamilyInstanceType* familyInstance);
std::pair<TypeId, ErrorVec> handleTypeFamilyReductionResult(const TypeFamilyInstanceType* familyInstance)
{
TypeFamilyContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}};
TypeId family = arena->addType(*familyInstance);
std::string familyString = toString(family);
FamilyGraphReductionResult result = reduceFamilies(family, {}, context, true);
ErrorVec errors;
if (result.blockedTypes.size() != 0 || result.blockedPacks.size() != 0)
{
errors.push_back(TypeError{{}, UninhabitedTypeFamily{family}});
return {builtinTypes->neverType, errors};
}
if (result.reducedTypes.contains(family))
return {family, errors};
return {builtinTypes->neverType, errors};
}
[[noreturn]] void unexpected(TypeId ty); [[noreturn]] void unexpected(TypeId ty);
[[noreturn]] void unexpected(TypePackId tp); [[noreturn]] void unexpected(TypePackId tp);

View File

@ -57,6 +57,8 @@ struct TypeFamilyContext
, constraint(nullptr) , constraint(nullptr)
{ {
} }
NotNull<Constraint> pushConstraint(ConstraintV&& c);
}; };
/// Represents a reduction result, which may have successfully reduced the type, /// Represents a reduction result, which may have successfully reduced the type,
/// may have concretely failed to reduce the type, or may simply be stuck /// may have concretely failed to reduce the type, or may simply be stuck

View File

@ -0,0 +1,81 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/VecDeque.h"
#include "Luau/DenseHash.h"
#include "Luau/TypeFamily.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/Normalize.h"
#include "Luau/TypeFwd.h"
#include "Luau/VisitType.h"
#include "Luau/NotNull.h"
namespace Luau
{
struct TypeFamilyReductionGuessResult
{
std::vector<std::pair<std::string, TypeId>> guessedFunctionAnnotations;
TypeId guessedReturnType;
bool shouldRecommendAnnotation = true;
};
// An Inference result for a type family is a list of types corresponding to the guessed argument types, followed by a type for the result
struct TypeFamilyInferenceResult
{
std::vector<TypeId> operandInference;
TypeId familyResultInference;
};
struct TypeFamilyReductionGuesser
{
// Tracks our hypothesis about what a type family reduces to
DenseHashMap<TypeId, TypeId> familyReducesTo{nullptr};
// Tracks our constraints on type family operands
DenseHashMap<TypeId, TypeId> substitutable{nullptr};
// List of instances to try progress
VecDeque<TypeId> toInfer;
DenseHashSet<TypeId> cyclicInstances{nullptr};
// Utilities
NotNull<BuiltinTypes> builtins;
NotNull<Normalizer> normalizer;
TypeFamilyReductionGuesser(NotNull<BuiltinTypes> builtins, NotNull<Normalizer> normalizer);
TypeFamilyReductionGuessResult guessTypeFamilyReductionForFunction(const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy);
private:
std::optional<TypeId> guessType(TypeId arg);
void dumpGuesses();
bool isNumericBinopFamily(const TypeFamilyInstanceType& instance);
bool isComparisonFamily(const TypeFamilyInstanceType& instance);
bool isOrAndFamily(const TypeFamilyInstanceType& instance);
bool isNotFamily(const TypeFamilyInstanceType& instance);
bool isLenFamily(const TypeFamilyInstanceType& instance);
bool isUnaryMinus(const TypeFamilyInstanceType& instance);
// Operand is assignable if it looks like a cyclic family instance, or a generic type
bool operandIsAssignable(TypeId ty);
std::optional<TypeId> tryAssignOperandType(TypeId ty);
const NormalizedType* normalize(TypeId ty);
void step();
void infer();
bool done();
bool isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet<TypeId>& instanceArgs);
void inferTypeFamilySubstitutions(TypeId ty, const TypeFamilyInstanceType* instance);
TypeFamilyInferenceResult inferNumericBinopFamily(const TypeFamilyInstanceType* instance);
TypeFamilyInferenceResult inferComparisonFamily(const TypeFamilyInstanceType* instance);
TypeFamilyInferenceResult inferOrAndFamily(const TypeFamilyInstanceType* instance);
TypeFamilyInferenceResult inferNotFamily(const TypeFamilyInstanceType* instance);
TypeFamilyInferenceResult inferLenFamily(const TypeFamilyInstanceType* instance);
TypeFamilyInferenceResult inferUnaryMinusFamily(const TypeFamilyInstanceType* instance);
};
} // namespace Luau

View File

@ -86,6 +86,10 @@ private:
*/ */
TypeId mkIntersection(TypeId left, TypeId right); TypeId mkIntersection(TypeId left, TypeId right);
// Returns true if needle occurs within haystack already. ie if we bound
// needle to haystack, would a cyclic type result?
OccursCheckResult occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
// Returns true if needle occurs within haystack already. ie if we bound // Returns true if needle occurs within haystack already. ie if we bound
// needle to haystack, would a cyclic TypePack result? // needle to haystack, would a cyclic TypePack result?
OccursCheckResult occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack); OccursCheckResult occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);

View File

@ -25,6 +25,7 @@
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauSetMetatableOnUnionsOfTables, false); LUAU_FASTFLAGVARIABLE(LuauSetMetatableOnUnionsOfTables, false);
LUAU_FASTFLAGVARIABLE(LuauMakeStringMethodsChecked, false);
namespace Luau namespace Luau
{ {
@ -62,26 +63,26 @@ TypeId makeOption(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId t
} }
TypeId makeFunction( TypeId makeFunction(
TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes) TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked)
{ {
return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes); return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes, checked);
} }
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes) std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<TypeId> retTypes, bool checked)
{ {
return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes); return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes, checked);
} }
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes, TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> paramTypes,
std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes) std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes, bool checked)
{ {
return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes); return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes, checked);
} }
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics, TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes) std::initializer_list<TypeId> retTypes, bool checked)
{ {
std::vector<TypeId> params; std::vector<TypeId> params;
if (selfType) if (selfType)
@ -108,6 +109,8 @@ TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initi
ftv.argNames.push_back(std::nullopt); ftv.argNames.push_back(std::nullopt);
} }
ftv.isCheckedFunction = checked;
return arena.addType(std::move(ftv)); return arena.addType(std::move(ftv));
} }
@ -289,17 +292,10 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// declare function assert<T>(value: T, errorMessage: string?): intersect<T, ~(false?)> // declare function assert<T>(value: T, errorMessage: string?): intersect<T, ~(false?)>
TypeId genericT = arena.addType(GenericType{"T"}); TypeId genericT = arena.addType(GenericType{"T"});
TypeId refinedTy = arena.addType(TypeFamilyInstanceType{ TypeId refinedTy = arena.addType(TypeFamilyInstanceType{
NotNull{&kBuiltinTypeFamilies.intersectFamily}, NotNull{&kBuiltinTypeFamilies.intersectFamily}, {genericT, arena.addType(NegationType{builtinTypes->falsyType})}, {}});
{genericT, arena.addType(NegationType{builtinTypes->falsyType})},
{}
});
TypeId assertTy = arena.addType(FunctionType{ TypeId assertTy = arena.addType(FunctionType{
{genericT}, {genericT}, {}, arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}), arena.addTypePack(TypePack{{refinedTy}})});
{},
arena.addTypePack(TypePack{{genericT, builtinTypes->optionalStringType}}),
arena.addTypePack(TypePack{{refinedTy}})
});
addGlobalBinding(globals, "assert", assertTy, "@luau"); addGlobalBinding(globals, "assert", assertTy, "@luau");
} }
@ -773,72 +769,158 @@ TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
const TypePackId anyTypePack = builtinTypes->anyTypePack; const TypePackId anyTypePack = builtinTypes->anyTypePack;
const TypePackId variadicTailPack = FFlag::DebugLuauDeferredConstraintResolution ? builtinTypes->unknownTypePack : anyTypePack; const TypePackId variadicTailPack = FFlag::DebugLuauDeferredConstraintResolution ? builtinTypes->unknownTypePack : anyTypePack;
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypePackId emptyPack = arena->addTypePack({}); const TypePackId emptyPack = arena->addTypePack({});
const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}});
const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType = if (FFlag::LuauMakeStringMethodsChecked)
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), {
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); formatFTV.magicFunction = &magicFunctionFormat;
const TypeId gmatchFunc = formatFTV.isCheckedFunction = true;
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); const TypeId formatFn = arena->addType(formatFTV);
attachMagicFunction(gmatchFunc, magicFunctionGmatch); attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(
FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true);
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = { const TypeId replArgType = arena->addType(
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ false)}});
{"find", {findFunc}}, const TypeId gsubFunc =
{"format", {formatFn}}, // FIXME makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false);
{"gmatch", {gmatchFunc}}, const TypeId gmatchFunc = makeFunction(
{"gsub", {gsubFunc}}, *arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true);
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, attachMagicFunction(gmatchFunc, magicFunctionGmatch);
{"lower", {stringToStringType}}, attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
variadicTailPack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string"); FunctionType matchFuncTy{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})};
matchFuncTy.isCheckedFunction = true;
const TypeId matchFunc = arena->addType(matchFuncTy);
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); FunctionType findFuncTy{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})};
findFuncTy.isCheckedFunction = true;
const TypeId findFunc = arena->addType(findFuncTy);
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
if (TableType* ttv = getMutable<TableType>(tableType)) // string.byte : string -> number? -> number? -> ...number
ttv->name = "typeof(string)"; FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList};
stringDotByte.isCheckedFunction = true;
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); // string.char : .... number -> string
FunctionType stringDotChar{numberVariadicList, arena->addTypePack({stringType})};
stringDotChar.isCheckedFunction = true;
// string.unpack : string -> string -> number? -> ...any
FunctionType stringDotUnpack{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
variadicTailPack,
};
stringDotUnpack.isCheckedFunction = true;
TableType::Props stringLib = {
{"byte", {arena->addType(stringDotByte)}},
{"char", {arena->addType(stringDotChar)}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType}, /* checked */ true)}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType}, /* checked */ true)}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})},
/* checked */ true)}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType}, /* checked */ true)}},
{"unpack", {arena->addType(stringDotUnpack)}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
else
{
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType = arena->addType(
UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(FunctionType{
arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = {
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, variadicTailPack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
variadicTailPack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect( static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(

View File

@ -290,8 +290,8 @@ std::optional<TypeId> ConstraintGenerator::lookup(const ScopePtr& scope, DefId d
{ {
if (auto found = scope->lookup(def)) if (auto found = scope->lookup(def))
return *found; return *found;
else if (phi->operands.size() == 1) else if (!prototype && phi->operands.size() == 1)
return lookup(scope, phi->operands[0], prototype); return lookup(scope, phi->operands.at(0), prototype);
else if (!prototype) else if (!prototype)
return std::nullopt; return std::nullopt;
@ -963,7 +963,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f
DefId def = dfg->getDef(function->name); DefId def = dfg->getDef(function->name);
std::optional<TypeId> existingFunctionTy = lookup(scope, def); std::optional<TypeId> existingFunctionTy = lookup(scope, def);
if (sigFullyDefined && existingFunctionTy && get<BlockedType>(*existingFunctionTy)) if (existingFunctionTy && (sigFullyDefined || function->name->is<AstExprLocal>()) && get<BlockedType>(*existingFunctionTy))
asMutable(*existingFunctionTy)->ty.emplace<BoundType>(sig.signature); asMutable(*existingFunctionTy)->ty.emplace<BoundType>(sig.signature);
if (AstExprLocal* localName = function->name->as<AstExprLocal>()) if (AstExprLocal* localName = function->name->as<AstExprLocal>())
@ -2537,7 +2537,7 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr,
std::vector<std::string> segmentStrings(begin(segments), end(segments)); std::vector<std::string> segmentStrings(begin(segments), end(segments));
TypeId updatedType = arena->addType(BlockedType{}); TypeId updatedType = arena->addType(BlockedType{});
addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), assignedTy}); auto setC = addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), assignedTy});
TypeId prevSegmentTy = updatedType; TypeId prevSegmentTy = updatedType;
for (size_t i = 0; i < segments.size(); ++i) for (size_t i = 0; i < segments.size(); ++i)
@ -2545,7 +2545,8 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr,
TypeId segmentTy = arena->addType(BlockedType{}); TypeId segmentTy = arena->addType(BlockedType{});
module->astTypes[exprs[i]] = segmentTy; module->astTypes[exprs[i]] = segmentTy;
ValueContext ctx = i == segments.size() - 1 ? ValueContext::LValue : ValueContext::RValue; ValueContext ctx = i == segments.size() - 1 ? ValueContext::LValue : ValueContext::RValue;
addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i], ctx}); auto hasC = addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i], ctx});
setC->dependencies.push_back(hasC);
prevSegmentTy = segmentTy; prevSegmentTy = segmentTy;
} }

View File

@ -1068,7 +1068,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
} }
OverloadResolver resolver{ OverloadResolver resolver{
builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, c.callSite->location}; builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location};
auto [status, overload] = resolver.selectOverload(fn, argsPack); auto [status, overload] = resolver.selectOverload(fn, argsPack);
TypeId overloadToUse = fn; TypeId overloadToUse = fn;
if (status == OverloadResolver::Analysis::Ok) if (status == OverloadResolver::Analysis::Ok)
@ -2184,14 +2184,15 @@ bool ConstraintSolver::block_(BlockedConstraintId target, NotNull<const Constrai
{ {
// If a set is not present for the target, construct a new DenseHashSet for it, // If a set is not present for the target, construct a new DenseHashSet for it,
// else grab the address of the existing set. // else grab the address of the existing set.
NotNull<DenseHashSet<const Constraint*>> blockVec{&blocked.try_emplace(target, nullptr).first->second}; auto [iter, inserted] = blocked.try_emplace(target, nullptr);
auto& [key, blockVec] = *iter;
if (blockVec->find(constraint)) if (blockVec.find(constraint))
return false; return false;
blockVec->insert(constraint); blockVec.insert(constraint);
auto& count = blockedConstraints[constraint]; size_t& count = blockedConstraints[constraint];
count += 1; count += 1;
return true; return true;

View File

@ -509,6 +509,26 @@ struct ErrorConverter
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited"; return "Type family instance " + Luau::toString(e.ty) + " is uninhabited";
} }
std::string operator()(const ExplicitFunctionAnnotationRecommended& r) const
{
std::string toReturn = toString(r.recommendedReturn);
std::string argAnnotations;
for (auto [arg, type] : r.recommendedArgs)
{
argAnnotations += arg + ": " + toString(type) + ", ";
}
if (argAnnotations.length() >= 2)
{
argAnnotations.pop_back();
argAnnotations.pop_back();
}
if (argAnnotations.empty())
return "Consider annotating the return with " + toReturn;
return "Consider placing the following annotations on the arguments: " + argAnnotations + " or instead annotating the return as " + toReturn;
}
std::string operator()(const UninhabitedTypePackFamily& e) const std::string operator()(const UninhabitedTypePackFamily& e) const
{ {
return "Type pack family instance " + Luau::toString(e.tp) + " is uninhabited"; return "Type pack family instance " + Luau::toString(e.tp) + " is uninhabited";
@ -883,6 +903,12 @@ bool UninhabitedTypeFamily::operator==(const UninhabitedTypeFamily& rhs) const
return ty == rhs.ty; return ty == rhs.ty;
} }
bool ExplicitFunctionAnnotationRecommended::operator==(const ExplicitFunctionAnnotationRecommended& rhs) const
{
return recommendedReturn == rhs.recommendedReturn && recommendedArgs == rhs.recommendedArgs;
}
bool UninhabitedTypePackFamily::operator==(const UninhabitedTypePackFamily& rhs) const bool UninhabitedTypePackFamily::operator==(const UninhabitedTypePackFamily& rhs) const
{ {
return tp == rhs.tp; return tp == rhs.tp;
@ -1084,6 +1110,12 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
e.ty = clone(e.ty); e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, UninhabitedTypeFamily>) else if constexpr (std::is_same_v<T, UninhabitedTypeFamily>)
e.ty = clone(e.ty); e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, ExplicitFunctionAnnotationRecommended>)
{
e.recommendedReturn = clone(e.recommendedReturn);
for (auto [_, t] : e.recommendedArgs)
t = clone(t);
}
else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>) else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>)
e.tp = clone(e.tp); e.tp = clone(e.tp);
else if constexpr (std::is_same_v<T, WhereClauseNeeded>) else if constexpr (std::is_same_v<T, WhereClauseNeeded>)

View File

@ -195,6 +195,15 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }"; stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, UninhabitedTypeFamily>) else if constexpr (std::is_same_v<T, UninhabitedTypeFamily>)
stream << "UninhabitedTypeFamily { " << toString(err.ty) << " }"; stream << "UninhabitedTypeFamily { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, ExplicitFunctionAnnotationRecommended>)
{
std::string recArgs = "[";
for (auto [s, t] : err.recommendedArgs)
recArgs += " " + s + ": " + toString(t);
recArgs += " ]";
stream << "ExplicitFunctionAnnotationRecommended { recommmendedReturn = '" + toString(err.recommendedReturn) +
"', recommmendedArgs = " + recArgs + "}";
}
else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>) else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>)
stream << "UninhabitedTypePackFamily { " << toString(err.tp) << " }"; stream << "UninhabitedTypePackFamily { " << toString(err.tp) << " }";
else if constexpr (std::is_same_v<T, WhereClauseNeeded>) else if constexpr (std::is_same_v<T, WhereClauseNeeded>)

View File

@ -542,11 +542,11 @@ struct NonStrictTypeChecker
} }
} }
} }
// For a checked function, these gotta be the same size
std::string functionName = getFunctionNameAsString(*call->func).value_or(""); std::string functionName = getFunctionNameAsString(*call->func).value_or("");
if (call->args.size != argTypes.size()) if (call->args.size > argTypes.size())
{ {
// We are passing more arguments than we expect, so we should error
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location);
return fresh; return fresh;
} }
@ -572,6 +572,20 @@ struct NonStrictTypeChecker
if (auto runTimeFailureType = willRunTimeError(arg, fresh)) if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location);
} }
if (call->args.size < argTypes.size())
{
// We are passing fewer arguments than we expect
// so we need to ensure that the rest of the args are optional.
bool remainingArgsOptional = true;
for (size_t i = call->args.size; i < argTypes.size(); i++)
remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]);
if (!remainingArgsOptional)
{
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location);
return fresh;
}
}
} }
} }

View File

@ -394,6 +394,25 @@ bool NormalizedType::hasTyvars() const
return !tyvars.empty(); return !tyvars.empty();
} }
bool NormalizedType::isFalsy() const
{
bool hasAFalse = false;
if (auto singleton = get<SingletonType>(booleans))
{
if (auto bs = singleton->variant.get_if<BooleanSingleton>())
hasAFalse = !bs->value;
}
return (hasAFalse || hasNils()) && (!hasTops() && !hasClasses() && !hasErrors() && !hasNumbers() && !hasStrings() && !hasThreads() &&
!hasBuffers() && !hasTables() && !hasFunctions() && !hasTyvars());
}
bool NormalizedType::isTruthy() const
{
return !isFalsy();
}
static bool isShallowInhabited(const NormalizedType& norm) static bool isShallowInhabited(const NormalizedType& norm)
{ {
// This test is just a shallow check, for example it returns `true` for `{ p : never }` // This test is just a shallow check, for example it returns `true` for `{ p : never }`

View File

@ -1,12 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/OverloadResolution.h" #include "Luau/OverloadResolution.h"
#include "Luau/Instantiation2.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeFamily.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/TypeFamily.h" #include "Luau/Unifier2.h"
namespace Luau namespace Luau
{ {
@ -26,19 +28,28 @@ OverloadResolver::OverloadResolver(NotNull<BuiltinTypes> builtinTypes, NotNull<T
std::pair<OverloadResolver::Analysis, TypeId> OverloadResolver::selectOverload(TypeId ty, TypePackId argsPack) std::pair<OverloadResolver::Analysis, TypeId> OverloadResolver::selectOverload(TypeId ty, TypePackId argsPack)
{ {
auto tryOne = [&](TypeId f) {
if (auto ftv = get<FunctionType>(f))
{
SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes);
if (r.isSubtype)
return true;
}
return false;
};
TypeId t = follow(ty); TypeId t = follow(ty);
if (tryOne(ty))
return {Analysis::Ok, ty};
if (auto it = get<IntersectionType>(t)) if (auto it = get<IntersectionType>(t))
{ {
for (TypeId component : it) for (TypeId component : it)
{ {
if (auto ftv = get<FunctionType>(component)) if (tryOne(component))
{ return {Analysis::Ok, component};
SubtypingResult r = subtyping.isSubtype(argsPack, ftv->argTypes);
if (r.isSubtype)
return {Analysis::Ok, component};
}
else
continue;
} }
} }
@ -348,4 +359,63 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors)
} }
SolveResult solveFunctionCall(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter,
NotNull<TypeCheckLimits> limits,
NotNull<Scope> scope,
const Location& location,
TypeId fn,
TypePackId argsPack
)
{
OverloadResolver resolver{
builtinTypes, NotNull{arena}, normalizer, scope, iceReporter, limits, location};
auto [status, overload] = resolver.selectOverload(fn, argsPack);
TypeId overloadToUse = fn;
if (status == OverloadResolver::Analysis::Ok)
overloadToUse = overload;
else if (get<AnyType>(fn) || get<FreeType>(fn))
{
// Nothing. Let's keep going
}
else
return {SolveResult::NoMatchingOverload};
TypePackId resultPack = arena->freshTypePack(scope);
TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, scope.get(), argsPack, resultPack});
Unifier2 u2{NotNull{arena}, builtinTypes, scope, iceReporter};
const bool occursCheckPassed = u2.unify(overloadToUse, inferredTy);
if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty())
{
Instantiation2 instantiation{arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)};
std::optional<TypePackId> subst = instantiation.substitute(resultPack);
if (!subst)
return {SolveResult::CodeTooComplex};
else
resultPack = *subst;
}
if (!occursCheckPassed)
return {SolveResult::OccursCheckFailed};
SolveResult result;
result.result = SolveResult::Ok;
result.typePackId = resultPack;
LUAU_ASSERT(overloadToUse);
result.overloadToUse = overloadToUse;
result.inferredTy = inferredTy;
result.expandedFreeTypes = std::move(u2.expandedFreeTypes);
return result;
}
} // namespace Luau } // namespace Luau

View File

@ -44,8 +44,9 @@ std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(DefId def)
while (true) while (true)
{ {
TypeId* it = s->lvalueTypes.find(def); if (TypeId* it = s->lvalueTypes.find(def))
if (it) return std::pair{*it, s};
else if (TypeId* it = s->rvalueRefinements.find(def))
return std::pair{*it, s}; return std::pair{*it, s};
if (s->parent) if (s->parent)

View File

@ -1128,6 +1128,9 @@ TypeId TypeSimplifier::intersect(TypeId left, TypeId right)
left = simplify(left); left = simplify(left);
right = simplify(right); right = simplify(right);
if (left == right)
return left;
if (get<AnyType>(left) && get<ErrorType>(right)) if (get<AnyType>(left) && get<ErrorType>(right))
return right; return right;
if (get<AnyType>(right) && get<ErrorType>(left)) if (get<AnyType>(right) && get<ErrorType>(left))

View File

@ -480,7 +480,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
// tested as though it were its upper bounds. We do not yet support bounded // tested as though it were its upper bounds. We do not yet support bounded
// generics, so the upper bound is always unknown. // generics, so the upper bound is always unknown.
if (auto subGeneric = get<GenericType>(subTy); subGeneric && subsumes(subGeneric->scope, scope)) if (auto subGeneric = get<GenericType>(subTy); subGeneric && subsumes(subGeneric->scope, scope))
return isCovariantWith(env, builtinTypes->unknownType, superTy); return isCovariantWith(env, builtinTypes->neverType, superTy);
if (auto superGeneric = get<GenericType>(superTy); superGeneric && subsumes(superGeneric->scope, scope)) if (auto superGeneric = get<GenericType>(superTy); superGeneric && subsumes(superGeneric->scope, scope))
return isCovariantWith(env, subTy, builtinTypes->unknownType); return isCovariantWith(env, subTy, builtinTypes->unknownType);
@ -1611,4 +1611,21 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse)
return arena->addType(T{std::vector<TypeId>(begin(container), end(container))}); return arena->addType(T{std::vector<TypeId>(begin(container), end(container))});
} }
std::pair<TypeId, ErrorVec> Subtyping::handleTypeFamilyReductionResult(const TypeFamilyInstanceType* familyInstance)
{
TypeFamilyContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}};
TypeId family = arena->addType(*familyInstance);
std::string familyString = toString(family);
FamilyGraphReductionResult result = reduceFamilies(family, {}, context, true);
ErrorVec errors;
if (result.blockedTypes.size() != 0 || result.blockedPacks.size() != 0)
{
errors.push_back(TypeError{{}, UninhabitedTypeFamily{family}});
return {builtinTypes->neverType, errors};
}
if (result.reducedTypes.contains(family))
return {family, errors};
return {builtinTypes->neverType, errors};
}
} // namespace Luau } // namespace Luau

View File

@ -20,7 +20,6 @@
#include <string> #include <string>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAGVARIABLE(LuauToStringPrettifyLocation, false)
LUAU_FASTFLAGVARIABLE(LuauToStringSimpleCompositeTypesSingleLine, false) LUAU_FASTFLAGVARIABLE(LuauToStringSimpleCompositeTypesSingleLine, false)
/* /*
@ -1879,15 +1878,8 @@ std::string toString(const Position& position)
std::string toString(const Location& location, int offset, bool useBegin) std::string toString(const Location& location, int offset, bool useBegin)
{ {
if (FFlag::LuauToStringPrettifyLocation) return "(" + std::to_string(location.begin.line + offset) + ", " + std::to_string(location.begin.column + offset) + ") - (" +
{ std::to_string(location.end.line + offset) + ", " + std::to_string(location.end.column + offset) + ")";
return "(" + std::to_string(location.begin.line + offset) + ", " + std::to_string(location.begin.column + offset) + ") - (" +
std::to_string(location.end.line + offset) + ", " + std::to_string(location.end.column + offset) + ")";
}
else
{
return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }";
}
} }
std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts) std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts)

View File

@ -17,6 +17,7 @@
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeFamily.h" #include "Luau/TypeFamily.h"
#include "Luau/TypeFamilyReductionGuesser.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypePath.h" #include "Luau/TypePath.h"
@ -25,6 +26,8 @@
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
#include <algorithm> #include <algorithm>
#include <iostream>
#include <ostream>
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
@ -36,6 +39,7 @@ namespace Luau
using PrintLineProc = void (*)(const std::string&); using PrintLineProc = void (*)(const std::string&);
extern PrintLineProc luauPrintLine; extern PrintLineProc luauPrintLine;
/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance.
* TypeChecker2 uses this to maintain knowledge about which scope encloses every * TypeChecker2 uses this to maintain knowledge about which scope encloses every
* given AstNode. * given AstNode.
@ -1271,13 +1275,13 @@ struct TypeChecker2
{ {
switch (shouldSuppressErrors(NotNull{&normalizer}, fnTy)) switch (shouldSuppressErrors(NotNull{&normalizer}, fnTy))
{ {
case ErrorSuppression::Suppress: case ErrorSuppression::Suppress:
break; break;
case ErrorSuppression::NormalizationFailed: case ErrorSuppression::NormalizationFailed:
reportError(NormalizationTooComplex{}, call->func->location); reportError(NormalizationTooComplex{}, call->func->location);
// fallthrough intentional // fallthrough intentional
case ErrorSuppression::DoNotSuppress: case ErrorSuppression::DoNotSuppress:
reportError(OptionalValueAccess{fnTy}, call->func->location); reportError(OptionalValueAccess{fnTy}, call->func->location);
} }
return; return;
} }
@ -1528,6 +1532,7 @@ struct TypeChecker2
functionDeclStack.push_back(inferredFnTy); functionDeclStack.push_back(inferredFnTy);
const NormalizedType* normalizedFnTy = normalizer.normalize(inferredFnTy); const NormalizedType* normalizedFnTy = normalizer.normalize(inferredFnTy);
const FunctionType* inferredFtv = get<FunctionType>(normalizedFnTy->functions.parts.front());
if (!normalizedFnTy) if (!normalizedFnTy)
{ {
reportError(CodeTooComplex{}, fn->location); reportError(CodeTooComplex{}, fn->location);
@ -1622,6 +1627,19 @@ struct TypeChecker2
if (fn->returnAnnotation) if (fn->returnAnnotation)
visit(*fn->returnAnnotation); visit(*fn->returnAnnotation);
// If the function type has a family annotation, we need to see if we can suggest an annotation
TypeFamilyReductionGuesser guesser{builtinTypes, NotNull{&normalizer}};
for (TypeId retTy : inferredFtv->retTypes)
{
if (get<TypeFamilyInstanceType>(follow(retTy)))
{
TypeFamilyReductionGuessResult result = guesser.guessTypeFamilyReductionForFunction(*fn, inferredFtv, retTy);
if (result.shouldRecommendAnnotation)
reportError(
ExplicitFunctionAnnotationRecommended{std::move(result.guessedFunctionAnnotations), result.guessedReturnType}, fn->location);
}
}
functionDeclStack.pop_back(); functionDeclStack.pop_back();
} }

View File

@ -8,6 +8,8 @@
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/OverloadResolution.h"
#include "Luau/Set.h"
#include "Luau/Simplify.h" #include "Luau/Simplify.h"
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
@ -19,7 +21,6 @@
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h" #include "Luau/Unifier2.h"
#include "Luau/VecDeque.h" #include "Luau/VecDeque.h"
#include "Luau/Set.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000);
@ -514,6 +515,18 @@ TypeFamilyReductionResult<TypeId> unmFamilyFn(
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
} }
NotNull<Constraint> TypeFamilyContext::pushConstraint(ConstraintV&& c)
{
NotNull<Constraint> newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c));
// Every constraint that is blocked on the current constraint must also be
// blocked on this new one.
if (constraint)
solver->inheritBlocks(NotNull{constraint}, newConstraint);
return newConstraint;
}
TypeFamilyReductionResult<TypeId> numericBinopFamilyFn(TypeId instance, const std::vector<TypeId>& typeParams, TypeFamilyReductionResult<TypeId> numericBinopFamilyFn(TypeId instance, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx, const std::string metamethod) const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx, const std::string metamethod)
{ {
@ -526,6 +539,8 @@ TypeFamilyReductionResult<TypeId> numericBinopFamilyFn(TypeId instance, const st
TypeId lhsTy = follow(typeParams.at(0)); TypeId lhsTy = follow(typeParams.at(0));
TypeId rhsTy = follow(typeParams.at(1)); TypeId rhsTy = follow(typeParams.at(1));
const Location location = ctx->constraint ? ctx->constraint->location : Location{};
// check to see if both operand types are resolved enough, and wait to reduce if not // check to see if both operand types are resolved enough, and wait to reduce if not
if (isPending(lhsTy, ctx->solver)) if (isPending(lhsTy, ctx->solver))
return {std::nullopt, false, {lhsTy}, {}}; return {std::nullopt, false, {lhsTy}, {}};
@ -555,11 +570,11 @@ TypeFamilyReductionResult<TypeId> numericBinopFamilyFn(TypeId instance, const st
// the necessary state to do that, even if we intend to just eat the errors. // the necessary state to do that, even if we intend to just eat the errors.
ErrorVec dummy; ErrorVec dummy;
std::optional<TypeId> mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, Location{}); std::optional<TypeId> mmType = findMetatableEntry(ctx->builtins, dummy, lhsTy, metamethod, location);
bool reversed = false; bool reversed = false;
if (!mmType) if (!mmType)
{ {
mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, Location{}); mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, location);
reversed = true; reversed = true;
} }
@ -570,33 +585,26 @@ TypeFamilyReductionResult<TypeId> numericBinopFamilyFn(TypeId instance, const st
if (isPending(*mmType, ctx->solver)) if (isPending(*mmType, ctx->solver))
return {std::nullopt, false, {*mmType}, {}}; return {std::nullopt, false, {*mmType}, {}};
const FunctionType* mmFtv = get<FunctionType>(*mmType); TypePackId argPack = ctx->arena->addTypePack({lhsTy, rhsTy});
if (!mmFtv) SolveResult solveResult;
return {std::nullopt, true, {}, {}};
std::optional<TypeId> instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType);
if (!instantiatedMmType)
return {std::nullopt, true, {}, {}};
const FunctionType* instantiatedMmFtv = get<FunctionType>(*instantiatedMmType);
if (!instantiatedMmFtv)
return {ctx->builtins->errorRecoveryType(), false, {}, {}};
std::vector<TypeId> inferredArgs;
if (!reversed) if (!reversed)
inferredArgs = {lhsTy, rhsTy}; solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack);
else else
inferredArgs = {rhsTy, lhsTy}; {
TypePack* p = getMutable<TypePack>(argPack);
std::swap(p->head.front(), p->head.back());
solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack);
}
TypePackId inferredArgPack = ctx->arena->addTypePack(std::move(inferredArgs)); if (!solveResult.typePackId.has_value())
Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice};
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
if (std::optional<TypeId> ret = first(instantiatedMmFtv->retTypes))
return {*ret, false, {}, {}};
else
return {std::nullopt, true, {}, {}}; return {std::nullopt, true, {}, {}};
TypePack extracted = extendTypePack(*ctx->arena, ctx->builtins, *solveResult.typePackId, 1);
if (extracted.head.empty())
return {std::nullopt, true, {}, {}};
return {extracted.head.front(), false, {}, {}};
} }
TypeFamilyReductionResult<TypeId> addFamilyFn( TypeFamilyReductionResult<TypeId> addFamilyFn(
@ -855,6 +863,11 @@ static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, con
TypeId lhsTy = follow(typeParams.at(0)); TypeId lhsTy = follow(typeParams.at(0));
TypeId rhsTy = follow(typeParams.at(1)); TypeId rhsTy = follow(typeParams.at(1));
if (isPending(lhsTy, ctx->solver))
return {std::nullopt, false, {lhsTy}, {}};
else if (isPending(rhsTy, ctx->solver))
return {std::nullopt, false, {rhsTy}, {}};
// Algebra Reduction Rules for comparison family functions // Algebra Reduction Rules for comparison family functions
// Note that comparing to never tells you nothing about the other operand // Note that comparing to never tells you nothing about the other operand
// lt< 'a , never> -> continue // lt< 'a , never> -> continue
@ -875,12 +888,12 @@ static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, con
asMutable(rhsTy)->ty.emplace<BoundType>(ctx->builtins->numberType); asMutable(rhsTy)->ty.emplace<BoundType>(ctx->builtins->numberType);
else if (lhsFree && get<NeverType>(rhsTy) == nullptr) else if (lhsFree && get<NeverType>(rhsTy) == nullptr)
{ {
auto c1 = ctx->solver->pushConstraint(ctx->scope, {}, EqualityConstraint{lhsTy, rhsTy}); auto c1 = ctx->pushConstraint(EqualityConstraint{lhsTy, rhsTy});
const_cast<Constraint*>(ctx->constraint)->dependencies.emplace_back(c1); const_cast<Constraint*>(ctx->constraint)->dependencies.emplace_back(c1);
} }
else if (rhsFree && get<NeverType>(lhsTy) == nullptr) else if (rhsFree && get<NeverType>(lhsTy) == nullptr)
{ {
auto c1 = ctx->solver->pushConstraint(ctx->scope, {}, EqualityConstraint{rhsTy, lhsTy}); auto c1 = ctx->pushConstraint(EqualityConstraint{rhsTy, lhsTy});
const_cast<Constraint*>(ctx->constraint)->dependencies.emplace_back(c1); const_cast<Constraint*>(ctx->constraint)->dependencies.emplace_back(c1);
} }
} }
@ -890,10 +903,6 @@ static TypeFamilyReductionResult<TypeId> comparisonFamilyFn(TypeId instance, con
rhsTy = follow(rhsTy); rhsTy = follow(rhsTy);
// check to see if both operand types are resolved enough, and wait to reduce if not // check to see if both operand types are resolved enough, and wait to reduce if not
if (isPending(lhsTy, ctx->solver))
return {std::nullopt, false, {lhsTy}, {}};
else if (isPending(rhsTy, ctx->solver))
return {std::nullopt, false, {rhsTy}, {}};
const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy);
const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy);

View File

@ -0,0 +1,409 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeFamilyReductionGuesser.h"
#include "Luau/DenseHash.h"
#include "Luau/Normalize.h"
#include "Luau/TypeFamily.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/VecDeque.h"
#include "Luau/VisitType.h"
#include <iostream>
#include <ostream>
namespace Luau
{
struct InstanceCollector2 : TypeOnceVisitor
{
VecDeque<TypeId> tys;
VecDeque<TypePackId> tps;
DenseHashSet<TypeId> cyclicInstance{nullptr};
DenseHashSet<TypeId> instanceArguments{nullptr};
bool visit(TypeId ty, const TypeFamilyInstanceType& it) override
{
// TypeOnceVisitor performs a depth-first traversal in the absence of
// cycles. This means that by pushing to the front of the queue, we will
// try to reduce deeper instances first if we start with the first thing
// in the queue. Consider Add<Add<Add<number, number>, number>, number>:
// we want to reduce the innermost Add<number, number> instantiation
// first.
tys.push_front(ty);
for (auto t : it.typeArguments)
instanceArguments.insert(follow(t));
return true;
}
void cycle(TypeId ty) override
{
/// Detected cyclic type pack
TypeId t = follow(ty);
if (get<TypeFamilyInstanceType>(t))
cyclicInstance.insert(t);
}
bool visit(TypeId ty, const ClassType&) override
{
return false;
}
bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override
{
// TypeOnceVisitor performs a depth-first traversal in the absence of
// cycles. This means that by pushing to the front of the queue, we will
// try to reduce deeper instances first if we start with the first thing
// in the queue. Consider Add<Add<Add<number, number>, number>, number>:
// we want to reduce the innermost Add<number, number> instantiation
// first.
tps.push_front(tp);
return true;
}
};
TypeFamilyReductionGuesser::TypeFamilyReductionGuesser(NotNull<BuiltinTypes> builtins, NotNull<Normalizer> normalizer)
: builtins(builtins)
, normalizer(normalizer)
{
}
bool TypeFamilyReductionGuesser::isFunctionGenericsSaturated(const FunctionType& ftv, DenseHashSet<TypeId>& argsUsed)
{
bool sameSize = ftv.generics.size() == argsUsed.size();
bool allGenericsAppear = true;
for (auto gt : ftv.generics)
allGenericsAppear = allGenericsAppear || argsUsed.contains(gt);
return sameSize && allGenericsAppear;
}
void TypeFamilyReductionGuesser::dumpGuesses()
{
for (auto [tf, t] : familyReducesTo)
printf("Type family %s ~~> %s\n", toString(tf).c_str(), toString(t).c_str());
for (auto [t, t_] : substitutable)
printf("Substitute %s for %s\n", toString(t).c_str(), toString(t_).c_str());
}
TypeFamilyReductionGuessResult TypeFamilyReductionGuesser::guessTypeFamilyReductionForFunction(
const AstExprFunction& expr, const FunctionType* ftv, TypeId retTy)
{
InstanceCollector2 collector;
collector.traverse(retTy);
toInfer = std::move(collector.tys);
cyclicInstances = std::move(collector.cyclicInstance);
if (isFunctionGenericsSaturated(*ftv, collector.instanceArguments))
return TypeFamilyReductionGuessResult{{}, nullptr, false};
infer();
std::vector<std::pair<std::string, TypeId>> results;
std::vector<TypeId> args;
for (TypeId t : ftv->argTypes)
args.push_back(t);
// Submit a guess for arg types
for (size_t i = 0; i < expr.args.size; i++)
{
TypeId argTy;
AstLocal* local = expr.args.data[i];
if (i >= args.size())
continue;
argTy = args[i];
std::optional<TypeId> guessedType = guessType(argTy);
if (!guessedType.has_value())
continue;
TypeId guess = follow(*guessedType);
if (get<TypeFamilyInstanceType>(guess))
continue;
results.push_back({local->name.value, guess});
}
// Submit a guess for return types
TypeId recommendedAnnotation;
std::optional<TypeId> guessedReturnType = guessType(retTy);
if (!guessedReturnType.has_value())
recommendedAnnotation = builtins->unknownType;
else
recommendedAnnotation = follow(*guessedReturnType);
if (auto t = get<TypeFamilyInstanceType>(recommendedAnnotation))
recommendedAnnotation = builtins->unknownType;
toInfer.clear();
cyclicInstances.clear();
familyReducesTo.clear();
substitutable.clear();
return TypeFamilyReductionGuessResult{results, recommendedAnnotation};
}
std::optional<TypeId> TypeFamilyReductionGuesser::guessType(TypeId arg)
{
TypeId t = follow(arg);
if (substitutable.contains(t))
{
TypeId subst = follow(substitutable[t]);
if (subst == t || substitutable.contains(subst))
return subst;
else if (!get<TypeFamilyInstanceType>(subst))
return subst;
else
return guessType(subst);
}
if (get<TypeFamilyInstanceType>(t))
{
if (familyReducesTo.contains(t))
return familyReducesTo[t];
}
return {};
}
bool TypeFamilyReductionGuesser::isNumericBinopFamily(const TypeFamilyInstanceType& instance)
{
return instance.family->name == "add" || instance.family->name == "sub" || instance.family->name == "mul" || instance.family->name == "div" ||
instance.family->name == "idiv" || instance.family->name == "pow" || instance.family->name == "mod";
}
bool TypeFamilyReductionGuesser::isComparisonFamily(const TypeFamilyInstanceType& instance)
{
return instance.family->name == "lt" || instance.family->name == "le" || instance.family->name == "eq";
}
bool TypeFamilyReductionGuesser::isOrAndFamily(const TypeFamilyInstanceType& instance)
{
return instance.family->name == "or" || instance.family->name == "and";
}
bool TypeFamilyReductionGuesser::isNotFamily(const TypeFamilyInstanceType& instance)
{
return instance.family->name == "not";
}
bool TypeFamilyReductionGuesser::isLenFamily(const TypeFamilyInstanceType& instance)
{
return instance.family->name == "len";
}
bool TypeFamilyReductionGuesser::isUnaryMinus(const TypeFamilyInstanceType& instance)
{
return instance.family->name == "unm";
}
// Operand is assignable if it looks like a cyclic family instance, or a generic type
bool TypeFamilyReductionGuesser::operandIsAssignable(TypeId ty)
{
if (get<TypeFamilyInstanceType>(ty))
return true;
if (get<GenericType>(ty))
return true;
if (cyclicInstances.contains(ty))
return true;
return false;
}
const NormalizedType* TypeFamilyReductionGuesser::normalize(TypeId ty)
{
return normalizer->normalize(ty);
}
std::optional<TypeId> TypeFamilyReductionGuesser::tryAssignOperandType(TypeId ty)
{
// Because we collect innermost instances first, if we see a typefamily instance as an operand,
// We try to check if we guessed a type for it
if (auto tfit = get<TypeFamilyInstanceType>(ty))
{
if (familyReducesTo.contains(ty))
return {familyReducesTo[ty]};
}
// If ty is a generic, we need to check if we inferred a substitution
if (auto gt = get<GenericType>(ty))
{
if (substitutable.contains(ty))
return {substitutable[ty]};
}
// If we cannot substitute a type for this value, we return an empty optional
return {};
}
void TypeFamilyReductionGuesser::step()
{
TypeId t = toInfer.front();
toInfer.pop_front();
t = follow(t);
if (auto tf = get<TypeFamilyInstanceType>(t))
inferTypeFamilySubstitutions(t, tf);
}
void TypeFamilyReductionGuesser::infer()
{
while (!done())
step();
}
bool TypeFamilyReductionGuesser::done()
{
return toInfer.empty();
}
void TypeFamilyReductionGuesser::inferTypeFamilySubstitutions(TypeId ty, const TypeFamilyInstanceType* instance)
{
TypeFamilyInferenceResult result;
LUAU_ASSERT(instance);
// TODO: Make an inexhaustive version of this warn in the compiler?
if (isNumericBinopFamily(*instance))
result = inferNumericBinopFamily(instance);
else if (isComparisonFamily(*instance))
result = inferComparisonFamily(instance);
else if (isOrAndFamily(*instance))
result = inferOrAndFamily(instance);
else if (isNotFamily(*instance))
result = inferNotFamily(instance);
else if (isLenFamily(*instance))
result = inferLenFamily(instance);
else if (isUnaryMinus(*instance))
result = inferUnaryMinusFamily(instance);
else
result = {{}, builtins->unknownType};
TypeId resultInference = follow(result.familyResultInference);
if (!familyReducesTo.contains(resultInference))
familyReducesTo[ty] = resultInference;
for (size_t i = 0; i < instance->typeArguments.size(); i++)
{
if (i < result.operandInference.size())
{
TypeId arg = follow(instance->typeArguments[i]);
TypeId inference = follow(result.operandInference[i]);
if (auto tfit = get<TypeFamilyInstanceType>(arg))
{
if (!familyReducesTo.contains(arg))
familyReducesTo.try_insert(arg, inference);
}
else if (auto gt = get<GenericType>(arg))
substitutable[arg] = inference;
}
}
}
TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferNumericBinopFamily(const TypeFamilyInstanceType* instance)
{
LUAU_ASSERT(instance->typeArguments.size() == 2);
TypeFamilyInferenceResult defaultNumericBinopInference{{builtins->numberType, builtins->numberType}, builtins->numberType};
return defaultNumericBinopInference;
}
TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferComparisonFamily(const TypeFamilyInstanceType* instance)
{
LUAU_ASSERT(instance->typeArguments.size() == 2);
// Comparison families are lt/le/eq.
// Heuristic: these are type functions from t -> t -> bool
TypeId lhsTy = follow(instance->typeArguments[0]);
TypeId rhsTy = follow(instance->typeArguments[1]);
auto comparisonInference = [&](TypeId op) -> TypeFamilyInferenceResult {
return TypeFamilyInferenceResult{{op, op}, builtins->booleanType};
};
if (std::optional<TypeId> ty = tryAssignOperandType(lhsTy))
lhsTy = follow(*ty);
if (std::optional<TypeId> ty = tryAssignOperandType(rhsTy))
rhsTy = follow(*ty);
if (operandIsAssignable(lhsTy) && !operandIsAssignable(rhsTy))
return comparisonInference(rhsTy);
if (operandIsAssignable(rhsTy) && !operandIsAssignable(lhsTy))
return comparisonInference(lhsTy);
return comparisonInference(builtins->numberType);
}
TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferOrAndFamily(const TypeFamilyInstanceType* instance)
{
LUAU_ASSERT(instance->typeArguments.size() == 2);
TypeId lhsTy = follow(instance->typeArguments[0]);
TypeId rhsTy = follow(instance->typeArguments[1]);
if (std::optional<TypeId> ty = tryAssignOperandType(lhsTy))
lhsTy = follow(*ty);
if (std::optional<TypeId> ty = tryAssignOperandType(rhsTy))
rhsTy = follow(*ty);
TypeFamilyInferenceResult defaultAndOrInference{{builtins->unknownType, builtins->unknownType}, builtins->booleanType};
const NormalizedType* lty = normalize(lhsTy);
const NormalizedType* rty = normalize(lhsTy);
bool lhsTruthy = lty ? lty->isTruthy() : false;
bool rhsTruthy = rty ? rty->isTruthy() : false;
// If at the end, we still don't have good substitutions, return the default type
if (instance->family->name == "or")
{
if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy))
return defaultAndOrInference;
if (operandIsAssignable(lhsTy))
return TypeFamilyInferenceResult{{builtins->unknownType, rhsTy}, rhsTy};
if (operandIsAssignable(rhsTy))
return TypeFamilyInferenceResult{{lhsTy, builtins->unknownType}, lhsTy};
if (lhsTruthy)
return {{lhsTy, rhsTy}, lhsTy};
if (rhsTruthy)
return {{builtins->unknownType, rhsTy}, rhsTy};
}
if (instance->family->name == "and")
{
if (operandIsAssignable(lhsTy) && operandIsAssignable(rhsTy))
return defaultAndOrInference;
if (operandIsAssignable(lhsTy))
return TypeFamilyInferenceResult{{}, rhsTy};
if (operandIsAssignable(rhsTy))
return TypeFamilyInferenceResult{{}, lhsTy};
if (lhsTruthy)
return {{lhsTy, rhsTy}, rhsTy};
else
return {{lhsTy, rhsTy}, lhsTy};
}
return defaultAndOrInference;
}
TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferNotFamily(const TypeFamilyInstanceType* instance)
{
LUAU_ASSERT(instance->typeArguments.size() == 1);
TypeId opTy = follow(instance->typeArguments[0]);
if (std::optional<TypeId> ty = tryAssignOperandType(opTy))
opTy = follow(*ty);
return {{opTy}, builtins->booleanType};
}
TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferLenFamily(const TypeFamilyInstanceType* instance)
{
LUAU_ASSERT(instance->typeArguments.size() == 1);
TypeId opTy = follow(instance->typeArguments[0]);
if (std::optional<TypeId> ty = tryAssignOperandType(opTy))
opTy = follow(*ty);
return {{opTy}, builtins->numberType};
}
TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferUnaryMinusFamily(const TypeFamilyInstanceType* instance)
{
LUAU_ASSERT(instance->typeArguments.size() == 1);
TypeId opTy = follow(instance->typeArguments[0]);
if (std::optional<TypeId> ty = tryAssignOperandType(opTy))
opTy = follow(*ty);
if (isNumber(opTy))
return {{builtins->numberType}, builtins->numberType};
return {{builtins->unknownType}, builtins->numberType};
}
} // namespace Luau

View File

@ -106,6 +106,18 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
if (subFree && superFree) if (subFree && superFree)
{ {
DenseHashSet<TypeId> seen{nullptr};
if (OccursCheckResult::Fail == occursCheck(seen, subTy, superTy))
{
asMutable(subTy)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
return false;
}
else if (OccursCheckResult::Fail == occursCheck(seen, superTy, subTy))
{
asMutable(subTy)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
return false;
}
superFree->lowerBound = mkUnion(subFree->lowerBound, superFree->lowerBound); superFree->lowerBound = mkUnion(subFree->lowerBound, superFree->lowerBound);
superFree->upperBound = mkIntersection(subFree->upperBound, superFree->upperBound); superFree->upperBound = mkIntersection(subFree->upperBound, superFree->upperBound);
asMutable(subTy)->ty.emplace<BoundType>(superTy); asMutable(subTy)->ty.emplace<BoundType>(superTy);
@ -821,6 +833,53 @@ TypeId Unifier2::mkIntersection(TypeId left, TypeId right)
return simplifyIntersection(builtinTypes, arena, left, right).result; return simplifyIntersection(builtinTypes, arena, left, right).result;
} }
OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
{
RecursionLimiter _ra(&recursionCount, recursionLimit);
OccursCheckResult occurrence = OccursCheckResult::Pass;
auto check = [&](TypeId ty) {
if (occursCheck(seen, needle, ty) == OccursCheckResult::Fail)
occurrence = OccursCheckResult::Fail;
};
needle = follow(needle);
haystack = follow(haystack);
if (seen.find(haystack))
return OccursCheckResult::Pass;
seen.insert(haystack);
if (get<ErrorType>(needle))
return OccursCheckResult::Pass;
if (!get<FreeType>(needle))
ice->ice("Expected needle to be free");
if (needle == haystack)
return OccursCheckResult::Fail;
if (auto haystackFree = get<FreeType>(haystack))
{
check(haystackFree->lowerBound);
check(haystackFree->upperBound);
}
else if (auto ut = get<UnionType>(haystack))
{
for (TypeId ty : ut->options)
check(ty);
}
else if (auto it = get<IntersectionType>(haystack))
{
for (TypeId ty : it->parts)
check(ty);
}
return occurrence;
}
OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack) OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack)
{ {
needle = follow(needle); needle = follow(needle);

View File

@ -398,6 +398,7 @@ enum class IrCmd : uint8_t
// A, B: tag // A, B: tag
// C: block/vmexit/undef // C: block/vmexit/undef
// In final x64 lowering, A can also be Rn // In final x64 lowering, A can also be Rn
// When DebugLuauAbortingChecks flag is enabled, A can also be Rn
// When undef is specified instead of a block, execution is aborted on check failure // When undef is specified instead of a block, execution is aborted on check failure
CHECK_TAG, CHECK_TAG,

View File

@ -4,6 +4,8 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/IrData.h" #include "Luau/IrData.h"
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2)
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -186,7 +188,15 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.def(inst.b); visitor.def(inst.b);
break; break;
case IrCmd::FALLBACK_FORGPREP: case IrCmd::FALLBACK_FORGPREP:
visitor.use(inst.b); if (FFlag::LuauCodegenRemoveDeadStores2)
{
// This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use
visitor.useRange(vmRegOp(inst.b), 3);
}
else
{
visitor.use(inst.b);
}
visitor.defRange(vmRegOp(inst.b), 3); visitor.defRange(vmRegOp(inst.b), 3);
break; break;
@ -204,6 +214,11 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
visitor.use(inst.a); visitor.use(inst.a);
break; break;
// After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated
case IrCmd::CHECK_TAG:
visitor.maybeUse(inst.a);
break;
default: default:
// All instructions which reference registers have to be handled explicitly // All instructions which reference registers have to be handled explicitly
CODEGEN_ASSERT(inst.a.kind != IrOpKind::VmReg); CODEGEN_ASSERT(inst.a.kind != IrOpKind::VmReg);

View File

@ -0,0 +1,16 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/IrData.h"
namespace Luau
{
namespace CodeGen
{
struct IrBuilder;
void markDeadStoresInBlockChains(IrBuilder& build);
} // namespace CodeGen
} // namespace Luau

View File

@ -27,6 +27,10 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE);
#endif #endif
#endif #endif
#ifdef __APPLE__
extern "C" void sys_icache_invalidate(void* start, size_t len);
#endif
static size_t alignToPageSize(size_t size) static size_t alignToPageSize(size_t size)
{ {
return (size + kPageSize - 1) & ~(kPageSize - 1); return (size + kPageSize - 1) & ~(kPageSize - 1);
@ -98,7 +102,11 @@ static void makePagesExecutable(uint8_t* mem, size_t size)
static void flushInstructionCache(uint8_t* mem, size_t size) static void flushInstructionCache(uint8_t* mem, size_t size)
{ {
#ifdef __APPLE__
sys_icache_invalidate(mem, size);
#else
__builtin___clear_cache((char*)mem, (char*)mem + size); __builtin___clear_cache((char*)mem, (char*)mem + size);
#endif
} }
#endif #endif

View File

@ -8,6 +8,7 @@
#include "Luau/IrDump.h" #include "Luau/IrDump.h"
#include "Luau/IrUtils.h" #include "Luau/IrUtils.h"
#include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeConstProp.h"
#include "Luau/OptimizeDeadStore.h"
#include "Luau/OptimizeFinalX64.h" #include "Luau/OptimizeFinalX64.h"
#include "EmitCommon.h" #include "EmitCommon.h"
@ -26,6 +27,7 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit)
LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit)
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2)
namespace Luau namespace Luau
{ {
@ -309,6 +311,9 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers&
stats->blockLinearizationStats.constPropInstructionCount += constPropInstructionCount; stats->blockLinearizationStats.constPropInstructionCount += constPropInstructionCount;
} }
} }
if (FFlag::LuauCodegenRemoveDeadStores2)
markDeadStoresInBlockChains(ir);
} }
std::vector<uint32_t> sortedBlocks = getSortedBlockOrder(ir.function); std::vector<uint32_t> sortedBlocks = getSortedBlockOrder(ir.function);

View File

@ -17,6 +17,7 @@
LUAU_FASTFLAG(LuauCodegenVectorTag2) LUAU_FASTFLAG(LuauCodegenVectorTag2)
LUAU_FASTFLAGVARIABLE(LuauCodegenVectorOptAnd, false) LUAU_FASTFLAGVARIABLE(LuauCodegenVectorOptAnd, false)
LUAU_FASTFLAGVARIABLE(LuauCodegenSmallerUnm, false)
namespace Luau namespace Luau
{ {
@ -542,18 +543,24 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a});
RegisterX64 src = regOp(inst.a); if (FFlag::LuauCodegenSmallerUnm)
if (inst.regX64 == src)
{ {
build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0)); build.vxorpd(inst.regX64, regOp(inst.a), build.f64(-0.0));
} }
else else
{ {
build.vmovsd(inst.regX64, src, src); RegisterX64 src = regOp(inst.a);
build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0));
}
if (inst.regX64 == src)
{
build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0));
}
else
{
build.vmovsd(inst.regX64, src, src);
build.vxorpd(inst.regX64, inst.regX64, build.f64(-0.0));
}
}
break; break;
} }
case IrCmd::FLOOR_NUM: case IrCmd::FLOOR_NUM:
@ -604,13 +611,26 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
ScopedRegX64 tmp1{regs}; if (FFlag::LuauCodegenVectorOptAnd)
ScopedRegX64 tmp2{regs}; {
ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};
RegisterX64 tmpa = vecOp(inst.a, tmp1); RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);
build.vaddps(inst.regX64, tmpa, tmpb); build.vaddps(inst.regX64, tmpa, tmpb);
}
else
{
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
// Fourth component is the tag number which is interpreted as a denormal and has to be filtered out
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vaddps(inst.regX64, tmp1.reg, tmp2.reg);
}
if (!FFlag::LuauCodegenVectorTag2) if (!FFlag::LuauCodegenVectorTag2)
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
@ -620,13 +640,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
ScopedRegX64 tmp1{regs}; if (FFlag::LuauCodegenVectorOptAnd)
ScopedRegX64 tmp2{regs}; {
ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};
RegisterX64 tmpa = vecOp(inst.a, tmp1); RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);
build.vsubps(inst.regX64, tmpa, tmpb);
}
else
{
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
// Fourth component is the tag number which is interpreted as a denormal and has to be filtered out
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vsubps(inst.regX64, tmp1.reg, tmp2.reg);
}
build.vsubps(inst.regX64, tmpa, tmpb);
if (!FFlag::LuauCodegenVectorTag2) if (!FFlag::LuauCodegenVectorTag2)
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
break; break;
@ -635,13 +669,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
ScopedRegX64 tmp1{regs}; if (FFlag::LuauCodegenVectorOptAnd)
ScopedRegX64 tmp2{regs}; {
ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};
RegisterX64 tmpa = vecOp(inst.a, tmp1); RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);
build.vmulps(inst.regX64, tmpa, tmpb);
}
else
{
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
// Fourth component is the tag number which is interpreted as a denormal and has to be filtered out
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vmulps(inst.regX64, tmp1.reg, tmp2.reg);
}
build.vmulps(inst.regX64, tmpa, tmpb);
if (!FFlag::LuauCodegenVectorTag2) if (!FFlag::LuauCodegenVectorTag2)
build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp());
break; break;
@ -650,13 +698,27 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
ScopedRegX64 tmp1{regs}; if (FFlag::LuauCodegenVectorOptAnd)
ScopedRegX64 tmp2{regs}; {
ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};
RegisterX64 tmpa = vecOp(inst.a, tmp1); RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);
build.vdivps(inst.regX64, tmpa, tmpb);
}
else
{
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
// Fourth component is the tag number which is interpreted as a denormal and has to be filtered out
build.vandps(tmp1.reg, regOp(inst.a), vectorAndMaskOp());
build.vandps(tmp2.reg, regOp(inst.b), vectorAndMaskOp());
build.vdivps(inst.regX64, tmp1.reg, tmp2.reg);
}
build.vdivps(inst.regX64, tmpa, tmpb);
if (!FFlag::LuauCodegenVectorTag2) if (!FFlag::LuauCodegenVectorTag2)
build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3); build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3);
break; break;
@ -665,16 +727,23 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a}); inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a});
RegisterX64 src = regOp(inst.a); if (FFlag::LuauCodegenSmallerUnm)
if (inst.regX64 == src)
{ {
build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0)); build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0));
} }
else else
{ {
build.vmovsd(inst.regX64, src, src); RegisterX64 src = regOp(inst.a);
build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0));
if (inst.regX64 == src)
{
build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0));
}
else
{
build.vmovsd(inst.regX64, src, src);
build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0));
}
} }
if (!FFlag::LuauCodegenVectorTag2) if (!FFlag::LuauCodegenVectorTag2)
@ -2299,6 +2368,7 @@ OperandX64 IrLoweringX64::vectorAndMaskOp()
OperandX64 IrLoweringX64::vectorOrMaskOp() OperandX64 IrLoweringX64::vectorOrMaskOp()
{ {
CODEGEN_ASSERT(!FFlag::LuauCodegenVectorTag2); CODEGEN_ASSERT(!FFlag::LuauCodegenVectorTag2);
if (vectorOrMask.base == noreg) if (vectorOrMask.base == noreg)
vectorOrMask = build.u32x4(0, 0, 0, LUA_TVECTOR); vectorOrMask = build.u32x4(0, 0, 0, LUA_TVECTOR);

View File

@ -0,0 +1,530 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/OptimizeDeadStore.h"
#include "Luau/IrBuilder.h"
#include "Luau/IrVisitUseDef.h"
#include "Luau/IrUtils.h"
#include <array>
#include "lobject.h"
LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores2, false)
LUAU_FASTFLAG(LuauCodegenVectorTag2)
// TODO: optimization can be improved by knowing which registers are live in at each VM exit
namespace Luau
{
namespace CodeGen
{
// Luau value structure reminder:
// [ TValue ]
// [ Value ][ Extra ][ Tag ]
// Storing individual components will not kill any previous TValue stores
// Storing TValue will kill any full store or a component store ('extra' excluded because it's rare)
struct StoreRegInfo
{
// Indices of the last unused store instructions
uint32_t tagInstIdx = ~0u;
uint32_t valueInstIdx = ~0u;
uint32_t tvalueInstIdx = ~0u;
// This register might contain a GC object
bool maybeGco = false;
};
struct RemoveDeadStoreState
{
RemoveDeadStoreState(IrFunction& function)
: function(function)
{
maxReg = function.proto ? function.proto->maxstacksize : 255;
}
void killTagStore(StoreRegInfo& regInfo)
{
if (regInfo.tagInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.tagInstIdx]);
regInfo.tagInstIdx = ~0u;
regInfo.maybeGco = false;
}
}
void killValueStore(StoreRegInfo& regInfo)
{
if (regInfo.valueInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.valueInstIdx]);
regInfo.valueInstIdx = ~0u;
regInfo.maybeGco = false;
}
}
void killTValueStore(StoreRegInfo& regInfo)
{
if (regInfo.tvalueInstIdx != ~0u)
{
kill(function, function.instructions[regInfo.tvalueInstIdx]);
regInfo.tvalueInstIdx = ~0u;
regInfo.maybeGco = false;
}
}
// When a register value is being defined, it kills previous stores
void defReg(uint8_t reg)
{
StoreRegInfo& regInfo = info[reg];
// Stores to captured registers are not removed since we don't track their uses outside of function
if (function.cfg.captured.regs.test(reg))
return;
killTagStore(regInfo);
killValueStore(regInfo);
killTValueStore(regInfo);
}
// When a register value is being used, we forget about the last store location to not kill them
void useReg(uint8_t reg)
{
info[reg] = StoreRegInfo{};
}
// When checking control flow, such as exit to fallback blocks:
// For VM exits, we keep all stores because we don't have information on what registers are live at the start of the VM assist
// For regular blocks, we check which registers are expected to be live at entry (if we have CFG information available)
void checkLiveIns(IrOp op)
{
if (op.kind == IrOpKind::VmExit)
{
clear();
}
else if (op.kind == IrOpKind::Block)
{
if (op.index < function.cfg.in.size())
{
const RegisterSet& in = function.cfg.in[op.index];
for (int i = 0; i <= maxReg; i++)
{
if (in.regs.test(i) || (in.varargSeq && i >= in.varargStart))
useReg(i);
}
}
else
{
clear();
}
}
else if (op.kind == IrOpKind::Undef)
{
// Nothing to do for a debug abort
}
else
{
CODEGEN_ASSERT(!"unexpected jump target type");
}
}
// When checking block terminators, any registers that are not live out can be removed by saying that a new value is being 'defined'
void checkLiveOuts(const IrBlock& block)
{
uint32_t index = function.getBlockIndex(block);
if (index < function.cfg.out.size())
{
const RegisterSet& out = function.cfg.out[index];
for (int i = 0; i <= maxReg; i++)
{
bool isOut = out.regs.test(i) || (out.varargSeq && i >= out.varargStart);
if (!isOut)
defReg(i);
}
}
}
// Common instruction visitor handling
void defVarargs(uint8_t varargStart)
{
for (int i = varargStart; i <= maxReg; i++)
defReg(uint8_t(i));
}
void useVarargs(uint8_t varargStart)
{
for (int i = varargStart; i <= maxReg; i++)
useReg(uint8_t(i));
}
void def(IrOp op, int offset = 0)
{
defReg(vmRegOp(op) + offset);
}
void use(IrOp op, int offset = 0)
{
useReg(vmRegOp(op) + offset);
}
void maybeDef(IrOp op)
{
if (op.kind == IrOpKind::VmReg)
defReg(vmRegOp(op));
}
void maybeUse(IrOp op)
{
if (op.kind == IrOpKind::VmReg)
useReg(vmRegOp(op));
}
void defRange(int start, int count)
{
if (count == -1)
{
defVarargs(start);
}
else
{
for (int i = start; i < start + count; i++)
defReg(i);
}
}
void useRange(int start, int count)
{
if (count == -1)
{
useVarargs(start);
}
else
{
for (int i = start; i < start + count; i++)
useReg(i);
}
}
// Required for a full visitor interface
void capture(int reg) {}
// Full clear of the tracked information
void clear()
{
for (int i = 0; i <= maxReg; i++)
info[i] = StoreRegInfo();
hasGcoToClear = false;
}
// Partial clear of information about registers that might contain a GC object
// This is used by instructions that might perform a GC assist and GC needs all pointers to be pinned to stack
void flushGcoRegs()
{
for (int i = 0; i <= maxReg; i++)
{
if (info[i].maybeGco)
info[i] = StoreRegInfo();
}
hasGcoToClear = false;
}
IrFunction& function;
std::array<StoreRegInfo, 256> info;
int maxReg = 255;
// Some of the registers contain values which might be a GC object
bool hasGcoToClear = false;
};
static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index)
{
switch (inst.cmd)
{
case IrCmd::STORE_TAG:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
state.killTagStore(regInfo);
uint8_t tag = function.tagOp(inst.b);
regInfo.tagInstIdx = index;
regInfo.maybeGco = isGCO(tag);
state.hasGcoToClear |= regInfo.maybeGco;
}
break;
case IrCmd::STORE_EXTRA:
// To simplify, extra field store is preserved along with all other stores made so far
if (inst.a.kind == IrOpKind::VmReg)
{
state.useReg(vmRegOp(inst.a));
}
break;
case IrCmd::STORE_POINTER:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
state.killValueStore(regInfo);
regInfo.valueInstIdx = index;
regInfo.maybeGco = true;
state.hasGcoToClear = true;
}
break;
case IrCmd::STORE_DOUBLE:
case IrCmd::STORE_INT:
case IrCmd::STORE_VECTOR:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
state.killValueStore(regInfo);
regInfo.valueInstIdx = index;
}
break;
case IrCmd::STORE_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
state.killTagStore(regInfo);
state.killValueStore(regInfo);
state.killTValueStore(regInfo);
regInfo.tvalueInstIdx = index;
regInfo.maybeGco = true;
// If the argument is a vector, it's not a GC object
// Note that for known boolean/number/GCO, we already optimize into STORE_SPLIT_TVALUE form
// TODO: this can be removed if TAG_VECTOR+STORE_TVALUE is replaced with STORE_SPLIT_TVALUE
if (IrInst* arg = function.asInstOp(inst.b))
{
if (FFlag::LuauCodegenVectorTag2)
{
if (arg->cmd == IrCmd::TAG_VECTOR)
regInfo.maybeGco = false;
}
else
{
if (arg->cmd == IrCmd::ADD_VEC || arg->cmd == IrCmd::SUB_VEC || arg->cmd == IrCmd::MUL_VEC || arg->cmd == IrCmd::DIV_VEC ||
arg->cmd == IrCmd::UNM_VEC)
regInfo.maybeGco = false;
}
}
state.hasGcoToClear |= regInfo.maybeGco;
}
break;
case IrCmd::STORE_SPLIT_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
{
int reg = vmRegOp(inst.a);
if (function.cfg.captured.regs.test(reg))
return;
StoreRegInfo& regInfo = state.info[reg];
state.killTagStore(regInfo);
state.killValueStore(regInfo);
state.killTValueStore(regInfo);
regInfo.tvalueInstIdx = index;
regInfo.maybeGco = isGCO(function.tagOp(inst.b));
state.hasGcoToClear |= regInfo.maybeGco;
}
break;
// Guard checks can jump to a block which might be using some or all the values we stored
case IrCmd::CHECK_TAG:
// After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG might use a VM register
visitVmRegDefsUses(state, function, inst);
state.checkLiveIns(inst.c);
break;
case IrCmd::TRY_NUM_TO_INDEX:
state.checkLiveIns(inst.b);
break;
case IrCmd::TRY_CALL_FASTGETTM:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_FASTCALL_RES:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_TRUTHY:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_READONLY:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_NO_METATABLE:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_SAFE_ENV:
state.checkLiveIns(inst.a);
break;
case IrCmd::CHECK_ARRAY_SIZE:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_SLOT_MATCH:
state.checkLiveIns(inst.c);
break;
case IrCmd::CHECK_NODE_NO_NEXT:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_NODE_VALUE:
state.checkLiveIns(inst.b);
break;
case IrCmd::CHECK_BUFFER_LEN:
state.checkLiveIns(inst.d);
break;
case IrCmd::JUMP:
// Ideally, we would be able to remove stores to registers that are not live out from a block
// But during chain optimizations, we rely on data stored in the predecessor even when it's not an explicit live out
break;
case IrCmd::RETURN:
visitVmRegDefsUses(state, function, inst);
// At the end of a function, we can kill stores to registers that are not live out
state.checkLiveOuts(block);
break;
case IrCmd::ADJUST_STACK_TO_REG:
// visitVmRegDefsUses considers adjustment as the fast call register definition point, but for dead store removal, we count the actual writes
break;
// This group of instructions can trigger GC assist internally
// For GC to work correctly, all values containing a GCO have to be stored on stack - otherwise a live reference might be missed
case IrCmd::CMP_ANY:
case IrCmd::DO_ARITH:
case IrCmd::DO_LEN:
case IrCmd::GET_TABLE:
case IrCmd::SET_TABLE:
case IrCmd::GET_IMPORT:
case IrCmd::CONCAT:
case IrCmd::INTERRUPT:
case IrCmd::CHECK_GC:
case IrCmd::CALL:
case IrCmd::FORGLOOP_FALLBACK:
case IrCmd::FALLBACK_GETGLOBAL:
case IrCmd::FALLBACK_SETGLOBAL:
case IrCmd::FALLBACK_GETTABLEKS:
case IrCmd::FALLBACK_SETTABLEKS:
case IrCmd::FALLBACK_NAMECALL:
case IrCmd::FALLBACK_DUPCLOSURE:
case IrCmd::FALLBACK_FORGPREP:
if (state.hasGcoToClear)
state.flushGcoRegs();
visitVmRegDefsUses(state, function, inst);
break;
default:
// Guards have to be covered explicitly
CODEGEN_ASSERT(!isNonTerminatingJump(inst.cmd));
visitVmRegDefsUses(state, function, inst);
break;
}
}
static void markDeadStoresInBlock(IrBuilder& build, IrBlock& block, RemoveDeadStoreState& state)
{
IrFunction& function = build.function;
for (uint32_t index = block.start; index <= block.finish; index++)
{
CODEGEN_ASSERT(index < function.instructions.size());
IrInst& inst = function.instructions[index];
markDeadStoresInInst(state, build, function, block, inst, index);
}
}
static void markDeadStoresInBlockChain(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock* block)
{
IrFunction& function = build.function;
RemoveDeadStoreState state{function};
while (block)
{
uint32_t blockIdx = function.getBlockIndex(*block);
CODEGEN_ASSERT(!visited[blockIdx]);
visited[blockIdx] = true;
markDeadStoresInBlock(build, *block, state);
IrInst& termInst = function.instructions[block->finish];
IrBlock* nextBlock = nullptr;
// Unconditional jump into a block with a single user (current block) allows us to continue optimization
// with the information we have gathered so far (unless we have already visited that block earlier)
if (termInst.cmd == IrCmd::JUMP && termInst.a.kind == IrOpKind::Block)
{
IrBlock& target = function.blockOp(termInst.a);
uint32_t targetIdx = function.getBlockIndex(target);
if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback)
nextBlock = &target;
}
block = nextBlock;
}
}
void markDeadStoresInBlockChains(IrBuilder& build)
{
IrFunction& function = build.function;
std::vector<uint8_t> visited(function.blocks.size(), false);
for (IrBlock& block : function.blocks)
{
if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead)
continue;
if (visited[function.getBlockIndex(block)])
continue;
markDeadStoresInBlockChain(build, visited, &block);
}
}
} // namespace CodeGen
} // namespace Luau

View File

@ -88,6 +88,7 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/include/Luau/Label.h CodeGen/include/Luau/Label.h
CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/OperandX64.h
CodeGen/include/Luau/OptimizeConstProp.h CodeGen/include/Luau/OptimizeConstProp.h
CodeGen/include/Luau/OptimizeDeadStore.h
CodeGen/include/Luau/OptimizeFinalX64.h CodeGen/include/Luau/OptimizeFinalX64.h
CodeGen/include/Luau/RegisterA64.h CodeGen/include/Luau/RegisterA64.h
CodeGen/include/Luau/RegisterX64.h CodeGen/include/Luau/RegisterX64.h
@ -125,6 +126,7 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/lcodegen.cpp CodeGen/src/lcodegen.cpp
CodeGen/src/NativeState.cpp CodeGen/src/NativeState.cpp
CodeGen/src/OptimizeConstProp.cpp CodeGen/src/OptimizeConstProp.cpp
CodeGen/src/OptimizeDeadStore.cpp
CodeGen/src/OptimizeFinalX64.cpp CodeGen/src/OptimizeFinalX64.cpp
CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderDwarf2.cpp
CodeGen/src/UnwindBuilderWin.cpp CodeGen/src/UnwindBuilderWin.cpp
@ -210,6 +212,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/TypeCheckLimits.h Analysis/include/Luau/TypeCheckLimits.h
Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypedAllocator.h
Analysis/include/Luau/TypeFamily.h Analysis/include/Luau/TypeFamily.h
Analysis/include/Luau/TypeFamilyReductionGuesser.h
Analysis/include/Luau/TypeFwd.h Analysis/include/Luau/TypeFwd.h
Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypeInfer.h
Analysis/include/Luau/TypeOrPack.h Analysis/include/Luau/TypeOrPack.h
@ -271,6 +274,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/TypeChecker2.cpp Analysis/src/TypeChecker2.cpp
Analysis/src/TypedAllocator.cpp Analysis/src/TypedAllocator.cpp
Analysis/src/TypeFamily.cpp Analysis/src/TypeFamily.cpp
Analysis/src/TypeFamilyReductionGuesser.cpp
Analysis/src/TypeInfer.cpp Analysis/src/TypeInfer.cpp
Analysis/src/TypeOrPack.cpp Analysis/src/TypeOrPack.cpp
Analysis/src/TypePack.cpp Analysis/src/TypePack.cpp

View File

@ -260,7 +260,10 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int
return page; return page;
} }
static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata) // this is part of a cold path in newblock and newgcoblock
// it is marked as noinline to prevent it from being inlined into those functions
// if it is inlined, then the compiler may determine those functions are "too big" to be profitably inlined, which results in reduced performance
LUAU_NOINLINE static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata)
{ {
if (FFlag::LuauExtendedSizeClasses) if (FFlag::LuauExtendedSizeClasses)
{ {

View File

@ -103,6 +103,12 @@ ClassFixture::ClassFixture()
}; };
getMutable<TableType>(vector2MetaType)->props = { getMutable<TableType>(vector2MetaType)->props = {
{"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}},
{"__mul", {
arena.addType(IntersectionType{{
makeFunction(arena, vector2InstanceType, {vector2InstanceType}, {vector2InstanceType}),
makeFunction(arena, vector2InstanceType, {builtinTypes->numberType}, {vector2InstanceType}),
}})
}}
}; };
globals.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; globals.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType};
addGlobalBinding(globals, "Vector2", vector2Type, "@test"); addGlobalBinding(globals, "Vector2", vector2Type, "@test");

View File

@ -4,6 +4,7 @@
#include "Luau/IrDump.h" #include "Luau/IrDump.h"
#include "Luau/IrUtils.h" #include "Luau/IrUtils.h"
#include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeConstProp.h"
#include "Luau/OptimizeDeadStore.h"
#include "Luau/OptimizeFinalX64.h" #include "Luau/OptimizeFinalX64.h"
#include "ScopedFlags.h" #include "ScopedFlags.h"
@ -15,6 +16,10 @@ LUAU_FASTFLAG(LuauCodegenVectorTag2)
using namespace Luau::CodeGen; using namespace Luau::CodeGen;
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2)
LUAU_FASTFLAG(DebugLuauAbortingChecks)
class IrBuilderFixture class IrBuilderFixture
{ {
public: public:
@ -2538,6 +2543,8 @@ bb_0: ; useCount: 0
TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation")
{ {
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp block = build.block(IrBlockKind::Internal); IrOp block = build.block(IrBlockKind::Internal);
IrOp followup = build.block(IrBlockKind::Internal); IrOp followup = build.block(IrBlockKind::Internal);
@ -2560,7 +2567,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation")
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0: bb_0:
; successors: bb_1 ; successors: bb_1
; in regs: R0, R1 ; in regs: R0, R1, R2, R3
; out regs: R1, R2, R3 ; out regs: R1, R2, R3
%0 = LOAD_POINTER R0 %0 = LOAD_POINTER R0
CHECK_READONLY %0, exit(1) CHECK_READONLY %0, exit(1)
@ -2884,6 +2891,65 @@ bb_1:
)"); )");
} }
TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepImplicitUse")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
IrOp direct = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Internal);
IrOp exit = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(10.0));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(1.0));
IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(0));
build.inst(IrCmd::JUMP_EQ_TAG, tag, build.constTag(tnumber), direct, fallback);
build.beginBlock(direct);
build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1));
build.beginBlock(fallback);
build.inst(IrCmd::FALLBACK_FORGPREP, build.constUint(0), build.vmReg(1), exit);
build.beginBlock(exit);
build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(3));
updateUseCounts(build.function);
computeCfgInfo(build.function);
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; successors: bb_1, bb_2
; in regs: R0
; out regs: R0, R1, R2, R3
STORE_DOUBLE R1, 1
STORE_DOUBLE R2, 10
STORE_DOUBLE R3, 1
%3 = LOAD_TAG R0
JUMP_EQ_TAG %3, tnumber, bb_1, bb_2
bb_1:
; predecessors: bb_0
; in regs: R0
RETURN R0, 1i
bb_2:
; predecessors: bb_0
; successors: bb_3
; in regs: R1, R2, R3
; out regs: R1, R2, R3
FALLBACK_FORGPREP 0u, R1, bb_3
bb_3:
; predecessors: bb_2
; in regs: R1, R2, R3
RETURN R1, 3i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable")
{ {
IrOp entry = build.block(IrBlockKind::Internal); IrOp entry = build.block(IrBlockKind::Internal);
@ -3333,6 +3399,358 @@ bb_1:
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("DeadStoreRemoval");
TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2.0)); // Should remove previous store
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(1.0));
build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); // Should remove previous store of different type
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnil));
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); // Should remove previous store
build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.constTag(tnil));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.constDouble(1.0));
build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(4), build.constTag(tnumber), build.constDouble(2.0)); // Should remove two previous stores
IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0));
build.inst(IrCmd::STORE_TAG, build.vmReg(5), build.constTag(tnil));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.constDouble(1.0));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(5), someTv); // Should remove two previous stores
build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(5));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; in regs: R0
STORE_DOUBLE R1, 2
STORE_INT R2, 4i
STORE_TAG R3, tnumber
STORE_SPLIT_TVALUE R4, tnumber, 2
%9 = LOAD_TVALUE R0
STORE_TVALUE R5, %9
RETURN R1, 5i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturn")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0));
build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4));
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber));
build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(4), build.constTag(tnumber), build.constDouble(2.0));
IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(5), someTv);
build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; in regs: R0
RETURN R0, 1i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse1")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
IrOp somePtr = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0));
build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtr);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable));
build.inst(IrCmd::CALL, build.vmReg(2), build.constInt(0), build.constInt(1));
build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; in regs: R0, R2
%0 = LOAD_POINTER R0
STORE_POINTER R1, %0
STORE_TAG R1, ttable
CALL R2, 0i, 1i
RETURN R2, 1i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse2")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
IrOp somePtrA = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0));
build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrA);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable));
build.inst(IrCmd::CALL, build.vmReg(2), build.constInt(0), build.constInt(1));
IrOp somePtrB = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2));
build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrB);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable));
build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
// Stores to pointers can be safely removed at 'return' point, but have to preserved for any GC assist trigger (such as a call)
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; in regs: R0, R2
%0 = LOAD_POINTER R0
STORE_POINTER R1, %0
STORE_TAG R1, ttable
CALL R2, 0i, 1i
RETURN R2, 1i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse3")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
IrOp somePtrA = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0));
build.inst(IrCmd::STORE_POINTER, build.vmReg(1), somePtrA);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable));
IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), someTv);
build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
// Stores to pointers can be safely removed if there are no potential implicit uses by any GC assists
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; in regs: R0, R2
%3 = LOAD_TVALUE R2
STORE_TVALUE R1, %3
RETURN R1, 1i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(-1.0));
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(1), build.constInt(1));
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0));
build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
STORE_TAG R1, tnumber
ADJUST_STACK_TO_REG R1, 1i
STORE_DOUBLE R1, 1
RETURN R1, 1i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
IrOp next = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0));
build.inst(IrCmd::JUMP, next);
build.beginBlock(next);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0));
build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
// Even though bb_0 doesn't have R1 as a live out, chain optimization used the knowledge of those writes happening to optimize duplicate stores
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; successors: bb_1
STORE_TAG R1, tnumber
STORE_DOUBLE R1, 1
JUMP bb_1
bb_1:
; predecessors: bb_0
RETURN R1, 1i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "KeepCapturedRegisterStores")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
build.inst(IrCmd::CAPTURE, build.vmReg(1), build.constUint(1));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0));
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber));
build.inst(IrCmd::DO_ARITH, build.vmReg(0), build.vmReg(2), build.vmReg(3), build.constInt(0));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(-1.0));
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber));
build.inst(IrCmd::DO_ARITH, build.vmReg(1), build.vmReg(4), build.vmReg(5), build.constInt(0));
build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
// Captured registers may be modified from called user functions (plain or hidden in metamethods)
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
; captured regs: R1
bb_0:
; in regs: R1, R2, R3, R4, R5
CAPTURE R1, 1u
STORE_DOUBLE R1, 1
STORE_TAG R1, tnumber
DO_ARITH R0, R2, R3, 0i
STORE_DOUBLE R1, -1
STORE_TAG R1, tnumber
DO_ARITH R1, R4, R5, 0i
RETURN R0, 2i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "AbortingChecksRequireStores")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
ScopedFastFlag debugLuauAbortingChecks{FFlag::DebugLuauAbortingChecks, true};
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)));
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil));
build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(6));
updateUseCounts(build.function);
computeCfgInfo(build.function);
constPropInBlockChains(build, true);
markDeadStoresInBlockChains(build);
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
; in regs: R1, R4
STORE_TAG R0, tnumber
STORE_DOUBLE R2, 0.5
STORE_TAG R3, tnumber
STORE_DOUBLE R5, 0.5
CHECK_TAG R0, tnumber, undef
STORE_TAG R0, tnil
RETURN R0, 6i
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue")
{
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
IrOp entry = build.block(IrBlockKind::Internal);
build.beginBlock(entry);
build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(tnumber), build.constDouble(1.0));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(4.0));
build.inst(
IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(ttable), build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32)));
build.inst(IrCmd::STORE_POINTER, build.vmReg(0), build.inst(IrCmd::NEW_TABLE, build.constUint(8), build.constUint(16)));
build.inst(IrCmd::STORE_POINTER, build.vmReg(0), build.inst(IrCmd::NEW_TABLE, build.constUint(4), build.constUint(8)));
build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(tnumber), build.constDouble(1.0));
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tstring));
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(ttable));
build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1));
updateUseCounts(build.function);
computeCfgInfo(build.function);
markDeadStoresInBlockChains(build);
CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"(
bb_0:
STORE_SPLIT_TVALUE R0, tnumber, 1
STORE_TAG R0, ttable
RETURN R0, 1i
)");
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("Dump"); TEST_SUITE_BEGIN("Dump");
TEST_CASE_FIXTURE(IrBuilderFixture, "ToDot") TEST_CASE_FIXTURE(IrBuilderFixture, "ToDot")

View File

@ -13,6 +13,7 @@
#include <memory> #include <memory>
LUAU_FASTFLAG(LuauCodegenVectorTag2) LUAU_FASTFLAG(LuauCodegenVectorTag2)
LUAU_FASTFLAG(LuauCodegenRemoveDeadStores2)
static std::string getCodegenAssembly(const char* source) static std::string getCodegenAssembly(const char* source)
{ {
@ -90,6 +91,8 @@ bb_bytecode_1:
TEST_CASE("VectorComponentRead") TEST_CASE("VectorComponentRead")
{ {
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
CHECK_EQ("\n" + getCodegenAssembly(R"( CHECK_EQ("\n" + getCodegenAssembly(R"(
local function compsum(a: vector) local function compsum(a: vector)
return a.X + a.Y + a.Z return a.X + a.Y + a.Z
@ -104,16 +107,9 @@ bb_2:
JUMP bb_bytecode_1 JUMP bb_bytecode_1
bb_bytecode_1: bb_bytecode_1:
%6 = LOAD_FLOAT R0, 0i %6 = LOAD_FLOAT R0, 0i
STORE_DOUBLE R3, %6
STORE_TAG R3, tnumber
%11 = LOAD_FLOAT R0, 4i %11 = LOAD_FLOAT R0, 4i
STORE_DOUBLE R4, %11
STORE_TAG R4, tnumber
%20 = ADD_NUM %6, %11 %20 = ADD_NUM %6, %11
STORE_DOUBLE R2, %20
STORE_TAG R2, tnumber
%25 = LOAD_FLOAT R0, 8i %25 = LOAD_FLOAT R0, 8i
STORE_DOUBLE R3, %25
%34 = ADD_NUM %20, %25 %34 = ADD_NUM %20, %25
STORE_DOUBLE R1, %34 STORE_DOUBLE R1, %34
STORE_TAG R1, tnumber STORE_TAG R1, tnumber
@ -179,6 +175,7 @@ bb_bytecode_1:
TEST_CASE("VectorSubMulDiv") TEST_CASE("VectorSubMulDiv")
{ {
ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true};
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
CHECK_EQ("\n" + getCodegenAssembly(R"( CHECK_EQ("\n" + getCodegenAssembly(R"(
local function vec3combo(a: vector, b: vector, c: vector, d: vector) local function vec3combo(a: vector, b: vector, c: vector, d: vector)
@ -199,13 +196,9 @@ bb_bytecode_1:
%14 = LOAD_TVALUE R0 %14 = LOAD_TVALUE R0
%15 = LOAD_TVALUE R1 %15 = LOAD_TVALUE R1
%16 = MUL_VEC %14, %15 %16 = MUL_VEC %14, %15
%17 = TAG_VECTOR %16
STORE_TVALUE R5, %17
%23 = LOAD_TVALUE R2 %23 = LOAD_TVALUE R2
%24 = LOAD_TVALUE R3 %24 = LOAD_TVALUE R3
%25 = DIV_VEC %23, %24 %25 = DIV_VEC %23, %24
%26 = TAG_VECTOR %25
STORE_TVALUE R6, %26
%34 = SUB_VEC %16, %25 %34 = SUB_VEC %16, %25
%35 = TAG_VECTOR %34 %35 = TAG_VECTOR %34
STORE_TVALUE R4, %35 STORE_TVALUE R4, %35
@ -217,6 +210,7 @@ bb_bytecode_1:
TEST_CASE("VectorSubMulDiv2") TEST_CASE("VectorSubMulDiv2")
{ {
ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true};
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
CHECK_EQ("\n" + getCodegenAssembly(R"( CHECK_EQ("\n" + getCodegenAssembly(R"(
local function vec3combo(a: vector) local function vec3combo(a: vector)
@ -234,14 +228,8 @@ bb_2:
bb_bytecode_1: bb_bytecode_1:
%8 = LOAD_TVALUE R0 %8 = LOAD_TVALUE R0
%10 = MUL_VEC %8, %8 %10 = MUL_VEC %8, %8
%11 = TAG_VECTOR %10
STORE_TVALUE R1, %11
%19 = SUB_VEC %10, %10 %19 = SUB_VEC %10, %10
%20 = TAG_VECTOR %19
STORE_TVALUE R3, %20
%28 = ADD_VEC %10, %10 %28 = ADD_VEC %10, %10
%29 = TAG_VECTOR %28
STORE_TVALUE R4, %29
%37 = DIV_VEC %19, %28 %37 = DIV_VEC %19, %28
%38 = TAG_VECTOR %37 %38 = TAG_VECTOR %37
STORE_TVALUE R2, %38 STORE_TVALUE R2, %38
@ -253,6 +241,7 @@ bb_bytecode_1:
TEST_CASE("VectorMulDivMixed") TEST_CASE("VectorMulDivMixed")
{ {
ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true};
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
CHECK_EQ("\n" + getCodegenAssembly(R"( CHECK_EQ("\n" + getCodegenAssembly(R"(
local function vec3combo(a: vector, b: vector, c: vector, d: vector) local function vec3combo(a: vector, b: vector, c: vector, d: vector)
@ -273,31 +262,17 @@ bb_bytecode_1:
%12 = LOAD_TVALUE R0 %12 = LOAD_TVALUE R0
%13 = NUM_TO_VEC 2 %13 = NUM_TO_VEC 2
%14 = MUL_VEC %12, %13 %14 = MUL_VEC %12, %13
%15 = TAG_VECTOR %14
STORE_TVALUE R7, %15
%19 = LOAD_TVALUE R1 %19 = LOAD_TVALUE R1
%20 = NUM_TO_VEC 4 %20 = NUM_TO_VEC 4
%21 = DIV_VEC %19, %20 %21 = DIV_VEC %19, %20
%22 = TAG_VECTOR %21
STORE_TVALUE R8, %22
%30 = ADD_VEC %14, %21 %30 = ADD_VEC %14, %21
%31 = TAG_VECTOR %30
STORE_TVALUE R6, %31
STORE_DOUBLE R8, 0.5
STORE_TAG R8, tnumber
%40 = NUM_TO_VEC 0.5 %40 = NUM_TO_VEC 0.5
%41 = LOAD_TVALUE R2 %41 = LOAD_TVALUE R2
%42 = MUL_VEC %40, %41 %42 = MUL_VEC %40, %41
%43 = TAG_VECTOR %42
STORE_TVALUE R7, %43
%51 = ADD_VEC %30, %42 %51 = ADD_VEC %30, %42
%52 = TAG_VECTOR %51
STORE_TVALUE R5, %52
%56 = NUM_TO_VEC 40 %56 = NUM_TO_VEC 40
%57 = LOAD_TVALUE R3 %57 = LOAD_TVALUE R3
%58 = DIV_VEC %56, %57 %58 = DIV_VEC %56, %57
%59 = TAG_VECTOR %58
STORE_TVALUE R6, %59
%67 = ADD_VEC %51, %58 %67 = ADD_VEC %51, %58
%68 = TAG_VECTOR %67 %68 = TAG_VECTOR %67
STORE_TVALUE R4, %68 STORE_TVALUE R4, %68
@ -308,6 +283,8 @@ bb_bytecode_1:
TEST_CASE("ExtraMathMemoryOperands") TEST_CASE("ExtraMathMemoryOperands")
{ {
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores2, true};
CHECK_EQ("\n" + getCodegenAssembly(R"( CHECK_EQ("\n" + getCodegenAssembly(R"(
local function foo(a: number, b: number, c: number, d: number, e: number) local function foo(a: number, b: number, c: number, d: number, e: number)
return math.floor(a) + math.ceil(b) + math.round(c) + math.sqrt(d) + math.abs(e) return math.floor(a) + math.ceil(b) + math.round(c) + math.sqrt(d) + math.abs(e)
@ -327,26 +304,13 @@ bb_2:
bb_bytecode_1: bb_bytecode_1:
CHECK_SAFE_ENV exit(1) CHECK_SAFE_ENV exit(1)
%16 = FLOOR_NUM R0 %16 = FLOOR_NUM R0
STORE_DOUBLE R9, %16
STORE_TAG R9, tnumber
%23 = CEIL_NUM R1 %23 = CEIL_NUM R1
STORE_DOUBLE R10, %23
STORE_TAG R10, tnumber
%32 = ADD_NUM %16, %23 %32 = ADD_NUM %16, %23
STORE_DOUBLE R8, %32
STORE_TAG R8, tnumber
%39 = ROUND_NUM R2 %39 = ROUND_NUM R2
STORE_DOUBLE R9, %39
%48 = ADD_NUM %32, %39 %48 = ADD_NUM %32, %39
STORE_DOUBLE R7, %48
STORE_TAG R7, tnumber
%55 = SQRT_NUM R3 %55 = SQRT_NUM R3
STORE_DOUBLE R8, %55
%64 = ADD_NUM %48, %55 %64 = ADD_NUM %48, %55
STORE_DOUBLE R6, %64
STORE_TAG R6, tnumber
%71 = ABS_NUM R4 %71 = ABS_NUM R4
STORE_DOUBLE R7, %71
%80 = ADD_NUM %64, %71 %80 = ADD_NUM %64, %71
STORE_DOUBLE R5, %80 STORE_DOUBLE R5, %80
STORE_TAG R5, tnumber STORE_TAG R5, tnumber

View File

@ -89,6 +89,9 @@ declare function @checked optionalArg(x: string?) : number
declare foo: { declare foo: {
bar: @checked (number) -> number, bar: @checked (number) -> number,
} }
declare function @checked optionalArgsAtTheEnd1(x: string, y: number?, z: number?) : number
declare function @checked optionalArgsAtTheEnd2(x: string, y: number?, z: string) : number
)BUILTIN_SRC"; )BUILTIN_SRC";
}; };
@ -474,4 +477,32 @@ abs(3, "hi");
CHECK_EQ("foo.bar", r2->functionName); CHECK_EQ("foo.bar", r2->functionName);
} }
TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "optionals_in_checked_function_can_be_omitted")
{
CheckResult result = checkNonStrict(R"(
optionalArgsAtTheEnd1("a")
optionalArgsAtTheEnd1("a", 3)
optionalArgsAtTheEnd1("a", nil, 3)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "optionals_in_checked_function_in_middle_cannot_be_omitted")
{
CheckResult result = checkNonStrict(R"(
optionalArgsAtTheEnd2("a", "a") -- error
optionalArgsAtTheEnd2("a", nil, "b")
optionalArgsAtTheEnd2("a", 3, "b")
optionalArgsAtTheEnd2("a", "b", "c") -- error
)");
LUAU_REQUIRE_ERROR_COUNT(3, result);
NONSTRICT_REQUIRE_CHECKED_ERR(Position(1, 27), "optionalArgsAtTheEnd2", result);
NONSTRICT_REQUIRE_CHECKED_ERR(Position(4, 27), "optionalArgsAtTheEnd2", result);
auto r1 = get<CheckedFunctionIncorrectArgs>(result.errors[2]);
LUAU_ASSERT(r1);
CHECK_EQ(3, r1->expected);
CHECK_EQ(2, r1->actual);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -6,7 +6,7 @@
#include <string.h> #include <string.h>
template<typename T> template<typename T>
struct ScopedFValue struct [[nodiscard]] ScopedFValue
{ {
private: private:
Luau::FValue<T>* value = nullptr; Luau::FValue<T>* value = nullptr;

View File

@ -563,4 +563,10 @@ TEST_CASE_FIXTURE(SimplifyFixture, "free_type_bound_by_any_with_any")
CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy));
} }
TEST_CASE_FIXTURE(SimplifyFixture, "bound_intersected_by_itself_should_be_itself")
{
TypeId blocked = arena->addType(BlockedType{});
CHECK(toString(blocked) == intersectStr(blocked, blocked));
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -552,6 +552,29 @@ TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_common_subset_if_union_of_dif
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(ClassFixture, "vector2_multiply_is_overloaded")
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
CheckResult result = check(R"(
local v = Vector2.New(1, 2)
local v2 = v * 1.5
local v3 = v * v
local v4 = v * "Hello" -- line 5
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(5 == result.errors[0].location.begin.line);
CHECK(5 == result.errors[0].location.end.line);
CHECK("Vector2" == toString(requireType("v2")));
CHECK("Vector2" == toString(requireType("v3")));
CHECK("mul<Vector2, string>" == toString(requireType("v4")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_rfc_example") TEST_CASE_FIXTURE(BuiltinsFixture, "keyof_rfc_example")
{ {
if (!FFlag::DebugLuauDeferredConstraintResolution) if (!FFlag::DebugLuauDeferredConstraintResolution)

View File

@ -2309,4 +2309,60 @@ end
CHECK_EQ("(number) -> boolean", toString(requireType("odd"))); CHECK_EQ("(number) -> boolean", toString(requireType("odd")));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type")
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
CheckResult result = check(R"(
function fib(n)
return n < 2 and 1 or fib(n-1) + fib(n-2)
end
)");
LUAU_REQUIRE_ERRORS(result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back());
LUAU_ASSERT(err);
CHECK("false | number" == toString(err->recommendedReturn));
CHECK(err->recommendedArgs.size() == 0);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type")
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
CheckResult result = check(R"(
function fib(n, u)
return (n or u) and (n < u and n + fib(n,u))
end
)");
LUAU_REQUIRE_ERRORS(result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back());
LUAU_ASSERT(err);
CHECK("number" == toString(err->recommendedReturn));
CHECK(err->recommendedArgs.size() == 2);
CHECK("number" == toString(err->recommendedArgs[0].second));
CHECK("number" == toString(err->recommendedArgs[1].second));
}
TEST_CASE_FIXTURE(Fixture, "local_function_fwd_decl_doesnt_crash")
{
CheckResult result = check(R"(
local foo
local function bar()
foo()
end
function foo()
end
bar()
)");
// This test verifies that an ICE doesn't occur, so the bulk of the test is
// just from running check above.
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -460,6 +460,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while")
CHECK_EQ(us->name, "a"); CHECK_EQ(us->name, "a");
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "trivial_ipairs_usage")
{
CheckResult result = check(R"(
local next, t, s = ipairs({1, 2, 3})
)");
LUAU_REQUIRE_NO_ERRORS(result);
REQUIRE_EQ("({number}, number) -> (number?, number)", toString(requireType("next")));
REQUIRE_EQ("{number}", toString(requireType("t")));
REQUIRE_EQ("number", toString(requireType("s")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(

View File

@ -502,4 +502,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "promise_type_error_too_complex" * doctest::t
LUAU_REQUIRE_ERRORS(result); LUAU_REQUIRE_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "method_should_not_create_cyclic_type")
{
ScopedFastFlag sff(FFlag::DebugLuauDeferredConstraintResolution, true);
CheckResult result = check(R"(
local Component = {}
function Component:__resolveUpdate(incomingState)
local oldState = self.state
incomingState = oldState
self.state = incomingState
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -1474,4 +1474,22 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compare_singleton_string_to_string")
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "no_infinite_expansion_of_free_type" * doctest::timeout(1.0))
{
ScopedFastFlag sff(FFlag::DebugLuauDeferredConstraintResolution, true);
check(R"(
local tooltip = {}
function tooltip:Show()
local playerGui = self.Player:FindFirstChild("PlayerGui")
for _,c in ipairs(playerGui:GetChildren()) do
if c:IsA("ScreenGui") and c.DisplayOrder > self.Gui.DisplayOrder then
end
end
end
)");
// just type-checking this code is enough
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -2022,4 +2022,39 @@ end
CHECK("string" == toString(t)); CHECK("string" == toString(t));
} }
TEST_CASE_FIXTURE(RefinementClassFixture, "mutate_prop_of_some_refined_symbol")
{
CheckResult result = check(R"(
local function instances(): {Instance} error("") end
local function vec3(x, y, z): Vector3 error("") end
for _, object in ipairs(instances()) do
if object:IsA("Part") then
object.Position = vec3(1, 2, 3)
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(RefinementClassFixture, "mutate_prop_of_some_refined_symbol_2")
{
CheckResult result = check(R"(
type Result<T, E> = never
| { tag: "ok", value: T }
| { tag: "err", error: E }
local function results(): {Result<number, string>} error("") end
for _, res in ipairs(results()) do
if res.tag == "ok" then
res.value = 7
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -4171,17 +4171,9 @@ TEST_CASE_FIXTURE(Fixture, "table_writes_introduce_write_properties")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK("<a, b, c...>({{ read Character: a }}, { Character: t1 }) -> () " CHECK("<a, b...>({{ read Character: t1 }}, { Character: t1 }) -> () "
"where " "where "
"t1 = a & { read FindFirstChild: (t1, string) -> (b, c...) }" == toString(requireType("oc"))); "t1 = { read FindFirstChild: (t1, string) -> (a, b...) }" == toString(requireType("oc")));
// We currently get
// <a, b, c...>({{ read Character: a }}, { Character: t1 }) -> () where t1 = { read FindFirstChild: (t1, string) -> (b, c...) }
// But we'd like to see
// <a, b...>({{ read Character: t1 }}, { Character: t1 }) -> () where t1 = { read FindFirstChild: (t1, string) -> (a, b...) }
// The type of speaker.Character should be the same as player[1].Character
} }
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -407,4 +407,24 @@ end
bufferbounds(0) bufferbounds(0)
function deadStoreChecks1()
local a = 1.0
local b = 0.0
local function update()
b += a
for i = 1, 100 do print(`{b} is {b}`) end
end
update()
a = 10
update()
a = 100
update()
return b
end
assert(deadStoreChecks1() == 111)
return('OK') return('OK')

View File

@ -11,7 +11,6 @@ BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_th
BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy
BuiltinTests.bad_select_should_not_crash BuiltinTests.bad_select_should_not_crash
BuiltinTests.coroutine_resume_anything_goes BuiltinTests.coroutine_resume_anything_goes
BuiltinTests.global_singleton_types_are_sealed
BuiltinTests.gmatch_capture_types BuiltinTests.gmatch_capture_types
BuiltinTests.gmatch_capture_types2 BuiltinTests.gmatch_capture_types2
BuiltinTests.gmatch_capture_types_balanced_escaped_parens BuiltinTests.gmatch_capture_types_balanced_escaped_parens
@ -20,26 +19,18 @@ BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored
BuiltinTests.gmatch_capture_types_set_containing_lbracket BuiltinTests.gmatch_capture_types_set_containing_lbracket
BuiltinTests.gmatch_definition BuiltinTests.gmatch_definition
BuiltinTests.ipairs_iterator_should_infer_types_and_type_check BuiltinTests.ipairs_iterator_should_infer_types_and_type_check
BuiltinTests.next_iterator_should_infer_types_and_type_check
BuiltinTests.os_time_takes_optional_date_table BuiltinTests.os_time_takes_optional_date_table
BuiltinTests.pairs_iterator_should_infer_types_and_type_check
BuiltinTests.select_slightly_out_of_range BuiltinTests.select_slightly_out_of_range
BuiltinTests.select_way_out_of_range BuiltinTests.select_way_out_of_range
BuiltinTests.select_with_variadic_typepack_tail_and_string_head BuiltinTests.select_with_variadic_typepack_tail_and_string_head
BuiltinTests.set_metatable_needs_arguments BuiltinTests.set_metatable_needs_arguments
BuiltinTests.setmetatable_should_not_mutate_persisted_types BuiltinTests.setmetatable_should_not_mutate_persisted_types
BuiltinTests.sort
BuiltinTests.sort_with_bad_predicate BuiltinTests.sort_with_bad_predicate
BuiltinTests.sort_with_predicate
BuiltinTests.string_format_as_method BuiltinTests.string_format_as_method
BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_correctly_ordered_types
BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_report_all_type_errors_at_correct_positions
BuiltinTests.string_format_use_correct_argument2 BuiltinTests.string_format_use_correct_argument2
BuiltinTests.table_concat_returns_string
BuiltinTests.table_dot_remove_optionally_returns_generic
BuiltinTests.table_freeze_is_generic BuiltinTests.table_freeze_is_generic
BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload
BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload
BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type
ControlFlowAnalysis.if_not_x_break_elif_not_y_break ControlFlowAnalysis.if_not_x_break_elif_not_y_break
ControlFlowAnalysis.if_not_x_break_elif_not_y_continue ControlFlowAnalysis.if_not_x_break_elif_not_y_continue
@ -98,12 +89,10 @@ GenericsTests.generic_type_pack_unification1
GenericsTests.generic_type_pack_unification2 GenericsTests.generic_type_pack_unification2
GenericsTests.generic_type_pack_unification3 GenericsTests.generic_type_pack_unification3
GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments
GenericsTests.hof_subtype_instantiation_regression
GenericsTests.infer_generic_function_function_argument GenericsTests.infer_generic_function_function_argument
GenericsTests.infer_generic_function_function_argument_2 GenericsTests.infer_generic_function_function_argument_2
GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_3
GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_function_function_argument_overloaded
GenericsTests.infer_generic_lib_function_function_argument
GenericsTests.instantiated_function_argument_names GenericsTests.instantiated_function_argument_names
GenericsTests.mutable_state_polymorphism GenericsTests.mutable_state_polymorphism
GenericsTests.no_stack_overflow_from_quantifying GenericsTests.no_stack_overflow_from_quantifying
@ -408,7 +397,6 @@ TypeInferFunctions.infer_anonymous_function_arguments
TypeInferFunctions.infer_anonymous_function_arguments_outside_call TypeInferFunctions.infer_anonymous_function_arguments_outside_call
TypeInferFunctions.infer_generic_function_function_argument TypeInferFunctions.infer_generic_function_function_argument
TypeInferFunctions.infer_generic_function_function_argument_overloaded TypeInferFunctions.infer_generic_function_function_argument_overloaded
TypeInferFunctions.infer_generic_lib_function_function_argument
TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_return_type_from_selected_overload
TypeInferFunctions.infer_return_value_type TypeInferFunctions.infer_return_value_type
TypeInferFunctions.inferred_higher_order_functions_are_quantified_at_the_right_time3 TypeInferFunctions.inferred_higher_order_functions_are_quantified_at_the_right_time3
@ -432,7 +420,6 @@ TypeInferFunctions.too_many_arguments_error_location
TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_in_parentheses
TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.too_many_return_values_no_function
TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.cli_68448_iterators_need_not_accept_nil
TypeInferLoops.dcr_iteration_explore_raycast_minimization
TypeInferLoops.dcr_iteration_fragmented_keys TypeInferLoops.dcr_iteration_fragmented_keys
TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_iteration_on_never_gives_never
TypeInferLoops.dcr_xpath_candidates TypeInferLoops.dcr_xpath_candidates
@ -443,7 +430,6 @@ TypeInferLoops.for_in_loop_on_error
TypeInferLoops.for_in_loop_on_non_function TypeInferLoops.for_in_loop_on_non_function
TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_loop_with_next
TypeInferLoops.for_in_with_an_iterator_of_type_any TypeInferLoops.for_in_with_an_iterator_of_type_any
TypeInferLoops.for_in_with_generic_next
TypeInferLoops.for_loop TypeInferLoops.for_loop
TypeInferLoops.ipairs_produces_integral_indices TypeInferLoops.ipairs_produces_integral_indices
TypeInferLoops.iterate_over_free_table TypeInferLoops.iterate_over_free_table
@ -486,7 +472,6 @@ TypeInferOperators.error_on_invalid_operand_types_to_relational_operators2
TypeInferOperators.luau_polyfill_is_array TypeInferOperators.luau_polyfill_is_array
TypeInferOperators.mm_comparisons_must_return_a_boolean TypeInferOperators.mm_comparisons_must_return_a_boolean
TypeInferOperators.operator_eq_verifies_types_do_intersect TypeInferOperators.operator_eq_verifies_types_do_intersect
TypeInferOperators.reducing_and
TypeInferOperators.refine_and_or TypeInferOperators.refine_and_or
TypeInferOperators.reworked_and TypeInferOperators.reworked_and
TypeInferOperators.reworked_or TypeInferOperators.reworked_or