Merge branch 'upstream' into merge

This commit is contained in:
Arseny Kapoulkine 2022-01-14 08:07:10 -08:00
commit e6dd6bd158
47 changed files with 2641 additions and 1261 deletions

View File

@ -97,6 +97,12 @@ struct ApplyTypeFunction : Substitution
TypePackId clean(TypePackId tp) override; TypePackId clean(TypePackId tp) override;
}; };
struct GenericTypeDefinitions
{
std::vector<GenericTypeDefinition> genericTypes;
std::vector<GenericTypePackDefinition> genericPacks;
};
// All TypeVars are retained via Environment::typeVars. All TypeIds // All TypeVars are retained via Environment::typeVars. All TypeIds
// within a program are borrowed pointers into this set. // within a program are borrowed pointers into this set.
struct TypeChecker struct TypeChecker
@ -146,7 +152,7 @@ struct TypeChecker
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr); ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr); ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes, TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType); std::optional<TypeId> expectedType);
@ -336,8 +342,8 @@ private:
const std::vector<TypePackId>& typePackParams, const Location& location); const std::vector<TypePackId>& typePackParams, const Location& location);
// Note: `scope` must be a fresh scope. // Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node,
const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames); const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames);
public: public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);

View File

@ -181,6 +181,18 @@ const T* get(const SingletonTypeVar* stv)
return nullptr; return nullptr;
} }
struct GenericTypeDefinition
{
TypeId ty;
std::optional<TypeId> defaultValue;
};
struct GenericTypePackDefinition
{
TypePackId tp;
std::optional<TypePackId> defaultValue;
};
struct FunctionArgument struct FunctionArgument
{ {
Name name; Name name;
@ -358,8 +370,8 @@ struct ClassTypeVar
struct TypeFun struct TypeFun
{ {
// These should all be generic // These should all be generic
std::vector<TypeId> typeParams; std::vector<GenericTypeDefinition> typeParams;
std::vector<TypePackId> typePackParams; std::vector<GenericTypePackDefinition> typePackParams;
/** The underlying type. /** The underlying type.
* *
@ -369,13 +381,13 @@ struct TypeFun
TypeId type; TypeId type;
TypeFun() = default; TypeFun() = default;
TypeFun(std::vector<TypeId> typeParams, TypeId type) TypeFun(std::vector<GenericTypeDefinition> typeParams, TypeId type)
: typeParams(std::move(typeParams)) : typeParams(std::move(typeParams))
, type(type) , type(type)
{ {
} }
TypeFun(std::vector<TypeId> typeParams, std::vector<TypePackId> typePackParams, TypeId type) TypeFun(std::vector<GenericTypeDefinition> typeParams, std::vector<GenericTypePackDefinition> typePackParams, TypeId type)
: typeParams(std::move(typeParams)) : typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams)) , typePackParams(std::move(typePackParams))
, type(type) , type(type)

View File

@ -13,11 +13,10 @@
#include <utility> #include <utility>
LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAG(LuauUseCommittingTxnLog)
LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false);
LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false);
LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false);
static const std::unordered_set<std::string> kStatementStartingKeywords = { static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -291,51 +290,23 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
expectedType = follow(*it); expectedType = follow(*it);
} }
if (FFlag::LuauAutocompletePreferToCallFunctions) // We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{ {
// We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
auto [retHead, retTail] = flatten(ftv->retType);
if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return TypeCorrectKind::CorrectFunctionResult;
// We might only have a variadic tail pack, check if the element is compatible
if (retTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
return TypeCorrectKind::CorrectFunctionResult;
}
}
return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
}
else
{
if (canUnify(ty, expectedType))
return TypeCorrectKind::Correct;
// We also want to suggest functions that return compatible result
const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty);
if (!ftv)
return TypeCorrectKind::None;
auto [retHead, retTail] = flatten(ftv->retType); auto [retHead, retTail] = flatten(ftv->retType);
if (!retHead.empty()) if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; return TypeCorrectKind::CorrectFunctionResult;
// We might only have a variadic tail pack, check if the element is compatible // We might only have a variadic tail pack, check if the element is compatible
if (retTail) if (retTail)
{ {
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail))) if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; return TypeCorrectKind::CorrectFunctionResult;
} }
return TypeCorrectKind::None;
} }
return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
} }
enum class PropIndexType enum class PropIndexType
@ -435,13 +406,28 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
auto indexIt = mtable->props.find("__index"); auto indexIt = mtable->props.find("__index");
if (indexIt != mtable->props.end()) if (indexIt != mtable->props.end())
{ {
if (get<TableTypeVar>(indexIt->second.type) || get<MetatableTypeVar>(indexIt->second.type)) if (FFlag::LuauMissingFollowACMetatables)
autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen);
else if (auto indexFunction = get<FunctionTypeVar>(indexIt->second.type))
{ {
std::optional<TypeId> indexFunctionResult = first(indexFunction->retType); TypeId followed = follow(indexIt->second.type);
if (indexFunctionResult) if (get<TableTypeVar>(followed) || get<MetatableTypeVar>(followed))
autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen);
else if (auto indexFunction = get<FunctionTypeVar>(followed))
{
std::optional<TypeId> indexFunctionResult = first(indexFunction->retType);
if (indexFunctionResult)
autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen);
}
}
else
{
if (get<TableTypeVar>(indexIt->second.type) || get<MetatableTypeVar>(indexIt->second.type))
autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen);
else if (auto indexFunction = get<FunctionTypeVar>(indexIt->second.type))
{
std::optional<TypeId> indexFunctionResult = first(indexFunction->retType);
if (indexFunctionResult)
autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen);
}
} }
} }
} }
@ -1224,7 +1210,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
if (auto it = module.astTypes.find(node->asExpr())) if (auto it = module.astTypes.find(node->asExpr()))
autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result);
} }
else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) else if (autocompleteIfElseExpression(node, ancestry, position, result))
return; return;
else if (node->is<AstExprFunction>()) else if (node->is<AstExprFunction>())
return; return;
@ -1261,8 +1247,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
TypeCorrectKind correctForFunction = TypeCorrectKind correctForFunction =
functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
if (FFlag::LuauIfElseExpressionAnalysisSupport) result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false};
result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false};
result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean};
result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean};
result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil};

View File

@ -190,24 +190,24 @@ struct ErrorConverter
{ {
name += "<"; name += "<";
bool first = true; bool first = true;
for (TypeId t : e.typeFun.typeParams) for (auto param : e.typeFun.typeParams)
{ {
if (first) if (first)
first = false; first = false;
else else
name += ", "; name += ", ";
name += toString(t); name += toString(param.ty);
} }
for (TypePackId t : e.typeFun.typePackParams) for (auto param : e.typeFun.typePackParams)
{ {
if (first) if (first)
first = false; first = false;
else else
name += ", "; name += ", ";
name += toString(t); name += toString(param.tp);
} }
name += ">"; name += ">";
@ -544,13 +544,13 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC
for (size_t i = 0; i < typeFun.typeParams.size(); ++i) for (size_t i = 0; i < typeFun.typeParams.size(); ++i)
{ {
if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) if (typeFun.typeParams[i].ty != rhs.typeFun.typeParams[i].ty)
return false; return false;
} }
for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) for (size_t i = 0; i < typeFun.typePackParams.size(); ++i)
{ {
if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) if (typeFun.typePackParams[i].tp != rhs.typeFun.typePackParams[i].tp)
return false; return false;
} }

View File

@ -96,24 +96,24 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo
{ {
stream << "<"; stream << "<";
bool first = true; bool first = true;
for (TypeId t : error.typeFun.typeParams) for (auto param : error.typeFun.typeParams)
{ {
if (first) if (first)
first = false; first = false;
else else
stream << ", "; stream << ", ";
stream << toString(t); stream << toString(param.ty);
} }
for (TypePackId t : error.typeFun.typePackParams) for (auto param : error.typeFun.typePackParams)
{ {
if (first) if (first)
first = false; first = false;
else else
stream << ", "; stream << ", ";
stream << toString(t); stream << toString(param.tp);
} }
stream << ">"; stream << ">";

View File

@ -5,6 +5,8 @@
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/Common.h" #include "Luau/Common.h"
LUAU_FASTFLAG(LuauTypeAliasDefaults)
namespace Luau namespace Luau
{ {
@ -337,6 +339,42 @@ struct AstJsonEncoder : public AstVisitor
writeRaw("}"); writeRaw("}");
} }
void write(const AstGenericType& genericType)
{
if (FFlag::LuauTypeAliasDefaults)
{
writeRaw("{");
bool c = pushComma();
write("name", genericType.name);
if (genericType.defaultValue)
write("type", genericType.defaultValue);
popComma(c);
writeRaw("}");
}
else
{
write(genericType.name);
}
}
void write(const AstGenericTypePack& genericTypePack)
{
if (FFlag::LuauTypeAliasDefaults)
{
writeRaw("{");
bool c = pushComma();
write("name", genericTypePack.name);
if (genericTypePack.defaultValue)
write("type", genericTypePack.defaultValue);
popComma(c);
writeRaw("}");
}
else
{
write(genericTypePack.name);
}
}
void write(AstExprTable::Item::Kind kind) void write(AstExprTable::Item::Kind kind)
{ {
switch (kind) switch (kind)

View File

@ -14,6 +14,7 @@
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false)
LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAG(LuauTypeAliasDefaults)
namespace Luau namespace Luau
{ {
@ -447,11 +448,28 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
{ {
TypeFun result; TypeFun result;
for (TypeId ty : typeFun.typeParams)
result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
for (TypePackId tp : typeFun.typePackParams) for (auto param : typeFun.typeParams)
result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); {
TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState);
std::optional<TypeId> defaultValue;
if (FFlag::LuauTypeAliasDefaults && param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState);
result.typeParams.push_back({ty, defaultValue});
}
for (auto param : typeFun.typePackParams)
{
TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState);
std::optional<TypePackId> defaultValue;
if (FFlag::LuauTypeAliasDefaults && param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState);
result.typePackParams.push_back({tp, defaultValue});
}
result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState);

View File

@ -11,6 +11,7 @@
#include <stdexcept> #include <stdexcept>
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAG(LuauTypeAliasDefaults)
/* /*
* Prefix generic typenames with gen- * Prefix generic typenames with gen-
@ -209,6 +210,14 @@ struct StringifierState
result.name += s; result.name += s;
} }
void emit(const char* s)
{
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
return;
result.name += s;
}
}; };
struct TypeVarStringifier struct TypeVarStringifier
@ -280,13 +289,28 @@ struct TypeVarStringifier
else else
first = false; first = false;
if (!singleTp) if (FFlag::LuauTypeAliasDefaults)
state.emit("("); {
bool wrap = !singleTp && get<TypePack>(follow(tp));
stringify(tp); if (wrap)
state.emit("(");
if (!singleTp) stringify(tp);
state.emit(")");
if (wrap)
state.emit(")");
}
else
{
if (!singleTp)
state.emit("(");
stringify(tp);
if (!singleTp)
state.emit(")");
}
} }
if (types.size() || typePacks.size()) if (types.size() || typePacks.size())
@ -1086,7 +1110,7 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts)
return toString(const_cast<TypePackId>(&tp), std::move(opts)); return toString(const_cast<TypePackId>(&tp), std::move(opts));
} }
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) std::string toStringNamedFunction_DEPRECATED(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts)
{ {
std::string s = prefix; std::string s = prefix;
@ -1175,6 +1199,77 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV
return s; return s;
} }
std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts)
{
if (!FFlag::LuauTypeAliasDefaults)
return toStringNamedFunction_DEPRECATED(prefix, ftv, opts);
ToStringResult result;
StringifierState state(opts, result, opts.nameMap);
TypeVarStringifier tvs{state};
state.emit(prefix);
if (!opts.hideNamedFunctionTypeParameters)
tvs.stringify(ftv.generics, ftv.genericPacks);
state.emit("(");
auto argPackIter = begin(ftv.argTypes);
auto argNameIter = ftv.argNames.begin();
bool first = true;
while (argPackIter != end(ftv.argTypes))
{
if (!first)
state.emit(", ");
first = false;
// We don't currently respect opts.functionTypeArguments. I don't think this function should.
if (argNameIter != ftv.argNames.end())
{
state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": ");
++argNameIter;
}
else
{
state.emit("_: ");
}
tvs.stringify(*argPackIter);
++argPackIter;
}
if (argPackIter.tail())
{
if (!first)
state.emit(", ");
state.emit("...: ");
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()))
tvs.stringify(vtp->ty);
else
tvs.stringify(*argPackIter.tail());
}
state.emit("): ");
size_t retSize = size(ftv.retType);
bool hasTail = !finite(ftv.retType);
bool wrap = get<TypePack>(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1);
if (wrap)
state.emit("(");
tvs.stringify(ftv.retType);
if (wrap)
state.emit(")");
return result.name;
}
std::string dump(TypeId ty) std::string dump(TypeId ty)
{ {
ToStringOptions opts; ToStringOptions opts;

View File

@ -10,6 +10,8 @@
#include <limits> #include <limits>
#include <math.h> #include <math.h>
LUAU_FASTFLAG(LuauTypeAliasDefaults)
namespace namespace
{ {
bool isIdentifierStartChar(char c) bool isIdentifierStartChar(char c)
@ -793,14 +795,47 @@ struct Printer
for (auto o : a->generics) for (auto o : a->generics)
{ {
comma(); comma();
writer.identifier(o.value);
if (FFlag::LuauTypeAliasDefaults)
{
writer.advance(o.location.begin);
writer.identifier(o.name.value);
if (o.defaultValue)
{
writer.maybeSpace(o.defaultValue->location.begin, 2);
writer.symbol("=");
visualizeTypeAnnotation(*o.defaultValue);
}
}
else
{
writer.identifier(o.name.value);
}
} }
for (auto o : a->genericPacks) for (auto o : a->genericPacks)
{ {
comma(); comma();
writer.identifier(o.value);
writer.symbol("..."); if (FFlag::LuauTypeAliasDefaults)
{
writer.advance(o.location.begin);
writer.identifier(o.name.value);
writer.symbol("...");
if (o.defaultValue)
{
writer.maybeSpace(o.defaultValue->location.begin, 2);
writer.symbol("=");
visualizeTypePackAnnotation(*o.defaultValue, false);
}
}
else
{
writer.identifier(o.name.value);
writer.symbol("...");
}
} }
writer.symbol(">"); writer.symbol(">");
@ -846,12 +881,20 @@ struct Printer
for (const auto& o : func.generics) for (const auto& o : func.generics)
{ {
comma(); comma();
writer.identifier(o.value);
if (FFlag::LuauTypeAliasDefaults)
writer.advance(o.location.begin);
writer.identifier(o.name.value);
} }
for (const auto& o : func.genericPacks) for (const auto& o : func.genericPacks)
{ {
comma(); comma();
writer.identifier(o.value);
if (FFlag::LuauTypeAliasDefaults)
writer.advance(o.location.begin);
writer.identifier(o.name.value);
writer.symbol("..."); writer.symbol("...");
} }
writer.symbol(">"); writer.symbol(">");
@ -979,12 +1022,20 @@ struct Printer
for (const auto& o : a->generics) for (const auto& o : a->generics)
{ {
comma(); comma();
writer.identifier(o.value);
if (FFlag::LuauTypeAliasDefaults)
writer.advance(o.location.begin);
writer.identifier(o.name.value);
} }
for (const auto& o : a->genericPacks) for (const auto& o : a->genericPacks)
{ {
comma(); comma();
writer.identifier(o.value);
if (FFlag::LuauTypeAliasDefaults)
writer.advance(o.location.begin);
writer.identifier(o.name.value);
writer.symbol("..."); writer.symbol("...");
} }
writer.symbol(">"); writer.symbol(">");

View File

@ -212,24 +212,24 @@ public:
if (hasSeen(&ftv)) if (hasSeen(&ftv))
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"));
AstArray<AstName> generics; AstArray<AstGenericType> generics;
generics.size = ftv.generics.size(); generics.size = ftv.generics.size();
generics.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * generics.size)); generics.data = static_cast<AstGenericType*>(allocator->allocate(sizeof(AstGenericType) * generics.size));
size_t numGenerics = 0; size_t numGenerics = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{ {
if (auto gtv = get<GenericTypeVar>(*it)) if (auto gtv = get<GenericTypeVar>(*it))
generics.data[numGenerics++] = AstName(gtv->name.c_str()); generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr};
} }
AstArray<AstName> genericPacks; AstArray<AstGenericTypePack> genericPacks;
genericPacks.size = ftv.genericPacks.size(); genericPacks.size = ftv.genericPacks.size();
genericPacks.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * genericPacks.size)); genericPacks.data = static_cast<AstGenericTypePack*>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size));
size_t numGenericPacks = 0; size_t numGenericPacks = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{ {
if (auto gtv = get<GenericTypeVar>(*it)) if (auto gtv = get<GenericTypeVar>(*it))
genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr};
} }
AstArray<AstType*> argTypes; AstArray<AstType*> argTypes;

View File

@ -15,8 +15,8 @@
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/TimeTrace.h" #include "Luau/TimeTrace.h"
#include <deque>
#include <algorithm> #include <algorithm>
#include <iterator>
LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false)
LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500)
@ -24,25 +24,30 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000)
LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false)
LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false)
LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false)
LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false)
LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAG(LuauUseCommittingTxnLog)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false)
LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false)
LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false)
LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false)
LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false)
LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false)
LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false)
LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false)
LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false)
LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false)
LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false)
LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false)
LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false)
LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false)
LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false)
namespace Luau namespace Luau
@ -279,6 +284,14 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona
GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}});
} }
if (FFlag::LuauPerModuleUnificationCache)
{
// Clear unifier cache since it's keyed off internal types that get deallocated
// This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs.
unifierState.cachedUnify.clear();
unifierState.skipCacheForType.clear();
}
return std::move(currentModule); return std::move(currentModule);
} }
@ -1213,18 +1226,18 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
ScopePtr aliasScope = childScope(scope, typealias.location); ScopePtr aliasScope = childScope(scope, typealias.location);
aliasScope->level = scope->level.incr(); aliasScope->level = scope->level.incr();
for (TypeId ty : binding->typeParams) for (auto param : binding->typeParams)
{ {
auto generic = get<GenericTypeVar>(ty); auto generic = get<GenericTypeVar>(param.ty);
LUAU_ASSERT(generic); LUAU_ASSERT(generic);
aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty};
} }
for (TypePackId tp : binding->typePackParams) for (auto param : binding->typePackParams)
{ {
auto generic = get<GenericTypePack>(tp); auto generic = get<GenericTypePack>(param.tp);
LUAU_ASSERT(generic); LUAU_ASSERT(generic);
aliasScope->privateTypePackBindings[generic->name] = tp; aliasScope->privateTypePackBindings[generic->name] = param.tp;
} }
TypeId ty = resolveType(aliasScope, *typealias.type); TypeId ty = resolveType(aliasScope, *typealias.type);
@ -1233,9 +1246,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
// If the table is already named and we want to rename the type function, we have to bind new alias to a copy // If the table is already named and we want to rename the type function, we have to bind new alias to a copy
if (ttv->name) if (ttv->name)
{ {
bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(),
binding->typeParams.end(), [](auto&& itp, auto&& tp) {
return itp == tp.ty;
});
bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(),
binding->typePackParams.begin(), binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) {
return itpp == tpp.tp;
});
// Copy can be skipped if this is an identical alias // Copy can be skipped if this is an identical alias
if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || if (ttv->name != name || !sameTys || !sameTps)
ttv->instantiatedTypePackParams != binding->typePackParams)
{ {
// This is a shallow clone, original recursive links to self are not updated // This is a shallow clone, original recursive links to self are not updated
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state};
@ -1243,8 +1264,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.methodDefinitionLocations = ttv->methodDefinitionLocations;
clone.definitionModuleName = ttv->definitionModuleName; clone.definitionModuleName = ttv->definitionModuleName;
clone.name = name; clone.name = name;
clone.instantiatedTypeParams = binding->typeParams;
clone.instantiatedTypePackParams = binding->typePackParams; for (auto param : binding->typeParams)
clone.instantiatedTypeParams.push_back(param.ty);
for (auto param : binding->typePackParams)
clone.instantiatedTypePackParams.push_back(param.tp);
ty = addType(std::move(clone)); ty = addType(std::move(clone));
} }
@ -1252,8 +1277,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
else else
{ {
ttv->name = name; ttv->name = name;
ttv->instantiatedTypeParams = binding->typeParams;
ttv->instantiatedTypePackParams = binding->typePackParams; ttv->instantiatedTypeParams.clear();
for (auto param : binding->typeParams)
ttv->instantiatedTypeParams.push_back(param.ty);
ttv->instantiatedTypePackParams.clear();
for (auto param : binding->typePackParams)
ttv->instantiatedTypePackParams.push_back(param.tp);
} }
} }
else if (auto mtv = getMutable<MetatableTypeVar>(follow(ty))) else if (auto mtv = getMutable<MetatableTypeVar>(follow(ty)))
@ -1367,9 +1398,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo
auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks);
std::vector<TypeId> genericTys;
genericTys.reserve(generics.size());
std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) {
return el.ty;
});
std::vector<TypePackId> genericTps;
genericTps.reserve(genericPacks.size());
std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) {
return el.tp;
});
TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId argPack = resolveTypePack(funScope, global.params);
TypePackId retPack = resolveTypePack(funScope, global.retTypes); TypePackId retPack = resolveTypePack(funScope, global.retTypes);
TypeId fnType = addType(FunctionTypeVar{funScope->level, generics, genericPacks, argPack, retPack}); TypeId fnType = addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(fnType); FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(fnType);
ftv->argNames.reserve(global.paramNames.size); ftv->argNames.reserve(global.paramNames.size);
@ -1394,7 +1437,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
ExprResult<TypeId> result; ExprResult<TypeId> result;
if (auto a = expr.as<AstExprGroup>()) if (auto a = expr.as<AstExprGroup>())
result = checkExpr(scope, *a->expr); result = checkExpr(scope, *a->expr, FFlag::LuauGroupExpectedType ? expectedType : std::nullopt);
else if (expr.is<AstExprConstantNil>()) else if (expr.is<AstExprConstantNil>())
result = {nilType}; result = {nilType};
else if (const AstExprConstantBool* bexpr = expr.as<AstExprConstantBool>()) else if (const AstExprConstantBool* bexpr = expr.as<AstExprConstantBool>())
@ -1438,21 +1481,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
else if (auto a = expr.as<AstExprError>()) else if (auto a = expr.as<AstExprError>())
result = checkExpr(scope, *a); result = checkExpr(scope, *a);
else if (auto a = expr.as<AstExprIfElse>()) else if (auto a = expr.as<AstExprIfElse>())
{ result = checkExpr(scope, *a, FFlag::LuauIfElseExpectedType2 ? expectedType : std::nullopt);
if (FFlag::LuauIfElseExpressionAnalysisSupport)
{
result = checkExpr(scope, *a);
}
else
{
// Note: When the fast flag is disabled we can't skip the handling of AstExprIfElse
// because we would generate an ICE. We also can't use the default value
// of result, because it will lead to a compiler crash.
// Note: LuauIfElseExpressionBaseSupport can be used to disable parser support
// for if-else expressions which will mean this node type is never created.
result = {anyType};
}
}
else else
ice("Unhandled AstExpr?"); ice("Unhandled AstExpr?");
@ -1895,7 +1924,7 @@ TypeId TypeChecker::checkExprTable(
} }
} }
TableState state = (expr.items.size == 0 || isNonstrictMode()) ? TableState::Unsealed : TableState::Sealed; TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed;
TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state};
table.definitionModuleName = currentModuleName; table.definitionModuleName = currentModuleName;
return addType(table); return addType(table);
@ -2549,23 +2578,34 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr
return {errorRecoveryType(scope)}; return {errorRecoveryType(scope)};
} }
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType)
{ {
ExprResult<TypeId> result = checkExpr(scope, *expr.condition); ExprResult<TypeId> result = checkExpr(scope, *expr.condition);
ScopePtr trueScope = childScope(scope, expr.trueExpr->location); ScopePtr trueScope = childScope(scope, expr.trueExpr->location);
reportErrors(resolve(result.predicates, trueScope, true)); reportErrors(resolve(result.predicates, trueScope, true));
ExprResult<TypeId> trueType = checkExpr(trueScope, *expr.trueExpr); ExprResult<TypeId> trueType = checkExpr(trueScope, *expr.trueExpr, expectedType);
ScopePtr falseScope = childScope(scope, expr.falseExpr->location); ScopePtr falseScope = childScope(scope, expr.falseExpr->location);
// Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope.
resolve(result.predicates, falseScope, false); resolve(result.predicates, falseScope, false);
ExprResult<TypeId> falseType = checkExpr(falseScope, *expr.falseExpr); ExprResult<TypeId> falseType = checkExpr(falseScope, *expr.falseExpr, expectedType);
unify(falseType.type, trueType.type, expr.location); if (FFlag::LuauIfElseBranchTypeUnion)
{
if (falseType.type == trueType.type)
return {trueType.type};
// TODO: normalize(UnionTypeVar{{trueType, falseType}}) std::vector<TypeId> types = reduceUnion({trueType.type, falseType.type});
// For now both trueType and falseType must be the same type. return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})};
return {trueType.type}; }
else
{
unify(falseType.type, trueType.type, expr.location);
// TODO: normalize(UnionTypeVar{{trueType, falseType}})
// For now both trueType and falseType must be the same type.
return {trueType.type};
}
} }
TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr)
@ -3032,7 +3072,20 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt;
defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0));
TypeId funTy = addType(FunctionTypeVar(funScope->level, generics, genericPacks, argPack, retPack, std::move(defn), bool(expr.self))); std::vector<TypeId> genericTys;
genericTys.reserve(generics.size());
std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) {
return el.ty;
});
std::vector<TypePackId> genericTps;
genericTps.reserve(genericPacks.size());
std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) {
return el.tp;
});
TypeId funTy =
addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self)));
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(funTy); FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(funTy);
@ -4848,11 +4901,38 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty())
return tf->type; return tf->type;
if (!lit->hasParameterList && !tf->typePackParams.empty()) bool hasDefaultTypes = false;
bool hasDefaultPacks = false;
bool parameterCountErrorReported = false;
if (FFlag::LuauTypeAliasDefaults)
{ {
reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) {
if (!FFlag::LuauErrorRecoveryType) return el.defaultValue.has_value();
return errorRecoveryType(scope); });
hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) {
return el.defaultValue.has_value();
});
if (!lit->hasParameterList)
{
if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks))
{
reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}});
parameterCountErrorReported = true;
if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
}
}
}
else
{
if (!lit->hasParameterList && !tf->typePackParams.empty())
{
reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}});
if (!FFlag::LuauErrorRecoveryType)
return errorRecoveryType(scope);
}
} }
std::vector<TypeId> typeParams; std::vector<TypeId> typeParams;
@ -4892,14 +4972,89 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (typePackParams.empty() && !extraTypes.empty()) if (typePackParams.empty() && !extraTypes.empty())
typePackParams.push_back(addTypePack(extraTypes)); typePackParams.push_back(addTypePack(extraTypes));
if (FFlag::LuauTypeAliasDefaults)
{
size_t typesProvided = typeParams.size();
size_t typesRequired = tf->typeParams.size();
size_t packsProvided = typePackParams.size();
size_t packsRequired = tf->typePackParams.size();
bool notEnoughParameters =
(typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired);
bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks;
// Add default type and type pack parameters if that's required and it's possible
if (notEnoughParameters && hasDefaultParameters)
{
// 'applyTypeFunction' is used to substitute default types that reference previous generic types
applyTypeFunction.typeArguments.clear();
applyTypeFunction.typePackArguments.clear();
applyTypeFunction.currentModule = currentModule;
applyTypeFunction.level = scope->level;
applyTypeFunction.encounteredForwardedType = false;
for (size_t i = 0; i < typesProvided; ++i)
applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i];
if (typesProvided < typesRequired)
{
for (size_t i = typesProvided; i < typesRequired; ++i)
{
TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr);
if (!defaultTy)
break;
std::optional<TypeId> maybeInstantiated = applyTypeFunction.substitute(defaultTy);
if (!maybeInstantiated.has_value())
{
reportError(annotation.location, UnificationTooComplex{});
maybeInstantiated = errorRecoveryType(scope);
}
applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated;
typeParams.push_back(*maybeInstantiated);
}
}
for (size_t i = 0; i < packsProvided; ++i)
applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i];
if (packsProvided < packsRequired)
{
for (size_t i = packsProvided; i < packsRequired; ++i)
{
TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr);
if (!defaultTp)
break;
std::optional<TypePackId> maybeInstantiated = applyTypeFunction.substitute(defaultTp);
if (!maybeInstantiated.has_value())
{
reportError(annotation.location, UnificationTooComplex{});
maybeInstantiated = errorRecoveryTypePack(scope);
}
applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated;
typePackParams.push_back(*maybeInstantiated);
}
}
}
}
// If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack
if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size())
typePackParams.push_back(addTypePack({})); typePackParams.push_back(addTypePack({}));
if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size())
{ {
reportError( if (!parameterCountErrorReported)
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); reportError(
TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}});
if (FFlag::LuauErrorRecoveryType) if (FFlag::LuauErrorRecoveryType)
{ {
@ -4913,11 +5068,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
return errorRecoveryType(scope); return errorRecoveryType(scope);
} }
if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) if (FFlag::LuauRecursiveTypeParameterRestriction)
{ {
bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) {
return itp == tp.ty;
});
bool sameTps = std::equal(
typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) {
return itpp == tpp.tp;
});
// If the generic parameters and the type arguments are the same, we are about to // If the generic parameters and the type arguments are the same, we are about to
// perform an identity substitution, which we can just short-circuit. // perform an identity substitution, which we can just short-circuit.
return tf->type; if (sameTys && sameTps)
return tf->type;
} }
return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location);
@ -4948,7 +5112,19 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId argTypes = resolveTypePack(funcScope, func->argTypes);
TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes);
TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(generics), std::move(genericPacks), argTypes, retTypes}); std::vector<TypeId> genericTys;
genericTys.reserve(generics.size());
std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) {
return el.ty;
});
std::vector<TypePackId> genericTps;
genericTps.reserve(genericPacks.size());
std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) {
return el.tp;
});
TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(fnType); FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(fnType);
@ -5137,11 +5313,11 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
applyTypeFunction.typeArguments.clear(); applyTypeFunction.typeArguments.clear();
for (size_t i = 0; i < tf.typeParams.size(); ++i) for (size_t i = 0; i < tf.typeParams.size(); ++i)
applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i];
applyTypeFunction.typePackArguments.clear(); applyTypeFunction.typePackArguments.clear();
for (size_t i = 0; i < tf.typePackParams.size(); ++i) for (size_t i = 0; i < tf.typePackParams.size(); ++i)
applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i];
applyTypeFunction.currentModule = currentModule; applyTypeFunction.currentModule = currentModule;
applyTypeFunction.level = scope->level; applyTypeFunction.level = scope->level;
@ -5213,17 +5389,23 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
return instantiated; return instantiated;
} }
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node,
const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames) const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames)
{ {
LUAU_ASSERT(scope->parent); LUAU_ASSERT(scope->parent);
const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level;
std::vector<TypeId> generics; std::vector<GenericTypeDefinition> generics;
for (const AstName& generic : genericNames)
for (const AstGenericType& generic : genericNames)
{ {
Name n = generic.value; std::optional<TypeId> defaultValue;
if (FFlag::LuauTypeAliasDefaults && generic.defaultValue)
defaultValue = resolveType(scope, *generic.defaultValue);
Name n = generic.name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that // These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic typevars have the same name. // a collision can only occur when two generic typevars have the same name.
@ -5246,14 +5428,20 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
g = addType(Unifiable::Generic{level, n}); g = addType(Unifiable::Generic{level, n});
} }
generics.push_back(g); generics.push_back({g, defaultValue});
scope->privateTypeBindings[n] = TypeFun{{}, g}; scope->privateTypeBindings[n] = TypeFun{{}, g};
} }
std::vector<TypePackId> genericPacks; std::vector<GenericTypePackDefinition> genericPacks;
for (const AstName& genericPack : genericPackNames)
for (const AstGenericTypePack& genericPack : genericPackNames)
{ {
Name n = genericPack.value; std::optional<TypePackId> defaultValue;
if (FFlag::LuauTypeAliasDefaults && genericPack.defaultValue)
defaultValue = resolveTypePack(scope, *genericPack.defaultValue);
Name n = genericPack.name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that // These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic typevars have the same name. // a collision can only occur when two generic typevars have the same name.
@ -5276,7 +5464,7 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); g = addTypePack(TypePackVar{Unifiable::Generic{level, n}});
} }
genericPacks.push_back(g); genericPacks.push_back({g, defaultValue});
scope->privateTypePackBindings[n] = g; scope->privateTypePackBindings[n] = g;
} }

View File

@ -19,6 +19,7 @@
LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauErrorRecoveryType)
LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauFreezeArena)
@ -453,6 +454,9 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs)
static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs)
{ {
if (FFlag::LuauMetatableAreEqualRecursion && areSeen(seen, &lhs, &rhs))
return true;
return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable);
} }

View File

@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false)
LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauSingletonTypes)
LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAG(LuauProperTypeLevels);
LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false)
LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false)
LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false)
@ -1170,9 +1171,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
// If both are at the end, we're done // If both are at the end, we're done
if (!superIter.good() && !subIter.good()) if (!superIter.good() && !subIter.good())
{ {
if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail)
{
tryUnify_(*subTpv->tail, *superTpv->tail);
break;
}
const bool lFreeTail = superTpv->tail && log.getMutable<FreeTypePack>(log.follow(*superTpv->tail)) != nullptr; const bool lFreeTail = superTpv->tail && log.getMutable<FreeTypePack>(log.follow(*superTpv->tail)) != nullptr;
const bool rFreeTail = subTpv->tail && log.getMutable<FreeTypePack>(log.follow(*subTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable<FreeTypePack>(log.follow(*subTpv->tail)) != nullptr;
if (lFreeTail && rFreeTail) if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail)
tryUnify_(*subTpv->tail, *superTpv->tail); tryUnify_(*subTpv->tail, *superTpv->tail);
else if (lFreeTail) else if (lFreeTail)
tryUnify_(emptyTp, *superTpv->tail); tryUnify_(emptyTp, *superTpv->tail);
@ -1370,9 +1377,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
// If both are at the end, we're done // If both are at the end, we're done
if (!superIter.good() && !subIter.good()) if (!superIter.good() && !subIter.good())
{ {
if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail)
{
tryUnify_(*subTpv->tail, *superTpv->tail);
break;
}
const bool lFreeTail = superTpv->tail && get<FreeTypePack>(follow(*superTpv->tail)) != nullptr; const bool lFreeTail = superTpv->tail && get<FreeTypePack>(follow(*superTpv->tail)) != nullptr;
const bool rFreeTail = subTpv->tail && get<FreeTypePack>(follow(*subTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && get<FreeTypePack>(follow(*subTpv->tail)) != nullptr;
if (lFreeTail && rFreeTail) if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail)
tryUnify_(*subTpv->tail, *superTpv->tail); tryUnify_(*subTpv->tail, *superTpv->tail);
else if (lFreeTail) else if (lFreeTail)
tryUnify_(emptyTp, *superTpv->tail); tryUnify_(emptyTp, *superTpv->tail);

View File

@ -334,6 +334,20 @@ struct AstTypeList
using AstArgumentName = std::pair<AstName, Location>; // TODO: remove and replace when we get a common struct for this pair instead of AstName using AstArgumentName = std::pair<AstName, Location>; // TODO: remove and replace when we get a common struct for this pair instead of AstName
struct AstGenericType
{
AstName name;
Location location;
AstType* defaultValue = nullptr;
};
struct AstGenericTypePack
{
AstName name;
Location location;
AstTypePack* defaultValue = nullptr;
};
extern int gAstRttiIndex; extern int gAstRttiIndex;
template<typename T> template<typename T>
@ -569,15 +583,15 @@ class AstExprFunction : public AstExpr
public: public:
LUAU_RTTI(AstExprFunction) LUAU_RTTI(AstExprFunction)
AstExprFunction(const Location& location, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks, AstLocal* self, AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstLocal*>& args, std::optional<Location> vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, AstLocal* self, const AstArray<AstLocal*>& args, std::optional<Location> vararg, AstStatBlock* body, size_t functionDepth,
std::optional<AstTypeList> returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, const AstName& debugname, std::optional<AstTypeList> returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false,
std::optional<Location> argLocation = std::nullopt); std::optional<Location> argLocation = std::nullopt);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstArray<AstName> generics; AstArray<AstGenericType> generics;
AstArray<AstName> genericPacks; AstArray<AstGenericTypePack> genericPacks;
AstLocal* self; AstLocal* self;
AstArray<AstLocal*> args; AstArray<AstLocal*> args;
bool hasReturnAnnotation; bool hasReturnAnnotation;
@ -942,14 +956,14 @@ class AstStatTypeAlias : public AstStat
public: public:
LUAU_RTTI(AstStatTypeAlias) LUAU_RTTI(AstStatTypeAlias)
AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks, AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
AstType* type, bool exported); const AstArray<AstGenericTypePack>& genericPacks, AstType* type, bool exported);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstName name; AstName name;
AstArray<AstName> generics; AstArray<AstGenericType> generics;
AstArray<AstName> genericPacks; AstArray<AstGenericTypePack> genericPacks;
AstType* type; AstType* type;
bool exported; bool exported;
}; };
@ -972,14 +986,15 @@ class AstStatDeclareFunction : public AstStat
public: public:
LUAU_RTTI(AstStatDeclareFunction) LUAU_RTTI(AstStatDeclareFunction)
AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks, AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, const AstTypeList& retTypes); const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstName name; AstName name;
AstArray<AstName> generics; AstArray<AstGenericType> generics;
AstArray<AstName> genericPacks; AstArray<AstGenericTypePack> genericPacks;
AstTypeList params; AstTypeList params;
AstArray<AstArgumentName> paramNames; AstArray<AstArgumentName> paramNames;
AstTypeList retTypes; AstTypeList retTypes;
@ -1077,13 +1092,13 @@ class AstTypeFunction : public AstType
public: public:
LUAU_RTTI(AstTypeFunction) LUAU_RTTI(AstTypeFunction)
AstTypeFunction(const Location& location, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks, const AstTypeList& argTypes, AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes); const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstArray<AstName> generics; AstArray<AstGenericType> generics;
AstArray<AstName> genericPacks; AstArray<AstGenericTypePack> genericPacks;
AstTypeList argTypes; AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames; AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes; AstTypeList returnTypes;

View File

@ -219,7 +219,7 @@ private:
AstTableIndexer* parseTableIndexerAnnotation(); AstTableIndexer* parseTableIndexerAnnotation();
AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack);
AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks, AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*>& params, AstArray<std::optional<AstArgumentName>>& paramNames, AstTypePack* varargAnnotation); AstArray<AstType*>& params, AstArray<std::optional<AstArgumentName>>& paramNames, AstTypePack* varargAnnotation);
AstType* parseTableTypeAnnotation(); AstType* parseTableTypeAnnotation();
@ -281,7 +281,7 @@ private:
Name parseIndexName(const char* context, const Position& previous); Name parseIndexName(const char* context, const Position& previous);
// `<' namelist `>' // `<' namelist `>'
std::pair<AstArray<AstName>, AstArray<AstName>> parseGenericTypeList(); std::pair<AstArray<AstGenericType>, AstArray<AstGenericTypePack>> parseGenericTypeList(bool withDefaultValues);
// `<' typeAnnotation[, ...] `>' // `<' typeAnnotation[, ...] `>'
AstArray<AstTypeOrPack> parseTypeParams(); AstArray<AstTypeOrPack> parseTypeParams();
@ -418,6 +418,8 @@ private:
std::vector<AstDeclaredClassProp> scratchDeclaredClassProps; std::vector<AstDeclaredClassProp> scratchDeclaredClassProps;
std::vector<AstExprTable::Item> scratchItem; std::vector<AstExprTable::Item> scratchItem;
std::vector<AstArgumentName> scratchArgName; std::vector<AstArgumentName> scratchArgName;
std::vector<AstGenericType> scratchGenericTypes;
std::vector<AstGenericTypePack> scratchGenericTypePacks;
std::vector<std::optional<AstArgumentName>> scratchOptArgName; std::vector<std::optional<AstArgumentName>> scratchOptArgName;
std::string scratchData; std::string scratchData;
}; };

View File

@ -158,9 +158,10 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
} }
} }
AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks, AstLocal* self, AstExprFunction::AstExprFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstArray<AstLocal*>& args, std::optional<Location> vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, AstLocal* self, const AstArray<AstLocal*>& args, std::optional<Location> vararg, AstStatBlock* body, size_t functionDepth,
std::optional<AstTypeList> returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, std::optional<Location> argLocation) const AstName& debugname, std::optional<AstTypeList> returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd,
std::optional<Location> argLocation)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, generics(generics) , generics(generics)
, genericPacks(genericPacks) , genericPacks(genericPacks)
@ -641,8 +642,8 @@ void AstStatLocalFunction::visit(AstVisitor* visitor)
func->visit(visitor); func->visit(visitor);
} }
AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstName>& genericPacks, AstType* type, bool exported) const AstArray<AstGenericTypePack>& genericPacks, AstType* type, bool exported)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
, generics(generics) , generics(generics)
@ -655,7 +656,21 @@ AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name
void AstStatTypeAlias::visit(AstVisitor* visitor) void AstStatTypeAlias::visit(AstVisitor* visitor)
{ {
if (visitor->visit(this)) if (visitor->visit(this))
{
for (const AstGenericType& el : generics)
{
if (el.defaultValue)
el.defaultValue->visit(visitor);
}
for (const AstGenericTypePack& el : genericPacks)
{
if (el.defaultValue)
el.defaultValue->visit(visitor);
}
type->visit(visitor); type->visit(visitor);
}
} }
AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type)
@ -671,8 +686,9 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor)
type->visit(visitor); type->visit(visitor);
} }
AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstName>& generics, AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray<AstGenericType>& generics,
const AstArray<AstName>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames, const AstTypeList& retTypes) const AstArray<AstGenericTypePack>& genericPacks, const AstTypeList& params, const AstArray<AstArgumentName>& paramNames,
const AstTypeList& retTypes)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
, generics(generics) , generics(generics)
@ -778,7 +794,7 @@ void AstTypeTable::visit(AstVisitor* visitor)
} }
} }
AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstName>& generics, const AstArray<AstName>& genericPacks, AstTypeFunction::AstTypeFunction(const Location& location, const AstArray<AstGenericType>& generics, const AstArray<AstGenericTypePack>& genericPacks,
const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes) const AstTypeList& argTypes, const AstArray<std::optional<AstArgumentName>>& argNames, const AstTypeList& returnTypes)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, generics(generics) , generics(generics)

View File

@ -10,10 +10,10 @@
// See docs/SyntaxChanges.md for an explanation. // See docs/SyntaxChanges.md for an explanation.
LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false)
LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false)
LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false)
LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false)
namespace Luau namespace Luau
{ {
@ -394,23 +394,13 @@ AstStat* Parser::parseIf()
if (lexer.current().type == Lexeme::ReservedElseif) if (lexer.current().type == Lexeme::ReservedElseif)
{ {
if (FFlag::LuauIfStatementRecursionGuard) unsigned int recursionCounterOld = recursionCounter;
{ incrementRecursionCounter("elseif");
unsigned int recursionCounterOld = recursionCounter; elseLocation = lexer.current().location;
incrementRecursionCounter("elseif"); elsebody = parseIf();
elseLocation = lexer.current().location; end = elsebody->location;
elsebody = parseIf(); hasEnd = elsebody->as<AstStatIf>()->hasEnd;
end = elsebody->location; recursionCounter = recursionCounterOld;
hasEnd = elsebody->as<AstStatIf>()->hasEnd;
recursionCounter = recursionCounterOld;
}
else
{
elseLocation = lexer.current().location;
elsebody = parseIf();
end = elsebody->location;
hasEnd = elsebody->as<AstStatIf>()->hasEnd;
}
} }
else else
{ {
@ -772,7 +762,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported)
if (!name) if (!name)
name = Name(nameError, lexer.current().location); name = Name(nameError, lexer.current().location);
auto [generics, genericPacks] = parseGenericTypeList(); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ FFlag::LuauParseTypeAliasDefaults);
expectAndConsume('=', "type alias"); expectAndConsume('=', "type alias");
@ -788,8 +778,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
Name fnName = parseName("function name"); Name fnName = parseName("function name");
// TODO: generic method declarations CLI-39909 // TODO: generic method declarations CLI-39909
AstArray<AstName> generics; AstArray<AstGenericType> generics;
AstArray<AstName> genericPacks; AstArray<AstGenericTypePack> genericPacks;
generics.size = 0; generics.size = 0;
generics.data = nullptr; generics.data = nullptr;
genericPacks.size = 0; genericPacks.size = 0;
@ -849,7 +839,7 @@ AstStat* Parser::parseDeclaration(const Location& start)
nextLexeme(); nextLexeme();
Name globalName = parseName("global function name"); Name globalName = parseName("global function name");
auto [generics, genericPacks] = parseGenericTypeList(); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
Lexeme matchParen = lexer.current(); Lexeme matchParen = lexer.current();
@ -991,7 +981,7 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
{ {
Location start = matchFunction.location; Location start = matchFunction.location;
auto [generics, genericPacks] = parseGenericTypeList(); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
Lexeme matchParen = lexer.current(); Lexeme matchParen = lexer.current();
expectAndConsume('(', "function"); expectAndConsume('(', "function");
@ -1228,8 +1218,8 @@ std::pair<Location, AstTypeList> Parser::parseReturnTypeAnnotation()
return {location, AstTypeList{copy(result), varargAnnotation}}; return {location, AstTypeList{copy(result), varargAnnotation}};
} }
AstArray<AstName> generics{nullptr, 0}; AstArray<AstGenericType> generics{nullptr, 0};
AstArray<AstName> genericPacks{nullptr, 0}; AstArray<AstGenericTypePack> genericPacks{nullptr, 0};
AstArray<AstType*> types = copy(result); AstArray<AstType*> types = copy(result);
AstArray<std::optional<AstArgumentName>> names = copy(resultNames); AstArray<std::optional<AstArgumentName>> names = copy(resultNames);
@ -1363,7 +1353,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
Lexeme begin = lexer.current(); Lexeme begin = lexer.current();
auto [generics, genericPacks] = parseGenericTypeList(); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
Lexeme parameterStart = lexer.current(); Lexeme parameterStart = lexer.current();
@ -1401,7 +1391,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}};
} }
AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstName> generics, AstArray<AstName> genericPacks, AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks,
AstArray<AstType*>& params, AstArray<std::optional<AstArgumentName>>& paramNames, AstTypePack* varargAnnotation) AstArray<AstType*>& params, AstArray<std::optional<AstArgumentName>>& paramNames, AstTypePack* varargAnnotation)
{ {
@ -1448,7 +1438,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
if (c == '|') if (c == '|')
{ {
nextLexeme(); nextLexeme();
parts.push_back(parseSimpleTypeAnnotation(false).type); parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type);
isUnion = true; isUnion = true;
} }
else if (c == '?') else if (c == '?')
@ -1461,7 +1451,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
else if (c == '&') else if (c == '&')
{ {
nextLexeme(); nextLexeme();
parts.push_back(parseSimpleTypeAnnotation(false).type); parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type);
isIntersection = true; isIntersection = true;
} }
else else
@ -1498,7 +1488,7 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation()
TempVector<AstType*> parts(scratchAnnotation); TempVector<AstType*> parts(scratchAnnotation);
auto [type, typePack] = parseSimpleTypeAnnotation(true); auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true);
if (typePack) if (typePack)
{ {
@ -1521,7 +1511,7 @@ AstType* Parser::parseTypeAnnotation()
Location begin = lexer.current().location; Location begin = lexer.current().location;
TempVector<AstType*> parts(scratchAnnotation); TempVector<AstType*> parts(scratchAnnotation);
parts.push_back(parseSimpleTypeAnnotation(false).type); parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type);
recursionCounter = oldRecursionCount; recursionCounter = oldRecursionCount;
@ -2121,7 +2111,7 @@ AstExpr* Parser::parseSimpleExpr()
{ {
return parseTableConstructor(); return parseTableConstructor();
} }
else if (FFlag::LuauIfElseExpressionBaseSupport && lexer.current().type == Lexeme::ReservedIf) else if (lexer.current().type == Lexeme::ReservedIf)
{ {
return parseIfElseExpr(); return parseIfElseExpr();
} }
@ -2341,10 +2331,10 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou
return Name(nameError, location); return Name(nameError, location);
} }
std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList() std::pair<AstArray<AstGenericType>, AstArray<AstGenericTypePack>> Parser::parseGenericTypeList(bool withDefaultValues)
{ {
TempVector<AstName> names{scratchName}; TempVector<AstGenericType> names{scratchGenericTypes};
TempVector<AstName> namePacks{scratchPackName}; TempVector<AstGenericTypePack> namePacks{scratchGenericTypePacks};
if (lexer.current().type == '<') if (lexer.current().type == '<')
{ {
@ -2352,21 +2342,73 @@ std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList()
nextLexeme(); nextLexeme();
bool seenPack = false; bool seenPack = false;
bool seenDefault = false;
while (true) while (true)
{ {
Location nameLocation = lexer.current().location;
AstName name = parseName().name; AstName name = parseName().name;
if (lexer.current().type == Lexeme::Dot3) if (lexer.current().type == Lexeme::Dot3 || (FFlag::LuauParseRecoverTypePackEllipsis && seenPack))
{ {
seenPack = true; seenPack = true;
nextLexeme();
namePacks.push_back(name); if (FFlag::LuauParseRecoverTypePackEllipsis && lexer.current().type != Lexeme::Dot3)
report(lexer.current().location, "Generic types come before generic type packs");
else
nextLexeme();
if (withDefaultValues && lexer.current().type == '=')
{
seenDefault = true;
nextLexeme();
Lexeme packBegin = lexer.current();
if (shouldParseTypePackAnnotation(lexer))
{
auto typePack = parseTypePackAnnotation();
namePacks.push_back({name, nameLocation, typePack});
}
else if (lexer.current().type == '(')
{
auto [type, typePack] = parseTypeOrPackAnnotation();
if (type)
report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type");
namePacks.push_back({name, nameLocation, typePack});
}
}
else
{
if (seenDefault)
report(lexer.current().location, "Expected default type pack after type pack name");
namePacks.push_back({name, nameLocation, nullptr});
}
} }
else else
{ {
if (seenPack) if (!FFlag::LuauParseRecoverTypePackEllipsis && seenPack)
report(lexer.current().location, "Generic types come before generic type packs"); report(lexer.current().location, "Generic types come before generic type packs");
names.push_back(name); if (withDefaultValues && lexer.current().type == '=')
{
seenDefault = true;
nextLexeme();
AstType* defaultType = parseTypeAnnotation();
names.push_back({name, nameLocation, defaultType});
}
else
{
if (seenDefault)
report(lexer.current().location, "Expected default type after type name");
names.push_back({name, nameLocation, nullptr});
}
} }
if (lexer.current().type == ',') if (lexer.current().type == ',')
@ -2378,8 +2420,8 @@ std::pair<AstArray<AstName>, AstArray<AstName>> Parser::parseGenericTypeList()
expectMatchAndConsume('>', begin); expectMatchAndConsume('>', begin);
} }
AstArray<AstName> generics = copy(names); AstArray<AstGenericType> generics = copy(names);
AstArray<AstName> genericPacks = copy(namePacks); AstArray<AstGenericTypePack> genericPacks = copy(namePacks);
return {generics, genericPacks}; return {generics, genericPacks};
} }

View File

@ -8,6 +8,8 @@
#include "FileUtils.h" #include "FileUtils.h"
LUAU_FASTFLAG(DebugLuauTimeTracing)
enum class ReportFormat enum class ReportFormat
{ {
Default, Default,
@ -105,6 +107,7 @@ static void displayHelp(const char* argv0)
printf("Available options:\n"); printf("Available options:\n");
printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n");
printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n");
printf(" --timetrace: record compiler time tracing information into trace.json\n");
} }
static int assertionHandler(const char* expr, const char* file, int line, const char* function) static int assertionHandler(const char* expr, const char* file, int line, const char* function)
@ -213,8 +216,18 @@ int main(int argc, char** argv)
format = ReportFormat::Gnu; format = ReportFormat::Gnu;
else if (strcmp(argv[i], "--annotate") == 0) else if (strcmp(argv[i], "--annotate") == 0)
annotate = true; annotate = true;
else if (strcmp(argv[i], "--timetrace") == 0)
FFlag::DebugLuauTimeTracing.value = true;
} }
#if !defined(LUAU_ENABLE_TIME_TRACE)
if (FFlag::DebugLuauTimeTracing)
{
printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n");
return 1;
}
#endif
Luau::FrontendOptions frontendOptions; Luau::FrontendOptions frontendOptions;
frontendOptions.retainFullTypeGraphs = annotate; frontendOptions.retainFullTypeGraphs = annotate;
@ -240,5 +253,8 @@ int main(int argc, char** argv)
fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str()); fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str());
} }
return (format == ReportFormat::Luacheck) ? 0 : failed; if (format == ReportFormat::Luacheck)
return 0;
else
return failed ? 1 : 0;
} }

View File

@ -19,6 +19,16 @@
#include <fcntl.h> #include <fcntl.h>
#endif #endif
LUAU_FASTFLAG(DebugLuauTimeTracing)
enum class CliMode
{
Unknown,
Repl,
Compile,
RunSourceFiles
};
enum class CompileFormat enum class CompileFormat
{ {
Text, Text,
@ -485,8 +495,10 @@ static void displayHelp(const char* argv0)
printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n");
printf("\n"); printf("\n");
printf("Available options:\n"); printf("Available options:\n");
printf(" -h, --help: Display this usage message.\n");
printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n");
printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n");
printf(" --timetrace: record compiler time tracing information into trace.json\n");
} }
static int assertionHandler(const char* expr, const char* file, int line, const char* function) static int assertionHandler(const char* expr, const char* file, int line, const char* function)
@ -503,71 +515,112 @@ int main(int argc, char** argv)
if (strncmp(flag->name, "Luau", 4) == 0) if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = true; flag->value = true;
if (argc == 1) CliMode mode = CliMode::Unknown;
{ CompileFormat compileFormat{};
runRepl(); int profile = 0;
return 0; bool coverage = false;
}
if (argc >= 2 && strcmp(argv[1], "--help") == 0)
{
displayHelp(argv[0]);
return 0;
}
// Set the mode if the user has explicitly specified one.
int argStart = 1;
if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0)
{ {
CompileFormat format = CompileFormat::Text; argStart++;
mode = CliMode::Compile;
if (strcmp(argv[1], "--compile") == 0)
{
compileFormat = CompileFormat::Text;
}
else if (strcmp(argv[1], "--compile=binary") == 0)
{
compileFormat = CompileFormat::Binary;
}
else if (strcmp(argv[1], "--compile=text") == 0)
{
compileFormat = CompileFormat::Text;
}
else
{
fprintf(stdout, "Error: Unrecognized value for '--compile' specified.\n");
return -1;
}
}
if (strcmp(argv[1], "--compile=binary") == 0) for (int i = argStart; i < argc; i++)
format = CompileFormat::Binary; {
if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0)
{
displayHelp(argv[0]);
return 0;
}
else if (strcmp(argv[i], "--profile") == 0)
{
profile = 10000; // default to 10 KHz
}
else if (strncmp(argv[i], "--profile=", 10) == 0)
{
profile = atoi(argv[i] + 10);
}
else if (strcmp(argv[i], "--coverage") == 0)
{
coverage = true;
}
else if (strcmp(argv[i], "--timetrace") == 0)
{
FFlag::DebugLuauTimeTracing.value = true;
#if !defined(LUAU_ENABLE_TIME_TRACE)
printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n");
return 1;
#endif
}
else if (argv[i][0] == '-')
{
fprintf(stdout, "Error: Unrecognized option '%s'.\n\n", argv[i]);
displayHelp(argv[0]);
return 1;
}
}
const std::vector<std::string> files = getSourceFiles(argc, argv);
if (mode == CliMode::Unknown)
{
mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles;
}
switch (mode)
{
case CliMode::Compile:
{
#ifdef _WIN32 #ifdef _WIN32
if (format == CompileFormat::Binary) if (compileFormat == CompileFormat::Binary)
_setmode(_fileno(stdout), _O_BINARY); _setmode(_fileno(stdout), _O_BINARY);
#endif #endif
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0; int failed = 0;
for (const std::string& path : files) for (const std::string& path : files)
failed += !compileFile(path.c_str(), format); failed += !compileFile(path.c_str(), compileFormat);
return failed; return failed ? 1 : 0;
} }
case CliMode::Repl:
{
runRepl();
return 0;
}
case CliMode::RunSourceFiles:
{ {
std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close); std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get(); lua_State* L = globalState.get();
setupState(L); setupState(L);
int profile = 0;
bool coverage = false;
for (int i = 1; i < argc; ++i)
{
if (argv[i][0] != '-')
continue;
if (strcmp(argv[i], "--profile") == 0)
profile = 10000; // default to 10 KHz
else if (strncmp(argv[i], "--profile=", 10) == 0)
profile = atoi(argv[i] + 10);
else if (strcmp(argv[i], "--coverage") == 0)
coverage = true;
}
if (profile) if (profile)
profilerStart(L, profile); profilerStart(L, profile);
if (coverage) if (coverage)
coverageInit(L); coverageInit(L);
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0; int failed = 0;
for (const std::string& path : files) for (const std::string& path : files)
@ -582,6 +635,10 @@ int main(int argc, char** argv)
if (coverage) if (coverage)
coverageDump("coverage.out"); coverageDump("coverage.out");
return failed; return failed ? 1 : 0;
}
case CliMode::Unknown:
default:
LUAU_ASSERT(!"Unhandled cli mode.");
} }
} }

197
Compiler/src/Builtins.cpp Normal file
View File

@ -0,0 +1,197 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Builtins.h"
#include "Luau/Bytecode.h"
#include "Luau/Compiler.h"
namespace Luau
{
namespace Compile
{
Builtin getBuiltin(AstExpr* node, const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstLocal*, Variable>& variables)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
{
const Variable* v = variables.find(expr->local);
return v && !v->written && v->init ? getBuiltin(v->init, globals, variables) : Builtin();
}
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
{
if (AstExprGlobal* object = expr->expr->as<AstExprGlobal>())
{
return getGlobalState(globals, object->name) == Global::Default ? Builtin{object->name, expr->index} : Builtin();
}
else
{
return Builtin();
}
}
else if (AstExprGlobal* expr = node->as<AstExprGlobal>())
{
return getGlobalState(globals, expr->name) == Global::Default ? Builtin{AstName(), expr->name} : Builtin();
}
else
{
return Builtin();
}
}
int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
{
if (builtin.empty())
return -1;
if (builtin.isGlobal("assert"))
return LBF_ASSERT;
if (builtin.isGlobal("type"))
return LBF_TYPE;
if (builtin.isGlobal("typeof"))
return LBF_TYPEOF;
if (builtin.isGlobal("rawset"))
return LBF_RAWSET;
if (builtin.isGlobal("rawget"))
return LBF_RAWGET;
if (builtin.isGlobal("rawequal"))
return LBF_RAWEQUAL;
if (builtin.isGlobal("unpack"))
return LBF_TABLE_UNPACK;
if (builtin.object == "math")
{
if (builtin.method == "abs")
return LBF_MATH_ABS;
if (builtin.method == "acos")
return LBF_MATH_ACOS;
if (builtin.method == "asin")
return LBF_MATH_ASIN;
if (builtin.method == "atan2")
return LBF_MATH_ATAN2;
if (builtin.method == "atan")
return LBF_MATH_ATAN;
if (builtin.method == "ceil")
return LBF_MATH_CEIL;
if (builtin.method == "cosh")
return LBF_MATH_COSH;
if (builtin.method == "cos")
return LBF_MATH_COS;
if (builtin.method == "deg")
return LBF_MATH_DEG;
if (builtin.method == "exp")
return LBF_MATH_EXP;
if (builtin.method == "floor")
return LBF_MATH_FLOOR;
if (builtin.method == "fmod")
return LBF_MATH_FMOD;
if (builtin.method == "frexp")
return LBF_MATH_FREXP;
if (builtin.method == "ldexp")
return LBF_MATH_LDEXP;
if (builtin.method == "log10")
return LBF_MATH_LOG10;
if (builtin.method == "log")
return LBF_MATH_LOG;
if (builtin.method == "max")
return LBF_MATH_MAX;
if (builtin.method == "min")
return LBF_MATH_MIN;
if (builtin.method == "modf")
return LBF_MATH_MODF;
if (builtin.method == "pow")
return LBF_MATH_POW;
if (builtin.method == "rad")
return LBF_MATH_RAD;
if (builtin.method == "sinh")
return LBF_MATH_SINH;
if (builtin.method == "sin")
return LBF_MATH_SIN;
if (builtin.method == "sqrt")
return LBF_MATH_SQRT;
if (builtin.method == "tanh")
return LBF_MATH_TANH;
if (builtin.method == "tan")
return LBF_MATH_TAN;
if (builtin.method == "clamp")
return LBF_MATH_CLAMP;
if (builtin.method == "sign")
return LBF_MATH_SIGN;
if (builtin.method == "round")
return LBF_MATH_ROUND;
}
if (builtin.object == "bit32")
{
if (builtin.method == "arshift")
return LBF_BIT32_ARSHIFT;
if (builtin.method == "band")
return LBF_BIT32_BAND;
if (builtin.method == "bnot")
return LBF_BIT32_BNOT;
if (builtin.method == "bor")
return LBF_BIT32_BOR;
if (builtin.method == "bxor")
return LBF_BIT32_BXOR;
if (builtin.method == "btest")
return LBF_BIT32_BTEST;
if (builtin.method == "extract")
return LBF_BIT32_EXTRACT;
if (builtin.method == "lrotate")
return LBF_BIT32_LROTATE;
if (builtin.method == "lshift")
return LBF_BIT32_LSHIFT;
if (builtin.method == "replace")
return LBF_BIT32_REPLACE;
if (builtin.method == "rrotate")
return LBF_BIT32_RROTATE;
if (builtin.method == "rshift")
return LBF_BIT32_RSHIFT;
if (builtin.method == "countlz")
return LBF_BIT32_COUNTLZ;
if (builtin.method == "countrz")
return LBF_BIT32_COUNTRZ;
}
if (builtin.object == "string")
{
if (builtin.method == "byte")
return LBF_STRING_BYTE;
if (builtin.method == "char")
return LBF_STRING_CHAR;
if (builtin.method == "len")
return LBF_STRING_LEN;
if (builtin.method == "sub")
return LBF_STRING_SUB;
}
if (builtin.object == "table")
{
if (builtin.method == "insert")
return LBF_TABLE_INSERT;
if (builtin.method == "unpack")
return LBF_TABLE_UNPACK;
}
if (options.vectorCtor)
{
if (options.vectorLib)
{
if (builtin.isMethod(options.vectorLib, options.vectorCtor))
return LBF_VECTOR;
}
else
{
if (builtin.isGlobal(options.vectorCtor))
return LBF_VECTOR;
}
}
return -1;
}
} // namespace Compile
} // namespace Luau

41
Compiler/src/Builtins.h Normal file
View File

@ -0,0 +1,41 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "ValueTracking.h"
namespace Luau
{
struct CompileOptions;
}
namespace Luau
{
namespace Compile
{
struct Builtin
{
AstName object;
AstName method;
bool empty() const
{
return object == AstName() && method == AstName();
}
bool isGlobal(const char* name) const
{
return object == AstName() && method == name;
}
bool isMethod(const char* table, const char* name) const
{
return object == table && method == name;
}
};
Builtin getBuiltin(AstExpr* node, const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstLocal*, Variable>& variables);
int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options);
} // namespace Compile
} // namespace Luau

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,394 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "ConstantFolding.h"
#include <math.h>
namespace Luau
{
namespace Compile
{
static bool constantsEqual(const Constant& la, const Constant& ra)
{
LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown);
switch (la.type)
{
case Constant::Type_Nil:
return ra.type == Constant::Type_Nil;
case Constant::Type_Boolean:
return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean;
case Constant::Type_Number:
return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber;
case Constant::Type_String:
return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0;
default:
LUAU_ASSERT(!"Unexpected constant type in comparison");
return false;
}
}
static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg)
{
switch (op)
{
case AstExprUnary::Not:
if (arg.type != Constant::Type_Unknown)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = !arg.isTruthful();
}
break;
case AstExprUnary::Minus:
if (arg.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = -arg.valueNumber;
}
break;
case AstExprUnary::Len:
if (arg.type == Constant::Type_String)
{
result.type = Constant::Type_Number;
result.valueNumber = double(arg.stringLength);
}
break;
default:
LUAU_ASSERT(!"Unexpected unary operation");
}
}
static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra)
{
switch (op)
{
case AstExprBinary::Add:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber + ra.valueNumber;
}
break;
case AstExprBinary::Sub:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber - ra.valueNumber;
}
break;
case AstExprBinary::Mul:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber * ra.valueNumber;
}
break;
case AstExprBinary::Div:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber / ra.valueNumber;
}
break;
case AstExprBinary::Mod:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber;
}
break;
case AstExprBinary::Pow:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = pow(la.valueNumber, ra.valueNumber);
}
break;
case AstExprBinary::Concat:
break;
case AstExprBinary::CompareNe:
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = !constantsEqual(la, ra);
}
break;
case AstExprBinary::CompareEq:
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = constantsEqual(la, ra);
}
break;
case AstExprBinary::CompareLt:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber < ra.valueNumber;
}
break;
case AstExprBinary::CompareLe:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber <= ra.valueNumber;
}
break;
case AstExprBinary::CompareGt:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber > ra.valueNumber;
}
break;
case AstExprBinary::CompareGe:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber >= ra.valueNumber;
}
break;
case AstExprBinary::And:
if (la.type != Constant::Type_Unknown)
{
result = la.isTruthful() ? ra : la;
}
break;
case AstExprBinary::Or:
if (la.type != Constant::Type_Unknown)
{
result = la.isTruthful() ? la : ra;
}
break;
default:
LUAU_ASSERT(!"Unexpected binary operation");
}
}
struct ConstantVisitor : AstVisitor
{
DenseHashMap<AstExpr*, Constant>& constants;
DenseHashMap<AstLocal*, Variable>& variables;
DenseHashMap<AstLocal*, Constant> locals;
ConstantVisitor(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables)
: constants(constants)
, variables(variables)
, locals(nullptr)
{
}
Constant analyze(AstExpr* node)
{
Constant result;
result.type = Constant::Type_Unknown;
if (AstExprGroup* expr = node->as<AstExprGroup>())
{
result = analyze(expr->expr);
}
else if (node->is<AstExprConstantNil>())
{
result.type = Constant::Type_Nil;
}
else if (AstExprConstantBool* expr = node->as<AstExprConstantBool>())
{
result.type = Constant::Type_Boolean;
result.valueBoolean = expr->value;
}
else if (AstExprConstantNumber* expr = node->as<AstExprConstantNumber>())
{
result.type = Constant::Type_Number;
result.valueNumber = expr->value;
}
else if (AstExprConstantString* expr = node->as<AstExprConstantString>())
{
result.type = Constant::Type_String;
result.valueString = expr->value.data;
result.stringLength = unsigned(expr->value.size);
}
else if (AstExprLocal* expr = node->as<AstExprLocal>())
{
const Constant* l = locals.find(expr->local);
if (l)
result = *l;
}
else if (node->is<AstExprGlobal>())
{
// nope
}
else if (node->is<AstExprVarargs>())
{
// nope
}
else if (AstExprCall* expr = node->as<AstExprCall>())
{
analyze(expr->func);
for (size_t i = 0; i < expr->args.size; ++i)
analyze(expr->args.data[i]);
}
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
{
analyze(expr->expr);
}
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
{
analyze(expr->expr);
analyze(expr->index);
}
else if (AstExprFunction* expr = node->as<AstExprFunction>())
{
// this is necessary to propagate constant information in all child functions
expr->body->visit(this);
}
else if (AstExprTable* expr = node->as<AstExprTable>())
{
for (size_t i = 0; i < expr->items.size; ++i)
{
const AstExprTable::Item& item = expr->items.data[i];
if (item.key)
analyze(item.key);
analyze(item.value);
}
}
else if (AstExprUnary* expr = node->as<AstExprUnary>())
{
Constant arg = analyze(expr->expr);
if (arg.type != Constant::Type_Unknown)
foldUnary(result, expr->op, arg);
}
else if (AstExprBinary* expr = node->as<AstExprBinary>())
{
Constant la = analyze(expr->left);
Constant ra = analyze(expr->right);
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
foldBinary(result, expr->op, la, ra);
}
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
{
Constant arg = analyze(expr->expr);
result = arg;
}
else if (AstExprIfElse* expr = node->as<AstExprIfElse>())
{
Constant cond = analyze(expr->condition);
Constant trueExpr = analyze(expr->trueExpr);
Constant falseExpr = analyze(expr->falseExpr);
if (cond.type != Constant::Type_Unknown)
result = cond.isTruthful() ? trueExpr : falseExpr;
}
else
{
LUAU_ASSERT(!"Unknown expression type");
}
if (result.type != Constant::Type_Unknown)
constants[node] = result;
return result;
}
bool visit(AstExpr* node) override
{
// note: we short-circuit the visitor traversal through any expression trees by returning false
// recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression
analyze(node);
return false;
}
bool visit(AstStatLocal* node) override
{
// all values that align wrt indexing are simple - we just match them 1-1
for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i)
{
Constant arg = analyze(node->values.data[i]);
if (arg.type != Constant::Type_Unknown)
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(node->vars.data[i]);
LUAU_ASSERT(v);
if (!v->written)
{
locals[node->vars.data[i]] = arg;
v->constant = true;
}
}
}
if (node->vars.size > node->values.size)
{
// if we have trailing variables, then depending on whether the last value is capable of returning multiple values
// (aka call or varargs), we either don't know anything about these vars, or we know they're nil
AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr;
bool multRet = last && (last->is<AstExprCall>() || last->is<AstExprVarargs>());
if (!multRet)
{
for (size_t i = node->values.size; i < node->vars.size; ++i)
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(node->vars.data[i]);
LUAU_ASSERT(v);
if (!v->written)
{
locals[node->vars.data[i]].type = Constant::Type_Nil;
v->constant = true;
}
}
}
}
else
{
// we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside
// them
for (size_t i = node->vars.size; i < node->values.size; ++i)
analyze(node->values.data[i]);
}
return false;
}
};
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, AstNode* root)
{
ConstantVisitor visitor{constants, variables};
root->visit(&visitor);
}
} // namespace Compile
} // namespace Luau

View File

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "ValueTracking.h"
namespace Luau
{
namespace Compile
{
struct Constant
{
enum Type
{
Type_Unknown,
Type_Nil,
Type_Boolean,
Type_Number,
Type_String,
};
Type type = Type_Unknown;
unsigned int stringLength = 0;
union
{
bool valueBoolean;
double valueNumber;
char* valueString = nullptr; // length stored in stringLength
};
bool isTruthful() const
{
LUAU_ASSERT(type != Type_Unknown);
return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false);
}
AstArray<char> getString() const
{
LUAU_ASSERT(type == Type_String);
return {valueString, stringLength};
}
};
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, AstNode* root);
} // namespace Compile
} // namespace Luau

129
Compiler/src/TableShape.cpp Normal file
View File

@ -0,0 +1,129 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "TableShape.h"
namespace Luau
{
namespace Compile
{
static AstExprTable* getTableHint(AstExpr* expr)
{
// unadorned table literal
if (AstExprTable* table = expr->as<AstExprTable>())
return table;
// setmetatable(table literal, ...)
if (AstExprCall* call = expr->as<AstExprCall>(); call && !call->self && call->args.size == 2)
if (AstExprGlobal* func = call->func->as<AstExprGlobal>(); func && func->name == "setmetatable")
if (AstExprTable* table = call->args.data[0]->as<AstExprTable>())
return table;
return nullptr;
}
struct ShapeVisitor : AstVisitor
{
struct Hasher
{
size_t operator()(const std::pair<AstExprTable*, AstName>& p) const
{
return std::hash<AstExprTable*>()(p.first) ^ std::hash<AstName>()(p.second);
}
};
DenseHashMap<AstExprTable*, TableShape>& shapes;
DenseHashMap<AstLocal*, AstExprTable*> tables;
DenseHashSet<std::pair<AstExprTable*, AstName>, Hasher> fields;
ShapeVisitor(DenseHashMap<AstExprTable*, TableShape>& shapes)
: shapes(shapes)
, tables(nullptr)
, fields(std::pair<AstExprTable*, AstName>())
{
}
void assignField(AstExpr* expr, AstName index)
{
if (AstExprLocal* lv = expr->as<AstExprLocal>())
{
if (AstExprTable** table = tables.find(lv->local))
{
std::pair<AstExprTable*, AstName> field = {*table, index};
if (!fields.contains(field))
{
fields.insert(field);
shapes[*table].hashSize += 1;
}
}
}
}
void assignField(AstExpr* expr, AstExpr* index)
{
AstExprLocal* lv = expr->as<AstExprLocal>();
AstExprConstantNumber* number = index->as<AstExprConstantNumber>();
if (lv && number)
{
if (AstExprTable** table = tables.find(lv->local))
{
TableShape& shape = shapes[*table];
if (number->value == double(shape.arraySize + 1))
shape.arraySize += 1;
}
}
}
void assign(AstExpr* var)
{
if (AstExprIndexName* index = var->as<AstExprIndexName>())
{
assignField(index->expr, index->index);
}
else if (AstExprIndexExpr* index = var->as<AstExprIndexExpr>())
{
assignField(index->expr, index->index);
}
}
bool visit(AstStatLocal* node) override
{
// track local -> table association so that we can update table size prediction in assignField
if (node->vars.size == 1 && node->values.size == 1)
if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0)
tables[node->vars.data[0]] = table;
return true;
}
bool visit(AstStatAssign* node) override
{
for (size_t i = 0; i < node->vars.size; ++i)
assign(node->vars.data[i]);
for (size_t i = 0; i < node->values.size; ++i)
node->values.data[i]->visit(this);
return false;
}
bool visit(AstStatFunction* node) override
{
assign(node->name);
node->func->visit(this);
return false;
}
};
void predictTableShapes(DenseHashMap<AstExprTable*, TableShape>& shapes, AstNode* root)
{
ShapeVisitor visitor{shapes};
root->visit(&visitor);
}
} // namespace Compile
} // namespace Luau

21
Compiler/src/TableShape.h Normal file
View File

@ -0,0 +1,21 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
namespace Luau
{
namespace Compile
{
struct TableShape
{
unsigned int arraySize = 0;
unsigned int hashSize = 0;
};
void predictTableShapes(DenseHashMap<AstExprTable*, TableShape>& shapes, AstNode* root);
} // namespace Compile
} // namespace Luau

View File

@ -0,0 +1,103 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "ValueTracking.h"
#include "Luau/Lexer.h"
namespace Luau
{
namespace Compile
{
struct ValueVisitor : AstVisitor
{
DenseHashMap<AstName, Global>& globals;
DenseHashMap<AstLocal*, Variable>& variables;
ValueVisitor(DenseHashMap<AstName, Global>& globals, DenseHashMap<AstLocal*, Variable>& variables)
: globals(globals)
, variables(variables)
{
}
void assign(AstExpr* var)
{
if (AstExprLocal* lv = var->as<AstExprLocal>())
{
variables[lv->local].written = true;
}
else if (AstExprGlobal* gv = var->as<AstExprGlobal>())
{
globals[gv->name] = Global::Written;
}
else
{
// we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5
var->visit(this);
}
}
bool visit(AstStatLocal* node) override
{
for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i)
variables[node->vars.data[i]].init = node->values.data[i];
for (size_t i = node->values.size; i < node->vars.size; ++i)
variables[node->vars.data[i]].init = nullptr;
return true;
}
bool visit(AstStatAssign* node) override
{
for (size_t i = 0; i < node->vars.size; ++i)
assign(node->vars.data[i]);
for (size_t i = 0; i < node->values.size; ++i)
node->values.data[i]->visit(this);
return false;
}
bool visit(AstStatCompoundAssign* node) override
{
assign(node->var);
node->value->visit(this);
return false;
}
bool visit(AstStatLocalFunction* node) override
{
variables[node->name].init = node->func;
return true;
}
bool visit(AstStatFunction* node) override
{
assign(node->name);
node->func->visit(this);
return false;
}
};
void assignMutable(DenseHashMap<AstName, Global>& globals, const AstNameTable& names, const char** mutableGlobals)
{
if (AstName name = names.get("_G"); name.value)
globals[name] = Global::Mutable;
if (mutableGlobals)
for (const char** ptr = mutableGlobals; *ptr; ++ptr)
if (AstName name = names.get(*ptr); name.value)
globals[name] = Global::Mutable;
}
void trackValues(DenseHashMap<AstName, Global>& globals, DenseHashMap<AstLocal*, Variable>& variables, AstNode* root)
{
ValueVisitor visitor{globals, variables};
root->visit(&visitor);
}
} // namespace Compile
} // namespace Luau

View File

@ -0,0 +1,42 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
namespace Luau
{
class AstNameTable;
}
namespace Luau
{
namespace Compile
{
enum class Global
{
Default = 0,
Mutable, // builtin that has contents unknown at compile time, blocks GETIMPORT for chains
Written, // written in the code which means we can't reason about the value
};
struct Variable
{
AstExpr* init = nullptr; // initial value of the variable; filled by trackValues
bool written = false; // is the variable ever assigned to? filled by trackValues
bool constant = false; // is the variable's value a compile-time constant? filled by constantFold
};
void assignMutable(DenseHashMap<AstName, Global>& globals, const AstNameTable& names, const char** mutableGlobals);
void trackValues(DenseHashMap<AstName, Global>& globals, DenseHashMap<AstLocal*, Variable>& variables, AstNode* root);
inline Global getGlobalState(const DenseHashMap<AstName, Global>& globals, AstName name)
{
const Global* it = globals.find(name);
return it ? *it : Global::Default;
}
} // namespace Compile
} // namespace Luau

View File

@ -29,7 +29,15 @@ target_sources(Luau.Compiler PRIVATE
Compiler/src/BytecodeBuilder.cpp Compiler/src/BytecodeBuilder.cpp
Compiler/src/Compiler.cpp Compiler/src/Compiler.cpp
Compiler/src/Builtins.cpp
Compiler/src/ConstantFolding.cpp
Compiler/src/TableShape.cpp
Compiler/src/ValueTracking.cpp
Compiler/src/lcode.cpp Compiler/src/lcode.cpp
Compiler/src/Builtins.h
Compiler/src/ConstantFolding.h
Compiler/src/TableShape.h
Compiler/src/ValueTracking.h
) )
# Luau.Analysis Sources # Luau.Analysis Sources

View File

@ -17,7 +17,6 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
LUAU_FASTFLAG(LuauCoroutineClose) LUAU_FASTFLAG(LuauCoroutineClose)
/* /*
@ -545,11 +544,8 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
if (!oldactive) if (!oldactive)
resetbit(L->stackstate, THREAD_ACTIVEBIT); resetbit(L->stackstate, THREAD_ACTIVEBIT);
if (FFlag::LuauCcallRestoreFix) // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored.
{ L->nCcalls = oldnCcalls;
// Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored.
L->nCcalls = oldnCcalls;
}
// an error occurred, check if we have a protected error callback // an error occurred, check if we have a protected error callback
if (L->global->cb.debugprotectederror) if (L->global->cb.debugprotectederror)
@ -564,10 +560,6 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
StkId oldtop = restorestack(L, old_top); StkId oldtop = restorestack(L, old_top);
luaF_close(L, oldtop); /* close eventual pending closures */ luaF_close(L, oldtop); /* close eventual pending closures */
seterrorobj(L, status, oldtop); seterrorobj(L, status, oldtop);
if (!FFlag::LuauCcallRestoreFix)
{
L->nCcalls = oldnCcalls;
}
L->ci = restoreci(L, old_ci); L->ci = restoreci(L, old_ci);
L->base = L->ci->base; L->base = L->ci->base;
restore_stack_limit(L); restore_stack_limit(L);

View File

@ -6,6 +6,8 @@
#include "lmem.h" #include "lmem.h"
#include "lgc.h" #include "lgc.h"
LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false)
Proto* luaF_newproto(lua_State* L) Proto* luaF_newproto(lua_State* L)
{ {
Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat);
@ -113,14 +115,16 @@ void luaF_freeupval(lua_State* L, UpVal* uv)
void luaF_close(lua_State* L, StkId level) void luaF_close(lua_State* L, StkId level)
{ {
UpVal* uv; UpVal* uv;
global_State* g = L->global; global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval
while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level)
{ {
GCObject* o = obj2gco(uv); GCObject* o = obj2gco(uv);
LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value);
L->openupval = uv->next; /* remove from `open' list */ L->openupval = uv->next; /* remove from `open' list */
if (isdead(g, o)) if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o))
{
luaF_freeupval(L, uv); /* free upvalue */ luaF_freeupval(L, uv); /* free upvalue */
}
else else
{ {
unlinkupval(uv); unlinkupval(uv);

View File

@ -8,8 +8,6 @@
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false)
/* macro to `unsign' a character */ /* macro to `unsign' a character */
#define uchar(c) ((unsigned char)(c)) #define uchar(c) ((unsigned char)(c))
@ -1406,20 +1404,10 @@ static int str_pack(lua_State* L)
} }
case Kuint: case Kuint:
{ /* unsigned integers */ { /* unsigned integers */
if (FFlag::LuauStrPackUBCastFix) long long n = (long long)luaL_checknumber(L, arg);
{ if (size < SZINT) /* need overflow check? */
long long n = (long long)luaL_checknumber(L, arg); luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow");
if (size < SZINT) /* need overflow check? */ packint(&b, (unsigned long long)n, h.islittle, size, 0);
luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow");
packint(&b, (unsigned long long)n, h.islittle, size, 0);
}
else
{
unsigned long long n = (unsigned long long)luaL_checknumber(L, arg);
if (size < SZINT) /* need overflow check? */
luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow");
packint(&b, n, h.islittle, size, 0);
}
break; break;
} }
case Kfloat: case Kfloat:

View File

@ -111,15 +111,10 @@ end
-- multiplies two matrices -- multiplies two matrices
function MMulti(M1, M2) function MMulti(M1, M2)
local M = {{},{},{},{}}; local M = {{},{},{},{}};
local i = 1; for i = 1,4 do
local j = 1; for j = 1,4 do
while i <= 4 do M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j];
j = 1;
while j <= 4 do
M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; j = j + 1
end end
i = i + 1
end end
return M; return M;
end end
@ -127,28 +122,27 @@ end
-- multiplies matrix with vector -- multiplies matrix with vector
function VMulti(M, V) function VMulti(M, V)
local Vect = {}; local Vect = {};
local i = 1; for i = 1,4 do
while i <= 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; i = i + 1 end Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4];
end
return Vect; return Vect;
end end
function VMulti2(M, V) function VMulti2(M, V)
local Vect = {}; local Vect = {};
local i = 1; for i = 1,3 do
while i < 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; i = i + 1 end Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3];
end
return Vect; return Vect;
end end
-- add to matrices -- add to matrices
function MAdd(M1, M2) function MAdd(M1, M2)
local M = {{},{},{},{}}; local M = {{},{},{},{}};
local i = 1; for i = 1,4 do
local j = 1; for j = 1,4 do
while i <= 4 do M[i][j] = M1[i][j] + M2[i][j];
j = 1; end
while j <= 4 do M[i][j] = M1[i][j] + M2[i][j]; j = j + 1 end
i = i + 1
end end
return M; return M;
end end

View File

@ -1938,7 +1938,6 @@ return target(b@1
TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses")
{ {
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true);
check(R"( check(R"(
local function bar(a: number) return -a end local function bar(a: number) return -a end
@ -1954,7 +1953,6 @@ local abc = b@1
TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses")
{ {
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true);
check(R"( check(R"(
local function foo() return 1 end local function foo() return 1 end
@ -2538,10 +2536,6 @@ TEST_CASE("autocomplete_documentation_symbols")
TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true};
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true};
{
check(R"( check(R"(
local temp = false local temp = false
local even = true; local even = true;
@ -2614,7 +2608,6 @@ a = if temp then even elseif true then temp else e@9
CHECK(ac.entryMap.count("then") == 0); CHECK(ac.entryMap.count("then") == 0);
CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("else") == 0);
CHECK(ac.entryMap.count("elseif") == 0); CHECK(ac.entryMap.count("elseif") == 0);
}
} }
TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack")
@ -2681,4 +2674,58 @@ local r4 = t:bar1(@4)
CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None);
} }
TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
check(R"(
type A<T = @1> = () -> T
)");
auto ac = autocomplete('1');
CHECK(ac.entryMap.count("number"));
CHECK(ac.entryMap.count("string"));
}
TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
check(R"(
type A<T... = ...@1> = () -> T
)");
auto ac = autocomplete('1');
CHECK(ac.entryMap.count("number"));
CHECK(ac.entryMap.count("string"));
}
TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self")
{
ScopedFastFlag flag("LuauMissingFollowACMetatables", true);
check(R"(
--!strict
local Class = {}
Class.__index = Class
type Class = typeof(setmetatable({} :: { x: number }, Class))
function Class.new(x: number): Class
return setmetatable({x = x}, Class)
end
function Class.getx(self: Class)
return self.x
end
function test()
local c = Class.new(42)
local n = c:@1
print(n)
end
)");
auto ac = autocomplete('1');
CHECK(ac.entryMap.count("getx"));
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -603,9 +603,9 @@ RETURN R0 1
)"); )");
} }
TEST_CASE("EmptyTableHashSizePredictionOptimization") TEST_CASE("TableSizePredictionBasic")
{ {
const char* hashSizeSource = R"( CHECK_EQ("\n" + compileFunction0(R"(
local t = {} local t = {}
t.a = 1 t.a = 1
t.b = 1 t.b = 1
@ -616,36 +616,8 @@ t.f = 1
t.g = 1 t.g = 1
t.h = 1 t.h = 1
t.i = 1 t.i = 1
)"; )"),
R"(
const char* hashSizeSource2 = R"(
local t = {}
t.x = 1
t.x = 2
t.x = 3
t.x = 4
t.x = 5
t.x = 6
t.x = 7
t.x = 8
t.x = 9
)";
const char* arraySizeSource = R"(
local t = {}
t[1] = 1
t[2] = 1
t[3] = 1
t[4] = 1
t[5] = 1
t[6] = 1
t[7] = 1
t[8] = 1
t[9] = 1
t[10] = 1
)";
CHECK_EQ("\n" + compileFunction0(hashSizeSource), R"(
NEWTABLE R0 16 0 NEWTABLE R0 16 0
LOADN R1 1 LOADN R1 1
SETTABLEKS R1 R0 K0 SETTABLEKS R1 R0 K0
@ -668,7 +640,19 @@ SETTABLEKS R1 R0 K8
RETURN R0 0 RETURN R0 0
)"); )");
CHECK_EQ("\n" + compileFunction0(hashSizeSource2), R"( CHECK_EQ("\n" + compileFunction0(R"(
local t = {}
t.x = 1
t.x = 2
t.x = 3
t.x = 4
t.x = 5
t.x = 6
t.x = 7
t.x = 8
t.x = 9
)"),
R"(
NEWTABLE R0 1 0 NEWTABLE R0 1 0
LOADN R1 1 LOADN R1 1
SETTABLEKS R1 R0 K0 SETTABLEKS R1 R0 K0
@ -691,7 +675,20 @@ SETTABLEKS R1 R0 K0
RETURN R0 0 RETURN R0 0
)"); )");
CHECK_EQ("\n" + compileFunction0(arraySizeSource), R"( CHECK_EQ("\n" + compileFunction0(R"(
local t = {}
t[1] = 1
t[2] = 1
t[3] = 1
t[4] = 1
t[5] = 1
t[6] = 1
t[7] = 1
t[8] = 1
t[9] = 1
t[10] = 1
)"),
R"(
NEWTABLE R0 0 10 NEWTABLE R0 0 10
LOADN R1 1 LOADN R1 1
SETTABLEN R1 R0 1 SETTABLEN R1 R0 1
@ -717,6 +714,27 @@ RETURN R0 0
)"); )");
} }
TEST_CASE("TableSizePredictionObject")
{
CHECK_EQ("\n" + compileFunction(R"(
local t = {}
t.field = 1
function t:getfield()
return self.field
end
return t
)",
1),
R"(
NEWTABLE R0 2 0
LOADN R1 1
SETTABLEKS R1 R0 K0
DUPCLOSURE R1 K1
SETTABLEKS R1 R0 K2
RETURN R0 1
)");
}
TEST_CASE("TableSizePredictionSetMetatable") TEST_CASE("TableSizePredictionSetMetatable")
{ {
CHECK_EQ("\n" + compileFunction0(R"( CHECK_EQ("\n" + compileFunction0(R"(
@ -1031,9 +1049,6 @@ RETURN R0 1
TEST_CASE("IfElseExpression") TEST_CASE("IfElseExpression")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true};
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true};
// codegen for a true constant condition // codegen for a true constant condition
CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"( CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"(
LOADN R0 10 LOADN R0 10
@ -3058,7 +3073,7 @@ RETURN R0 0
// table variants (indexed by string, number, variable) // table variants (indexed by string, number, variable)
CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"(
NEWTABLE R0 1 0 NEWTABLE R0 0 0
GETTABLEKS R1 R0 K0 GETTABLEKS R1 R0 K0
ADDK R1 R1 K1 ADDK R1 R1 K1
SETTABLEKS R1 R0 K0 SETTABLEKS R1 R0 K0
@ -3066,7 +3081,7 @@ RETURN R0 0
)"); )");
CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"(
NEWTABLE R0 0 1 NEWTABLE R0 0 0
GETTABLEN R1 R0 1 GETTABLEN R1 R0 1
ADDK R1 R1 K0 ADDK R1 R1 K0
SETTABLEN R1 R0 1 SETTABLEN R1 R0 1

View File

@ -366,15 +366,11 @@ TEST_CASE("PCall")
TEST_CASE("Pack") TEST_CASE("Pack")
{ {
ScopedFastFlag sff{"LuauStrPackUBCastFix", true};
runConformance("tpack.lua"); runConformance("tpack.lua");
} }
TEST_CASE("Vector") TEST_CASE("Vector")
{ {
ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true};
lua_CompileOptions copts = {}; lua_CompileOptions copts = {};
copts.optimizationLevel = 1; copts.optimizationLevel = 1;
copts.debugLevel = 1; copts.debugLevel = 1;
@ -861,15 +857,11 @@ TEST_CASE("ExceptionObject")
TEST_CASE("IfElseExpression") TEST_CASE("IfElseExpression")
{ {
ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true};
runConformance("ifelseexpr.lua"); runConformance("ifelseexpr.lua");
} }
TEST_CASE("TagMethodError") TEST_CASE("TagMethodError")
{ {
ScopedFastFlag sff{"LuauCcallRestoreFix", true};
runConformance("tmerror.lua", [](lua_State* L) { runConformance("tmerror.lua", [](lua_State* L) {
auto* cb = lua_callbacks(L); auto* cb = lua_callbacks(L);

View File

@ -191,7 +191,7 @@ ParseResult Fixture::tryParse(const std::string& source, const ParseOptions& par
return result; return result;
} }
ParseResult Fixture::matchParseError(const std::string& source, const std::string& message) ParseResult Fixture::matchParseError(const std::string& source, const std::string& message, std::optional<Location> location)
{ {
ParseOptions options; ParseOptions options;
options.allowDeclarationSyntax = true; options.allowDeclarationSyntax = true;
@ -203,6 +203,9 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin
CHECK_EQ(result.errors.front().getMessage(), message); CHECK_EQ(result.errors.front().getMessage(), message);
if (location)
CHECK_EQ(result.errors.front().getLocation(), *location);
return result; return result;
} }

View File

@ -106,7 +106,7 @@ struct Fixture
/// Parse with all language extensions enabled /// Parse with all language extensions enabled
ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {});
ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {});
ParseResult matchParseError(const std::string& source, const std::string& message); ParseResult matchParseError(const std::string& source, const std::string& message, std::optional<Location> location = std::nullopt);
// Verify a parse error occurs and the parse error message has the specified prefix // Verify a parse error occurs and the parse error message has the specified prefix
ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix);

View File

@ -1255,7 +1255,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group")
TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements")
{ {
ScopedFastInt sfis{"LuauRecursionLimit", 10}; ScopedFastInt sfis{"LuauRecursionLimit", 10};
ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true};
matchParseErrorPrefix( matchParseErrorPrefix(
"function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true "
@ -1266,7 +1265,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements")
TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements")
{ {
ScopedFastInt sfis{"LuauRecursionLimit", 10}; ScopedFastInt sfis{"LuauRecursionLimit", 10};
ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true};
matchParseErrorPrefix( matchParseErrorPrefix(
"function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif "
@ -1276,7 +1274,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements"
TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1")
{ {
ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true};
ScopedFastInt sfis{"LuauRecursionLimit", 10}; ScopedFastInt sfis{"LuauRecursionLimit", 10};
matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then "
@ -1286,7 +1283,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1
TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2")
{ {
ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true};
ScopedFastInt sfis{"LuauRecursionLimit", 10}; ScopedFastInt sfis{"LuauRecursionLimit", 10};
matchParseError( matchParseError(
@ -1962,6 +1958,37 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments")
matchParseError("type MyFunc = (number) -> (d: number) <a, b, c> -> number", "Expected '->' when parsing function type, got '<'"); matchParseError("type MyFunc = (number) -> (d: number) <a, b, c> -> number", "Expected '->' when parsing function type, got '<'");
} }
TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis")
{
matchParseError("local a: <T>(number -> string", "Expected ')' (to close '(' at column 13), got '->'");
}
TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
AstStat* stat = parse(R"(
type A<T = string> = {}
type B<T... = ...number> = {}
type C<T..., U... = T...> = {}
type D<T..., U... = ()> = {}
type E<T... = (), U... = ()> = {}
type F<T... = (string), U... = ()> = (T...) -> U...
type G<T... = ...number, U... = (string, number, boolean)> = (U...) -> T...
)");
REQUIRE(stat != nullptr);
}
TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
matchParseError("type Y<T = number, U> = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}});
matchParseError("type Y<T... = ...number, U...> = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}});
matchParseError("type Y<T... = (string) -> number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}});
}
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("ParseErrorRecovery"); TEST_SUITE_BEGIN("ParseErrorRecovery");
@ -2455,10 +2482,19 @@ do end
CHECK_EQ(1, result.errors.size()); CHECK_EQ(1, result.errors.size());
} }
TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauParseRecoverTypePackEllipsis{"LuauParseRecoverTypePackEllipsis", true};
ParseResult result = tryParse(R"(
type Y<T..., U = T...> = (T...) -> U...
)");
CHECK_EQ(1, result.errors.size());
}
TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression")
{ {
ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true};
{ {
AstStat* stat = parse("return if true then 1 else 2"); AstStat* stat = parse("return if true then 1 else 2");
@ -2524,9 +2560,4 @@ type C<X...> = Packed<(number, X...)>
REQUIRE(stat != nullptr); REQUIRE(stat != nullptr);
} }
TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis")
{
matchParseError("local a: <T>(number -> string", "Expected ')' (to close '(' at column 13), got '->'");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -338,6 +338,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed")
TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2")
{ {
ScopedFastFlag sff{"LuauUnsealedTableLiteral", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local base = {} local base = {}
function base:one() return 1 end function base:one() return 1 end
@ -353,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed2")
TypeId tType = requireType("inst"); TypeId tType = requireType("inst");
ToStringResult r = toStringDetailed(tType); ToStringResult r = toStringDetailed(tType);
CHECK_EQ("{ @metatable {| __index: { @metatable {| __index: base |}, child } |}, inst }", r.name); CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name);
CHECK_EQ(0, r.nameMap.typeVars.size()); CHECK_EQ(0, r.nameMap.typeVars.size());
ToStringOptions opts; ToStringOptions opts;
@ -500,6 +502,24 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map")
CHECK_EQ("map<a, b>(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); CHECK_EQ("map<a, b>(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv));
} }
TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack")
{
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
local function f(a: number, b: string) end
local function test<T..., U...>(...: T...): U...
f(...)
return 1, 2, 3
end
)");
TypeId ty = requireType("test");
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(ty));
CHECK_EQ("test<T..., U...>(...: T...): U...", toStringNamedFunction("test", *ftv));
}
TEST_CASE("toStringNamedFunction_unit_f") TEST_CASE("toStringNamedFunction_unit_f")
{ {
TypePackVar empty{TypePack{}}; TypePackVar empty{TypePack{}};

View File

@ -421,8 +421,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion")
TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else")
{ {
ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true);
std::string code = "local a = if 1 then 2 else 3"; std::string code = "local a = if 1 then 2 else 3";
CHECK_EQ(code, transpile(code).code); CHECK_EQ(code, transpile(code).code);
@ -641,4 +639,16 @@ TEST_CASE_FIXTURE(Fixture, "transpile_to_string")
CHECK_EQ("'hello'", toString(expr)); CHECK_EQ("'hello'", toString(expr));
} }
TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
std::string code = R"(
type Packed<T = string, U = T, V... = ...boolean, W... = (T, U, V...)> = (T, U, V...)->(W...)
local a: Packed<number>
)";
CHECK_EQ(code, transpile(code, {}, true).code);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -497,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly")
CHECK(arrayTable->indexer); CHECK(arrayTable->indexer);
CHECK(isInArena(array.type, mod.interfaceTypes)); CHECK(isInArena(array.type, mod.interfaceTypes));
CHECK_EQ(array.typeParams[0], arrayTable->indexer->indexResultType); CHECK_EQ(array.typeParams[0].ty, arrayTable->indexer->indexResultType);
} }
TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions")

View File

@ -1031,9 +1031,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n
TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true};
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true};
CheckResult result = check(R"( CheckResult result = check(R"(
function f(v:string?) function f(v:string?)
return if v then v else tostring(v) return if v then v else tostring(v)
@ -1048,9 +1045,6 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression")
TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true};
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true};
CheckResult result = check(R"( CheckResult result = check(R"(
function f(v:string?) function f(v:string?)
return if not v then tostring(v) else v return if not v then tostring(v) else v
@ -1065,9 +1059,6 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression")
TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true};
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true};
CheckResult result = check(R"( CheckResult result = check(R"(
function returnOne(x) function returnOne(x)
return 1 return 1
@ -1119,6 +1110,25 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_
CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); CHECK_EQ("string", toString(requireTypeAtPosition({5, 30})));
} }
TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2")
{
ScopedFastFlag sff{"LuauLValueAsKey", true};
CheckResult result = check(R"(
type T = { x: { y: number }? }
local function f(t: T?)
if t and t.x then
local foo = t.x.y
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("number", toString(requireTypeAtPosition({5, 32})));
}
TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string")
{ {
ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true};

View File

@ -360,6 +360,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes")
{ {
ScopedFastFlag sffs[] = { ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true}, {"LuauParseSingletonTypes", true},
{"LuauUnsealedTableLiteral", true},
}; };
CheckResult result = check(R"( CheckResult result = check(R"(
@ -369,7 +370,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes")
)"); )");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", CHECK_EQ(R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')",
toString(result.errors[0])); toString(result.errors[0]));
} }
@ -423,4 +424,27 @@ caused by:
toString(result.errors[0])); toString(result.errors[0]));
} }
TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options")
{
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauUnionHeuristic", true},
{"LuauExpectedTypesOfProperties", true},
{"LuauExtendedUnionMismatchError", true},
{"LuauIfElseExpectedType2", true},
{"LuauIfElseBranchTypeUnion", true},
};
CheckResult result = check(R"(
type Cat = { tag: 'cat', catfood: string }
type Dog = { tag: 'dog', dogfood: string }
type Animal = Cat | Dog
local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag = 'dog', dogfood = 'other' }
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table")
TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table")
{ {
CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); CheckResult result = check("function mkt() return {prop=999} end local t = mkt() t.foo = 'bar'");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeError& err = result.errors[0]; TypeError& err = result.errors[0];
@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table")
CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(s, "{| prop: number |}");
CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->prop, "foo");
CHECK_EQ(error->context, CannotExtendTable::Property); CHECK_EQ(error->context, CannotExtendTable::Property);
CHECK_EQ(err.location, (Location{Position{0, 24}, Position{0, 29}})); CHECK_EQ(err.location, (Location{Position{0, 59}, Position{0, 64}}));
} }
TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table")
@ -1155,7 +1155,8 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_mu
TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail") TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
local t = {x = 1} function mkt() return {x = 1} end
local t = mkt()
function t.m() end function t.m() end
)"); )");
@ -1165,13 +1166,38 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail
TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail") TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
local t = {x = 1} function mkt() return {x = 1} end
local t = mkt()
function t:m() end function t:m() end
)"); )");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
} }
TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok")
{
ScopedFastFlag sff{"LuauUnsealedTableLiteral", true};
CheckResult result = check(R"(
local t = {x = 1}
function t.m() end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok")
{
ScopedFastFlag sff{"LuauUnsealedTableLiteral", true};
CheckResult result = check(R"(
local t = {x = 1}
function t:m() end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
// This unit test could be flaky if the fix has regressed. // This unit test could be flaky if the fix has regressed.
TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing") TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing")
{ {
@ -1439,8 +1465,13 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2")
CHECK_EQ("{| |}", toString(mp->subType)); CHECK_EQ("{| |}", toString(mp->subType));
} }
TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer")
{ {
ScopedFastFlag sff[]{
{"LuauTableSubtypingVariance2", true},
{"LuauUnsealedTableLiteral", true},
};
CheckResult result = check(R"( CheckResult result = check(R"(
type StringToStringMap = { [string]: string } type StringToStringMap = { [string]: string }
local rt: StringToStringMap = { ["foo"] = 1 } local rt: StringToStringMap = { ["foo"] = 1 }
@ -1448,6 +1479,25 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer")
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
ToStringOptions o{/* exhaustive= */ true};
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o));
// Should t now have an indexer?
// It would if the assignment to rt was correctly typed.
CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o));
}
TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_indexer")
{
CheckResult result = check(R"(
type StringToStringMap = { [string]: string }
function mkrt() return { ["foo"] = 1 } end
local rt: StringToStringMap = mkrt()
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
ToStringOptions o{/* exhaustive= */ true}; ToStringOptions o{/* exhaustive= */ true};
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]); TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm); REQUIRE(tm);
@ -1467,7 +1517,10 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2")
TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3")
{ {
ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; ScopedFastFlag sff[]{
{"LuauTableSubtypingVariance2", true},
{"LuauUnsealedTableLiteral", true},
};
CheckResult result = check(R"( CheckResult result = check(R"(
local function foo(a: {[string]: number, a: string}) end local function foo(a: {[string]: number, a: string}) end
@ -1480,7 +1533,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3")
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]); TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm); REQUIRE(tm);
CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o));
CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); CHECK_EQ("{ a: number }", toString(tm->givenType, o));
} }
TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4")
@ -1536,8 +1589,11 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi
TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
local vec3 = {{x = 1, y = 2, z = 3}} function mkvec3() return {x = 1, y = 2, z = 3} end
local vec1 = {{x = 1}} function mkvec1() return {x = 1} end
local vec3 = {mkvec3()}
local vec1 = {mkvec1()}
vec1 = vec3 vec1 = vec3
)"); )");
@ -1620,7 +1676,8 @@ TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
local A = {"value"} function mkA() return {"value"} end
local A = mkA()
A.B = "Hello" A.B = "Hello"
)"); )");
@ -1668,7 +1725,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties")
--!strict --!strict
local function f() local function f()
local t = { x = 1 } local function mkt() return { x = 1 } end
local t = mkt()
function t.a() end function t.a() end
function t.b() end function t.b() end
@ -1995,7 +2053,10 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop")
{ {
ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag sff[]{
{"LuauTableSubtypingVariance2", true},
{"LuauUnsealedTableLiteral", true},
};
CheckResult result = check(R"( CheckResult result = check(R"(
local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end });
@ -2010,7 +2071,7 @@ local c2: typeof(a2) = b2
LUAU_REQUIRE_ERROR_COUNT(2, result); LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1'
caused by: caused by:
Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }'
caused by: caused by:
Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); Property 'y' is not compatible. Type 'string' could not be converted into 'number')");
@ -2018,7 +2079,7 @@ caused by:
{ {
CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2'
caused by: caused by:
Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: <a>(a) -> () |}' Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: <a>(a) -> () }'
caused by: caused by:
Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '<a>(a) -> ()'; different number of generic type parameters)"); Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '<a>(a) -> ()'; different number of generic type parameters)");
} }
@ -2026,7 +2087,7 @@ caused by:
{ {
CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2'
caused by: caused by:
Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: <a>(a) -> () |}' Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: <a>(a) -> () }'
caused by: caused by:
Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '<a>(a) -> ()')"); Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '<a>(a) -> ()')");
} }
@ -2059,6 +2120,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error")
{"LuauPropertiesGetExpectedType", true}, {"LuauPropertiesGetExpectedType", true},
{"LuauExpectedTypesOfProperties", true}, {"LuauExpectedTypesOfProperties", true},
{"LuauTableSubtypingVariance2", true}, {"LuauTableSubtypingVariance2", true},
{"LuauUnsealedTableLiteral", true},
}; };
CheckResult result = check(R"( CheckResult result = check(R"(
@ -2077,7 +2139,7 @@ local y: number = tmp.p.y
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper'
caused by: caused by:
Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')");
} }
TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer")
@ -2103,7 +2165,10 @@ a.p = { x = 9 }
TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call")
{ {
ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; ScopedFastFlag sff[]{
{"LuauFixRecursiveMetatableCall", true},
{"LuauUnsealedTableLiteral", true},
};
CheckResult result = check(R"( CheckResult result = check(R"(
local b local b
@ -2112,7 +2177,7 @@ b()
)"); )");
LUAU_REQUIRE_ERROR_COUNT(1, result); LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })");
} }
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -4525,7 +4525,9 @@ f(function(x) print(x) end)
} }
TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument")
{ {
ScopedFastFlag sff{"LuauUnsealedTableLiteral", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local function sum<a>(x: a, y: a, f: (a, a) -> a) return f(x, y) end local function sum<a>(x: a, y: a, f: (a, a) -> a) return f(x, y) end
return sum(2, 3, function(a, b) return a + b end) return sum(2, 3, function(a, b) return a + b end)
@ -4549,7 +4551,7 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e
)"); )");
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
REQUIRE_EQ("{| c: number, s: number |}", toString(requireType("r"))); REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r")));
} }
TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded")
@ -4689,6 +4691,18 @@ a = setmetatable(a, { __call = function(x) end })
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr")
{
ScopedFastFlag luauGroupExpectedType{"LuauGroupExpectedType", true};
CheckResult result = check(R"(
local function f(a: (number, number) -> number) return a(1, 3) end
f(((function(a, b) return a + b end)))
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "refine_and_or") TEST_CASE_FIXTURE(Fixture, "refine_and_or")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
@ -4743,46 +4757,75 @@ local c: X<B>
TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; CheckResult result = check(R"(local a = if true then "true" else "false")");
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; LUAU_REQUIRE_NO_ERRORS(result);
TypeId aType = requireType("a");
{ CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String);
CheckResult result = check(R"(local a = if true then "true" else "false")");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId aType = requireType("a");
CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String);
}
} }
TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; // Test expression containing elseif
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; CheckResult result = check(R"(
local a = if false then "a" elseif false then "b" else "c"
)");
LUAU_REQUIRE_NO_ERRORS(result);
TypeId aType = requireType("a");
CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String);
}
TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union")
{
ScopedFastFlag sff3{"LuauIfElseBranchTypeUnion", true};
{ {
// Test expression containing elseif CheckResult result = check(R"(local a: number? = if true then 42 else nil)");
CheckResult result = check(R"(
local a = if false then "a" elseif false then "b" else "c"
)");
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
TypeId aType = requireType("a"); CHECK_EQ(toString(requireType("a"), {true}), "number?");
CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String);
} }
} }
TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1")
{ {
ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true};
ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true};
{ CheckResult result = check(R"(
CheckResult result = check(R"(local a = if true then "true" else 42)"); type X = {number | string}
// We currently require both true/false expressions to unify to the same type. However, we do intend to lift local a: X = if true then {"1", 2, 3} else {4, 5, 6}
// this restriction in the future. )");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeId aType = requireType("a"); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); CHECK_EQ(toString(requireType("a"), {true}), "{number | string}");
} }
TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2")
{
ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true};
ScopedFastFlag luauIfElseBranchTypeUnion{ "LuauIfElseBranchTypeUnion", true };
CheckResult result = check(R"(
local a: number? = if true then 1 else nil
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3")
{
ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true};
CheckResult result = check(R"(
local function times<T>(n: any, f: () -> T)
local result: {T} = {}
local res = f()
table.insert(result, if true then res else n)
return result
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "type_error_addition") TEST_CASE_FIXTURE(Fixture, "type_error_addition")
@ -5039,4 +5082,51 @@ end
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "table_oop")
{
CheckResult result = check(R"(
--!strict
local Class = {}
Class.__index = Class
type Class = typeof(setmetatable({} :: { x: number }, Class))
function Class.new(x: number): Class
return setmetatable({x = x}, Class)
end
function Class.getx(self: Class)
return self.x
end
function test()
local c = Class.new(42)
local n = c:getx()
local nn = c.x
print(string.format("%d %d", n, nn))
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash")
{
ScopedFastFlag luauMetatableAreEqualRecursion{"LuauMetatableAreEqualRecursion", true};
CheckResult result = check(R"(
local function getIt()
local y
y = setmetatable({}, y)
return y
end
local a = getIt()
local b = getIt()
local c = a or b
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -621,4 +621,328 @@ type Other = Packed<number, string>
CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T..., U...>' expects 2 type pack arguments, but only 1 is specified"); CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed<T..., U...>' expects 2 type pack arguments, but only 1 is specified");
} }
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T, U = string> = { a: T, b: U }
local a: Y<number, number> = { a = 2, b = 3 }
local b: Y<number> = { a = 2, b = "s" }
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<number, number>");
CHECK_EQ(toString(requireType("b")), "Y<number, string>");
result = check(R"(
type Y<T = string> = { a: T }
local a: Y<number> = { a = 2 }
local b: Y<> = { a = "s" }
local c: Y = { a = "s" }
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<number>");
CHECK_EQ(toString(requireType("b")), "Y<string>");
CHECK_EQ(toString(requireType("c")), "Y<string>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T, U = T> = { a: T, b: U }
local a: Y<number> = { a = 2, b = 3 }
local b: Y<string> = { a = "h", b = "s" }
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<number, number>");
CHECK_EQ(toString(requireType("b")), "Y<string, string>");
result = check(R"(
type Y<T, U = (T, T) -> string> = { a: T, b: U }
local a: Y<number>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<number, (number, number) -> string>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T, U = T, V = U> = { a: T, b: U, c: V }
local a: Y<number>
local b: Y<number, string>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<number, number, number>");
CHECK_EQ(toString(requireType("b")), "Y<number, string, string>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T... = (string, number)> = { a: (T...) -> () }
local a: Y<>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<string, number>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T, U... = ...T> = { a: T, b: (U...) -> T }
local a: Y<number>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<number, ...number>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T..., U... = T...> = { a: (T...) -> U... }
local a: Y<number, string>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string)>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T..., U... = T..., V... = U...> = { a: (T...) -> U..., b: (T...) -> V... }
local a: Y<number, string>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string), (number, string)>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T, U = T, V... = ...number, W... = (T, U, V...)> = { a: (T, U, V...) -> W... }
local a: Y<number>
local b: Y<number, string>
local c: Y<number, string, ...boolean>
local d: Y<number, string, ...boolean, ...() -> ()>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "Y<number, number, ...number, (number, number, ...number)>");
CHECK_EQ(toString(requireType("b")), "Y<number, string, ...number, (number, string, ...number)>");
CHECK_EQ(toString(requireType("c")), "Y<number, string, ...boolean, (number, string, ...boolean)>");
CHECK_EQ(toString(requireType("d")), "Y<number, string, ...boolean, ...() -> ()>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T = T> = { a: T }
local a: Y = { a = 2 }
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'");
result = check(R"(
type Y<T... = T...> = { a: (T...) -> () }
local a: Y<>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'");
result = check(R"(
type Y<T = string, U... = ...string> = { a: (T) -> U... }
local a: Y<...number>
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Generic type 'Y<T, U...>' expects at least 1 type argument, but none are specified");
result = check(R"(
type Packed<T> = (T) -> T
local a: Packed
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type parameter list is required");
result = check(R"(
type Y<T, U = T, V> = { a: T }
local a: Y<number>
)");
LUAU_REQUIRE_ERRORS(result);
result = check(R"(
type Y<T..., U... = T..., V...> = { a: T }
local a: Y<...number>
)");
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_export")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
fileResolver.source["Module/Types"] = R"(
export type A<T, U = string> = { a: T, b: U }
export type B<T, U = T> = { a: T, b: U }
export type C<T, U = (T, T) -> string> = { a: T, b: U }
export type D<T, U = T, V = U> = { a: T, b: U, c: V }
export type E<T... = (string, number)> = { a: (T...) -> () }
export type F<T, U... = ...T> = { a: T, b: (U...) -> T }
export type G<T..., U... = ()> = { b: (U...) -> T... }
export type H<T... = ()> = { b: (T...) -> T... }
return {}
)";
CheckResult resultTypes = frontend.check("Module/Types");
LUAU_REQUIRE_NO_ERRORS(resultTypes);
fileResolver.source["Module/Users"] = R"(
local Types = require(script.Parent.Types)
local a: Types.A<number>
local b: Types.B<number>
local c: Types.C<number>
local d: Types.D<number>
local e: Types.E<>
local eVoid: Types.E<()>
local f: Types.F<number>
local g: Types.G<...number>
local h: Types.H<>
)";
CheckResult resultUsers = frontend.check("Module/Users");
LUAU_REQUIRE_NO_ERRORS(resultUsers);
CHECK_EQ(toString(requireType("Module/Users", "a")), "A<number, string>");
CHECK_EQ(toString(requireType("Module/Users", "b")), "B<number, number>");
CHECK_EQ(toString(requireType("Module/Users", "c")), "C<number, (number, number) -> string>");
CHECK_EQ(toString(requireType("Module/Users", "d")), "D<number, number, number>");
CHECK_EQ(toString(requireType("Module/Users", "e")), "E<string, number>");
CHECK_EQ(toString(requireType("Module/Users", "eVoid")), "E<>");
CHECK_EQ(toString(requireType("Module/Users", "f")), "F<number, ...number>");
CHECK_EQ(toString(requireType("Module/Users", "g")), "G<...number, ()>");
CHECK_EQ(toString(requireType("Module/Users", "h")), "H<>");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type Y<T... = ...string> = (T...) -> number
local a: Y
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "(...string) -> number");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type A<T, U = T, V... = ...any, W... = V...> = (T, V...) -> (U, W...)
type B = A<string, (number)>
type C = A<string, (number), (boolean)>
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("B"), {true}), "(string, ...any) -> (number, ...any)");
CHECK_EQ(toString(*lookupType("C"), {true}), "(string, boolean) -> (number, boolean)");
}
TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type")
{
ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true};
ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true};
CheckResult result = check(R"(
type F<K = string, V = (K) -> ()> = (K) -> V
type R = { m: F<R> }
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}");
}
TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check")
{
ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true};
ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true};
CheckResult result = check(R"(
local a: () -> (number, ...string)
local b: () -> (number, ...boolean)
a = b
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)'
caused by:
Type 'boolean' could not be converted into 'string')");
}
TEST_SUITE_END(); TEST_SUITE_END();