Sync to upstream/release/551 (#727)

* https://github.com/Roblox/luau/pull/719
* Improved `Failed to unify type packs` error message to be reported as
`Type pack 'X' could not be converted into 'Y'`
* https://github.com/Roblox/luau/pull/722
* 1% reduction in executed instruction count by removing a check in fast
call dispatch
* Additional fixes to reported error location of OOM errors in VM
* Improve `math.sqrt`, `math.floor` and `math.ceil` performance on
additional compilers and platforms (1-2% geomean improvement including
8-9% on math-cordic)
* All thrown exceptions by Luau analysis are derived from
`Luau::InternalCompilerError`
* When a call site has fewer arguments than required, error now reports
the location of the function name instead of the argument to the
function
* https://github.com/Roblox/luau/pull/724
* Fixed https://github.com/Roblox/luau/issues/725

Co-authored-by: Arseny Kapoulkine <arseny.kapoulkine@gmail.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
This commit is contained in:
vegorov-rbx 2022-10-28 03:37:29 -07:00 committed by GitHub
parent ad7d006fb2
commit a6b5051edc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
65 changed files with 3058 additions and 881 deletions

View File

@ -23,6 +23,30 @@ using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger; struct DcrLogger;
struct Inference
{
TypeId ty = nullptr;
Inference() = default;
explicit Inference(TypeId ty)
: ty(ty)
{
}
};
struct InferencePack
{
TypePackId tp = nullptr;
InferencePack() = default;
explicit InferencePack(TypePackId tp)
: tp(tp)
{
}
};
struct ConstraintGraphBuilder struct ConstraintGraphBuilder
{ {
// A list of all the scopes in the module. This vector holds ownership of the // A list of all the scopes in the module. This vector holds ownership of the
@ -130,8 +154,10 @@ struct ConstraintGraphBuilder
void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
void visit(const ScopePtr& scope, AstStatError* error); void visit(const ScopePtr& scope, AstStatError* error);
TypePackId checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes = {}); InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes = {});
TypePackId checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes = {}); InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes = {});
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector<TypeId>& expectedTypes);
/** /**
* Checks an expression that is expected to evaluate to one type. * Checks an expression that is expected to evaluate to one type.
@ -141,18 +167,19 @@ struct ConstraintGraphBuilder
* surrounding context. Used to implement bidirectional type checking. * surrounding context. Used to implement bidirectional type checking.
* @return the type of the expression. * @return the type of the expression.
*/ */
TypeId check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}); Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {});
TypeId check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprLocal* local);
TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); Inference check(const ScopePtr& scope, AstExprGlobal* global);
TypeId check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
TypeId check_(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
TypeId check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprUnary* unary);
TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
TypeId check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs); TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
@ -202,7 +229,7 @@ struct ConstraintGraphBuilder
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(const ScopePtr& scope, AstArray<AstGenericType> generics); std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(const ScopePtr& scope, AstArray<AstGenericType> generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(const ScopePtr& scope, AstArray<AstGenericTypePack> packs); std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(const ScopePtr& scope, AstArray<AstGenericTypePack> packs);
TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
void reportError(Location location, TypeErrorData err); void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location); void reportCodeTooComplex(Location location);

View File

@ -7,6 +7,8 @@
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
LUAU_FASTFLAG(LuauIceExceptionInheritanceChange)
namespace Luau namespace Luau
{ {
struct TypeError; struct TypeError;
@ -302,12 +304,20 @@ struct NormalizationTooComplex
} }
}; };
struct TypePackMismatch
{
TypePackId wantedTp;
TypePackId givenTp;
bool operator==(const TypePackMismatch& rhs) const;
};
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods,
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError, IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty,
TypesAreUnrelated, NormalizationTooComplex>; TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch>;
struct TypeError struct TypeError
{ {
@ -374,6 +384,10 @@ struct InternalErrorReporter
class InternalCompilerError : public std::exception class InternalCompilerError : public std::exception
{ {
public: public:
explicit InternalCompilerError(const std::string& message)
: message(message)
{
}
explicit InternalCompilerError(const std::string& message, const std::string& moduleName) explicit InternalCompilerError(const std::string& message, const std::string& moduleName)
: message(message) : message(message)
, moduleName(moduleName) , moduleName(moduleName)
@ -388,8 +402,14 @@ public:
virtual const char* what() const throw(); virtual const char* what() const throw();
const std::string message; const std::string message;
const std::string moduleName; const std::optional<std::string> moduleName;
const std::optional<Location> location; const std::optional<Location> location;
}; };
// These two function overloads only exist to facilitate fast flagging a change to InternalCompilerError
// Both functions can be removed when FFlagLuauIceExceptionInheritanceChange is removed and calling code
// can directly throw InternalCompilerError.
[[noreturn]] void throwRuntimeError(const std::string& message);
[[noreturn]] void throwRuntimeError(const std::string& message, const std::string& moduleName);
} // namespace Luau } // namespace Luau

View File

@ -106,9 +106,68 @@ struct std::equal_to<const Luau::TypeIds*>
namespace Luau namespace Luau
{ {
// A normalized string type is either `string` (represented by `nullopt`) /** A normalized string type is either `string` (represented by `nullopt`) or a
// or a union of string singletons. * union of string singletons.
using NormalizedStringType = std::optional<std::map<std::string, TypeId>>; *
* When FFlagLuauNegatedStringSingletons is unset, the representation is as
* follows:
*
* * The `string` data type is represented by the option `singletons` having the
* value `std::nullopt`.
* * The type `never` is represented by `singletons` being populated with an
* empty map.
* * A union of string singletons is represented by a map populated by the names
* and TypeIds of the singletons contained therein.
*
* When FFlagLuauNegatedStringSingletons is set, the representation is as
* follows:
*
* * A union of string singletons is finite and includes the singletons named by
* the `singletons` field.
* * An intersection of negated string singletons is cofinite and includes the
* singletons excluded by the `singletons` field. It is implied that cofinite
* values are exclusions from `string` itself.
* * The `string` data type is a cofinite set minus zero elements.
* * The `never` data type is a finite set plus zero elements.
*/
struct NormalizedStringType
{
// When false, this type represents a union of singleton string types.
// eg "a" | "b" | "c"
//
// When true, this type represents string intersected with negated string
// singleton types.
// eg string & ~"a" & ~"b" & ...
bool isCofinite = false;
// TODO: This field cannot be nullopt when FFlagLuauNegatedStringSingletons
// is set. When clipping that flag, we can remove the wrapping optional.
std::optional<std::map<std::string, TypeId>> singletons;
void resetToString();
void resetToNever();
bool isNever() const;
bool isString() const;
/// Returns true if the string has finite domain.
///
/// Important subtlety: This method returns true for `never`. The empty set
/// is indeed an empty set.
bool isUnion() const;
/// Returns true if the string has infinite domain.
bool isIntersection() const;
bool includes(const std::string& str) const;
static const NormalizedStringType never;
NormalizedStringType() = default;
NormalizedStringType(bool isCofinite, std::optional<std::map<std::string, TypeId>> singletons);
};
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr);
// A normalized function type is either `never` (represented by `nullopt`) // A normalized function type is either `never` (represented by `nullopt`)
// or an intersection of function types. // or an intersection of function types.
@ -157,7 +216,7 @@ struct NormalizedType
// The string part of the type. // The string part of the type.
// This may be the `string` type, or a union of singletons. // This may be the `string` type, or a union of singletons.
NormalizedStringType strings = std::map<std::string, TypeId>{}; NormalizedStringType strings;
// The thread part of the type. // The thread part of the type.
// This type is either never or thread. // This type is either never or thread.
@ -231,8 +290,14 @@ public:
bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
// ------- Negations
NormalizedType negateNormal(const NormalizedType& here);
TypeIds negateAll(const TypeIds& theres);
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);
void subtractSingleton(NormalizedType& here, TypeId ty);
// ------- Normalizing intersections // ------- Normalizing intersections
void intersectTysWithTy(TypeIds& here, TypeId there);
TypeId intersectionOfTops(TypeId here, TypeId there); TypeId intersectionOfTops(TypeId here, TypeId there);
TypeId intersectionOfBools(TypeId here, TypeId there); TypeId intersectionOfBools(TypeId here, TypeId there);
void intersectClasses(TypeIds& heres, const TypeIds& theres); void intersectClasses(TypeIds& heres, const TypeIds& theres);

View File

@ -2,6 +2,7 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Error.h"
#include <stdexcept> #include <stdexcept>
#include <exception> #include <exception>
@ -9,10 +10,20 @@
namespace Luau namespace Luau
{ {
struct RecursionLimitException : public std::exception struct RecursionLimitException : public InternalCompilerError
{
RecursionLimitException()
: InternalCompilerError("Internal recursion counter limit exceeded")
{
LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange);
}
};
struct RecursionLimitException_DEPRECATED : public std::exception
{ {
const char* what() const noexcept const char* what() const noexcept
{ {
LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange);
return "Internal recursion counter limit exceeded"; return "Internal recursion counter limit exceeded";
} }
}; };
@ -42,7 +53,14 @@ struct RecursionLimiter : RecursionCounter
{ {
if (limit > 0 && *count > limit) if (limit > 0 && *count > limit)
{ {
throw RecursionLimitException(); if (FFlag::LuauIceExceptionInheritanceChange)
{
throw RecursionLimitException();
}
else
{
throw RecursionLimitException_DEPRECATED();
}
} }
} }
}; };

View File

@ -117,6 +117,8 @@ inline std::string toStringNamedFunction(const std::string& funcName, const Func
return toStringNamedFunction(funcName, ftv, opts); return toStringNamedFunction(funcName, ftv, opts);
} }
std::optional<std::string> getFunctionNameAsString(const AstExpr& expr);
// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
std::string dump(TypeId ty); std::string dump(TypeId ty);

View File

@ -48,7 +48,17 @@ struct HashBoolNamePair
size_t operator()(const std::pair<bool, Name>& pair) const; size_t operator()(const std::pair<bool, Name>& pair) const;
}; };
class TimeLimitError : public std::exception class TimeLimitError : public InternalCompilerError
{
public:
explicit TimeLimitError(const std::string& moduleName)
: InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName)
{
LUAU_ASSERT(FFlag::LuauIceExceptionInheritanceChange);
}
};
class TimeLimitError_DEPRECATED : public std::exception
{ {
public: public:
virtual const char* what() const throw(); virtual const char* what() const throw();
@ -236,6 +246,7 @@ public:
[[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message); [[noreturn]] void ice(const std::string& message);
[[noreturn]] void throwTimeLimitError();
ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0);
ScopePtr childScope(const ScopePtr& parent, const Location& location); ScopePtr childScope(const ScopePtr& parent, const Location& location);

View File

@ -96,6 +96,8 @@ private:
void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy);
void tryUnifyNegationWithType(TypeId subTy, TypeId superTy);
TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args); TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args);

View File

@ -11,6 +11,8 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false)
namespace Luau namespace Luau
{ {
@ -430,6 +432,8 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol( static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol) const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol)
{ {
LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol);
if (!documentationSymbol) if (!documentationSymbol)
return std::nullopt; return std::nullopt;
@ -466,7 +470,38 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
if (std::optional<Binding> binding = findBindingAtPosition(module, source, position)) if (std::optional<Binding> binding = findBindingAtPosition(module, source, position))
{ {
return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); if (FFlag::LuauCheckOverloadedDocSymbol)
{
return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol);
}
else
{
if (binding->documentationSymbol)
{
// This might be an overloaded function binding.
if (get<IntersectionTypeVar>(follow(binding->typeId)))
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = *it;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *binding->documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
}
return binding->documentationSymbol;
}
} }
if (targetExpr) if (targetExpr)
@ -480,14 +515,20 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
{ {
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())
{ {
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); if (FFlag::LuauCheckOverloadedDocSymbol)
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
else
return propIt->second.documentationSymbol;
} }
} }
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy)) else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy))
{ {
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{ {
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); if (FFlag::LuauCheckOverloadedDocSymbol)
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
else
return propIt->second.documentationSymbol;
} }
} }
} }

View File

@ -263,7 +263,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
if (hasAnnotation) if (hasAnnotation)
expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes));
TypePackId exprPack = checkPack(scope, value, expectedTypes); TypePackId exprPack = checkPack(scope, value, expectedTypes).tp;
if (i < local->vars.size) if (i < local->vars.size)
{ {
@ -292,7 +292,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
if (hasAnnotation) if (hasAnnotation)
expectedType = varTypes.at(i); expectedType = varTypes.at(i);
TypeId exprType = check(scope, value, expectedType); TypeId exprType = check(scope, value, expectedType).ty;
if (i < varTypes.size()) if (i < varTypes.size())
{ {
if (varTypes[i]) if (varTypes[i])
@ -350,7 +350,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_)
if (!expr) if (!expr)
return; return;
TypeId t = check(scope, expr); TypeId t = check(scope, expr).ty;
addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType}); addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType});
}; };
@ -368,7 +368,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn)
{ {
ScopePtr loopScope = childScope(forIn, scope); ScopePtr loopScope = childScope(forIn, scope);
TypePackId iterator = checkPack(scope, forIn->values); TypePackId iterator = checkPack(scope, forIn->values).tp;
std::vector<TypeId> variableTypes; std::vector<TypeId> variableTypes;
variableTypes.reserve(forIn->vars.size); variableTypes.reserve(forIn->vars.size);
@ -489,7 +489,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
} }
else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>()) else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>())
{ {
TypeId containingTableType = check(scope, indexName->expr); TypeId containingTableType = check(scope, indexName->expr).ty;
functionType = arena->addType(BlockedTypeVar{}); functionType = arena->addType(BlockedTypeVar{});
@ -531,7 +531,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret)
for (TypeId ty : scope->returnType) for (TypeId ty : scope->returnType)
expectedTypes.push_back(ty); expectedTypes.push_back(ty);
TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes); TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp;
addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType});
} }
@ -545,7 +545,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
{ {
TypePackId varPackId = checkLValues(scope, assign->vars); TypePackId varPackId = checkLValues(scope, assign->vars);
TypePackId valuePack = checkPack(scope, assign->values); TypePackId valuePack = checkPack(scope, assign->values).tp;
addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId});
} }
@ -732,7 +732,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global)
{ {
std::vector<std::pair<Name, GenericTypeDefinition>> generics = createGenerics(scope, global->generics); std::vector<std::pair<Name, GenericTypeDefinition>> generics = createGenerics(scope, global->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPacks = createGenericPacks(scope, global->genericPacks); std::vector<std::pair<Name, GenericTypePackDefinition>> genericPacks = createGenericPacks(scope, global->genericPacks);
@ -779,7 +778,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error)
check(scope, expr); check(scope, expr);
} }
TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes) InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes)
{ {
std::vector<TypeId> head; std::vector<TypeId> head;
std::optional<TypePackId> tail; std::optional<TypePackId> tail;
@ -792,201 +791,180 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray<Ast
std::optional<TypeId> expectedType; std::optional<TypeId> expectedType;
if (i < expectedTypes.size()) if (i < expectedTypes.size())
expectedType = expectedTypes[i]; expectedType = expectedTypes[i];
head.push_back(check(scope, expr)); head.push_back(check(scope, expr).ty);
} }
else else
{ {
std::vector<TypeId> expectedTailTypes; std::vector<TypeId> expectedTailTypes;
if (i < expectedTypes.size()) if (i < expectedTypes.size())
expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes));
tail = checkPack(scope, expr, expectedTailTypes); tail = checkPack(scope, expr, expectedTailTypes).tp;
} }
} }
if (head.empty() && tail) if (head.empty() && tail)
return *tail; return InferencePack{*tail};
else else
return arena->addTypePack(TypePack{std::move(head), tail}); return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})};
} }
TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes) InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes)
{ {
RecursionCounter counter{&recursionCount}; RecursionCounter counter{&recursionCount};
if (recursionCount >= FInt::LuauCheckRecursionLimit) if (recursionCount >= FInt::LuauCheckRecursionLimit)
{ {
reportCodeTooComplex(expr->location); reportCodeTooComplex(expr->location);
return singletonTypes->errorRecoveryTypePack(); return InferencePack{singletonTypes->errorRecoveryTypePack()};
} }
TypePackId result = nullptr; InferencePack result;
if (AstExprCall* call = expr->as<AstExprCall>()) if (AstExprCall* call = expr->as<AstExprCall>())
{ result = {checkPack(scope, call, expectedTypes)};
TypeId fnType = check(scope, call->func);
const size_t constraintIndex = scope->constraints.size();
const size_t scopeIndex = scopes.size();
std::vector<TypeId> args;
for (AstExpr* arg : call->args)
{
args.push_back(check(scope, arg));
}
// TODO self
if (matchSetmetatable(*call))
{
LUAU_ASSERT(args.size() == 2);
TypeId target = args[0];
TypeId mt = args[1];
MetatableTypeVar mtv{target, mt};
TypeId resultTy = arena->addType(mtv);
result = arena->addTypePack({resultTy});
}
else
{
const size_t constraintEndIndex = scope->constraints.size();
const size_t scopeEndIndex = scopes.size();
astOriginalCallTypes[call->func] = fnType;
TypeId instantiatedType = arena->addType(BlockedTypeVar{});
// TODO: How do expectedTypes play into this? Do they?
TypePackId rets = arena->addTypePack(BlockedTypePack{});
TypePackId argPack = arena->addTypePack(TypePack{args, {}});
FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets);
TypeId inferredFnType = arena->addType(ftv);
scope->unqueuedConstraints.push_back(
std::make_unique<Constraint>(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType}));
NotNull<const Constraint> ic(scope->unqueuedConstraints.back().get());
scope->unqueuedConstraints.push_back(
std::make_unique<Constraint>(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType}));
NotNull<Constraint> sc(scope->unqueuedConstraints.back().get());
// We force constraints produced by checking function arguments to wait
// until after we have resolved the constraint on the function itself.
// This ensures, for instance, that we start inferring the contents of
// lambdas under the assumption that their arguments and return types
// will be compatible with the enclosing function call.
for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci)
scope->constraints[ci]->dependencies.push_back(sc);
for (size_t si = scopeIndex; si < scopeEndIndex; ++si)
{
for (auto& c : scopes[si].second->constraints)
{
c->dependencies.push_back(sc);
}
}
addConstraint(scope, call->func->location,
FunctionCallConstraint{
{ic, sc},
fnType,
argPack,
rets,
call,
});
result = rets;
}
}
else if (AstExprVarargs* varargs = expr->as<AstExprVarargs>()) else if (AstExprVarargs* varargs = expr->as<AstExprVarargs>())
{ {
if (scope->varargPack) if (scope->varargPack)
result = *scope->varargPack; result = InferencePack{*scope->varargPack};
else else
result = singletonTypes->errorRecoveryTypePack(); result = InferencePack{singletonTypes->errorRecoveryTypePack()};
} }
else else
{ {
std::optional<TypeId> expectedType; std::optional<TypeId> expectedType;
if (!expectedTypes.empty()) if (!expectedTypes.empty())
expectedType = expectedTypes[0]; expectedType = expectedTypes[0];
TypeId t = check(scope, expr, expectedType); TypeId t = check(scope, expr, expectedType).ty;
result = arena->addTypePack({t}); result = InferencePack{arena->addTypePack({t})};
} }
LUAU_ASSERT(result); LUAU_ASSERT(result.tp);
astTypePacks[expr] = result; astTypePacks[expr] = result.tp;
return result; return result;
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType) InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector<TypeId>& expectedTypes)
{
TypeId fnType = check(scope, call->func).ty;
const size_t constraintIndex = scope->constraints.size();
const size_t scopeIndex = scopes.size();
std::vector<TypeId> args;
for (AstExpr* arg : call->args)
{
args.push_back(check(scope, arg).ty);
}
// TODO self
if (matchSetmetatable(*call))
{
LUAU_ASSERT(args.size() == 2);
TypeId target = args[0];
TypeId mt = args[1];
AstExpr* targetExpr = call->args.data[0];
MetatableTypeVar mtv{target, mt};
TypeId resultTy = arena->addType(mtv);
if (AstExprLocal* targetLocal = targetExpr->as<AstExprLocal>())
scope->bindings[targetLocal->local].typeId = resultTy;
return InferencePack{arena->addTypePack({resultTy})};
}
else
{
const size_t constraintEndIndex = scope->constraints.size();
const size_t scopeEndIndex = scopes.size();
astOriginalCallTypes[call->func] = fnType;
TypeId instantiatedType = arena->addType(BlockedTypeVar{});
// TODO: How do expectedTypes play into this? Do they?
TypePackId rets = arena->addTypePack(BlockedTypePack{});
TypePackId argPack = arena->addTypePack(TypePack{args, {}});
FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets);
TypeId inferredFnType = arena->addType(ftv);
scope->unqueuedConstraints.push_back(
std::make_unique<Constraint>(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType}));
NotNull<const Constraint> ic(scope->unqueuedConstraints.back().get());
scope->unqueuedConstraints.push_back(
std::make_unique<Constraint>(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType}));
NotNull<Constraint> sc(scope->unqueuedConstraints.back().get());
// We force constraints produced by checking function arguments to wait
// until after we have resolved the constraint on the function itself.
// This ensures, for instance, that we start inferring the contents of
// lambdas under the assumption that their arguments and return types
// will be compatible with the enclosing function call.
for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci)
scope->constraints[ci]->dependencies.push_back(sc);
for (size_t si = scopeIndex; si < scopeEndIndex; ++si)
{
for (auto& c : scopes[si].second->constraints)
{
c->dependencies.push_back(sc);
}
}
addConstraint(scope, call->func->location,
FunctionCallConstraint{
{ic, sc},
fnType,
argPack,
rets,
call,
});
return InferencePack{rets};
}
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType)
{ {
RecursionCounter counter{&recursionCount}; RecursionCounter counter{&recursionCount};
if (recursionCount >= FInt::LuauCheckRecursionLimit) if (recursionCount >= FInt::LuauCheckRecursionLimit)
{ {
reportCodeTooComplex(expr->location); reportCodeTooComplex(expr->location);
return singletonTypes->errorRecoveryType(); return Inference{singletonTypes->errorRecoveryType()};
} }
TypeId result = nullptr; Inference result;
if (auto group = expr->as<AstExprGroup>()) if (auto group = expr->as<AstExprGroup>())
result = check(scope, group->expr, expectedType); result = check(scope, group->expr, expectedType);
else if (auto stringExpr = expr->as<AstExprConstantString>()) else if (auto stringExpr = expr->as<AstExprConstantString>())
{ result = check(scope, stringExpr, expectedType);
if (expectedType)
{
const TypeId expectedTy = follow(*expectedType);
if (get<BlockedTypeVar>(expectedTy) || get<PendingExpansionTypeVar>(expectedTy))
{
result = arena->addType(BlockedTypeVar{});
TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringExpr->value.data, stringExpr->value.size)}));
addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->stringType});
}
else if (maybeSingleton(expectedTy))
result = arena->addType(SingletonTypeVar{StringSingleton{std::string{stringExpr->value.data, stringExpr->value.size}}});
else
result = singletonTypes->stringType;
}
else
result = singletonTypes->stringType;
}
else if (expr->is<AstExprConstantNumber>()) else if (expr->is<AstExprConstantNumber>())
result = singletonTypes->numberType; result = Inference{singletonTypes->numberType};
else if (auto boolExpr = expr->as<AstExprConstantBool>()) else if (auto boolExpr = expr->as<AstExprConstantBool>())
{ result = check(scope, boolExpr, expectedType);
if (expectedType)
{
const TypeId expectedTy = follow(*expectedType);
const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType;
if (get<BlockedTypeVar>(expectedTy) || get<PendingExpansionTypeVar>(expectedTy))
{
result = arena->addType(BlockedTypeVar{});
addConstraint(scope, expr->location, PrimitiveTypeConstraint{result, expectedTy, singletonType, singletonTypes->booleanType});
}
else if (maybeSingleton(expectedTy))
result = singletonType;
else
result = singletonTypes->booleanType;
}
else
result = singletonTypes->booleanType;
}
else if (expr->is<AstExprConstantNil>()) else if (expr->is<AstExprConstantNil>())
result = singletonTypes->nilType; result = Inference{singletonTypes->nilType};
else if (auto local = expr->as<AstExprLocal>()) else if (auto local = expr->as<AstExprLocal>())
result = check(scope, local); result = check(scope, local);
else if (auto global = expr->as<AstExprGlobal>()) else if (auto global = expr->as<AstExprGlobal>())
result = check(scope, global); result = check(scope, global);
else if (expr->is<AstExprVarargs>()) else if (expr->is<AstExprVarargs>())
result = flattenPack(scope, expr->location, checkPack(scope, expr)); result = flattenPack(scope, expr->location, checkPack(scope, expr));
else if (expr->is<AstExprCall>()) else if (auto call = expr->as<AstExprCall>())
result = flattenPack(scope, expr->location, checkPack(scope, expr)); // TODO: needs predicates too {
std::vector<TypeId> expectedTypes;
if (expectedType)
expectedTypes.push_back(*expectedType);
result = flattenPack(scope, expr->location, checkPack(scope, call, expectedTypes)); // TODO: needs predicates too
}
else if (auto a = expr->as<AstExprFunction>()) else if (auto a = expr->as<AstExprFunction>())
{ {
FunctionSignature sig = checkFunctionSignature(scope, a); FunctionSignature sig = checkFunctionSignature(scope, a);
checkFunctionBody(sig.bodyScope, a); checkFunctionBody(sig.bodyScope, a);
return sig.signature; return Inference{sig.signature};
} }
else if (auto indexName = expr->as<AstExprIndexName>()) else if (auto indexName = expr->as<AstExprIndexName>())
result = check(scope, indexName); result = check(scope, indexName);
@ -1008,20 +986,63 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::
for (AstExpr* subExpr : err->expressions) for (AstExpr* subExpr : err->expressions)
check(scope, subExpr); check(scope, subExpr);
result = singletonTypes->errorRecoveryType(); result = Inference{singletonTypes->errorRecoveryType()};
} }
else else
{ {
LUAU_ASSERT(0); LUAU_ASSERT(0);
result = freshType(scope); result = Inference{freshType(scope)};
} }
LUAU_ASSERT(result); LUAU_ASSERT(result.ty);
astTypes[expr] = result; astTypes[expr] = result.ty;
return result; return result;
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType)
{
if (expectedType)
{
const TypeId expectedTy = follow(*expectedType);
if (get<BlockedTypeVar>(expectedTy) || get<PendingExpansionTypeVar>(expectedTy))
{
TypeId ty = arena->addType(BlockedTypeVar{});
TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(string->value.data, string->value.size)}));
addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->stringType});
return Inference{ty};
}
else if (maybeSingleton(expectedTy))
return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})};
return Inference{singletonTypes->stringType};
}
return Inference{singletonTypes->stringType};
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional<TypeId> expectedType)
{
if (expectedType)
{
const TypeId expectedTy = follow(*expectedType);
const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType;
if (get<BlockedTypeVar>(expectedTy) || get<PendingExpansionTypeVar>(expectedTy))
{
TypeId ty = arena->addType(BlockedTypeVar{});
addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->booleanType});
return Inference{ty};
}
else if (maybeSingleton(expectedTy))
return Inference{singletonType};
return Inference{singletonTypes->booleanType};
}
return Inference{singletonTypes->booleanType};
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local)
{ {
std::optional<TypeId> resultTy; std::optional<TypeId> resultTy;
@ -1035,26 +1056,26 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local)
} }
if (!resultTy) if (!resultTy)
return singletonTypes->errorRecoveryType(); // TODO: replace with ice, locals should never exist before its definition. return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition.
return *resultTy; return Inference{*resultTy};
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global)
{ {
if (std::optional<TypeId> ty = scope->lookup(global->name)) if (std::optional<TypeId> ty = scope->lookup(global->name))
return *ty; return Inference{*ty};
/* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any
* global that is not already in-scope is definitely an unknown symbol. * global that is not already in-scope is definitely an unknown symbol.
*/ */
reportError(global->location, UnknownSymbol{global->name.value}); reportError(global->location, UnknownSymbol{global->name.value});
return singletonTypes->errorRecoveryType(); return Inference{singletonTypes->errorRecoveryType()};
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName)
{ {
TypeId obj = check(scope, indexName->expr); TypeId obj = check(scope, indexName->expr).ty;
TypeId result = freshType(scope); TypeId result = freshType(scope);
TableTypeVar::Props props{{indexName->index.value, Property{result}}}; TableTypeVar::Props props{{indexName->index.value, Property{result}}};
@ -1065,13 +1086,13 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* in
addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType});
return result; return Inference{result};
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr)
{ {
TypeId obj = check(scope, indexExpr->expr); TypeId obj = check(scope, indexExpr->expr).ty;
TypeId indexType = check(scope, indexExpr->index); TypeId indexType = check(scope, indexExpr->index).ty;
TypeId result = freshType(scope); TypeId result = freshType(scope);
@ -1081,61 +1102,49 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* in
addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType});
return result; return Inference{result};
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary)
{ {
TypeId operandType = check_(scope, unary); TypeId operandType = check(scope, unary->expr).ty;
TypeId resultType = arena->addType(BlockedTypeVar{}); TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType});
return resultType; return Inference{resultType};
} }
TypeId ConstraintGraphBuilder::check_(const ScopePtr& scope, AstExprUnary* unary) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
{ {
if (unary->op == AstExprUnary::Not) TypeId leftType = check(scope, binary->left, expectedType).ty;
{ TypeId rightType = check(scope, binary->right, expectedType).ty;
TypeId ty = check(scope, unary->expr, std::nullopt);
return ty;
}
return check(scope, unary->expr);
}
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
{
TypeId leftType = check(scope, binary->left, expectedType);
TypeId rightType = check(scope, binary->right, expectedType);
TypeId resultType = arena->addType(BlockedTypeVar{}); TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType});
return resultType; return Inference{resultType};
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType)
{ {
check(scope, ifElse->condition); check(scope, ifElse->condition);
TypeId thenType = check(scope, ifElse->trueExpr, expectedType); TypeId thenType = check(scope, ifElse->trueExpr, expectedType).ty;
TypeId elseType = check(scope, ifElse->falseExpr, expectedType); TypeId elseType = check(scope, ifElse->falseExpr, expectedType).ty;
if (ifElse->hasElse) if (ifElse->hasElse)
{ {
TypeId resultType = expectedType ? *expectedType : freshType(scope); TypeId resultType = expectedType ? *expectedType : freshType(scope);
addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType}); addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType});
addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType}); addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType});
return resultType; return Inference{resultType};
} }
return thenType; return Inference{thenType};
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert)
{ {
check(scope, typeAssert->expr, std::nullopt); check(scope, typeAssert->expr, std::nullopt);
return resolveType(scope, typeAssert->annotation); return Inference{resolveType(scope, typeAssert->annotation)};
} }
TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs) TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
@ -1286,22 +1295,22 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
auto dottedPath = extractDottedName(expr); auto dottedPath = extractDottedName(expr);
if (!dottedPath) if (!dottedPath)
return check(scope, expr); return check(scope, expr).ty;
const auto [sym, segments] = std::move(*dottedPath); const auto [sym, segments] = std::move(*dottedPath);
if (!sym.local) if (!sym.local)
return check(scope, expr); return check(scope, expr).ty;
auto lookupResult = scope->lookupEx(sym); auto lookupResult = scope->lookupEx(sym);
if (!lookupResult) if (!lookupResult)
return check(scope, expr); return check(scope, expr).ty;
const auto [ty, symbolScope] = std::move(*lookupResult); const auto [ty, symbolScope] = std::move(*lookupResult);
TypeId replaceTy = arena->freshType(scope.get()); TypeId replaceTy = arena->freshType(scope.get());
std::optional<TypeId> updatedType = updateTheTableType(arena, ty, segments, replaceTy); std::optional<TypeId> updatedType = updateTheTableType(arena, ty, segments, replaceTy);
if (!updatedType) if (!updatedType)
return check(scope, expr); return check(scope, expr).ty;
std::optional<DefId> def = dfg->getDef(sym); std::optional<DefId> def = dfg->getDef(sym);
LUAU_ASSERT(def); LUAU_ASSERT(def);
@ -1310,7 +1319,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
return replaceTy; return replaceTy;
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType)
{ {
TypeId ty = arena->addType(TableTypeVar{}); TypeId ty = arena->addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(ty); TableTypeVar* ttv = getMutable<TableTypeVar>(ty);
@ -1344,16 +1353,14 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr,
} }
} }
TypeId itemTy = check(scope, item.value, expectedValueType); TypeId itemTy = check(scope, item.value, expectedValueType).ty;
if (get<ErrorTypeVar>(follow(itemTy)))
return ty;
if (item.key) if (item.key)
{ {
// Even though we don't need to use the type of the item's key if // Even though we don't need to use the type of the item's key if
// it's a string constant, we still want to check it to populate // it's a string constant, we still want to check it to populate
// astTypes. // astTypes.
TypeId keyTy = check(scope, item.key); TypeId keyTy = check(scope, item.key).ty;
if (AstExprConstantString* key = item.key->as<AstExprConstantString>()) if (AstExprConstantString* key = item.key->as<AstExprConstantString>())
{ {
@ -1373,7 +1380,7 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr,
} }
} }
return ty; return Inference{ty};
} }
ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn) ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn)
@ -1541,9 +1548,18 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
} }
} }
std::optional<TypeFun> alias = scope->lookupType(ref->name.value); std::optional<TypeFun> alias;
if (alias.has_value() || ref->prefix.has_value()) if (ref->prefix.has_value())
{
alias = scope->lookupImportedType(ref->prefix->value, ref->name.value);
}
else
{
alias = scope->lookupType(ref->name.value);
}
if (alias.has_value())
{ {
// If the alias is not generic, we don't need to set up a blocked // If the alias is not generic, we don't need to set up a blocked
// type and an instantiation constraint. // type and an instantiation constraint.
@ -1586,7 +1602,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
} }
else else
{ {
reportError(ty->location, UnknownSymbol{ref->name.value, UnknownSymbol::Context::Type}); std::string typeName;
if (ref->prefix)
typeName = std::string(ref->prefix->value) + ".";
typeName += ref->name.value;
result = singletonTypes->errorRecoveryType(); result = singletonTypes->errorRecoveryType();
} }
} }
@ -1685,7 +1705,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
else if (auto tof = ty->as<AstTypeTypeof>()) else if (auto tof = ty->as<AstTypeTypeof>())
{ {
// TODO: Recursion limit. // TODO: Recursion limit.
TypeId exprType = check(scope, tof->expr); TypeId exprType = check(scope, tof->expr).ty;
result = exprType; result = exprType;
} }
else if (auto unionAnnotation = ty->as<AstTypeUnion>()) else if (auto unionAnnotation = ty->as<AstTypeUnion>())
@ -1694,7 +1714,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
for (AstType* part : unionAnnotation->types) for (AstType* part : unionAnnotation->types)
{ {
// TODO: Recursion limit. // TODO: Recursion limit.
parts.push_back(resolveType(scope, part)); parts.push_back(resolveType(scope, part, topLevel));
} }
result = arena->addType(UnionTypeVar{parts}); result = arena->addType(UnionTypeVar{parts});
@ -1705,7 +1725,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
for (AstType* part : intersectionAnnotation->types) for (AstType* part : intersectionAnnotation->types)
{ {
// TODO: Recursion limit. // TODO: Recursion limit.
parts.push_back(resolveType(scope, part)); parts.push_back(resolveType(scope, part, topLevel));
} }
result = arena->addType(IntersectionTypeVar{parts}); result = arena->addType(IntersectionTypeVar{parts});
@ -1795,10 +1815,7 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::crea
if (generic.defaultValue) if (generic.defaultValue)
defaultTy = resolveType(scope, generic.defaultValue); defaultTy = resolveType(scope, generic.defaultValue);
result.push_back({generic.name.value, GenericTypeDefinition{ result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}});
genericTy,
defaultTy,
}});
} }
return result; return result;
@ -1816,19 +1833,17 @@ std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::
if (generic.defaultValue) if (generic.defaultValue)
defaultTy = resolveTypePack(scope, generic.defaultValue); defaultTy = resolveTypePack(scope, generic.defaultValue);
result.push_back({generic.name.value, GenericTypePackDefinition{ result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}});
genericTy,
defaultTy,
}});
} }
return result; return result;
} }
TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, TypePackId tp) Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack)
{ {
auto [tp] = pack;
if (auto f = first(tp)) if (auto f = first(tp))
return *f; return Inference{*f};
TypeId typeResult = freshType(scope); TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)}; TypePack onePack{{typeResult}, freshTypePack(scope)};
@ -1836,7 +1851,7 @@ TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location locat
addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack});
return typeResult; return Inference{typeResult};
} }
void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err)

View File

@ -544,6 +544,7 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Const
} }
case AstExprUnary::Len: case AstExprUnary::Len:
{ {
// __len must return a number.
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->numberType); asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->numberType);
return true; return true;
} }
@ -552,13 +553,46 @@ bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Const
if (isNumber(operandType) || get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType)) if (isNumber(operandType) || get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType))
{ {
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(c.operandType); asMutable(c.resultType)->ty.emplace<BoundTypeVar>(c.operandType);
return true;
} }
break; else if (std::optional<TypeId> mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location))
{
const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm));
if (!ftv)
{
if (std::optional<TypeId> callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location))
{
ftv = get<FunctionTypeVar>(follow(*callMm));
}
}
if (!ftv)
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->errorRecoveryType());
return true;
}
TypePackId argsPack = arena->addTypePack({operandType});
unify(ftv->argTypes, argsPack, constraint->scope);
TypeId result = singletonTypes->errorRecoveryType();
if (ftv)
{
result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType());
}
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(result);
}
else
{
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(singletonTypes->errorRecoveryType());
}
return true;
} }
} }
LUAU_ASSERT(false); // TODO metatable handling LUAU_ASSERT(false);
return false; return false;
} }
@ -862,6 +896,10 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull<const Constr
ttv->name = c.name; ttv->name = c.name;
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(target)) else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(target))
mtv->syntheticName = c.name; mtv->syntheticName = c.name;
else if (get<IntersectionTypeVar>(target) || get<UnionTypeVar>(target))
{
// nothing (yet)
}
else else
return block(c.namedType, constraint); return block(c.namedType, constraint);

View File

@ -7,6 +7,8 @@
#include <stdexcept> #include <stdexcept>
LUAU_FASTFLAGVARIABLE(LuauIceExceptionInheritanceChange, false)
static std::string wrongNumberOfArgsString( static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
{ {
@ -460,6 +462,11 @@ struct ErrorConverter
{ {
return "Code is too complex to typecheck! Consider simplifying the code around this area"; return "Code is too complex to typecheck! Consider simplifying the code around this area";
} }
std::string operator()(const TypePackMismatch& e) const
{
return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
}
}; };
struct InvalidNameChecker struct InvalidNameChecker
@ -718,6 +725,11 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const
return left == rhs.left && right == rhs.right; return left == rhs.left && right == rhs.right;
} }
bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const
{
return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp;
}
std::string toString(const TypeError& error) std::string toString(const TypeError& error)
{ {
return toString(error, TypeErrorToStringOptions{}); return toString(error, TypeErrorToStringOptions{});
@ -869,6 +881,11 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState)
else if constexpr (std::is_same_v<T, NormalizationTooComplex>) else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
{ {
} }
else if constexpr (std::is_same_v<T, TypePackMismatch>)
{
e.wantedTp = clone(e.wantedTp);
e.givenTp = clone(e.givenTp);
}
else else
static_assert(always_false_v<T>, "Non-exhaustive type switch"); static_assert(always_false_v<T>, "Non-exhaustive type switch");
} }
@ -913,4 +930,30 @@ const char* InternalCompilerError::what() const throw()
return this->message.data(); return this->message.data();
} }
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void throwRuntimeError(const std::string& message)
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw InternalCompilerError(message);
}
else
{
throw std::runtime_error(message);
}
}
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void throwRuntimeError(const std::string& message, const std::string& moduleName)
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw InternalCompilerError(message, moduleName);
}
else
{
throw std::runtime_error(message);
}
}
} // namespace Luau } // namespace Luau

View File

@ -31,6 +31,7 @@ LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false) LUAU_FASTFLAGVARIABLE(LuauFixMarkDirtyReverseDeps, false)
LUAU_FASTFLAGVARIABLE(LuauPersistTypesAfterGeneratingDocSyms, false)
namespace Luau namespace Luau
{ {
@ -111,24 +112,57 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c
CloneState cloneState; CloneState cloneState;
for (const auto& [name, ty] : checkedModule->declaredGlobals) if (FFlag::LuauPersistTypesAfterGeneratingDocSyms)
{ {
TypeId globalTy = clone(ty, globalTypes, cloneState); std::vector<TypeId> typesToPersist;
std::string documentationSymbol = packageName + "/global/" + name; typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy); for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
typesToPersist.push_back(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
typesToPersist.push_back(globalTy.type);
}
for (TypeId ty : typesToPersist)
{
persist(ty);
}
} }
else
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{ {
TypeFun globalTy = clone(ty, globalTypes, cloneState); for (const auto& [name, ty] : checkedModule->declaredGlobals)
std::string documentationSymbol = packageName + "/globaltype/" + name; {
generateDocumentationSymbols(globalTy.type, documentationSymbol); TypeId globalTy = clone(ty, globalTypes, cloneState);
globalScope->exportedTypeBindings[name] = globalTy; std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy.type); persist(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
persist(globalTy.type);
}
} }
return LoadDefinitionFileResult{true, parseResult, checkedModule}; return LoadDefinitionFileResult{true, parseResult, checkedModule};
@ -160,24 +194,57 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
CloneState cloneState; CloneState cloneState;
for (const auto& [name, ty] : checkedModule->declaredGlobals) if (FFlag::LuauPersistTypesAfterGeneratingDocSyms)
{ {
TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::vector<TypeId> typesToPersist;
std::string documentationSymbol = packageName + "/global/" + name; typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy); for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
typesToPersist.push_back(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
typesToPersist.push_back(globalTy.type);
}
for (TypeId ty : typesToPersist)
{
persist(ty);
}
} }
else
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{ {
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); for (const auto& [name, ty] : checkedModule->declaredGlobals)
std::string documentationSymbol = packageName + "/globaltype/" + name; {
generateDocumentationSymbols(globalTy.type, documentationSymbol); TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
targetScope->exportedTypeBindings[name] = globalTy; std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
persist(globalTy.type); persist(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
persist(globalTy.type);
}
} }
return LoadDefinitionFileResult{true, parseResult, checkedModule}; return LoadDefinitionFileResult{true, parseResult, checkedModule};
@ -426,13 +493,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
{ {
auto it2 = moduleResolverForAutocomplete.modules.find(name); auto it2 = moduleResolverForAutocomplete.modules.find(name);
if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name); throwRuntimeError("Frontend::modules does not have data for " + name, name);
} }
else else
{ {
auto it2 = moduleResolver.modules.find(name); auto it2 = moduleResolver.modules.find(name);
if (it2 == moduleResolver.modules.end() || it2->second == nullptr) if (it2 == moduleResolver.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name); throwRuntimeError("Frontend::modules does not have data for " + name, name);
} }
return CheckResult{ return CheckResult{
@ -539,7 +606,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
stats.filesNonstrict += mode == Mode::Nonstrict; stats.filesNonstrict += mode == Mode::Nonstrict;
if (module == nullptr) if (module == nullptr)
throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); throwRuntimeError("Frontend::check produced a nullptr module for " + moduleName, moduleName);
if (!frontendOptions.retainFullTypeGraphs) if (!frontendOptions.retainFullTypeGraphs)
{ {
@ -1007,11 +1074,11 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const
double timestamp = getTimestamp(); double timestamp = getTimestamp();
auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions);
stats.timeParse += getTimestamp() - timestamp; stats.timeParse += getTimestamp() - timestamp;
stats.files++; stats.files++;
stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); stats.lines += parseResult.lines;
if (!parseResult.errors.empty()) if (!parseResult.errors.empty())
sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end());

View File

@ -188,6 +188,8 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }";
else if constexpr (std::is_same_v<T, NormalizationTooComplex>) else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
stream << "NormalizationTooComplex { }"; stream << "NormalizationTooComplex { }";
else if constexpr (std::is_same_v<T, TypePackMismatch>)
stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }";
else else
static_assert(always_false_v<T>, "Non-exhaustive type switch"); static_assert(always_false_v<T>, "Non-exhaustive type switch");
} }

View File

@ -7,6 +7,7 @@
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
#include "Luau/VisitTypeVar.h" #include "Luau/VisitTypeVar.h"
@ -18,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false);
LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf);
@ -107,12 +109,110 @@ bool TypeIds::operator==(const TypeIds& there) const
return hash == there.hash && types == there.types; return hash == there.hash && types == there.types;
} }
NormalizedStringType::NormalizedStringType(bool isCofinite, std::optional<std::map<std::string, TypeId>> singletons)
: isCofinite(isCofinite)
, singletons(std::move(singletons))
{
if (!FFlag::LuauNegatedStringSingletons)
LUAU_ASSERT(!isCofinite);
}
void NormalizedStringType::resetToString()
{
if (FFlag::LuauNegatedStringSingletons)
{
isCofinite = true;
singletons->clear();
}
else
singletons.reset();
}
void NormalizedStringType::resetToNever()
{
if (FFlag::LuauNegatedStringSingletons)
{
isCofinite = false;
singletons.emplace();
}
else
{
if (singletons)
singletons->clear();
else
singletons.emplace();
}
}
bool NormalizedStringType::isNever() const
{
if (FFlag::LuauNegatedStringSingletons)
return !isCofinite && singletons->empty();
else
return singletons && singletons->empty();
}
bool NormalizedStringType::isString() const
{
if (FFlag::LuauNegatedStringSingletons)
return isCofinite && singletons->empty();
else
return !singletons;
}
bool NormalizedStringType::isUnion() const
{
if (FFlag::LuauNegatedStringSingletons)
return !isCofinite;
else
return singletons.has_value();
}
bool NormalizedStringType::isIntersection() const
{
if (FFlag::LuauNegatedStringSingletons)
return isCofinite;
else
return false;
}
bool NormalizedStringType::includes(const std::string& str) const
{
if (isString())
return true;
else if (isUnion() && singletons->count(str))
return true;
else if (isIntersection() && !singletons->count(str))
return true;
else
return false;
}
const NormalizedStringType NormalizedStringType::never{false, {{}}};
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr)
{
if (subStr.isUnion() && superStr.isUnion())
{
for (auto [name, ty] : *subStr.singletons)
{
if (!superStr.singletons->count(name))
return false;
}
}
else if (subStr.isString() && superStr.isUnion())
return false;
return true;
}
NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes) NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes)
: tops(singletonTypes->neverType) : tops(singletonTypes->neverType)
, booleans(singletonTypes->neverType) , booleans(singletonTypes->neverType)
, errors(singletonTypes->neverType) , errors(singletonTypes->neverType)
, nils(singletonTypes->neverType) , nils(singletonTypes->neverType)
, numbers(singletonTypes->neverType) , numbers(singletonTypes->neverType)
, strings{NormalizedStringType::never}
, threads(singletonTypes->neverType) , threads(singletonTypes->neverType)
{ {
} }
@ -120,7 +220,7 @@ NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes)
static bool isInhabited(const NormalizedType& norm) static bool isInhabited(const NormalizedType& norm)
{ {
return !get<NeverTypeVar>(norm.tops) || !get<NeverTypeVar>(norm.booleans) || !norm.classes.empty() || !get<NeverTypeVar>(norm.errors) || return !get<NeverTypeVar>(norm.tops) || !get<NeverTypeVar>(norm.booleans) || !norm.classes.empty() || !get<NeverTypeVar>(norm.errors) ||
!get<NeverTypeVar>(norm.nils) || !get<NeverTypeVar>(norm.numbers) || !norm.strings || !norm.strings->empty() || !get<NeverTypeVar>(norm.nils) || !get<NeverTypeVar>(norm.numbers) || !norm.strings.isNever() ||
!get<NeverTypeVar>(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); !get<NeverTypeVar>(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty();
} }
@ -183,10 +283,10 @@ static bool isNormalizedNumber(TypeId ty)
static bool isNormalizedString(const NormalizedStringType& ty) static bool isNormalizedString(const NormalizedStringType& ty)
{ {
if (!ty) if (ty.isString())
return true; return true;
for (auto& [str, ty] : *ty) for (auto& [str, ty] : *ty.singletons)
{ {
if (const SingletonTypeVar* stv = get<SingletonTypeVar>(ty)) if (const SingletonTypeVar* stv = get<SingletonTypeVar>(ty))
{ {
@ -317,10 +417,7 @@ void Normalizer::clearNormal(NormalizedType& norm)
norm.errors = singletonTypes->neverType; norm.errors = singletonTypes->neverType;
norm.nils = singletonTypes->neverType; norm.nils = singletonTypes->neverType;
norm.numbers = singletonTypes->neverType; norm.numbers = singletonTypes->neverType;
if (norm.strings) norm.strings.resetToNever();
norm.strings->clear();
else
norm.strings.emplace();
norm.threads = singletonTypes->neverType; norm.threads = singletonTypes->neverType;
norm.tables.clear(); norm.tables.clear();
norm.functions = std::nullopt; norm.functions = std::nullopt;
@ -495,10 +592,56 @@ void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres)
void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there)
{ {
if (!there) if (FFlag::LuauNegatedStringSingletons)
here.reset(); {
else if (here) if (there.isString())
here->insert(there->begin(), there->end()); here.resetToString();
else if (here.isUnion() && there.isUnion())
here.singletons->insert(there.singletons->begin(), there.singletons->end());
else if (here.isUnion() && there.isIntersection())
{
here.isCofinite = true;
for (const auto& pair : *there.singletons)
{
auto it = here.singletons->find(pair.first);
if (it != end(*here.singletons))
here.singletons->erase(it);
else
here.singletons->insert(pair);
}
}
else if (here.isIntersection() && there.isUnion())
{
for (const auto& [name, ty] : *there.singletons)
here.singletons->erase(name);
}
else if (here.isIntersection() && there.isIntersection())
{
auto iter = begin(*here.singletons);
auto endIter = end(*here.singletons);
while (iter != endIter)
{
if (!there.singletons->count(iter->first))
{
auto eraseIt = iter;
++iter;
here.singletons->erase(eraseIt);
}
else
++iter;
}
}
else
LUAU_ASSERT(!"Unreachable");
}
else
{
if (there.isString())
here.resetToString();
else if (here.isUnion())
here.singletons->insert(there.singletons->begin(), there.singletons->end());
}
} }
std::optional<TypePackId> Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) std::optional<TypePackId> Normalizer::unionOfTypePacks(TypePackId here, TypePackId there)
@ -858,7 +1001,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
else if (ptv->type == PrimitiveTypeVar::Number) else if (ptv->type == PrimitiveTypeVar::Number)
here.numbers = there; here.numbers = there;
else if (ptv->type == PrimitiveTypeVar::String) else if (ptv->type == PrimitiveTypeVar::String)
here.strings = std::nullopt; here.strings.resetToString();
else if (ptv->type == PrimitiveTypeVar::Thread) else if (ptv->type == PrimitiveTypeVar::Thread)
here.threads = there; here.threads = there;
else else
@ -870,12 +1013,33 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
here.booleans = unionOfBools(here.booleans, there); here.booleans = unionOfBools(here.booleans, there);
else if (const StringSingleton* sstv = get<StringSingleton>(stv)) else if (const StringSingleton* sstv = get<StringSingleton>(stv))
{ {
if (here.strings) if (FFlag::LuauNegatedStringSingletons)
here.strings->insert({sstv->value, there}); {
if (here.strings.isCofinite)
{
auto it = here.strings.singletons->find(sstv->value);
if (it != here.strings.singletons->end())
here.strings.singletons->erase(it);
}
else
here.strings.singletons->insert({sstv->value, there});
}
else
{
if (here.strings.isUnion())
here.strings.singletons->insert({sstv->value, there});
}
} }
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
} }
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(there))
{
const NormalizedType* thereNormal = normalize(ntv->ty);
NormalizedType tn = negateNormal(*thereNormal);
if (!unionNormals(here, tn))
return false;
}
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
@ -887,6 +1051,159 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
return true; return true;
} }
// ------- Negations
NormalizedType Normalizer::negateNormal(const NormalizedType& here)
{
NormalizedType result{singletonTypes};
if (!get<NeverTypeVar>(here.tops))
{
// The negation of unknown or any is never. Easy.
return result;
}
if (!get<NeverTypeVar>(here.errors))
{
// Negating an error yields the same error.
result.errors = here.errors;
return result;
}
if (get<NeverTypeVar>(here.booleans))
result.booleans = singletonTypes->booleanType;
else if (get<PrimitiveTypeVar>(here.booleans))
result.booleans = singletonTypes->neverType;
else if (auto stv = get<SingletonTypeVar>(here.booleans))
{
auto boolean = get<BooleanSingleton>(stv);
LUAU_ASSERT(boolean != nullptr);
if (boolean->value)
result.booleans = singletonTypes->falseType;
else
result.booleans = singletonTypes->trueType;
}
result.classes = negateAll(here.classes);
result.nils = get<NeverTypeVar>(here.nils) ? singletonTypes->nilType : singletonTypes->neverType;
result.numbers = get<NeverTypeVar>(here.numbers) ? singletonTypes->numberType : singletonTypes->neverType;
result.strings = here.strings;
result.strings.isCofinite = !result.strings.isCofinite;
result.threads = get<NeverTypeVar>(here.threads) ? singletonTypes->threadType : singletonTypes->neverType;
// TODO: negating tables
// TODO: negating functions
// TODO: negating tyvars?
return result;
}
TypeIds Normalizer::negateAll(const TypeIds& theres)
{
TypeIds tys;
for (TypeId there : theres)
tys.insert(negate(there));
return tys;
}
TypeId Normalizer::negate(TypeId there)
{
there = follow(there);
if (get<AnyTypeVar>(there))
return there;
else if (get<UnknownTypeVar>(there))
return singletonTypes->neverType;
else if (get<NeverTypeVar>(there))
return singletonTypes->unknownType;
else if (auto ntv = get<NegationTypeVar>(there))
return ntv->ty; // TODO: do we want to normalize this?
else if (auto utv = get<UnionTypeVar>(there))
{
std::vector<TypeId> parts;
for (TypeId option : utv)
parts.push_back(negate(option));
return arena->addType(IntersectionTypeVar{std::move(parts)});
}
else if (auto itv = get<IntersectionTypeVar>(there))
{
std::vector<TypeId> options;
for (TypeId part : itv)
options.push_back(negate(part));
return arena->addType(UnionTypeVar{std::move(options)});
}
else
return there;
}
void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty)
{
const PrimitiveTypeVar* ptv = get<PrimitiveTypeVar>(follow(ty));
LUAU_ASSERT(ptv);
switch (ptv->type)
{
case PrimitiveTypeVar::NilType:
here.nils = singletonTypes->neverType;
break;
case PrimitiveTypeVar::Boolean:
here.booleans = singletonTypes->neverType;
break;
case PrimitiveTypeVar::Number:
here.numbers = singletonTypes->neverType;
break;
case PrimitiveTypeVar::String:
here.strings.resetToNever();
break;
case PrimitiveTypeVar::Thread:
here.threads = singletonTypes->neverType;
break;
}
}
void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty)
{
LUAU_ASSERT(FFlag::LuauNegatedStringSingletons);
const SingletonTypeVar* stv = get<SingletonTypeVar>(ty);
LUAU_ASSERT(stv);
if (const StringSingleton* ss = get<StringSingleton>(stv))
{
if (here.strings.isCofinite)
here.strings.singletons->insert({ss->value, ty});
else
{
auto it = here.strings.singletons->find(ss->value);
if (it != here.strings.singletons->end())
here.strings.singletons->erase(it);
}
}
else if (const BooleanSingleton* bs = get<BooleanSingleton>(stv))
{
if (get<NeverTypeVar>(here.booleans))
{
// Nothing
}
else if (get<PrimitiveTypeVar>(here.booleans))
here.booleans = bs->value ? singletonTypes->falseType : singletonTypes->trueType;
else if (auto hereSingleton = get<SingletonTypeVar>(here.booleans))
{
const BooleanSingleton* hereBooleanSingleton = get<BooleanSingleton>(hereSingleton);
LUAU_ASSERT(hereBooleanSingleton);
// Crucial subtlety: ty (and thus bs) are the value that is being
// negated out. We therefore reduce to never when the values match,
// rather than when they differ.
if (bs->value == hereBooleanSingleton->value)
here.booleans = singletonTypes->neverType;
}
else
LUAU_ASSERT(!"Unreachable");
}
else
LUAU_ASSERT(!"Unreachable");
}
// ------- Normalizing intersections // ------- Normalizing intersections
TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there)
{ {
@ -971,17 +1288,17 @@ void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there)
void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there)
{ {
if (!there) if (there.isString())
return; return;
if (!here) if (here.isString())
here.emplace(); here.resetToNever();
for (auto it = here->begin(); it != here->end();) for (auto it = here.singletons->begin(); it != here.singletons->end();)
{ {
if (there->count(it->first)) if (there.singletons->count(it->first))
it++; it++;
else else
it = here->erase(it); it = here.singletons->erase(it);
} }
} }
@ -1646,12 +1963,35 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
here.booleans = intersectionOfBools(booleans, there); here.booleans = intersectionOfBools(booleans, there);
else if (const StringSingleton* sstv = get<StringSingleton>(stv)) else if (const StringSingleton* sstv = get<StringSingleton>(stv))
{ {
if (!strings || strings->count(sstv->value)) if (strings.includes(sstv->value))
here.strings->insert({sstv->value, there}); here.strings.singletons->insert({sstv->value, there});
} }
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
} }
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(there); FFlag::LuauNegatedStringSingletons && ntv)
{
TypeId t = follow(ntv->ty);
if (const PrimitiveTypeVar* ptv = get<PrimitiveTypeVar>(t))
subtractPrimitive(here, ntv->ty);
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(t))
subtractSingleton(here, follow(ntv->ty));
else if (const UnionTypeVar* itv = get<UnionTypeVar>(t))
{
for (TypeId part : itv->options)
{
const NormalizedType* normalPart = normalize(part);
NormalizedType negated = negateNormal(*normalPart);
intersectNormals(here, negated);
}
}
else
{
// TODO negated unions, intersections, table, and function.
// Report a TypeError for other types.
LUAU_ASSERT(!"Unimplemented");
}
}
else else
LUAU_ASSERT(!"Unreachable"); LUAU_ASSERT(!"Unreachable");
@ -1691,11 +2031,25 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
result.push_back(norm.nils); result.push_back(norm.nils);
if (!get<NeverTypeVar>(norm.numbers)) if (!get<NeverTypeVar>(norm.numbers))
result.push_back(norm.numbers); result.push_back(norm.numbers);
if (norm.strings) if (norm.strings.isString())
for (auto& [_, ty] : *norm.strings)
result.push_back(ty);
else
result.push_back(singletonTypes->stringType); result.push_back(singletonTypes->stringType);
else if (norm.strings.isUnion())
{
for (auto& [_, ty] : *norm.strings.singletons)
result.push_back(ty);
}
else if (FFlag::LuauNegatedStringSingletons && norm.strings.isIntersection())
{
std::vector<TypeId> parts;
parts.push_back(singletonTypes->stringType);
for (const auto& [name, ty] : *norm.strings.singletons)
parts.push_back(arena->addType(NegationTypeVar{ty}));
result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)}));
}
if (!get<NeverTypeVar>(norm.threads))
result.push_back(singletonTypes->threadType);
result.insert(result.end(), norm.tables.begin(), norm.tables.end()); result.insert(result.end(), norm.tables.begin(), norm.tables.end());
for (auto& [tyvar, intersect] : norm.tyvars) for (auto& [tyvar, intersect] : norm.tyvars)
{ {

View File

@ -11,6 +11,7 @@
#include <stdexcept> #include <stdexcept>
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauLvaluelessPath)
LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false)
LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
@ -435,7 +436,7 @@ struct TypeVarStringifier
return; return;
default: default:
LUAU_ASSERT(!"Unknown primitive type"); LUAU_ASSERT(!"Unknown primitive type");
throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type)); throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type));
} }
} }
@ -452,7 +453,7 @@ struct TypeVarStringifier
else else
{ {
LUAU_ASSERT(!"Unknown singleton type"); LUAU_ASSERT(!"Unknown singleton type");
throw std::runtime_error("Unknown singleton type"); throwRuntimeError("Unknown singleton type");
} }
} }
@ -1548,6 +1549,8 @@ std::string dump(const Constraint& c)
std::string toString(const LValue& lvalue) std::string toString(const LValue& lvalue)
{ {
LUAU_ASSERT(!FFlag::LuauLvaluelessPath);
std::string s; std::string s;
for (const LValue* current = &lvalue; current; current = baseof(*current)) for (const LValue* current = &lvalue; current; current = baseof(*current))
{ {
@ -1562,4 +1565,37 @@ std::string toString(const LValue& lvalue)
return s; return s;
} }
std::optional<std::string> getFunctionNameAsString(const AstExpr& expr)
{
LUAU_ASSERT(FFlag::LuauLvaluelessPath);
const AstExpr* curr = &expr;
std::string s;
for (;;)
{
if (auto local = curr->as<AstExprLocal>())
return local->local->name.value + s;
if (auto global = curr->as<AstExprGlobal>())
return global->name.value + s;
if (auto indexname = curr->as<AstExprIndexName>())
{
curr = indexname->expr;
s = "." + std::string(indexname->index.value) + s;
}
else if (auto group = curr->as<AstExprGroup>())
{
curr = group->expr;
}
else
{
return std::nullopt;
}
}
return s;
}
} // namespace Luau } // namespace Luau

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TopoSortStatements.h" #include "Luau/TopoSortStatements.h"
#include "Luau/Error.h"
/* Decide the order in which we typecheck Lua statements in a block. /* Decide the order in which we typecheck Lua statements in a block.
* *
* Algorithm: * Algorithm:
@ -149,7 +150,7 @@ Identifier mkName(const AstStatFunction& function)
auto name = mkName(*function.name); auto name = mkName(*function.name);
LUAU_ASSERT(bool(name)); LUAU_ASSERT(bool(name));
if (!name) if (!name)
throw std::runtime_error("Internal error: Function declaration has a bad name"); throwRuntimeError("Internal error: Function declaration has a bad name");
return *name; return *name;
} }
@ -255,7 +256,7 @@ struct ArcCollector : public AstVisitor
{ {
auto name = mkName(*node->name); auto name = mkName(*node->name);
if (!name) if (!name)
throw std::runtime_error("Internal error: AstStatFunction has a bad name"); throwRuntimeError("Internal error: AstStatFunction has a bad name");
add(*name); add(*name);
return true; return true;

View File

@ -347,7 +347,7 @@ public:
AstType* operator()(const NegationTypeVar& ntv) AstType* operator()(const NegationTypeVar& ntv)
{ {
// FIXME: do the same thing we do with ErrorTypeVar // FIXME: do the same thing we do with ErrorTypeVar
throw std::runtime_error("Cannot convert NegationTypeVar into AstNode"); throwRuntimeError("Cannot convert NegationTypeVar into AstNode");
} }
private: private:

View File

@ -934,8 +934,62 @@ struct TypeChecker2
void visit(AstExprUnary* expr) void visit(AstExprUnary* expr)
{ {
// TODO!
visit(expr->expr); visit(expr->expr);
NotNull<Scope> scope = stack.back();
TypeId operandType = lookupType(expr->expr);
if (get<AnyTypeVar>(operandType) || get<ErrorTypeVar>(operandType) || get<NeverTypeVar>(operandType))
return;
if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end())
{
std::optional<TypeId> mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location);
if (mm)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*mm)))
{
TypePackId expectedArgs = module->internalTypes.addTypePack({operandType});
reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs));
if (std::optional<TypeId> ret = first(ftv->retTypes))
{
if (expr->op == AstExprUnary::Op::Len)
{
reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType));
}
}
else
{
reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location);
}
}
return;
}
}
if (expr->op == AstExprUnary::Op::Len)
{
DenseHashSet<TypeId> seen{nullptr};
int recursionCount = 0;
if (!hasLength(operandType, seen, &recursionCount))
{
reportError(NotATable{operandType}, expr->location);
}
}
else if (expr->op == AstExprUnary::Op::Minus)
{
reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType));
}
else if (expr->op == AstExprUnary::Op::Not)
{
}
else
{
LUAU_ASSERT(!"Unhandled unary operator");
}
} }
void visit(AstExprBinary* expr) void visit(AstExprBinary* expr)
@ -1240,9 +1294,8 @@ struct TypeChecker2
Scope* scope = findInnermostScope(ty->location); Scope* scope = findInnermostScope(ty->location);
LUAU_ASSERT(scope); LUAU_ASSERT(scope);
// TODO: Imported types std::optional<TypeFun> alias =
(ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value);
std::optional<TypeFun> alias = scope->lookupType(ty->name.value);
if (alias.has_value()) if (alias.has_value())
{ {

View File

@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false)
LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false)
LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false)
LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false)
LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false)
@ -43,15 +44,15 @@ LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false)
LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false)
LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false)
LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false)
LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false)
LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false)
namespace Luau namespace Luau
{ {
const char* TimeLimitError_DEPRECATED::what() const throw()
const char* TimeLimitError::what() const throw()
{ {
LUAU_ASSERT(!FFlag::LuauIceExceptionInheritanceChange);
return "Typeinfer failed to complete in allotted time"; return "Typeinfer failed to complete in allotted time";
} }
@ -264,6 +265,11 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona
reportErrorCodeTooComplex(module.root->location); reportErrorCodeTooComplex(module.root->location);
return std::move(currentModule); return std::move(currentModule);
} }
catch (const RecursionLimitException_DEPRECATED&)
{
reportErrorCodeTooComplex(module.root->location);
return std::move(currentModule);
}
} }
ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope) ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope)
@ -308,6 +314,10 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
{ {
currentModule->timeout = true; currentModule->timeout = true;
} }
catch (const TimeLimitError_DEPRECATED&)
{
currentModule->timeout = true;
}
if (FFlag::DebugLuauSharedSelf) if (FFlag::DebugLuauSharedSelf)
{ {
@ -415,7 +425,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program)
ice("Unknown AstStat"); ice("Unknown AstStat");
if (finishTime && TimeTrace::getClock() > *finishTime) if (finishTime && TimeTrace::getClock() > *finishTime)
throw TimeLimitError(); throwTimeLimitError();
} }
// This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly.
@ -442,6 +452,11 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
reportErrorCodeTooComplex(block.location); reportErrorCodeTooComplex(block.location);
return; return;
} }
catch (const RecursionLimitException_DEPRECATED&)
{
reportErrorCodeTooComplex(block.location);
return;
}
} }
struct InplaceDemoter : TypeVarOnceVisitor struct InplaceDemoter : TypeVarOnceVisitor
@ -2456,11 +2471,8 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op)
TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes)
{ {
if (FFlag::LuauUnionOfTypesFollow) a = follow(a);
{ b = follow(b);
a = follow(a);
b = follow(b);
}
if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b))) if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b)))
{ {
@ -3596,8 +3608,17 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
location = {state.location.begin, argLocations.back().end}; location = {state.location.begin, argLocations.back().end};
std::string namePath; std::string namePath;
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue); if (FFlag::LuauLvaluelessPath)
{
if (std::optional<std::string> path = getFunctionNameAsString(funName))
namePath = *path;
}
else
{
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
}
auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack);
state.reportError(TypeError{location, state.reportError(TypeError{location,
@ -3706,11 +3727,28 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
bool isVariadic = tail && Luau::isVariadic(*tail); bool isVariadic = tail && Luau::isVariadic(*tail);
std::string namePath; std::string namePath;
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
state.reportError(TypeError{ if (FFlag::LuauLvaluelessPath)
state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); {
if (std::optional<std::string> path = getFunctionNameAsString(funName))
namePath = *path;
}
else
{
if (std::optional<LValue> lValue = tryGetLValue(funName))
namePath = toString(*lValue);
}
if (FFlag::LuauArgMismatchReportFunctionLocation)
{
state.reportError(TypeError{
funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
}
else
{
state.reportError(TypeError{
state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}});
}
return; return;
} }
++paramIter; ++paramIter;
@ -4647,6 +4685,19 @@ void TypeChecker::ice(const std::string& message)
iceHandler->ice(message); iceHandler->ice(message);
} }
// TODO: Inline me when LuauIceExceptionInheritanceChange is deleted.
void TypeChecker::throwTimeLimitError()
{
if (FFlag::LuauIceExceptionInheritanceChange)
{
throw TimeLimitError(iceHandler->moduleName);
}
else
{
throw TimeLimitError_DEPRECATED();
}
}
void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec)
{ {
// Remove errors with names that were generated by recovery from a parse error // Remove errors with names that were generated by recovery from a parse error

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/Error.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include <stdexcept> #include <stdexcept>
@ -234,7 +235,7 @@ TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper)
cycleTester = nullptr; cycleTester = nullptr;
if (tp == cycleTester) if (tp == cycleTester)
throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); throwRuntimeError("Luau::follow detected a TypeVar cycle!!");
} }
} }
} }

View File

@ -61,7 +61,7 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
{ {
std::optional<TypeId> ty = utv->scope->lookup(utv->def); std::optional<TypeId> ty = utv->scope->lookup(utv->def);
if (!ty) if (!ty)
throw std::runtime_error("UseTypeVar must map to another TypeId"); throwRuntimeError("UseTypeVar must map to another TypeId");
return *ty; return *ty;
} }
else else
@ -73,7 +73,7 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
{ {
TypeId res = ltv->thunk(); TypeId res = ltv->thunk();
if (get<LazyTypeVar>(res)) if (get<LazyTypeVar>(res))
throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar"); throwRuntimeError("Lazy TypeVar cannot resolve to another Lazy TypeVar");
*asMutable(ty) = BoundTypeVar(res); *asMutable(ty) = BoundTypeVar(res);
} }
@ -111,7 +111,7 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
cycleTester = nullptr; cycleTester = nullptr;
if (t == cycleTester) if (t == cycleTester)
throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); throwRuntimeError("Luau::follow detected a TypeVar cycle!!");
} }
} }
} }
@ -946,7 +946,7 @@ void persist(TypeId ty)
queue.push_back(mtv->table); queue.push_back(mtv->table);
queue.push_back(mtv->metatable); queue.push_back(mtv->metatable);
} }
else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t)) else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t) || get<NegationTypeVar>(t))
{ {
} }
else else

View File

@ -16,6 +16,7 @@
LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit);
LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false)
LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false);
LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
@ -273,7 +274,7 @@ TypeId Widen::clean(TypeId ty)
TypePackId Widen::clean(TypePackId) TypePackId Widen::clean(TypePackId)
{ {
throw std::runtime_error("Widen attempted to clean a dirty type pack?"); throwRuntimeError("Widen attempted to clean a dirty type pack?");
} }
bool Widen::ignoreChildren(TypeId ty) bool Widen::ignoreChildren(TypeId ty)
@ -551,6 +552,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
else if (log.getMutable<ClassTypeVar>(subTy)) else if (log.getMutable<ClassTypeVar>(subTy))
tryUnifyWithClass(subTy, superTy, /*reversed*/ true); tryUnifyWithClass(subTy, superTy, /*reversed*/ true);
else if (log.get<NegationTypeVar>(superTy))
tryUnifyTypeWithNegation(subTy, superTy);
else if (log.get<NegationTypeVar>(subTy))
tryUnifyNegationWithType(subTy, superTy);
else else
reportError(TypeError{location, TypeMismatch{superTy, subTy}}); reportError(TypeError{location, TypeMismatch{superTy, subTy}});
@ -866,13 +873,7 @@ void Unifier::tryUnifyNormalizedTypes(
if (!get<PrimitiveTypeVar>(superNorm.numbers)) if (!get<PrimitiveTypeVar>(superNorm.numbers))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
if (subNorm.strings && superNorm.strings) if (!isSubtype(subNorm.strings, superNorm.strings))
{
for (auto [name, ty] : *subNorm.strings)
if (!superNorm.strings->count(name))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
}
else if (!subNorm.strings && superNorm.strings)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
if (get<PrimitiveTypeVar>(subNorm.threads)) if (get<PrimitiveTypeVar>(subNorm.threads))
@ -1392,7 +1393,10 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
} }
else else
{ {
reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure)
reportError(TypeError{location, TypePackMismatch{subTp, superTp}});
else
reportError(TypeError{location, GenericError{"Failed to unify type packs"}});
} }
} }
@ -1441,7 +1445,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0);
if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate) // TODO: This is unsound when the context is invariant, but the annotation burden without allowing it and without
// read-only properties is too high for lua-apps. Read-only properties _should_ resolve their issue by allowing
// generic methods in tables to be marked read-only.
if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate)
{ {
Instantiation instantiation{&log, types, scope->level, scope}; Instantiation instantiation{&log, types, scope->level, scope};
@ -1576,6 +1583,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
{ {
TableTypeVar* superTable = log.getMutable<TableTypeVar>(superTy); TableTypeVar* superTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* subTable = log.getMutable<TableTypeVar>(subTy); TableTypeVar* subTable = log.getMutable<TableTypeVar>(subTy);
TableTypeVar* instantiatedSubTable = subTable;
if (!superTable || !subTable) if (!superTable || !subTable)
ice("passed non-table types to unifyTables"); ice("passed non-table types to unifyTables");
@ -1593,6 +1601,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (instantiated.has_value()) if (instantiated.has_value())
{ {
subTable = log.getMutable<TableTypeVar>(*instantiated); subTable = log.getMutable<TableTypeVar>(*instantiated);
instantiatedSubTable = subTable;
if (!subTable) if (!subTable)
ice("instantiation made a table type into a non-table type in tryUnifyTables"); ice("instantiation made a table type into a non-table type in tryUnifyTables");
@ -1696,7 +1705,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
// txn log. // txn log.
TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy); TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy); TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy);
if (superTable != newSuperTable || subTable != newSubTable) if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable))
{ {
if (errors.empty()) if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection); return tryUnifyTables(subTy, superTy, isIntersection);
@ -1758,7 +1767,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
// txn log. // txn log.
TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy); TableTypeVar* newSuperTable = log.getMutable<TableTypeVar>(superTy);
TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy); TableTypeVar* newSubTable = log.getMutable<TableTypeVar>(subTy);
if (superTable != newSuperTable || subTable != newSubTable) if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable))
{ {
if (errors.empty()) if (errors.empty())
return tryUnifyTables(subTy, superTy, isIntersection); return tryUnifyTables(subTy, superTy, isIntersection);
@ -2098,6 +2107,34 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
return fail(); return fail();
} }
void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy)
{
const NegationTypeVar* ntv = get<NegationTypeVar>(superTy);
if (!ntv)
ice("tryUnifyTypeWithNegation superTy must be a negation type");
const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(TypeError{location, UnificationTooComplex{}});
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy)
{
const NegationTypeVar* ntv = get<NegationTypeVar>(subTy);
if (!ntv)
ice("tryUnifyNegationWithType subTy must be a negation type");
// TODO: ~T </: U iff T <: U
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
{ {
while (true) while (true)

View File

@ -58,6 +58,8 @@ struct Comment
struct ParseResult struct ParseResult
{ {
AstStatBlock* root; AstStatBlock* root;
size_t lines = 0;
std::vector<HotComment> hotcomments; std::vector<HotComment> hotcomments;
std::vector<ParseError> errors; std::vector<ParseError> errors;

View File

@ -302,8 +302,8 @@ private:
AstStatError* reportStatError(const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements, AstStatError* reportStatError(const Location& location, const AstArray<AstExpr*>& expressions, const AstArray<AstStat*>& statements,
const char* format, ...) LUAU_PRINTF_ATTR(5, 6); const char* format, ...) LUAU_PRINTF_ATTR(5, 6);
AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstExprError* reportExprError(const Location& location, const AstArray<AstExpr*>& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5);
AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...) AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, const char* format, ...)
LUAU_PRINTF_ATTR(5, 6); LUAU_PRINTF_ATTR(4, 5);
// `parseErrorLocation` is associated with the parser error // `parseErrorLocation` is associated with the parser error
// `astErrorLocation` is associated with the AstTypeError created // `astErrorLocation` is associated with the AstTypeError created
// It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely

View File

@ -23,7 +23,6 @@ LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false)
LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false)
@ -164,15 +163,16 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n
try try
{ {
AstStatBlock* root = p.parseChunk(); AstStatBlock* root = p.parseChunk();
size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n');
return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)};
} }
catch (ParseError& err) catch (ParseError& err)
{ {
// when catching a fatal error, append it to the list of non-fatal errors and return // when catching a fatal error, append it to the list of non-fatal errors and return
p.parseErrors.push_back(err); p.parseErrors.push_back(err);
return ParseResult{nullptr, {}, p.parseErrors}; return ParseResult{nullptr, 0, {}, p.parseErrors};
} }
} }
@ -811,9 +811,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr)
{ {
return AstDeclaredClassProp{fnName.name, return AstDeclaredClassProp{
reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"), fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true};
true};
} }
// Skip the first index. // Skip the first index.
@ -824,8 +823,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
if (args[i].annotation) if (args[i].annotation)
vars.push_back(args[i].annotation); vars.push_back(args[i].annotation);
else else
vars.push_back(reportTypeAnnotationError( vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated"));
Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated"));
} }
if (vararg && !varargAnnotation) if (vararg && !varargAnnotation)
@ -1537,7 +1535,7 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
if (isUnion && isIntersection) if (isUnion && isIntersection)
{ {
return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false, return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts),
"Mixing union and intersection types is not allowed; consider wrapping in parentheses."); "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
} }
@ -1623,18 +1621,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
return {allocator.alloc<AstTypeSingletonString>(start, svalue)}; return {allocator.alloc<AstTypeSingletonString>(start, svalue)};
} }
else else
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")};
} }
else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple)
{ {
parseInterpString(); parseInterpString();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")};
} }
else if (lexer.current().type == Lexeme::BrokenString) else if (lexer.current().type == Lexeme::BrokenString)
{ {
nextLexeme(); nextLexeme();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, "Malformed string")}; return {reportTypeAnnotationError(start, {}, "Malformed string")};
} }
else if (lexer.current().type == Lexeme::Name) else if (lexer.current().type == Lexeme::Name)
{ {
@ -1693,33 +1691,20 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{ {
nextLexeme(); nextLexeme();
return {reportTypeAnnotationError(start, {}, /*isMissing*/ false, return {reportTypeAnnotationError(start, {},
"Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
"...any'"), "...any'"),
{}}; {}};
} }
else else
{ {
if (FFlag::LuauTypeAnnotationLocationChange) // For a missing type annotation, capture 'space' between last token and the next one
{ Location astErrorlocation(lexer.previousLocation().end, start.begin);
// For a missing type annotation, capture 'space' between last token and the next one // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message).
Location astErrorlocation(lexer.previousLocation().end, start.begin); // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau.
// The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). Location parseErrorLocation(lexer.previousLocation().end, start.end);
// Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. return {
Location parseErrorLocation(lexer.previousLocation().end, start.end); reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}};
return {
reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()),
{}};
}
else
{
Location location = lexer.current().location;
// For a missing type annotation, capture 'space' between last token and the next one
location = Location(lexer.previousLocation().end, lexer.current().location.begin);
return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}};
}
} }
} }
@ -3033,27 +3018,18 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray<A
return allocator.alloc<AstExprError>(location, expressions, unsigned(parseErrors.size() - 1)); return allocator.alloc<AstExprError>(location, expressions, unsigned(parseErrors.size() - 1));
} }
AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...) AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, const char* format, ...)
{ {
if (FFlag::LuauTypeAnnotationLocationChange)
{
// Missing type annotations should be using `reportMissingTypeAnnotationError` when LuauTypeAnnotationLocationChange is enabled
// Note: `isMissing` can be removed once FFlag::LuauTypeAnnotationLocationChange is removed since it will always be true.
LUAU_ASSERT(!isMissing);
}
va_list args; va_list args;
va_start(args, format); va_start(args, format);
report(location, format, args); report(location, format, args);
va_end(args); va_end(args);
return allocator.alloc<AstTypeError>(location, types, isMissing, unsigned(parseErrors.size() - 1)); return allocator.alloc<AstTypeError>(location, types, false, unsigned(parseErrors.size() - 1));
} }
AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...)
{ {
LUAU_ASSERT(FFlag::LuauTypeAnnotationLocationChange);
va_list args; va_list args;
va_start(args, format); va_start(args, format);
report(parseErrorLocation, format, args); report(parseErrorLocation, format, args);

View File

@ -16,7 +16,6 @@
#include "isocline.h" #include "isocline.h"
#include <algorithm>
#include <memory> #include <memory>
#ifdef _WIN32 #ifdef _WIN32
@ -688,11 +687,11 @@ static std::string getCodegenAssembly(const char* name, const std::string& bytec
return ""; return "";
} }
static void annotateInstruction(void* context, std::string& text, int fid, int instid) static void annotateInstruction(void* context, std::string& text, int fid, int instpos)
{ {
Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context;
bcb.annotateInstruction(text, fid, instid); bcb.annotateInstruction(text, fid, instpos);
} }
struct CompileStats struct CompileStats
@ -711,7 +710,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st
return false; return false;
} }
stats.lines += std::count(source->begin(), source->end(), '\n'); // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example)
// This function is much more complicated because it supports many output human-readable formats through internal interfaces
try try
{ {
@ -736,7 +736,16 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st
bcb.setDumpSource(*source); bcb.setDumpSource(*source);
} }
Luau::compileOrThrow(bcb, *source, copts()); Luau::Allocator allocator;
Luau::AstNameTable names(allocator);
Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator);
if (!result.errors.empty())
throw Luau::ParseErrors(result.errors);
stats.lines += result.lines;
Luau::compileOrThrow(bcb, result, names, copts());
stats.bytecode += bcb.getBytecode().size(); stats.bytecode += bcb.getBytecode().size();
switch (format) switch (format)

View File

@ -143,6 +143,11 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924)
set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-)
endif() endif()
if (NOT MSVC)
# disable support for math_errno which allows compilers to lower sqrt() into a single CPU instruction
target_compile_options(Luau.VM PRIVATE -fno-math-errno)
endif()
if(MSVC AND LUAU_BUILD_CLI) if(MSVC AND LUAU_BUILD_CLI)
# the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger
set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152)

View File

@ -17,7 +17,7 @@ void create(lua_State* L);
// Builds target function and all inner functions // Builds target function and all inner functions
void compile(lua_State* L, int idx); void compile(lua_State* L, int idx);
using annotatorFn = void (*)(void* context, std::string& result, int fid, int instid); using annotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
struct AssemblyOptions struct AssemblyOptions
{ {

View File

@ -34,7 +34,350 @@ namespace CodeGen
constexpr uint32_t kFunctionAlignment = 32; constexpr uint32_t kFunctionAlignment = 32;
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, Proto* proto, AssemblyOptions options) struct InstructionOutline
{
int pcpos;
int length;
};
static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
{
if (build.logText)
build.logAppend("; exitContinueVm\n");
helpers.exitContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ true);
if (build.logText)
build.logAppend("; exitNoContinueVm\n");
helpers.exitNoContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ false);
}
static int emitInst(
AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, Label* labelarr, Label& fallback)
{
int skip = 0;
switch (op)
{
case LOP_NOP:
break;
case LOP_LOADNIL:
emitInstLoadNil(build, pc);
break;
case LOP_LOADB:
emitInstLoadB(build, pc, i, labelarr);
break;
case LOP_LOADN:
emitInstLoadN(build, pc);
break;
case LOP_LOADK:
emitInstLoadK(build, pc);
break;
case LOP_LOADKX:
emitInstLoadKX(build, pc);
break;
case LOP_MOVE:
emitInstMove(build, pc);
break;
case LOP_GETGLOBAL:
emitInstGetGlobal(build, pc, i, fallback);
break;
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, labelarr, fallback);
break;
case LOP_RETURN:
emitInstReturn(build, helpers, pc, i, labelarr);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, i, fallback);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, i, labelarr, fallback);
break;
case LOP_GETTABLEKS:
emitInstGetTableKS(build, pc, i, fallback);
break;
case LOP_SETTABLEKS:
emitInstSetTableKS(build, pc, i, labelarr, fallback);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, i, fallback);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, i, labelarr, fallback);
break;
case LOP_JUMP:
emitInstJump(build, pc, i, labelarr);
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, pc, i, labelarr);
break;
case LOP_JUMPIF:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::LessEqual, fallback);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::Less, fallback);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLessEqual, fallback);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLess, fallback);
break;
case LOP_JUMPX:
emitInstJumpX(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, pc, proto->k, i, labelarr);
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, pc, i, labelarr);
break;
case LOP_ADD:
emitInstBinary(build, pc, i, TM_ADD, fallback);
break;
case LOP_SUB:
emitInstBinary(build, pc, i, TM_SUB, fallback);
break;
case LOP_MUL:
emitInstBinary(build, pc, i, TM_MUL, fallback);
break;
case LOP_DIV:
emitInstBinary(build, pc, i, TM_DIV, fallback);
break;
case LOP_MOD:
emitInstBinary(build, pc, i, TM_MOD, fallback);
break;
case LOP_POW:
emitInstBinary(build, pc, i, TM_POW, fallback);
break;
case LOP_ADDK:
emitInstBinaryK(build, pc, i, TM_ADD, fallback);
break;
case LOP_SUBK:
emitInstBinaryK(build, pc, i, TM_SUB, fallback);
break;
case LOP_MULK:
emitInstBinaryK(build, pc, i, TM_MUL, fallback);
break;
case LOP_DIVK:
emitInstBinaryK(build, pc, i, TM_DIV, fallback);
break;
case LOP_MODK:
emitInstBinaryK(build, pc, i, TM_MOD, fallback);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, i, fallback);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, i, fallback);
break;
case LOP_LENGTH:
emitInstLength(build, pc, i, fallback);
break;
case LOP_NEWTABLE:
emitInstNewTable(build, pc, i, labelarr);
break;
case LOP_DUPTABLE:
emitInstDupTable(build, pc, i, labelarr);
break;
case LOP_SETLIST:
emitInstSetList(build, pc, i, labelarr);
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc, i);
break;
case LOP_SETUPVAL:
emitInstSetUpval(build, pc, i, labelarr);
break;
case LOP_CLOSEUPVALS:
emitInstCloseUpvals(build, pc, i, labelarr);
break;
case LOP_FASTCALL:
skip = emitInstFastCall(build, pc, i, labelarr);
break;
case LOP_FASTCALL1:
skip = emitInstFastCall1(build, pc, i, labelarr);
break;
case LOP_FASTCALL2:
skip = emitInstFastCall2(build, pc, i, labelarr);
break;
case LOP_FASTCALL2K:
skip = emitInstFastCall2K(build, pc, i, labelarr);
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, labelarr);
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, labelarr);
break;
case LOP_FORGLOOP:
emitinstForGLoop(build, pc, i, labelarr, fallback);
break;
case LOP_FORGPREP_NEXT:
emitInstForGPrepNext(build, pc, i, labelarr, fallback);
break;
case LOP_FORGPREP_INEXT:
emitInstForGPrepInext(build, pc, i, labelarr, fallback);
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
case LOP_GETIMPORT:
emitInstGetImport(build, pc, fallback);
break;
case LOP_CONCAT:
emitInstConcat(build, pc, i, labelarr);
break;
default:
emitFallback(build, data, op, i);
break;
}
return skip;
}
static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr)
{
switch (op)
{
case LOP_GETIMPORT:
emitInstGetImportFallback(build, pc, i);
break;
case LOP_GETTABLE:
emitInstGetTableFallback(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTableFallback(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableNFallback(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableNFallback(build, pc, i);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
break;
case LOP_SUB:
emitInstBinaryFallback(build, pc, i, TM_SUB);
break;
case LOP_MUL:
emitInstBinaryFallback(build, pc, i, TM_MUL);
break;
case LOP_DIV:
emitInstBinaryFallback(build, pc, i, TM_DIV);
break;
case LOP_MOD:
emitInstBinaryFallback(build, pc, i, TM_MOD);
break;
case LOP_POW:
emitInstBinaryFallback(build, pc, i, TM_POW);
break;
case LOP_ADDK:
emitInstBinaryKFallback(build, pc, i, TM_ADD);
break;
case LOP_SUBK:
emitInstBinaryKFallback(build, pc, i, TM_SUB);
break;
case LOP_MULK:
emitInstBinaryKFallback(build, pc, i, TM_MUL);
break;
case LOP_DIVK:
emitInstBinaryKFallback(build, pc, i, TM_DIV);
break;
case LOP_MODK:
emitInstBinaryKFallback(build, pc, i, TM_MOD);
break;
case LOP_POWK:
emitInstBinaryKFallback(build, pc, i, TM_POW);
break;
case LOP_MINUS:
emitInstMinusFallback(build, pc, i);
break;
case LOP_LENGTH:
emitInstLengthFallback(build, pc, i);
break;
case LOP_FORGLOOP:
emitinstForGLoopFallback(build, pc, i, labelarr);
break;
case LOP_FORGPREP_NEXT:
case LOP_FORGPREP_INEXT:
emitInstForGPrepXnextFallback(build, pc, i, labelarr);
break;
case LOP_GETGLOBAL:
// TODO: luaV_gettable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETGLOBAL:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_GETTABLEKS:
// Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access
// It is also required to perform cached slot update
// TODO: extra fast-paths could be lowered before the full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETTABLEKS:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
default:
LUAU_ASSERT(!"Expected fallback for instruction");
}
}
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
{ {
NativeProto* result = new NativeProto(); NativeProto* result = new NativeProto();
@ -59,222 +402,65 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
std::vector<Label> instFallbacks; std::vector<Label> instFallbacks;
instFallbacks.resize(proto->sizecode); instFallbacks.resize(proto->sizecode);
std::vector<InstructionOutline> instOutlines;
instOutlines.reserve(64);
build.align(kFunctionAlignment, AlignmentDataX64::Ud2); build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel(); Label start = build.setLabel();
for (int i = 0, instid = 0; i < proto->sizecode; ++instid) for (int i = 0; i < proto->sizecode;)
{ {
const Instruction* pc = &proto->code[i]; const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc)); LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
build.setLabel(instLabels[i]); build.setLabel(instLabels[i]);
if (options.annotator) if (options.annotator)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, instid); options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
switch (op) int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]);
{
case LOP_NOP:
break;
case LOP_LOADNIL:
emitInstLoadNil(build, pc);
break;
case LOP_LOADB:
emitInstLoadB(build, pc, i, instLabels.data());
break;
case LOP_LOADN:
emitInstLoadN(build, pc);
break;
case LOP_LOADK:
emitInstLoadK(build, pc);
break;
case LOP_LOADKX:
emitInstLoadKX(build, pc);
break;
case LOP_MOVE:
emitInstMove(build, pc);
break;
case LOP_GETGLOBAL:
emitInstGetGlobal(build, pc, i, instFallbacks[i]);
break;
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, instLabels.data(), instFallbacks[i]);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, i, instFallbacks[i]);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, i, instLabels.data(), instFallbacks[i]);
break;
case LOP_GETTABLEKS:
emitInstGetTableKS(build, pc, i, instFallbacks[i]);
break;
case LOP_SETTABLEKS:
emitInstSetTableKS(build, pc, i, instLabels.data(), instFallbacks[i]);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, i, instFallbacks[i]);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, i, instLabels.data(), instFallbacks[i]);
break;
case LOP_JUMP:
emitInstJump(build, pc, i, instLabels.data());
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, pc, i, instLabels.data());
break;
case LOP_JUMPIF:
emitInstJumpIf(build, pc, i, instLabels.data(), /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, pc, i, instLabels.data(), /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, pc, i, instLabels.data(), /* not_ */ false, instFallbacks[i]);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::LessEqual, instFallbacks[i]);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::Less, instFallbacks[i]);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, instLabels.data(), /* not_ */ true, instFallbacks[i]);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::NotLessEqual, instFallbacks[i]);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, instLabels.data(), Condition::NotLess, instFallbacks[i]);
break;
case LOP_JUMPX:
emitInstJumpX(build, pc, i, instLabels.data());
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, pc, i, instLabels.data());
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, pc, i, instLabels.data());
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, pc, proto->k, i, instLabels.data());
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, pc, i, instLabels.data());
break;
case LOP_ADD:
emitInstBinary(build, pc, i, TM_ADD, instFallbacks[i]);
break;
case LOP_SUB:
emitInstBinary(build, pc, i, TM_SUB, instFallbacks[i]);
break;
case LOP_MUL:
emitInstBinary(build, pc, i, TM_MUL, instFallbacks[i]);
break;
case LOP_DIV:
emitInstBinary(build, pc, i, TM_DIV, instFallbacks[i]);
break;
case LOP_MOD:
emitInstBinary(build, pc, i, TM_MOD, instFallbacks[i]);
break;
case LOP_POW:
emitInstBinary(build, pc, i, TM_POW, instFallbacks[i]);
break;
case LOP_ADDK:
emitInstBinaryK(build, pc, i, TM_ADD, instFallbacks[i]);
break;
case LOP_SUBK:
emitInstBinaryK(build, pc, i, TM_SUB, instFallbacks[i]);
break;
case LOP_MULK:
emitInstBinaryK(build, pc, i, TM_MUL, instFallbacks[i]);
break;
case LOP_DIVK:
emitInstBinaryK(build, pc, i, TM_DIV, instFallbacks[i]);
break;
case LOP_MODK:
emitInstBinaryK(build, pc, i, TM_MOD, instFallbacks[i]);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, i, instFallbacks[i]);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, i, instFallbacks[i]);
break;
case LOP_LENGTH:
emitInstLength(build, pc, i, instFallbacks[i]);
break;
case LOP_NEWTABLE:
emitInstNewTable(build, pc, i, instLabels.data());
break;
case LOP_DUPTABLE:
emitInstDupTable(build, pc, i, instLabels.data());
break;
case LOP_SETLIST:
emitInstSetList(build, pc, i, instLabels.data());
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc, i);
break;
case LOP_SETUPVAL:
emitInstSetUpval(build, pc, i, instLabels.data());
break;
case LOP_CLOSEUPVALS:
emitInstCloseUpvals(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL:
emitInstFastCall(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL1:
emitInstFastCall1(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL2:
emitInstFastCall2(build, pc, i, instLabels.data());
break;
case LOP_FASTCALL2K:
emitInstFastCall2K(build, pc, i, instLabels.data());
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, instLabels.data());
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, instLabels.data());
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
case LOP_GETIMPORT:
emitInstGetImport(build, pc, instFallbacks[i]);
break;
case LOP_CONCAT:
emitInstConcat(build, pc, i, instLabels.data());
break;
default:
emitFallback(build, data, op, i);
break;
}
i += getOpLength(op); if (skip != 0)
instOutlines.push_back({nexti, skip});
i = nexti + skip;
LUAU_ASSERT(i <= proto->sizecode); LUAU_ASSERT(i <= proto->sizecode);
} }
size_t textSize = build.text.size(); size_t textSize = build.text.size();
uint32_t codeSize = build.getCodeSize(); uint32_t codeSize = build.getCodeSize();
if (options.annotator && !options.skipOutlinedCode)
build.logAppend("; outlined instructions\n");
for (auto [pcpos, length] : instOutlines)
{
int i = pcpos;
while (i < pcpos + length)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
build.setLabel(instLabels[i]);
if (options.annotator && !options.skipOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), instFallbacks[i]);
LUAU_ASSERT(skip == 0);
i += getOpLength(op);
}
if (i < proto->sizecode)
build.jmp(instLabels[i]);
}
if (options.annotator && !options.skipOutlinedCode) if (options.annotator && !options.skipOutlinedCode)
build.logAppend("; outlined code\n"); build.logAppend("; outlined code\n");
@ -297,104 +483,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
build.setLabel(instFallbacks[i]); build.setLabel(instFallbacks[i]);
switch (op) emitInstFallback(build, data, op, pc, i, instLabels.data());
{
case LOP_GETIMPORT:
emitInstGetImportFallback(build, pc, i);
break;
case LOP_GETTABLE:
emitInstGetTableFallback(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTableFallback(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableNFallback(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableNFallback(build, pc, i);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEqFallback(build, pc, i, instLabels.data(), /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, instLabels.data(), /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, instLabels.data(), Condition::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
break;
case LOP_SUB:
emitInstBinaryFallback(build, pc, i, TM_SUB);
break;
case LOP_MUL:
emitInstBinaryFallback(build, pc, i, TM_MUL);
break;
case LOP_DIV:
emitInstBinaryFallback(build, pc, i, TM_DIV);
break;
case LOP_MOD:
emitInstBinaryFallback(build, pc, i, TM_MOD);
break;
case LOP_POW:
emitInstBinaryFallback(build, pc, i, TM_POW);
break;
case LOP_ADDK:
emitInstBinaryKFallback(build, pc, i, TM_ADD);
break;
case LOP_SUBK:
emitInstBinaryKFallback(build, pc, i, TM_SUB);
break;
case LOP_MULK:
emitInstBinaryKFallback(build, pc, i, TM_MUL);
break;
case LOP_DIVK:
emitInstBinaryKFallback(build, pc, i, TM_DIV);
break;
case LOP_MODK:
emitInstBinaryKFallback(build, pc, i, TM_MOD);
break;
case LOP_POWK:
emitInstBinaryKFallback(build, pc, i, TM_POW);
break;
case LOP_MINUS:
emitInstMinusFallback(build, pc, i);
break;
case LOP_LENGTH:
emitInstLengthFallback(build, pc, i);
break;
case LOP_GETGLOBAL:
// TODO: luaV_gettable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETGLOBAL:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_GETTABLEKS:
// Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access
// It is also required to perform cached slot update
// TODO: extra fast-paths could be lowered before the full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETTABLEKS:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
default:
LUAU_ASSERT(!"Expected fallback for instruction");
}
// Jump back to the next instruction handler // Jump back to the next instruction handler
if (nexti < proto->sizecode) if (nexti < proto->sizecode)
@ -568,13 +657,16 @@ void compile(lua_State* L, int idx)
std::vector<Proto*> protos; std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p); gatherFunctions(protos, clvalue(func)->l.p);
ModuleHelpers helpers;
assembleHelpers(build, helpers);
std::vector<NativeProto*> results; std::vector<NativeProto*> results;
results.reserve(protos.size()); results.reserve(protos.size());
// Skip protos that have been compiled during previous invocations of CodeGen::compile // Skip protos that have been compiled during previous invocations of CodeGen::compile
for (Proto* p : protos) for (Proto* p : protos)
if (p && getProtoExecData(p) == nullptr) if (p && getProtoExecData(p) == nullptr)
results.push_back(assembleFunction(build, *data, p, {})); results.push_back(assembleFunction(build, *data, helpers, p, {}));
build.finalize(); build.finalize();
@ -615,10 +707,13 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options)
std::vector<Proto*> protos; std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p); gatherFunctions(protos, clvalue(func)->l.p);
ModuleHelpers helpers;
assembleHelpers(build, helpers);
for (Proto* p : protos) for (Proto* p : protos)
if (p) if (p)
{ {
NativeProto* nativeProto = assembleFunction(build, data, p, options); NativeProto* nativeProto = assembleFunction(build, data, helpers, p, options);
destroyNativeProto(nativeProto); destroyNativeProto(nativeProto);
} }

View File

@ -0,0 +1,76 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "CodeGenUtils.h"
#include "ldo.h"
#include "ltable.h"
#include "FallbacksProlog.h"
#include <string.h>
namespace Luau
{
namespace CodeGen
{
bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra)
{
// then we advance index through the hash portion
while (unsigned(index - h->sizearray) < unsigned(1 << h->lsizenode))
{
LuaNode* n = &h->node[index - h->sizearray];
if (!ttisnil(gval(n)))
{
setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
getnodekey(L, ra + 3, n);
setobj(L, ra + 4, gval(n));
return true;
}
index++;
}
return false;
}
bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux)
{
TValue* base = L->base;
TValue* ra = VM_REG(insnA);
// note: it's safe to push arguments past top for complicated reasons (see lvmexecute.cpp)
setobj2s(L, ra + 3 + 2, ra + 2);
setobj2s(L, ra + 3 + 1, ra + 1);
setobj2s(L, ra + 3, ra);
L->top = ra + 3 + 3; // func + 2 args (state and index)
LUAU_ASSERT(L->top <= L->stack_last);
luaD_call(L, ra + 3, uint8_t(aux));
L->top = L->ci->top;
// recompute ra since stack might have been reallocated
base = L->base;
ra = VM_REG(insnA);
// copy first variable back into the iteration index
setobj2s(L, ra + 2, ra + 3);
return !ttisnil(ra + 3);
}
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc)
{
if (!ttisfunction(ra))
{
Closure* cl = clvalue(L->ci->func);
L->ci->savedpc = cl->l.p->code + pc;
luaG_typeerror(L, ra, "iterate over");
}
}
} // namespace CodeGen
} // namespace Luau

View File

@ -0,0 +1,17 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "lobject.h"
namespace Luau
{
namespace CodeGen
{
bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra);
bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux);
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
} // namespace CodeGen
} // namespace Luau

View File

@ -67,6 +67,13 @@ constexpr unsigned kLuaNodeTagMask = 0xf;
constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field constexpr unsigned kOffsetOfLuaNodeTag = 12; // offsetof cannot be used on a bit field
constexpr unsigned kOffsetOfInstructionC = 3; constexpr unsigned kOffsetOfInstructionC = 3;
// Leaf functions that are placed in every module to perform common instruction sequences
struct ModuleHelpers
{
Label exitContinueVm;
Label exitNoContinueVm;
};
inline OperandX64 luauReg(int ri) inline OperandX64 luauReg(int ri)
{ {
return xmmword[rBase + ri * sizeof(TValue)]; return xmmword[rBase + ri * sizeof(TValue)];
@ -107,6 +114,12 @@ inline OperandX64 luauNodeKeyValue(RegisterX64 node)
return qword[node + offsetof(LuaNode, key) + offsetof(TKey, value)]; return qword[node + offsetof(LuaNode, key) + offsetof(TKey, value)];
} }
// Note: tag has dirty upper bits
inline OperandX64 luauNodeKeyTag(RegisterX64 node)
{
return dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeTag];
}
inline OperandX64 luauNodeValue(RegisterX64 node) inline OperandX64 luauNodeValue(RegisterX64 node)
{ {
return xmmword[node + offsetof(LuaNode, val)]; return xmmword[node + offsetof(LuaNode, val)];
@ -184,7 +197,7 @@ inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, Re
{ {
tmp.size = SizeX64::dword; tmp.size = SizeX64::dword;
build.mov(tmp, dword[node + offsetof(LuaNode, key) + kOffsetOfLuaNodeTag]); build.mov(tmp, luauNodeKeyTag(node));
build.and_(tmp, kLuaNodeTagMask); build.and_(tmp, kLuaNodeTagMask);
build.cmp(tmp, tag); build.cmp(tmp, tag);
build.jcc(Condition::NotEqual, label); build.jcc(Condition::NotEqual, label);

View File

@ -34,7 +34,7 @@ void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
build.mov(luauRegTag(ra), LUA_TBOOLEAN); build.mov(luauRegTag(ra), LUA_TBOOLEAN);
if (int target = LUAU_INSN_C(*pc)) if (int target = LUAU_INSN_C(*pc))
build.jmp(labelarr[pcpos + target + 1]); build.jmp(labelarr[pcpos + 1 + target]);
} }
void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc) void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc)
@ -72,23 +72,185 @@ void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc)
build.vmovups(luauReg(ra), xmm0); build.vmovups(luauReg(ra), xmm0);
} }
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr)
{
emitInterrupt(build, pcpos);
int ra = LUAU_INSN_A(*pc);
int b = LUAU_INSN_B(*pc) - 1;
RegisterX64 ci = r8;
RegisterX64 cip = r9;
RegisterX64 res = rdi;
RegisterX64 nresults = esi;
build.mov(ci, qword[rState + offsetof(lua_State, ci)]);
build.lea(cip, qword[ci - sizeof(CallInfo)]);
// res = ci->func; note: we assume CALL always puts func+args and expects results to start at func
build.mov(res, qword[ci + offsetof(CallInfo, func)]);
// nresults = ci->nresults
build.mov(nresults, dword[ci + offsetof(CallInfo, nresults)]);
{
Label skipResultCopy;
RegisterX64 counter = ecx;
if (b == 0)
{
// Our instruction doesn't have any results, so just fill results expected in parent with 'nil'
build.test(nresults, nresults); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
build.mov(counter, nresults);
Label repeatNilLoop = build.setLabel();
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop);
}
else if (b == 1)
{
// Try setting our 1 result
build.test(nresults, nresults);
build.jcc(Condition::Zero, skipResultCopy);
build.lea(counter, dword[nresults - 1]);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[res], xmm0);
build.add(res, sizeof(TValue));
// Fill the rest of the expected results with 'nil'
build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
Label repeatNilLoop = build.setLabel();
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop);
}
else
{
RegisterX64 vali = rax;
RegisterX64 valend = rdx;
// Copy return values into parent stack (but only up to nresults!)
build.test(nresults, nresults);
build.jcc(Condition::Zero, skipResultCopy);
// vali = ra
build.lea(vali, luauRegValue(ra));
// Copy as much as possible for MULTRET calls, and only as much as needed otherwise
if (b == LUA_MULTRET)
build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top
else
build.lea(valend, luauRegValue(ra + b)); // valend = ra + b
build.mov(counter, nresults);
Label repeatValueLoop, exitValueLoop;
build.setLabel(repeatValueLoop);
build.cmp(vali, valend);
build.jcc(Condition::NotBelow, exitValueLoop);
build.vmovups(xmm0, xmmword[vali]);
build.vmovups(xmmword[res], xmm0);
build.add(vali, sizeof(TValue));
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatValueLoop);
build.setLabel(exitValueLoop);
// Fill the rest of the expected results with 'nil'
build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
Label repeatNilLoop = build.setLabel();
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop);
}
build.setLabel(skipResultCopy);
}
build.mov(qword[rState + offsetof(lua_State, ci)], cip); // L->ci = cip
build.mov(rBase, qword[cip + offsetof(CallInfo, base)]); // sync base = L->base while we have a chance
build.mov(qword[rState + offsetof(lua_State, base)], rBase); // L->base = cip->base
// Start with result for LUA_MULTRET/exit value
build.mov(qword[rState + offsetof(lua_State, top)], res); // L->top = res
// Unlikely, but this might be the last return from VM
build.test(byte[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_RETURN);
build.jcc(Condition::NotZero, helpers.exitNoContinueVm);
Label skipFixedRetTop;
build.test(nresults, nresults); // test here will set SF=1 for a negative number and it always sets OF to 0
build.jcc(Condition::Less, skipFixedRetTop); // jl jumps if SF != OF
build.mov(rax, qword[cip + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax); // L->top = cip->top
build.setLabel(skipFixedRetTop);
// Returning back to the previous function is a bit tricky
// Registers alive: r9 (cip)
RegisterX64 proto = rcx;
RegisterX64 execdata = rbx;
// Change closure
build.mov(rax, qword[cip + offsetof(CallInfo, func)]);
build.mov(rax, qword[rax + offsetof(TValue, value.gc)]);
build.mov(sClosure, rax);
build.mov(proto, qword[rax + offsetof(Closure, l.p)]);
build.mov(execdata, qword[proto + offsetofProtoExecData]);
build.test(execdata, execdata);
build.jcc(Condition::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data
// Change constants
build.mov(rConstants, qword[proto + offsetof(Proto, k)]);
// Change code
build.mov(rdx, qword[proto + offsetof(Proto, code)]);
build.mov(sCode, rdx);
build.mov(rax, qword[cip + offsetof(CallInfo, savedpc)]);
// To get instruction index from instruction pointer, we need to divide byte offset by 4
// But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out
build.sub(rax, rdx);
// Get new instruction location and jump to it
build.mov(rdx, qword[execdata + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rdx + rax * 2]);
}
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
build.jmp(labelarr[pcpos + LUAU_INSN_D(*pc) + 1]); build.jmp(labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]);
} }
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInterrupt(build, pcpos); emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + LUAU_INSN_D(*pc) + 1]); build.jmp(labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]);
} }
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_) void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 1]; Label& exit = labelarr[pcpos + 1];
if (not_) if (not_)
@ -102,7 +264,7 @@ void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = pc[1]; int rb = pc[1];
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2]; Label& exit = labelarr[pcpos + 2];
build.mov(eax, luauRegTag(ra)); build.mov(eax, luauRegTag(ra));
@ -121,7 +283,7 @@ void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_) void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{ {
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? Condition::NotEqual : Condition::Equal, target, pcpos); jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? Condition::NotEqual : Condition::Equal, target, pcpos);
} }
@ -131,7 +293,7 @@ void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pc
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
int rb = pc[1]; int rb = pc[1];
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
// fast-path: number // fast-path: number
jumpIfTagIsNot(build, ra, LUA_TNUMBER, fallback); jumpIfTagIsNot(build, ra, LUA_TNUMBER, fallback);
@ -142,7 +304,7 @@ void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pc
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond) void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond)
{ {
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], cond, target, pcpos); jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], cond, target, pcpos);
} }
@ -151,7 +313,7 @@ void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
{ {
emitInterrupt(build, pcpos); emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + LUAU_INSN_E(*pc) + 1]); build.jmp(labelarr[pcpos + 1 + LUAU_INSN_E(*pc)]);
} }
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
@ -159,7 +321,7 @@ void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pc
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
bool not_ = (pc[1] & 0x80000000) != 0; bool not_ = (pc[1] & 0x80000000) != 0;
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.cmp(luauRegTag(ra), LUA_TNIL); build.cmp(luauRegTag(ra), LUA_TNIL);
build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target); build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target);
@ -171,7 +333,7 @@ void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
uint32_t aux = pc[1]; uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0; bool not_ = (aux & 0x80000000) != 0;
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2]; Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit); jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit);
@ -187,7 +349,7 @@ void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TV
bool not_ = (aux & 0x80000000) != 0; bool not_ = (aux & 0x80000000) != 0;
TValue kv = k[aux & 0xffffff]; TValue kv = k[aux & 0xffffff];
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2]; Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TNUMBER, not_ ? target : exit); jumpIfTagIsNot(build, ra, LUA_TNUMBER, not_ ? target : exit);
@ -212,7 +374,7 @@ void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
uint32_t aux = pc[1]; uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0; bool not_ = (aux & 0x80000000) != 0;
Label& target = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2]; Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TSTRING, not_ ? target : exit); jumpIfTagIsNot(build, ra, LUA_TSTRING, not_ ? target : exit);
@ -596,8 +758,8 @@ void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int p
build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]);
} }
static void emitInstFastCallN( static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs,
AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label* labelarr) int pcpos, int instLen, Label* labelarr)
{ {
int bfid = LUAU_INSN_A(*pc); int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc); int skip = LUAU_INSN_C(*pc);
@ -611,7 +773,7 @@ static void emitInstFastCallN(
int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1; int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1;
OperandX64 args = customParams ? customArgs : luauRegValue(ra + 2); OperandX64 args = customParams ? customArgs : luauRegValue(ra + 2);
Label exit; Label& exit = labelarr[pcpos + instLen];
jumpIfUnsafeEnv(build, rax, exit); jumpIfUnsafeEnv(build, rax, exit);
@ -633,10 +795,7 @@ static void emitInstFastCallN(
build.mov(qword[rState + offsetof(lua_State, top)], rax); build.mov(qword[rState + offsetof(lua_State, top)], rax);
} }
// TODO: once we start outlining the fallback, we will be able to fallthrough to the next instruction return skip + 2 - instLen; // Return fallback instruction sequence length
build.jmp(labelarr[pcpos + skip + 2]);
build.setLabel(exit);
return;
} }
// TODO: we can skip saving pc for some well-behaved builtins which we didn't inline // TODO: we can skip saving pc for some well-behaved builtins which we didn't inline
@ -705,37 +864,35 @@ static void emitInstFastCallN(
build.mov(qword[rState + offsetof(lua_State, top)], rax); build.mov(qword[rState + offsetof(lua_State, top)], rax);
} }
build.jmp(labelarr[pcpos + skip + 2]); return skip + 2 - instLen; // Return fallback instruction sequence length
build.setLabel(exit);
// TODO: fallback to LOP_CALL after a fast call should be outlined
} }
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, labelarr); return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, /* instLen */ 1, labelarr);
} }
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegValue(pc[1]), pcpos, labelarr); return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegValue(pc[1]), pcpos, /* instLen */ 2, labelarr);
} }
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantValue(pc[1]), pcpos, labelarr); return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantValue(pc[1]), pcpos, /* instLen */ 2, labelarr);
} }
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, labelarr); return emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, /* instLen */ 1, labelarr);
} }
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr) void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{ {
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
Label& loopExit = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& loopExit = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label tryConvert, exit; Label tryConvert, exit;
@ -784,7 +941,7 @@ void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
emitInterrupt(build, pcpos); emitInterrupt(build, pcpos);
int ra = LUAU_INSN_A(*pc); int ra = LUAU_INSN_A(*pc);
Label& loopRepeat = labelarr[pcpos + LUAU_INSN_D(*pc) + 1]; Label& loopRepeat = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
RegisterX64 limit = xmm0; RegisterX64 limit = xmm0;
RegisterX64 step = xmm1; RegisterX64 step = xmm1;
@ -814,6 +971,162 @@ void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.setLabel(exit); build.setLabel(exit);
} }
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int aux = pc[1];
Label& loopRepeat = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
emitInterrupt(build, pcpos);
// fast-path: builtin table iteration
jumpIfTagIsNot(build, ra, LUA_TNIL, fallback);
// Registers are chosen in this way to simplify fallback code for the node part
RegisterX64 table = rArg2;
RegisterX64 index = rArg3;
RegisterX64 elemPtr = rax;
build.mov(table, luauRegValue(ra + 1));
build.mov(index, luauRegValue(ra + 2));
// &array[index]
build.mov(dwordReg(elemPtr), dwordReg(index));
build.shl(dwordReg(elemPtr), kTValueSizeLog2);
build.add(elemPtr, qword[table + offsetof(Table, array)]);
// Clear extra variables since we might have more than two
for (int i = 2; i < aux; ++i)
build.mov(luauRegTag(ra + 3 + i), LUA_TNIL);
// ipairs-style traversal is terminated early when array part ends of nil array element is encountered
bool isIpairsIter = aux < 0;
Label skipArray, skipArrayNil;
// First we advance index through the array portion
// while (unsigned(index) < unsigned(sizearray))
Label arrayLoop = build.setLabel();
build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]);
build.jcc(Condition::NotBelow, isIpairsIter ? exit : skipArray);
// If element is nil, we increment the index; if it's not, we still need 'index + 1' inside
build.inc(index);
build.cmp(dword[elemPtr + offsetof(TValue, tt)], LUA_TNIL);
build.jcc(Condition::Equal, isIpairsIter ? exit : skipArrayNil);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
build.mov(luauRegValue(ra + 2), index);
// Tag should already be set to lightuserdata
// setnvalue(ra + 3, double(index + 1));
build.vcvtsi2sd(xmm0, xmm0, dwordReg(index));
build.vmovsd(luauRegValue(ra + 3), xmm0);
build.mov(luauRegTag(ra + 3), LUA_TNUMBER);
// setobj2s(L, ra + 4, e);
setLuauReg(build, xmm2, ra + 4, xmmword[elemPtr]);
build.jmp(loopRepeat);
if (!isIpairsIter)
{
build.setLabel(skipArrayNil);
// Index already incremented, advance to next array element
build.add(elemPtr, sizeof(TValue));
build.jmp(arrayLoop);
build.setLabel(skipArray);
// Call helper to assign next node value or to signal loop exit
build.mov(rArg1, rState);
// rArg2 and rArg3 are already set
build.lea(rArg4, luauRegValue(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]);
build.test(al, al);
build.jcc(Condition::NotZero, loopRepeat);
}
}
void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
int aux = pc[1];
Label& loopRepeat = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, ra);
build.mov(rArg3, aux);
build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNonTableFallback)]);
emitUpdateBase(build);
build.test(al, al);
build.jcc(Condition::NotZero, loopRepeat);
}
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
// fast-path: pairs/next
jumpIfUnsafeEnv(build, rax, fallback);
jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, ra + 2, LUA_TNIL, fallback);
build.mov(luauRegTag(ra), LUA_TNIL);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.mov(luauRegValue(ra + 2), 0);
build.mov(luauRegTag(ra + 2), LUA_TLIGHTUSERDATA);
build.jmp(target);
}
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
// fast-path: ipairs/inext
jumpIfUnsafeEnv(build, rax, fallback);
jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, ra + 2, LUA_TNUMBER, fallback);
build.vxorpd(xmm0, xmm0, xmm0);
build.vmovsd(xmm1, luauRegValue(ra + 2));
jumpOnNumberCmp(build, noreg, xmm0, xmm1, Condition::NotEqual, fallback);
build.mov(luauRegTag(ra), LUA_TNIL);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.mov(luauRegValue(ra + 2), 0);
build.mov(luauRegTag(ra + 2), LUA_TLIGHTUSERDATA);
build.jmp(target);
}
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.mov(rArg3, pcpos + 1);
build.call(qword[rNativeContext + offsetof(NativeContext, forgPrepXnextFallback)]);
build.jmp(target);
}
static void emitInstAndX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) static void emitInstAndX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c)
{ {
Label target, fallthrough; Label target, fallthrough;

View File

@ -16,7 +16,7 @@ namespace CodeGen
class AssemblyBuilderX64; class AssemblyBuilderX64;
enum class Condition; enum class Condition;
struct Label; struct Label;
struct NativeState; struct ModuleHelpers;
void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
@ -24,6 +24,7 @@ void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc); void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc); void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_); void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
@ -52,12 +53,17 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos); void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr); void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback);
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc); void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc); void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc);

View File

@ -3,6 +3,7 @@
#include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilder.h"
#include "CodeGenUtils.h"
#include "CustomExecUtils.h" #include "CustomExecUtils.h"
#include "Fallbacks.h" #include "Fallbacks.h"
@ -13,14 +14,10 @@
#include "lvm.h" #include "lvm.h"
#include <math.h> #include <math.h>
#include <string.h>
#define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags} #define CODEGEN_SET_FALLBACK(op, flags) data.context.fallback[op] = {execute_##op, flags}
static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
return -1;
}
namespace Luau namespace Luau
{ {
namespace CodeGen namespace CodeGen
@ -42,11 +39,7 @@ void initFallbackTable(NativeState& data)
CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0); CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0); CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0);
CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt); CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_RETURN, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc); CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_FORGLOOP, kFallbackUpdatePc | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP_INEXT, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_FORGPREP_NEXT, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0); CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0); CODEGEN_SET_FALLBACK(LOP_PREPVARARGS, 0);
@ -62,12 +55,8 @@ void initFallbackTable(NativeState& data)
void initHelperFunctions(NativeState& data) void initHelperFunctions(NativeState& data)
{ {
static_assert(sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]) == sizeof(luauF_table) / sizeof(luauF_table[0]), static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length");
"fast call tables are not of the same length"); memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table));
// Replace missing fast call functions with an empty placeholder that forces LOP_CALL fallback
for (size_t i = 0; i < sizeof(data.context.luauF_table) / sizeof(data.context.luauF_table[0]); i++)
data.context.luauF_table[i] = luauF_table[i] ? luauF_table[i] : luauF_missing;
data.context.luaV_lessthan = luaV_lessthan; data.context.luaV_lessthan = luaV_lessthan;
data.context.luaV_lessequal = luaV_lessequal; data.context.luaV_lessequal = luaV_lessequal;
@ -93,6 +82,10 @@ void initHelperFunctions(NativeState& data)
data.context.luaF_close = luaF_close; data.context.luaF_close = luaF_close;
data.context.libm_pow = pow; data.context.libm_pow = pow;
data.context.forgLoopNodeIter = forgLoopNodeIter;
data.context.forgLoopNonTableFallback = forgLoopNonTableFallback;
data.context.forgPrepXnextFallback = forgPrepXnextFallback;
} }
} // namespace CodeGen } // namespace CodeGen

View File

@ -3,13 +3,16 @@
#include "Luau/Bytecode.h" #include "Luau/Bytecode.h"
#include "Luau/CodeAllocator.h" #include "Luau/CodeAllocator.h"
#include "Luau/Label.h"
#include <memory> #include <memory>
#include <stdint.h> #include <stdint.h>
#include "ldebug.h"
#include "lobject.h" #include "lobject.h"
#include "ltm.h" #include "ltm.h"
#include "lstate.h"
typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams); typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams);
@ -77,6 +80,11 @@ struct NativeContext
void (*luaF_close)(lua_State* L, StkId level) = nullptr; void (*luaF_close)(lua_State* L, StkId level) = nullptr;
double (*libm_pow)(double, double) = nullptr; double (*libm_pow)(double, double) = nullptr;
// Helper functions
bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr;
bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr;
void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr;
}; };
struct NativeState struct NativeState

View File

@ -117,7 +117,7 @@ public:
std::string dumpEverything() const; std::string dumpEverything() const;
std::string dumpSourceRemarks() const; std::string dumpSourceRemarks() const;
void annotateInstruction(std::string& result, uint32_t fid, uint32_t instid) const; void annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const;
static uint32_t getImportId(int32_t id0); static uint32_t getImportId(int32_t id0);
static uint32_t getImportId(int32_t id0, int32_t id1); static uint32_t getImportId(int32_t id0, int32_t id1);

View File

@ -1977,14 +1977,14 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector<int>& dumpinstoffs)
if (labels[i] == 0) if (labels[i] == 0)
labels[i] = nextLabel++; labels[i] = nextLabel++;
dumpinstoffs.reserve(insns.size()); dumpinstoffs.resize(insns.size() + 1, -1);
for (size_t i = 0; i < insns.size();) for (size_t i = 0; i < insns.size();)
{ {
const uint32_t* code = &insns[i]; const uint32_t* code = &insns[i];
uint8_t op = LUAU_INSN_OP(*code); uint8_t op = LUAU_INSN_OP(*code);
dumpinstoffs.push_back(int(result.size())); dumpinstoffs[i] = int(result.size());
if (op == LOP_PREPVARARGS) if (op == LOP_PREPVARARGS)
{ {
@ -2028,7 +2028,7 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector<int>& dumpinstoffs)
LUAU_ASSERT(i <= insns.size()); LUAU_ASSERT(i <= insns.size());
} }
dumpinstoffs.push_back(int(result.size())); dumpinstoffs[insns.size()] = int(result.size());
return result; return result;
} }
@ -2119,7 +2119,7 @@ std::string BytecodeBuilder::dumpSourceRemarks() const
return result; return result;
} }
void BytecodeBuilder::annotateInstruction(std::string& result, uint32_t fid, uint32_t instid) const void BytecodeBuilder::annotateInstruction(std::string& result, uint32_t fid, uint32_t instpos) const
{ {
if ((dumpFlags & Dump_Code) == 0) if ((dumpFlags & Dump_Code) == 0)
return; return;
@ -2130,9 +2130,15 @@ void BytecodeBuilder::annotateInstruction(std::string& result, uint32_t fid, uin
const std::string& dump = function.dump; const std::string& dump = function.dump;
const std::vector<int>& dumpinstoffs = function.dumpinstoffs; const std::vector<int>& dumpinstoffs = function.dumpinstoffs;
LUAU_ASSERT(instid + 1 < dumpinstoffs.size()); uint32_t next = instpos + 1;
formatAppend(result, "%.*s", dumpinstoffs[instid + 1] - dumpinstoffs[instid], dump.data() + dumpinstoffs[instid]); LUAU_ASSERT(next < dumpinstoffs.size());
// Skip locations of multi-dword instructions
while (next < dumpinstoffs.size() && dumpinstoffs[next] == -1)
next++;
formatAppend(result, "%.*s", dumpinstoffs[next] - dumpinstoffs[instpos], dump.data() + dumpinstoffs[instpos]);
} }
} // namespace Luau } // namespace Luau

View File

@ -75,7 +75,7 @@ endif
# configuration-specific flags # configuration-specific flags
ifeq ($(config),release) ifeq ($(config),release)
CXXFLAGS+=-O2 -DNDEBUG CXXFLAGS+=-O2 -DNDEBUG -fno-math-errno
endif endif
ifeq ($(config),coverage) ifeq ($(config),coverage)
@ -102,7 +102,7 @@ ifeq ($(config),fuzz)
endif endif
ifeq ($(config),profile) ifeq ($(config),profile)
CXXFLAGS+=-O2 -DNDEBUG -gdwarf-4 -DCALLGRIND=1 CXXFLAGS+=-O2 -DNDEBUG -fno-math-errno -gdwarf-4 -DCALLGRIND=1
endif endif
ifeq ($(protobuf),download) ifeq ($(protobuf),download)

View File

@ -71,6 +71,7 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/CodeAllocator.cpp CodeGen/src/CodeAllocator.cpp
CodeGen/src/CodeBlockUnwind.cpp CodeGen/src/CodeBlockUnwind.cpp
CodeGen/src/CodeGen.cpp CodeGen/src/CodeGen.cpp
CodeGen/src/CodeGenUtils.cpp
CodeGen/src/CodeGenX64.cpp CodeGen/src/CodeGenX64.cpp
CodeGen/src/EmitBuiltinsX64.cpp CodeGen/src/EmitBuiltinsX64.cpp
CodeGen/src/EmitCommonX64.cpp CodeGen/src/EmitCommonX64.cpp
@ -82,6 +83,7 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/ByteUtils.h CodeGen/src/ByteUtils.h
CodeGen/src/CustomExecUtils.h CodeGen/src/CustomExecUtils.h
CodeGen/src/CodeGenUtils.h
CodeGen/src/CodeGenX64.h CodeGen/src/CodeGenX64.h
CodeGen/src/EmitBuiltinsX64.h CodeGen/src/EmitBuiltinsX64.h
CodeGen/src/EmitCommonX64.h CodeGen/src/EmitCommonX64.h
@ -339,6 +341,7 @@ if(TARGET Luau.UnitTest)
tests/TypeInfer.intersectionTypes.test.cpp tests/TypeInfer.intersectionTypes.test.cpp
tests/TypeInfer.loops.test.cpp tests/TypeInfer.loops.test.cpp
tests/TypeInfer.modules.test.cpp tests/TypeInfer.modules.test.cpp
tests/TypeInfer.negations.test.cpp
tests/TypeInfer.oop.test.cpp tests/TypeInfer.oop.test.cpp
tests/TypeInfer.operators.test.cpp tests/TypeInfer.operators.test.cpp
tests/TypeInfer.primitives.test.cpp tests/TypeInfer.primitives.test.cpp

View File

@ -20,6 +20,14 @@
#define LUAU_FASTMATH_END #define LUAU_FASTMATH_END
#endif #endif
// Some functions like floor/ceil have SSE4.1 equivalents but we currently support systems without SSE4.1
// On newer GCC and Clang we can use function multi-versioning to generate SSE4.1 code plus CPUID based dispatch.
#if !defined(__APPLE__) && (defined(__x86_64__) || defined(_M_X64)) && ((defined(__clang__) && __clang_major__ >= 14) || (defined(__GNUC__) && __GNUC__ >= 6)) && !defined(__SSE4_1__)
#define LUAU_DISPATCH_SSE41 __attribute__((target_clones("default", "sse4.1")))
#else
#define LUAU_DISPATCH_SSE41
#endif
// Used on functions that have a printf-like interface to validate them statically // Used on functions that have a printf-like interface to validate them statically
#if defined(__GNUC__) #if defined(__GNUC__)
#define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) #define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg)))

View File

@ -96,6 +96,7 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -159,6 +160,7 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -936,6 +938,7 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -1239,7 +1242,12 @@ static int luauF_setmetatable(lua_State* L, StkId res, TValue* arg0, int nresult
return -1; return -1;
} }
luau_FastFunction luauF_table[256] = { static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
return -1;
}
const luau_FastFunction luauF_table[256] = {
NULL, NULL,
luauF_assert, luauF_assert,
@ -1317,4 +1325,20 @@ luau_FastFunction luauF_table[256] = {
luauF_getmetatable, luauF_getmetatable,
luauF_setmetatable, luauF_setmetatable,
// When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback.
// This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing.
// Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest.
#define MISSING8 luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing, luauF_missing
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
MISSING8,
#undef MISSING8
}; };

View File

@ -6,4 +6,4 @@
typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams); typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams);
extern luau_FastFunction luauF_table[256]; extern const luau_FastFunction luauF_table[256];

View File

@ -34,6 +34,7 @@ inline bool luai_vecisnan(const float* a)
} }
LUAU_FASTMATH_BEGIN LUAU_FASTMATH_BEGIN
// TODO: LUAU_DISPATCH_SSE41 would be nice here, but clang-14 doesn't support it correctly on inline functions...
inline double luai_nummod(double a, double b) inline double luai_nummod(double a, double b)
{ {
return a - floor(a / b) * b; return a - floor(a / b) * b;

View File

@ -16,8 +16,6 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAGVARIABLE(LuauNoTopRestoreInFastCall, false)
// Disable c99-designator to avoid the warning in CGOTO dispatch table // Disable c99-designator to avoid the warning in CGOTO dispatch table
#ifdef __clang__ #ifdef __clang__
#if __has_warning("-Wc99-designator") #if __has_warning("-Wc99-designator")
@ -2094,12 +2092,17 @@ reentry:
Table* h = hvalue(ra); Table* h = hvalue(ra);
// TODO: we really don't need this anymore
if (!ttistable(ra)) if (!ttistable(ra))
return; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode return; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode
int last = index + c - 1; int last = index + c - 1;
if (last > h->sizearray) if (last > h->sizearray)
{
VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM
luaH_resizearray(L, h, last); luaH_resizearray(L, h, last);
}
TValue* array = h->array; TValue* array = h->array;
@ -2542,8 +2545,9 @@ reentry:
nparams = (nparams == LUA_MULTRET) ? int(L->top - ra - 1) : nparams; nparams = (nparams == LUA_MULTRET) ? int(L->top - ra - 1) : nparams;
luau_FastFunction f = luauF_table[bfid]; luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f) if (cl->env->safeenv)
{ {
VM_PROTECT_PC(); // f may fail due to OOM VM_PROTECT_PC(); // f may fail due to OOM
@ -2620,8 +2624,9 @@ reentry:
int nresults = LUAU_INSN_C(call) - 1; int nresults = LUAU_INSN_C(call) - 1;
luau_FastFunction f = luauF_table[bfid]; luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f) if (cl->env->safeenv)
{ {
VM_PROTECT_PC(); // f may fail due to OOM VM_PROTECT_PC(); // f may fail due to OOM
@ -2629,15 +2634,8 @@ reentry:
if (n >= 0) if (n >= 0)
{ {
if (FFlag::LuauNoTopRestoreInFastCall) if (nresults == LUA_MULTRET)
{ L->top = ra + n;
if (nresults == LUA_MULTRET)
L->top = ra + n;
}
else
{
L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
}
pc += skip + 1; // skip instructions that compute function as well as CALL pc += skip + 1; // skip instructions that compute function as well as CALL
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
@ -2676,8 +2674,9 @@ reentry:
int nresults = LUAU_INSN_C(call) - 1; int nresults = LUAU_INSN_C(call) - 1;
luau_FastFunction f = luauF_table[bfid]; luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f) if (cl->env->safeenv)
{ {
VM_PROTECT_PC(); // f may fail due to OOM VM_PROTECT_PC(); // f may fail due to OOM
@ -2685,15 +2684,8 @@ reentry:
if (n >= 0) if (n >= 0)
{ {
if (FFlag::LuauNoTopRestoreInFastCall) if (nresults == LUA_MULTRET)
{ L->top = ra + n;
if (nresults == LUA_MULTRET)
L->top = ra + n;
}
else
{
L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
}
pc += skip + 1; // skip instructions that compute function as well as CALL pc += skip + 1; // skip instructions that compute function as well as CALL
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
@ -2732,8 +2724,9 @@ reentry:
int nresults = LUAU_INSN_C(call) - 1; int nresults = LUAU_INSN_C(call) - 1;
luau_FastFunction f = luauF_table[bfid]; luau_FastFunction f = luauF_table[bfid];
LUAU_ASSERT(f);
if (cl->env->safeenv && f) if (cl->env->safeenv)
{ {
VM_PROTECT_PC(); // f may fail due to OOM VM_PROTECT_PC(); // f may fail due to OOM
@ -2741,15 +2734,8 @@ reentry:
if (n >= 0) if (n >= 0)
{ {
if (FFlag::LuauNoTopRestoreInFastCall) if (nresults == LUA_MULTRET)
{ L->top = ra + n;
if (nresults == LUA_MULTRET)
L->top = ra + n;
}
else
{
L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top;
}
pc += skip + 1; // skip instructions that compute function as well as CALL pc += skip + 1; // skip instructions that compute function as well as CALL
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));

456
bench/tests/voxelgen.lua Normal file
View File

@ -0,0 +1,456 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
-- Based on voxel terrain generator by Stickmasterluke
local kSelectedBiomes = {
['Mountains'] = true,
['Canyons'] = true,
['Dunes'] = true,
['Arctic'] = true,
['Lavaflow'] = true,
['Hills'] = true,
['Plains'] = true,
['Marsh'] = true,
['Water'] = true,
}
---------Directly used in Generation---------
local masterSeed = 618033988
local mapWidth = 32
local mapHeight = 32
local biomeSize = 16
local generateCaves = true
local waterLevel = .48
local surfaceThickness = .018
local biomes = {}
---------------------------------------------
local rock = "Rock"
local snow = "Snow"
local ice = "Glacier"
local grass = "Grass"
local ground = "Ground"
local mud = "Mud"
local slate = "Slate"
local concrete = "Concrete"
local lava = "CrackedLava"
local basalt = "Basalt"
local air = "Air"
local sand = "Sand"
local sandstone = "Sandstone"
local water = "Water"
math.randomseed(6180339)
local theseed={}
for i=1,999 do
table.insert(theseed,math.random())
end
local function getPerlin(x,y,z,seed,scale,raw)
local seed = seed or 0
local scale = scale or 1
if not raw then
return math.noise(x/scale+(seed*17)+masterSeed,y/scale-masterSeed,z/scale-seed*seed)*.5 + .5 -- accounts for bleeding from interpolated line
else
return math.noise(x/scale+(seed*17)+masterSeed,y/scale-masterSeed,z/scale-seed*seed)
end
end
local function getNoise(x,y,z,seed1)
local x = x or 0
local y = y or 0
local z = z or 0
local seed1 = seed1 or 7
local wtf=x+y+z+seed1+masterSeed + (masterSeed-x)*(seed1+z) + (seed1-y)*(masterSeed+z) -- + x*(y+z) + z*(masterSeed+seed1) + seed1*(x+y) --x+y+z+seed1+masterSeed + x*y*masterSeed-y*z+(z+masterSeed)*x --((x+y)*(y-seed1)*seed1)-(x+z)*seed2+x*11+z*23-y*17
return theseed[(math.floor(wtf%(#theseed)))+1]
end
local function thresholdFilter(value, bottom, size)
if value <= bottom then
return 0
elseif value >= bottom+size then
return 1
else
return (value-bottom)/size
end
end
local function ridgedFilter(value) --absolute and flip for ridges. and normalize
return value<.5 and value*2 or 2-value*2
end
local function ridgedFlippedFilter(value) --unflipped
return value < .5 and 1-value*2 or value*2-1
end
local function advancedRidgedFilter(value, cutoff)
local cutoff = cutoff or .5
value = value - cutoff
return 1 - (value < 0 and -value or value) * 1/(1-cutoff)
end
local function fractalize(operation,x,y,z, operationCount, scale, offset, gain)
local operationCount = operationCount or 3
local scale = scale or .5
local offset = 0
local gain = gain or 1
local totalValue = 0
local totalScale = 0
for i=1, operationCount do
local thisScale = scale^(i-1)
totalScale = totalScale + thisScale
totalValue = totalValue + (offset + gain * operation(x,y,z,i))*thisScale
end
return totalValue/totalScale
end
local function mountainsOperation(x,y,z,i)
return ridgedFilter(getPerlin(x,y,z,100+i,(1/i)*160))
end
local canyonBandingMaterial = {rock,mud,sand,sand,sandstone,sandstone,sandstone,sandstone,sandstone,sandstone,}
local function findBiomeInfo(choiceBiome,x,y,z,verticalGradientTurbulence)
local choiceBiomeValue = .5
local choiceBiomeSurface = grass
local choiceBiomeFill = rock
if choiceBiome == 'City' then
choiceBiomeValue = .55
choiceBiomeSurface = concrete
choiceBiomeFill = slate
elseif choiceBiome == 'Water' then
choiceBiomeValue = .36+getPerlin(x,y,z,2,50)*.08
choiceBiomeSurface =
(1-verticalGradientTurbulence < .44 and slate)
or sand
elseif choiceBiome == 'Marsh' then
local preLedge = getPerlin(x+getPerlin(x,0,z,5,7,true)*10+getPerlin(x,0,z,6,30,true)*50,0,z+getPerlin(x,0,z,9,7,true)*10+getPerlin(x,0,z,10,30,true)*50,2,70) --could use some turbulence
local grassyLedge = thresholdFilter(preLedge,.65,0)
local largeGradient = getPerlin(x,y,z,4,100)
local smallGradient = getPerlin(x,y,z,3,20)
local smallGradientThreshold = thresholdFilter(smallGradient,.5,0)
choiceBiomeValue = waterLevel-.04
+preLedge*grassyLedge*.025
+largeGradient*.035
+smallGradient*.025
choiceBiomeSurface =
(grassyLedge >= 1 and grass)
or (1-verticalGradientTurbulence < waterLevel-.01 and mud)
or (1-verticalGradientTurbulence < waterLevel+.01 and ground)
or grass
choiceBiomeFill = slate
elseif choiceBiome == 'Plains' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,40)*25,0,z+getPerlin(x,y,z,19,40)*25,2,200))
local rivuletThreshold = thresholdFilter(rivulet,.01,0)
local rockMap = thresholdFilter(ridgedFlippedFilter(getPerlin(x,0,z,101,7)),.3,.7) --rocks
* thresholdFilter(getPerlin(x,0,z,102,50),.6,.05) --zoning
choiceBiomeValue = .5 --.51
+getPerlin(x,y,z,2,100)*.02 --.05
+rivulet*.05 --.02
+rockMap*.05 --.03
+rivuletThreshold*.005
local verticalGradient = 1-((y-1)/(mapHeight-1))
local surfaceGradient = verticalGradient*.5 + choiceBiomeValue*.5
local thinSurface = surfaceGradient > .5-surfaceThickness*.4 and surfaceGradient < .5+surfaceThickness*.4
choiceBiomeSurface =
(rockMap>0 and rock)
or (not thinSurface and mud)
or (thinSurface and rivuletThreshold <=0 and water)
or (1-verticalGradientTurbulence < waterLevel-.01 and sand)
or grass
choiceBiomeFill =
(rockMap>0 and rock)
or sandstone
elseif choiceBiome == 'Canyons' then
local canyonNoise = ridgedFlippedFilter(getPerlin(x,0,z,2,200))
local canyonNoiseTurbed = ridgedFlippedFilter(getPerlin(x+getPerlin(x,0,z,5,20,true)*20,0,z+getPerlin(x,0,z,9,20,true)*20,2,200))
local sandbank = thresholdFilter(canyonNoiseTurbed,0,.05)
local canyonTop = thresholdFilter(canyonNoiseTurbed,.125,0)
local mesaSlope = thresholdFilter(canyonNoise,.33,.12)
local mesaTop = thresholdFilter(canyonNoiseTurbed,.49,0)
choiceBiomeValue = .42
+getPerlin(x,y,z,2,70)*.05
+canyonNoise*.05
+sandbank*.04 --canyon bottom slope
+thresholdFilter(canyonNoiseTurbed,.05,0)*.08 --canyon cliff
+thresholdFilter(canyonNoiseTurbed,.05,.075)*.04 --canyon cliff top slope
+canyonTop*.01 --canyon cliff top ledge
+thresholdFilter(canyonNoiseTurbed,.0575,.2725)*.01 --plane slope
+mesaSlope*.06 --mesa slope
+thresholdFilter(canyonNoiseTurbed,.45,0)*.14 --mesa cliff
+thresholdFilter(canyonNoiseTurbed,.45,.04)*.025 --mesa cap
+mesaTop*.02 --mesa top ledge
choiceBiomeSurface =
(1-verticalGradientTurbulence < waterLevel+.015 and sand) --this for biome blending in to lakes
or (sandbank>0 and sandbank<1 and sand) --this for canyonbase sandbanks
--or (canyonTop>0 and canyonTop<=1 and mesaSlope<=0 and grass) --this for grassy canyon tops
--or (mesaTop>0 and mesaTop<=1 and grass) --this for grassy mesa tops
or sandstone
choiceBiomeFill = canyonBandingMaterial[math.ceil((1-getNoise(1,y,2))*10)]
elseif choiceBiome == 'Hills' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,20)*20,0,z+getPerlin(x,y,z,19,20)*20,2,200))^(1/2)
local largeHills = getPerlin(x,y,z,3,60)
choiceBiomeValue = .48
+largeHills*.05
+(.05
+largeHills*.1
+getPerlin(x,y,z,4,25)*.125)
*rivulet
local surfaceMaterialGradient = (1-verticalGradientTurbulence)*.9 + rivulet*.1
choiceBiomeSurface =
(surfaceMaterialGradient < waterLevel-.015 and mud)
or (surfaceMaterialGradient < waterLevel and ground)
or grass
choiceBiomeFill = slate
elseif choiceBiome == 'Dunes' then
local duneTurbulence = getPerlin(x,0,z,227,20)*24
local layer1 = ridgedFilter(getPerlin(x,0,z,201,40))
local layer2 = ridgedFilter(getPerlin(x/10+duneTurbulence,0,z+duneTurbulence,200,48))
choiceBiomeValue = .4+.1*(layer1 + layer2)
choiceBiomeSurface = sand
choiceBiomeFill = sandstone
elseif choiceBiome == 'Mountains' then
local rivulet = ridgedFlippedFilter(getPerlin(x+getPerlin(x,y,z,17,20)*20,0,z+getPerlin(x,y,z,19,20)*20,2,200))
choiceBiomeValue = -.4 --.3
+fractalize(mountainsOperation,x,y/20,z, 8, .65)*1.2
+rivulet*.2
choiceBiomeSurface =
(verticalGradientTurbulence < .275 and snow)
or (verticalGradientTurbulence < .35 and rock)
or (verticalGradientTurbulence < .4 and ground)
or (1-verticalGradientTurbulence < waterLevel and rock)
or (1-verticalGradientTurbulence < waterLevel+.01 and mud)
or (1-verticalGradientTurbulence < waterLevel+.015 and ground)
or grass
elseif choiceBiome == 'Lavaflow' then
local crackX = x+getPerlin(x,y*.25,z,21,8,true)*5
local crackY = y+getPerlin(x,y*.25,z,22,8,true)*5
local crackZ = z+getPerlin(x,y*.25,z,23,8,true)*5
local crack1 = ridgedFilter(getPerlin(crackX+getPerlin(x,y,z,22,30,true)*30,crackY,crackZ+getPerlin(x,y,z,24,30,true)*30,2,120))
local crack2 = ridgedFilter(getPerlin(crackX,crackY,crackZ,3,40))*(crack1*.25+.75)
local crack3 = ridgedFilter(getPerlin(crackX,crackY,crackZ,4,20))*(crack2*.25+.75)
local generalHills = thresholdFilter(getPerlin(x,y,z,9,40),.25,.5)*getPerlin(x,y,z,10,60)
local cracks = math.max(0,1-thresholdFilter(crack1,.975,0)-thresholdFilter(crack2,.925,0)-thresholdFilter(crack3,.9,0))
local spires = thresholdFilter(getPerlin(crackX/40,crackY/300,crackZ/30,123,1),.6,.4)
choiceBiomeValue = waterLevel+.02
+cracks*(.5+generalHills*.5)*.02
+generalHills*.05
+spires*.3
+((1-verticalGradientTurbulence > waterLevel+.01 or spires>0) and .04 or 0) --This lets it lip over water
choiceBiomeFill = (spires>0 and rock) or (cracks<1 and lava) or basalt
choiceBiomeSurface = (choiceBiomeFill == lava and 1-verticalGradientTurbulence < waterLevel and basalt) or choiceBiomeFill
elseif choiceBiome == 'Arctic' then
local preBoundary = getPerlin(x+getPerlin(x,0,z,5,8,true)*5,y/8,z+getPerlin(x,0,z,9,8,true)*5,2,20)
--local cliffs = thresholdFilter(preBoundary,.5,0)
local boundary = ridgedFilter(preBoundary)
local roughChunks = getPerlin(x,y/4,z,436,2)
local boundaryMask = thresholdFilter(boundary,.8,.1) --,.7,.25)
local boundaryTypeMask = getPerlin(x,0,z,6,74)-.5
local boundaryComp = 0
if boundaryTypeMask < 0 then --divergent
boundaryComp = (boundary > (1+boundaryTypeMask*.5) and -.17 or 0)
--* boundaryTypeMask*-2
else --convergent
boundaryComp = boundaryMask*.1*roughChunks
* boundaryTypeMask
end
choiceBiomeValue = .55
+boundary*.05*boundaryTypeMask --.1 --soft slope up or down to boundary
+boundaryComp --convergent/divergent effects
+getPerlin(x,0,z,123,25)*.025 --*cliffs --gentle rolling slopes
choiceBiomeSurface = (1-verticalGradientTurbulence < waterLevel-.1 and ice) or (boundaryMask>.6 and boundaryTypeMask>.1 and roughChunks>.5 and ice) or snow
choiceBiomeFill = ice
end
return choiceBiomeValue, choiceBiomeSurface, choiceBiomeFill
end
function findBiomeTransitionValue(biome,weight,value,averageValue)
if biome == 'Arctic' then
return (weight>.2 and 1 or 0)*value
elseif biome == 'Canyons' then
return (weight>.7 and 1 or 0)*value
elseif biome == 'Mountains' then
local weight = weight^3 --This improves the ease of mountains transitioning to other biomes
return averageValue*(1-weight)+value*weight
else
return averageValue*(1-weight)+value*weight
end
end
function generate()
local mapWidth = mapWidth
local biomeSize = biomeSize
local biomeBlendPercent = .25 --(biomeSize==50 or biomeSize == 100) and .5 or .25
local biomeBlendPercentInverse = 1-biomeBlendPercent
local biomeBlendDistortion = biomeBlendPercent
local smoothScale = .5/mapHeight
biomes = {}
for i,v in pairs(kSelectedBiomes) do
if v then
table.insert(biomes,i)
end
end
if #biomes<=0 then
table.insert(biomes,'Hills')
end
table.sort(biomes)
--local oMap = {}
--local mMap = {}
for x = 1, mapWidth do
local oMapX = {}
--oMap[x] = oMapX
local mMapX = {}
--mMap[x] = mMapX
for z = 1, mapWidth do
local biomeNoCave = false
local cellToBiomeX = x/biomeSize + getPerlin(x,0,z,233,biomeSize*.3)*.25 + getPerlin(x,0,z,235,biomeSize*.05)*.075
local cellToBiomeZ = z/biomeSize + getPerlin(x,0,z,234,biomeSize*.3)*.25 + getPerlin(x,0,z,236,biomeSize*.05)*.075
local closestDistance = 1000000
local biomePoints = {}
for vx=-1,1 do
for vz=-1,1 do
local gridPointX = math.floor(cellToBiomeX+vx+.5)
local gridPointZ = math.floor(cellToBiomeZ+vz+.5)
--local pointX, pointZ = getBiomePoint(gridPointX,gridPointZ)
local pointX = gridPointX+(getNoise(gridPointX,gridPointZ,53)-.5)*.75 --de-uniforming grid for vornonoi
local pointZ = gridPointZ+(getNoise(gridPointX,gridPointZ,73)-.5)*.75
local dist = math.sqrt((pointX-cellToBiomeX)^2 + (pointZ-cellToBiomeZ)^2)
if dist < closestDistance then
closestDistance = dist
end
table.insert(biomePoints,{
x = pointX,
z = pointZ,
dist = dist,
biomeNoise = getNoise(gridPointX,gridPointZ),
weight = 0
})
end
end
local weightTotal = 0
local weightPoints = {}
for _,point in pairs(biomePoints) do
local weight = point.dist == closestDistance and 1 or ((closestDistance / point.dist)-biomeBlendPercentInverse)/biomeBlendPercent
if weight > 0 then
local weight = weight^2.1 --this smooths the biome transition from linear to cubic InOut
weightTotal = weightTotal + weight
local biome = biomes[math.ceil(#biomes*(1-point.biomeNoise))] --inverting the noise so that it is limited as (0,1]. One less addition operation when finding a random list index
weightPoints[biome] = {
weight = weightPoints[biome] and weightPoints[biome].weight + weight or weight
}
end
end
for biome,info in pairs(weightPoints) do
info.weight = info.weight / weightTotal
if biome == 'Arctic' then --biomes that don't have caves that breach the surface
biomeNoCave = true
end
end
for y = 1, mapHeight do
local oMapY = oMapX[y] or {}
oMapX[y] = oMapY
local mMapY = mMapX[y] or {}
mMapX[y] = mMapY
--[[local oMapY = {}
oMapX[y] = oMapY
local mMapY = {}
mMapX[z] = mMapY]]
local verticalGradient = 1-((y-1)/(mapHeight-1))
local caves = 0
local verticalGradientTurbulence = verticalGradient*.9 + .1*getPerlin(x,y,z,107,15)
local choiceValue = 0
local choiceSurface = lava
local choiceFill = rock
if verticalGradient > .65 or verticalGradient < .1 then
--under surface of every biome; don't get biome data; waste of time.
choiceValue = .5
elseif #biomes == 1 then
choiceValue, choiceSurface, choiceFill = findBiomeInfo(biomes[1],x,y,z,verticalGradientTurbulence)
else
local averageValue = 0
--local findChoiceMaterial = -getNoise(x,y,z,19)
for biome,info in pairs(weightPoints) do
local biomeValue, biomeSurface, biomeFill = findBiomeInfo(biome,x,y,z,verticalGradientTurbulence)
info.biomeValue = biomeValue
info.biomeSurface = biomeSurface
info.biomeFill = biomeFill
local value = biomeValue * info.weight
averageValue = averageValue + value
--[[if findChoiceMaterial < 0 and findChoiceMaterial + weight >= 0 then
choiceMaterial = biomeMaterial
end
findChoiceMaterial = findChoiceMaterial + weight]]
end
for biome,info in pairs(weightPoints) do
local value = findBiomeTransitionValue(biome,info.weight,info.biomeValue,averageValue)
if value > choiceValue then
choiceValue = value
choiceSurface = info.biomeSurface
choiceFill = info.biomeFill
end
end
end
local preCaveComp = verticalGradient*.5 + choiceValue*.5
local surface = preCaveComp > .5-surfaceThickness and preCaveComp < .5+surfaceThickness
if generateCaves --user wants caves
and (not biomeNoCave or verticalGradient > .65) --biome allows caves or deep enough
and not (surface and (1-verticalGradient) < waterLevel+.005) --caves only breach surface above waterlevel
and not (surface and (1-verticalGradient) > waterLevel+.58) then --caves don't go too high so that they don't cut up mountain tops
local ridged2 = ridgedFilter(getPerlin(x,y,z,4,30))
local caves2 = thresholdFilter(ridged2,.84,.01)
local ridged3 = ridgedFilter(getPerlin(x,y,z,5,30))
local caves3 = thresholdFilter(ridged3,.84,.01)
local ridged4 = ridgedFilter(getPerlin(x,y,z,6,30))
local caves4 = thresholdFilter(ridged4,.84,.01)
local caveOpenings = (surface and 1 or 0) * thresholdFilter(getPerlin(x,0,z,143,62),.35,0) --.45
caves = caves2 * caves3 * caves4 - caveOpenings
caves = caves < 0 and 0 or caves > 1 and 1 or caves
end
local comp = preCaveComp - caves
local smoothedResult = thresholdFilter(comp,.5,smoothScale)
---below water level -above surface -no terrain
if 1-verticalGradient < waterLevel and preCaveComp <= .5 and smoothedResult <= 0 then
smoothedResult = 1
choiceSurface = water
choiceFill = water
surface = true
end
oMapY[z] = (y == 1 and 1) or smoothedResult
mMapY[z] = (y == 1 and lava) or (smoothedResult <= 0 and air) or (surface and choiceSurface) or choiceFill
end
end
-- local regionStart = Vector3.new(mapWidth*-2+(x-1)*4,mapHeight*-2,mapWidth*-2)
-- local regionEnd = Vector3.new(mapWidth*-2+x*4,mapHeight*2,mapWidth*2)
-- local mapRegion = Region3.new(regionStart, regionEnd)
-- terrain:WriteVoxels(mapRegion, 4, {mMapX}, {oMapX})
end
end
bench.runCode(generate, "voxelgen")

View File

@ -91,6 +91,8 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "class_method")
TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_class_method") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_class_method")
{ {
ScopedFastFlag luauCheckOverloadedDocSymbol{"LuauCheckOverloadedDocSymbol", true};
loadDefinition(R"( loadDefinition(R"(
declare class Foo declare class Foo
function bar(self, x: string): number function bar(self, x: string): number
@ -125,6 +127,8 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_function_prop")
TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop")
{ {
ScopedFastFlag luauCheckOverloadedDocSymbol{"LuauCheckOverloadedDocSymbol", true};
loadDefinition(R"( loadDefinition(R"(
declare Foo: { declare Foo: {
new: ((number) -> string) & ((string) -> number) new: ((number) -> string) & ((string) -> number)

View File

@ -506,6 +506,15 @@ std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name)
return std::nullopt; return std::nullopt;
} }
void registerNotType(Fixture& fixture, TypeArena& arena)
{
TypeId t = arena.addType(GenericTypeVar{"T"});
GenericTypeDefinition genericT{t};
ScopePtr moduleScope = fixture.frontend.getGlobalScope();
moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationTypeVar{t})};
}
void dump(const std::vector<Constraint>& constraints) void dump(const std::vector<Constraint>& constraints)
{ {
ToStringOptions opts; ToStringOptions opts;

View File

@ -186,6 +186,8 @@ std::optional<TypeId> lookupName(ScopePtr scope, const std::string& name); // Wa
std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name); std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name);
void registerNotType(Fixture& fixture, TypeArena& arena);
} // namespace Luau } // namespace Luau
#define LUAU_REQUIRE_ERRORS(result) \ #define LUAU_REQUIRE_ERRORS(result) \

View File

@ -11,6 +11,7 @@
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauIceExceptionInheritanceChange);
TEST_SUITE_BEGIN("ModuleTests"); TEST_SUITE_BEGIN("ModuleTests");
@ -278,7 +279,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit")
TypeArena dest; TypeArena dest;
CloneState cloneState; CloneState cloneState;
CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); if (FFlag::LuauIceExceptionInheritanceChange)
{
CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException);
}
else
{
CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException_DEPRECATED);
}
} }
TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak")

View File

@ -10,13 +10,16 @@
using namespace Luau; using namespace Luau;
struct NormalizeFixture : Fixture namespace
{
struct IsSubtypeFixture : Fixture
{ {
bool isSubtype(TypeId a, TypeId b) bool isSubtype(TypeId a, TypeId b)
{ {
return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice);
} }
}; };
}
void createSomeClasses(Frontend& frontend) void createSomeClasses(Frontend& frontend)
{ {
@ -55,7 +58,7 @@ void createSomeClasses(Frontend& frontend)
TEST_SUITE_BEGIN("isSubtype"); TEST_SUITE_BEGIN("isSubtype");
TEST_CASE_FIXTURE(NormalizeFixture, "primitives") TEST_CASE_FIXTURE(IsSubtypeFixture, "primitives")
{ {
check(R"( check(R"(
local a = 41 local a = 41
@ -75,7 +78,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "primitives")
CHECK(!isSubtype(d, a)); CHECK(!isSubtype(d, a));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "functions") TEST_CASE_FIXTURE(IsSubtypeFixture, "functions")
{ {
check(R"( check(R"(
function a(x: number): number return x end function a(x: number): number return x end
@ -96,7 +99,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions")
CHECK(isSubtype(a, d)); CHECK(isSubtype(a, d));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any") TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any")
{ {
check(R"( check(R"(
function a(n: number) return "string" end function a(n: number) return "string" end
@ -114,7 +117,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any")
CHECK(!isSubtype(a, b)); CHECK(!isSubtype(a, b));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head") TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head")
{ {
check(R"( check(R"(
local a: (...number) -> () local a: (...number) -> ()
@ -129,7 +132,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head")
} }
#if 0 #if 0
TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head") TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_function_with_head")
{ {
check(R"( check(R"(
local a: (...number) -> () local a: (...number) -> ()
@ -144,7 +147,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head")
} }
#endif #endif
TEST_CASE_FIXTURE(NormalizeFixture, "union") TEST_CASE_FIXTURE(IsSubtypeFixture, "union")
{ {
check(R"( check(R"(
local a: number | string local a: number | string
@ -171,7 +174,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union")
CHECK(!isSubtype(d, b)); CHECK(!isSubtype(d, b));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop") TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_union_prop")
{ {
check(R"( check(R"(
local a: {x: number} local a: {x: number}
@ -185,7 +188,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop")
CHECK(!isSubtype(b, a)); CHECK(!isSubtype(b, a));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop")
{ {
check(R"( check(R"(
local a: {x: number} local a: {x: number}
@ -199,7 +202,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop")
CHECK(!isSubtype(b, a)); CHECK(!isSubtype(b, a));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "intersection") TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection")
{ {
ScopedFastFlag sffs[]{ ScopedFastFlag sffs[]{
{"LuauSubtypeNormalizer", true}, {"LuauSubtypeNormalizer", true},
@ -229,7 +232,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection")
CHECK(isSubtype(a, d)); CHECK(isSubtype(a, d));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") TEST_CASE_FIXTURE(IsSubtypeFixture, "union_and_intersection")
{ {
check(R"( check(R"(
local a: number & string local a: number & string
@ -243,7 +246,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection")
CHECK(isSubtype(a, b)); CHECK(isSubtype(a, b));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "tables") TEST_CASE_FIXTURE(IsSubtypeFixture, "tables")
{ {
check(R"( check(R"(
local a: {x: number} local a: {x: number}
@ -271,7 +274,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "tables")
} }
#if 0 #if 0
TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") TEST_CASE_FIXTURE(IsSubtypeFixture, "table_indexers_are_invariant")
{ {
check(R"( check(R"(
local a: {[string]: number} local a: {[string]: number}
@ -290,7 +293,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant")
CHECK(isSubtype(a, c)); CHECK(isSubtype(a, c));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers") TEST_CASE_FIXTURE(IsSubtypeFixture, "mismatched_indexers")
{ {
check(R"( check(R"(
local a: {x: number} local a: {x: number}
@ -309,7 +312,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers")
CHECK(isSubtype(b, c)); CHECK(isSubtype(b, c));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") TEST_CASE_FIXTURE(IsSubtypeFixture, "cyclic_table")
{ {
check(R"( check(R"(
type A = {method: (A) -> ()} type A = {method: (A) -> ()}
@ -348,7 +351,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table")
} }
#endif #endif
TEST_CASE_FIXTURE(NormalizeFixture, "classes") TEST_CASE_FIXTURE(IsSubtypeFixture, "classes")
{ {
createSomeClasses(frontend); createSomeClasses(frontend);
@ -365,7 +368,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "classes")
} }
#if 0 #if 0
TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1}) TEST_CASE_FIXTURE(IsSubtypeFixture, "metatable" * doctest::expected_failures{1})
{ {
check(R"( check(R"(
local T = {} local T = {}
@ -389,8 +392,112 @@ TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1})
TEST_SUITE_END(); TEST_SUITE_END();
struct NormalizeFixture : Fixture
{
ScopedFastFlag sff{"LuauNegatedStringSingletons", true};
TypeArena arena;
InternalErrorReporter iceHandler;
UnifierSharedState unifierState{&iceHandler};
Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}};
NormalizeFixture()
{
registerNotType(*this, arena);
}
TypeId normal(const std::string& annotation)
{
CheckResult result = check("type _Res = " + annotation);
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = lookupType("_Res");
REQUIRE(ty);
const NormalizedType* norm = normalizer.normalize(*ty);
REQUIRE(norm);
return normalizer.typeFromNormal(*norm);
}
};
TEST_SUITE_BEGIN("Normalize"); TEST_SUITE_BEGIN("Normalize");
TEST_CASE_FIXTURE(NormalizeFixture, "negate_string")
{
CHECK("number" == toString(normal(R"(
(number | string) & Not<string>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "negate_string_from_cofinite_string_intersection")
{
CHECK("number" == toString(normal(R"(
(number | (string & Not<"hello"> & Not<"world">)) & Not<string>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "no_op_negation_is_dropped")
{
CHECK("number" == toString(normal(R"(
number & Not<string>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negation")
{
CHECK("string" == toString(normal(R"(
(string & Not<"hello">) | "hello"
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersect_truthy")
{
CHECK("number | string | true" == toString(normal(R"(
(string | number | boolean | nil) & Not<false | nil>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersect_truthy_expressed_as_intersection")
{
CHECK("number | string | true" == toString(normal(R"(
(string | number | boolean | nil) & Not<false> & Not<nil>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "union_of_union")
{
CHECK(R"("alpha" | "beta" | "gamma")" == toString(normal(R"(
("alpha" | "beta") | "gamma"
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negations")
{
CHECK(R"(string & ~"world")" == toString(normal(R"(
(string & Not<"hello"> & Not<"world">) | (string & Not<"goodbye"> & Not<"world">)
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean")
{
CHECK("true" == toString(normal(R"(
boolean & Not<false>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean_2")
{
CHECK("never" == toString(normal(R"(
true & Not<true>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "bare_negation")
{
// TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function
CHECK("(number | string | thread)?" == toString(normal(R"(
Not<boolean>
)")));
}
TEST_CASE_FIXTURE(Fixture, "higher_order_function") TEST_CASE_FIXTURE(Fixture, "higher_order_function")
{ {
check(R"( check(R"(

View File

@ -1736,8 +1736,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_type_annotation")
TEST_CASE_FIXTURE(Fixture, "parse_error_missing_type_annotation") TEST_CASE_FIXTURE(Fixture, "parse_error_missing_type_annotation")
{ {
ScopedFastFlag LuauTypeAnnotationLocationChange{"LuauTypeAnnotationLocationChange", true};
{ {
ParseResult result = tryParse("local x:"); ParseResult result = tryParse("local x:");
CHECK(result.errors.size() == 1); CHECK(result.errors.size() == 1);

View File

@ -512,9 +512,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2")
TableTypeVar* tMeta5 = getMutable<TableTypeVar>(tMeta4->props["__index"].type); TableTypeVar* tMeta5 = getMutable<TableTypeVar>(tMeta4->props["__index"].type);
REQUIRE(tMeta5); REQUIRE(tMeta5);
REQUIRE(tMeta5->props.count("one") > 0);
TableTypeVar* tMeta6 = getMutable<TableTypeVar>(tMeta3->table); TableTypeVar* tMeta6 = getMutable<TableTypeVar>(tMeta3->table);
REQUIRE(tMeta6); REQUIRE(tMeta6);
REQUIRE(tMeta6->props.count("two") > 0);
ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts);
if (!FFlag::LuauFixNameMaps) if (!FFlag::LuauFixNameMaps)

View File

@ -913,4 +913,14 @@ TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_shadow_user_defined_alias")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(Fixture, "cannot_create_cyclic_type_with_unknown_module")
{
CheckResult result = check(R"(
type AAA = B.AAA
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK(toString(result.errors[0]) == "Unknown type 'B.AAA'");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -7,6 +7,8 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauIceExceptionInheritanceChange)
using namespace Luau; using namespace Luau;
TEST_SUITE_BEGIN("AnnotationTests"); TEST_SUITE_BEGIN("AnnotationTests");
@ -664,8 +666,8 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag")
AssertionCatcher ac; AssertionCatcher ac;
CHECK_THROWS_AS(check(R"( CHECK_THROWS_AS(check(R"(
local a: _luau_ice = 55 local a: _luau_ice = 55
)"), )"),
InternalCompilerError); InternalCompilerError);
LUAU_ASSERT(1 == AssertionCatcher::tripped); LUAU_ASSERT(1 == AssertionCatcher::tripped);
@ -682,8 +684,8 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler
}; };
CHECK_THROWS_AS(check(R"( CHECK_THROWS_AS(check(R"(
local a: _luau_ice = 55 local a: _luau_ice = 55
)"), )"),
InternalCompilerError); InternalCompilerError);
CHECK_EQ(true, caught); CHECK_EQ(true, caught);

View File

@ -348,8 +348,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options")
TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_types_regression_test") TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_types_regression_test")
{ {
ScopedFastFlag LuauUnionOfTypesFollow{"LuauUnionOfTypesFollow", true};
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
local stat local stat

View File

@ -309,6 +309,23 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols")
CHECK_EQ(yTtv->props["x"].documentationSymbol, "@test/global/y.x"); CHECK_EQ(yTtv->props["x"].documentationSymbol, "@test/global/y.x");
} }
TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_referenced_types")
{
ScopedFastFlag LuauPersistTypesAfterGeneratingDocSyms("LuauPersistTypesAfterGeneratingDocSyms", true);
loadDefinition(R"(
declare class MyClass
function myMethod(self)
end
declare function myFunc(): MyClass
)");
std::optional<TypeFun> myClassTy = typeChecker.globalScope->lookupType("MyClass");
REQUIRE(bool(myClassTy));
CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass");
}
TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types") TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types")
{ {
loadDefinition(R"( loadDefinition(R"(

View File

@ -229,6 +229,48 @@ TEST_CASE_FIXTURE(Fixture, "too_many_arguments")
CHECK_EQ(0, acm->actual); CHECK_EQ(0, acm->actual);
} }
TEST_CASE_FIXTURE(Fixture, "too_many_arguments_error_location")
{
ScopedFastFlag sff{"LuauArgMismatchReportFunctionLocation", true};
CheckResult result = check(R"(
--!strict
function myfunction(a: number, b:number) end
myfunction(1)
function getmyfunction()
return myfunction
end
getmyfunction()()
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
{
TypeError err = result.errors[0];
// Ensure the location matches the location of the function identifier
CHECK_EQ(err.location, Location(Position(4, 8), Position(4, 18)));
auto acm = get<CountMismatch>(err);
REQUIRE(acm);
CHECK_EQ(2, acm->expected);
CHECK_EQ(1, acm->actual);
}
{
TypeError err = result.errors[1];
// Ensure the location matches the location of the expression returning the function
CHECK_EQ(err.location, Location(Position(9, 8), Position(9, 23)));
auto acm = get<CountMismatch>(err);
REQUIRE(acm);
CHECK_EQ(2, acm->expected);
CHECK_EQ(0, acm->actual);
}
}
TEST_CASE_FIXTURE(Fixture, "recursive_function") TEST_CASE_FIXTURE(Fixture, "recursive_function")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
@ -1659,17 +1701,20 @@ foo3()
string.find() string.find()
local t = {} local t = {}
function t.foo(x: number, y: string?, ...: any) end function t.foo(x: number, y: string?, ...: any) return 1 end
function t:bar(x: number, y: string?) end function t:bar(x: number, y: string?) end
t.foo() t.foo()
t:bar() t:bar()
local u = { a = t } local u = { a = t, b = function() return t end }
u.a.foo() u.a.foo()
local x = (u.a).foo()
u.b().foo()
)"); )");
LUAU_REQUIRE_ERROR_COUNT(7, result); LUAU_REQUIRE_ERROR_COUNT(9, result);
CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified"); CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified");
CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified"); CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified");
CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified"); CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified");
@ -1677,6 +1722,8 @@ u.a.foo()
CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified"); CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified");
CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified"); CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified");
CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified");
CHECK_EQ(toString(result.errors[7]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified");
CHECK_EQ(toString(result.errors[8]), "Argument count mismatch. Function expects at least 1 argument, but none are specified");
} }
// This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments // This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments

View File

@ -251,23 +251,7 @@ end
return m return m
)"); )");
if (FFlag::LuauInstantiateInSubtyping) LUAU_REQUIRE_NO_ERRORS(result);
{
// though this didn't error before the flag, it seems as though it should error since fields of a table are invariant.
// the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be
// unsound.
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(R"(Type 'n' could not be converted into 't1 where t1 = {- Clone: (t1) -> (a...) -}'
caused by:
Property 'Clone' is not compatible. Type '<a>(a) -> ()' could not be converted into 't1 where t1 = ({- Clone: t1 -}) -> (a...)'; different number of generic type parameters)",
toString(result.errors[0]));
}
else
{
LUAU_REQUIRE_NO_ERRORS(result);
}
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global")

View File

@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "doctest.h"
#include "Luau/Common.h"
#include "ScopedFlags.h"
using namespace Luau;
namespace
{
struct NegationFixture : Fixture
{
TypeArena arena;
ScopedFastFlag sff[2] {
{"LuauNegatedStringSingletons", true},
{"LuauSubtypeNormalizer", true},
};
NegationFixture()
{
registerNotType(*this, arena);
}
};
}
TEST_SUITE_BEGIN("Negations");
TEST_CASE_FIXTURE(NegationFixture, "negated_string_is_a_subtype_of_string")
{
CheckResult result = check(R"(
function foo(arg: string) end
local a: string & Not<"Hello">
foo(a)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string")
{
CheckResult result = check(R"(
function foo(arg: string & Not<"hello">) end
local a: string
foo(a)
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
}
TEST_SUITE_END();

View File

@ -434,16 +434,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
local foo = { local foo
value = 10
}
local mt = {} local mt = {}
setmetatable(foo, mt)
mt.__unm = function(val: typeof(foo)): string mt.__unm = function(val: typeof(foo)): string
return val.value .. "test" return tostring(val.value) .. "test"
end end
foo = setmetatable({
value = 10
}, mt)
local a = -foo local a = -foo
local b = 1+-1 local b = 1+-1
@ -459,25 +460,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus")
CHECK_EQ("string", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("a")));
CHECK_EQ("number", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("b")));
GenericError* gen = get<GenericError>(result.errors[0]); if (FFlag::DebugLuauDeferredConstraintResolution)
REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); {
CHECK(toString(result.errors[0]) == "Type '{ value: number }' could not be converted into 'number'");
}
else
{
GenericError* gen = get<GenericError>(result.errors[0]);
REQUIRE(gen);
REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'");
}
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
local foo = {
value = 10
}
local mt = {} local mt = {}
setmetatable(foo, mt)
mt.__unm = function(val: boolean): string mt.__unm = function(val: boolean): string
return "test" return "test"
end end
local foo = setmetatable({
value = 10
}, mt)
local a = -foo local a = -foo
)"); )");
@ -494,16 +502,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error")
{ {
CheckResult result = check(R"( CheckResult result = check(R"(
--!strict --!strict
local foo = {
value = 10
}
local mt = {} local mt = {}
setmetatable(foo, mt)
mt.__len = function(val: any): string mt.__len = function(val): string
return "test" return "test"
end end
local foo = setmetatable({
value = 10,
}, mt)
local a = #foo local a = #foo
)"); )");

View File

@ -624,15 +624,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument")
CHECK_EQ("{string | string}", toString(requireType("t"))); CHECK_EQ("{string | string}", toString(requireType("t")));
} }
struct NormalizeFixture : Fixture namespace
{
struct IsSubtypeFixture : Fixture
{ {
bool isSubtype(TypeId a, TypeId b) bool isSubtype(TypeId a, TypeId b)
{ {
return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice);
} }
}; };
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities")
{ {
check(R"( check(R"(
type A = (any) -> () type A = (any) -> ()
@ -653,7 +656,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arit
CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t")));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity")
{ {
check(R"( check(R"(
local a: (number) -> () local a: (number) -> ()
@ -676,7 +679,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity")
CHECK(!isSubtype(b, c)); CHECK(!isSubtype(b, c));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_optional_parameters")
{ {
/* /*
* (T0..TN) <: (T0..TN, A?) * (T0..TN) <: (T0..TN, A?)
@ -736,7 +739,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_option
// CHECK(!isSubtype(b, c)); // CHECK(!isSubtype(b, c));
} }
TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param") TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_any_is_an_optional_param")
{ {
check(R"( check(R"(
local a: (number?) -> () local a: (number?) -> ()

View File

@ -1,5 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h"
#include "Luau/Frontend.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
@ -14,6 +17,7 @@ using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
TEST_SUITE_BEGIN("TableTests"); TEST_SUITE_BEGIN("TableTests");
@ -1957,7 +1961,11 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
local c : string = t.m("hi") local c : string = t.m("hi")
)"); )");
LUAU_REQUIRE_ERRORS(result); // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping
if (FFlag::LuauInstantiateInSubtyping)
LUAU_REQUIRE_NO_ERRORS(result);
else
LUAU_REQUIRE_ERRORS(result);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict")
@ -3262,11 +3270,13 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
local c : string = t.m("hi") local c : string = t.m("hi")
)"); )");
LUAU_REQUIRE_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping
caused by: // LUAU_REQUIRE_ERRORS(result);
Property 'm' is not compatible. Type '<a>(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); // CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}'
// this error message is not great since the underlying issue is that the context is invariant, // caused by:
// Property 'm' is not compatible. Type '<a>(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)");
// // this error message is not great since the underlying issue is that the context is invariant,
// and `(number) -> number` cannot be a subtype of `<a>(a) -> a`. // and `(number) -> number` cannot be a subtype of `<a>(a) -> a`.
} }
@ -3292,4 +3302,43 @@ local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f
CHECK_EQ("r", error->properties[0]); CHECK_EQ("r", error->properties[0]);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_has_a_side_effect")
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return;
CheckResult result = check(R"(
local mt = {
__add = function(x, y)
return 123
end,
}
local foo = {}
setmetatable(foo, mt)
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK(toString(requireType("foo")) == "{ @metatable { __add: (a, b) -> number }, { } }");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated")
{
CheckResult result = check(R"(
local t = {
x = 5 :: NonexistingTypeWhichEndsUpReturningAnErrorType,
y = 5
}
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
ToStringOptions opts;
opts.exhaustive = true;
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts));
else
CHECK_EQ("{ x: <error-type>, y: number }", toString(requireType("t"), opts));
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -22,7 +22,14 @@ TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded")
TypeId tType = requireType("t"); TypeId tType = requireType("t");
CHECK_THROWS_AS(toString(tType), RecursionLimitException); if (FFlag::LuauIceExceptionInheritanceChange)
{
CHECK_THROWS_AS(toString(tType), RecursionLimitException);
}
else
{
CHECK_THROWS_AS(toString(tType), RecursionLimitException_DEPRECATED);
}
} }
TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough")

View File

@ -1,14 +1,14 @@
AnnotationTests.builtin_types_are_not_exported
AnnotationTests.corecursive_types_error_on_tight_loop AnnotationTests.corecursive_types_error_on_tight_loop
AnnotationTests.duplicate_type_param_name AnnotationTests.duplicate_type_param_name
AnnotationTests.for_loop_counter_annotation_is_checked AnnotationTests.for_loop_counter_annotation_is_checked
AnnotationTests.generic_aliases_are_cloned_properly AnnotationTests.generic_aliases_are_cloned_properly
AnnotationTests.instantiation_clone_has_to_follow AnnotationTests.instantiation_clone_has_to_follow
AnnotationTests.luau_print_is_not_special_without_the_flag
AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_intersection_typevar
AnnotationTests.occurs_check_on_cyclic_union_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar
AnnotationTests.too_many_type_params AnnotationTests.too_many_type_params
AnnotationTests.two_type_params AnnotationTests.two_type_params
AnnotationTests.use_type_required_from_another_file AnnotationTests.unknown_type_reference_generates_error
AstQuery.last_argument_function_call_type AstQuery.last_argument_function_call_type
AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.overloaded_fn
AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_first_function_arg_expected_type
@ -86,7 +86,6 @@ BuiltinTests.table_pack
BuiltinTests.table_pack_reduce BuiltinTests.table_pack_reduce
BuiltinTests.table_pack_variadic BuiltinTests.table_pack_variadic
BuiltinTests.tonumber_returns_optional_number_type BuiltinTests.tonumber_returns_optional_number_type
BuiltinTests.tonumber_returns_optional_number_type2
DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_overload_metamethods
DefinitionTests.class_definition_string_props DefinitionTests.class_definition_string_props
DefinitionTests.declaring_generic_functions DefinitionTests.declaring_generic_functions
@ -96,7 +95,6 @@ FrontendTest.imported_table_modification_2
FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded
FrontendTest.nocheck_cycle_used_by_checked FrontendTest.nocheck_cycle_used_by_checked
FrontendTest.reexport_cyclic_type FrontendTest.reexport_cyclic_type
FrontendTest.reexport_type_alias
FrontendTest.trace_requires_in_nonstrict_mode FrontendTest.trace_requires_in_nonstrict_mode
GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics1
GenericsTests.apply_type_function_nested_generics2 GenericsTests.apply_type_function_nested_generics2
@ -105,7 +103,6 @@ GenericsTests.calling_self_generic_methods
GenericsTests.check_generic_typepack_function GenericsTests.check_generic_typepack_function
GenericsTests.check_mutual_generic_functions GenericsTests.check_mutual_generic_functions
GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.correctly_instantiate_polymorphic_member_functions
GenericsTests.do_not_always_instantiate_generic_intersection_types
GenericsTests.do_not_infer_generic_functions GenericsTests.do_not_infer_generic_functions
GenericsTests.duplicate_generic_type_packs GenericsTests.duplicate_generic_type_packs
GenericsTests.duplicate_generic_types GenericsTests.duplicate_generic_types
@ -143,7 +140,6 @@ IntersectionTypes.table_write_sealed_indirect
ModuleTests.any_persistance_does_not_leak ModuleTests.any_persistance_does_not_leak
ModuleTests.clone_self_property ModuleTests.clone_self_property
ModuleTests.deepClone_cyclic_table ModuleTests.deepClone_cyclic_table
ModuleTests.do_not_clone_reexports
NonstrictModeTests.for_in_iterator_variables_are_any NonstrictModeTests.for_in_iterator_variables_are_any
NonstrictModeTests.function_parameters_are_any NonstrictModeTests.function_parameters_are_any
NonstrictModeTests.inconsistent_module_return_types_are_ok NonstrictModeTests.inconsistent_module_return_types_are_ok
@ -158,7 +154,6 @@ NonstrictModeTests.parameters_having_type_any_are_optional
NonstrictModeTests.table_dot_insert_and_recursive_calls NonstrictModeTests.table_dot_insert_and_recursive_calls
NonstrictModeTests.table_props_are_any NonstrictModeTests.table_props_are_any
Normalize.cyclic_table_normalizes_sensibly Normalize.cyclic_table_normalizes_sensibly
Normalize.intersection_combine_on_bound_self
ParseErrorRecovery.generic_type_list_recovery ParseErrorRecovery.generic_type_list_recovery
ParseErrorRecovery.recovery_of_parenthesized_expressions ParseErrorRecovery.recovery_of_parenthesized_expressions
ParserTests.parse_nesting_based_end_detection_failsafe_earlier ParserTests.parse_nesting_based_end_detection_failsafe_earlier
@ -249,7 +244,6 @@ TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail
TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail
TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar
TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index
TableTests.dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back
TableTests.dont_leak_free_table_props TableTests.dont_leak_free_table_props
TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_quantify_table_that_belongs_to_outer_scope
TableTests.dont_suggest_exact_match_keys TableTests.dont_suggest_exact_match_keys
@ -279,7 +273,6 @@ TableTests.inferring_crazy_table_should_also_be_quick
TableTests.instantiate_table_cloning_3 TableTests.instantiate_table_cloning_3
TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound
TableTests.leaking_bad_metatable_errors TableTests.leaking_bad_metatable_errors
TableTests.length_operator_union_errors
TableTests.less_exponential_blowup_please TableTests.less_exponential_blowup_please
TableTests.meta_add TableTests.meta_add
TableTests.meta_add_both_ways TableTests.meta_add_both_ways
@ -347,9 +340,9 @@ TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType
TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.result_of_failed_typepack_unification_is_constrained
TryUnifyTests.typepack_unification_should_trim_free_tails TryUnifyTests.typepack_unification_should_trim_free_tails
TryUnifyTests.variadics_should_use_reversed_properly TryUnifyTests.variadics_should_use_reversed_properly
TypeAliases.cannot_create_cyclic_type_with_unknown_module
TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any
TypeAliases.generic_param_remap TypeAliases.generic_param_remap
TypeAliases.mismatched_generic_pack_type_param
TypeAliases.mismatched_generic_type_param TypeAliases.mismatched_generic_type_param
TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_1
TypeAliases.mutually_recursive_types_restriction_not_ok_2 TypeAliases.mutually_recursive_types_restriction_not_ok_2
@ -363,7 +356,7 @@ TypeAliases.type_alias_fwd_declaration_is_precise
TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_mutation
TypeAliases.type_alias_local_rename TypeAliases.type_alias_local_rename
TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeAliases.type_alias_of_an_imported_recursive_generic_type
TypeAliases.type_alias_of_an_imported_recursive_type TypeInfer.check_type_infer_recursion_count
TypeInfer.checking_should_not_ice TypeInfer.checking_should_not_ice
TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error
TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstExprError
@ -394,6 +387,7 @@ TypeInferClasses.warn_when_prop_almost_matches
TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class
TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types
TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument
TypeInferFunctions.cannot_hoist_interior_defns_into_signature
TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists
TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site
TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict
@ -439,12 +433,9 @@ TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free
TypeInferModules.bound_free_table_export_is_ok TypeInferModules.bound_free_table_export_is_ok
TypeInferModules.custom_require_global TypeInferModules.custom_require_global
TypeInferModules.do_not_modify_imported_types TypeInferModules.do_not_modify_imported_types
TypeInferModules.do_not_modify_imported_types_2
TypeInferModules.do_not_modify_imported_types_3
TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict
TypeInferModules.module_type_conflict_instantiated TypeInferModules.module_type_conflict_instantiated
TypeInferModules.require_a_variadic_function TypeInferModules.require_a_variadic_function
TypeInferModules.require_types
TypeInferModules.type_error_of_unknown_qualified_type TypeInferModules.type_error_of_unknown_qualified_type
TypeInferOOP.CheckMethodsOfSealed TypeInferOOP.CheckMethodsOfSealed
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works
@ -468,9 +459,6 @@ TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with
TypeInferOperators.refine_and_or TypeInferOperators.refine_and_or
TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection
TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs
TypeInferOperators.typecheck_unary_len_error
TypeInferOperators.typecheck_unary_minus
TypeInferOperators.typecheck_unary_minus_error
TypeInferOperators.UnknownGlobalCompoundAssign TypeInferOperators.UnknownGlobalCompoundAssign
TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.CheckMethodsOfNumber
TypeInferPrimitives.singleton_types TypeInferPrimitives.singleton_types
@ -489,6 +477,7 @@ TypeInferUnknownNever.math_operators_and_never
TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable
TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2
TypeInferUnknownNever.unary_minus_of_never TypeInferUnknownNever.unary_minus_of_never
TypePackTests.detect_cyclic_typepacks2
TypePackTests.higher_order_function TypePackTests.higher_order_function
TypePackTests.pack_tail_unification_check TypePackTests.pack_tail_unification_check
TypePackTests.parenthesized_varargs_returns_any TypePackTests.parenthesized_varargs_returns_any