Sync to upstream/release/552 (#735)

* Reduce the stack utilization of type checking.
* Improve the error message that's reported when a delimiting comma is
missing from a table literal. eg
```lua
local t = {
    first = 1
    second = 2
}```
This commit is contained in:
Andy Friesen 2022-11-04 10:33:22 -07:00 committed by GitHub
parent e43a9e927c
commit c33700e473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
78 changed files with 3061 additions and 2986 deletions

View File

@ -0,0 +1,68 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Def.h"
#include "Luau/TypedAllocator.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <memory>
namespace Luau
{
struct Negation;
struct Conjunction;
struct Disjunction;
struct Equivalence;
struct Proposition;
using Connective = Variant<Negation, Conjunction, Disjunction, Equivalence, Proposition>;
using ConnectiveId = Connective*; // Can and most likely is nullptr.
struct Negation
{
ConnectiveId connective;
};
struct Conjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Disjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Equivalence
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Proposition
{
DefId def;
TypeId discriminantTy;
};
template<typename T>
const T* get(ConnectiveId connective)
{
return get_if<T>(connective);
}
struct ConnectiveArena
{
TypedAllocator<Connective> allocator;
ConnectiveId negation(ConnectiveId connective);
ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId proposition(DefId def, TypeId discriminantTy);
};
} // namespace Luau

View File

@ -132,15 +132,16 @@ struct HasPropConstraint
std::string prop;
};
struct RefinementConstraint
// result ~ if isSingleton D then ~D else unknown where D = discriminantType
struct SingletonOrTopTypeConstraint
{
DefId def;
TypeId resultType;
TypeId discriminantType;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, RefinementConstraint>;
HasPropConstraint, SingletonOrTopTypeConstraint>;
struct Constraint
{

View File

@ -2,6 +2,7 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/Connective.h"
#include "Luau/Constraint.h"
#include "Luau/DataFlowGraphBuilder.h"
#include "Luau/Module.h"
@ -26,11 +27,13 @@ struct DcrLogger;
struct Inference
{
TypeId ty = nullptr;
ConnectiveId connective = nullptr;
Inference() = default;
explicit Inference(TypeId ty)
explicit Inference(TypeId ty, ConnectiveId connective = nullptr)
: ty(ty)
, connective(connective)
{
}
};
@ -38,11 +41,13 @@ struct Inference
struct InferencePack
{
TypePackId tp = nullptr;
std::vector<ConnectiveId> connectives;
InferencePack() = default;
explicit InferencePack(TypePackId tp)
explicit InferencePack(TypePackId tp, const std::vector<ConnectiveId>& connectives = {})
: tp(tp)
, connectives(connectives)
{
}
};
@ -73,6 +78,7 @@ struct ConstraintGraphBuilder
// Defining scopes for AST nodes.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
NotNull<const DataFlowGraph> dfg;
ConnectiveArena connectiveArena;
int recursionCount = 0;
@ -126,6 +132,8 @@ struct ConstraintGraphBuilder
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
* of scopes, constraints, and free types that can be solved later.
@ -167,10 +175,10 @@ struct ConstraintGraphBuilder
* surrounding context. Used to implement bidirectional type checking.
* @return the type of the expression.
*/
Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {});
Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprLocal* local);
Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
@ -180,6 +188,7 @@ struct ConstraintGraphBuilder
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, ConnectiveId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);

View File

@ -110,7 +110,7 @@ struct ConstraintSolver
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const RefinementConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do
// also handles __iter metamethod

View File

@ -17,10 +17,8 @@ struct SingletonTypes;
using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(
TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice,
bool anyIsTop = true);
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
class TypeIds
{
@ -169,12 +167,26 @@ struct NormalizedStringType
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr);
// A normalized function type is either `never` (represented by `nullopt`)
// A normalized function type can be `never`, the top function type `function`,
// or an intersection of function types.
// NOTE: type normalization can fail on function types with generics
// (e.g. because we do not support unions and intersections of generic type packs),
// so this type may contain `error`.
using NormalizedFunctionType = std::optional<TypeIds>;
//
// NOTE: type normalization can fail on function types with generics (e.g.
// because we do not support unions and intersections of generic type packs), so
// this type may contain `error`.
struct NormalizedFunctionType
{
NormalizedFunctionType();
bool isTop = false;
// TODO: Remove this wrapping optional when clipping
// FFlagLuauNegatedFunctionTypes.
std::optional<TypeIds> parts;
void resetToNever();
void resetToTop();
bool isNever() const;
};
// A normalized generic/free type is a union, where each option is of the form (X & T) where
// * X is either a free type or a generic
@ -234,12 +246,14 @@ struct NormalizedType
NormalizedType(NotNull<SingletonTypes> singletonTypes);
NormalizedType(const NormalizedType&) = delete;
NormalizedType(NormalizedType&&) = default;
NormalizedType() = delete;
~NormalizedType() = default;
NormalizedType(const NormalizedType&) = delete;
NormalizedType& operator=(const NormalizedType&) = delete;
NormalizedType(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&) = delete;
};
class Normalizer
@ -291,7 +305,7 @@ public:
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
// ------- Negations
NormalizedType negateNormal(const NormalizedType& here);
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
TypeIds negateAll(const TypeIds& theres);
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);

View File

@ -115,6 +115,7 @@ struct PrimitiveTypeVar
Number,
String,
Thread,
Function,
};
Type type;
@ -504,14 +505,6 @@ struct NeverTypeVar
{
};
// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type.
// Invariant 2: UseTypeVar should always disappear across modules.
struct UseTypeVar
{
DefId def;
NotNull<Scope> scope;
};
// ~T
// TODO: Some simplification step that overwrites the type graph to make sure negation
// types disappear from the user's view, and (?) a debug flag to disable that
@ -522,9 +515,9 @@ struct NegationTypeVar
using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar,
TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar,
UseTypeVar, NegationTypeVar>;
using TypeVariant =
Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar, NegationTypeVar>;
struct TypeVar final
{
@ -644,6 +637,7 @@ public:
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId functionType;
const TypeId trueType;
const TypeId falseType;
const TypeId anyType;

View File

@ -61,7 +61,6 @@ struct Unifier
ErrorVec errors;
Location location;
Variance variance = Covariant;
bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once.
bool normalize; // Normalize unions and intersections if necessary
bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels
CountMismatch::Context ctx = CountMismatch::Arg;
@ -131,6 +130,7 @@ public:
Unifier makeChildUnifier();
void reportError(TypeError err);
LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
private:
bool isNonstrictMode() const;

View File

@ -58,13 +58,15 @@ public:
constexpr int tid = getTypeId<T>();
typeId = tid;
new (&storage) TT(value);
new (&storage) TT(std::forward<T>(value));
}
Variant(const Variant& other)
{
static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy<Ts>...};
typeId = other.typeId;
tableCopy[typeId](&storage, &other.storage);
table[typeId](&storage, &other.storage);
}
Variant(Variant&& other)
@ -192,7 +194,6 @@ private:
return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs);
}
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};

View File

@ -155,10 +155,6 @@ struct GenericTypeVarVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UseTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NegationTypeVar& ntv)
{
return visit(ty);
@ -321,8 +317,6 @@ struct GenericTypeVarVisitor
traverse(a);
}
}
else if (auto utv = get<UseTypeVar>(ty))
visit(ty, *utv);
else if (auto ntv = get<NegationTypeVar>(ty))
visit(ty, *ntv);
else if (!FFlag::LuauCompleteVisitor)

View File

@ -714,8 +714,7 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context)
result = arena->addType(UnionTypeVar{std::move(options)});
TypeId numberType = context.solver->singletonTypes->numberType;
TypeId packedTable = arena->addType(
TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed});
TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed});
TypePackId tableTypePack = arena->addTypePack({packedTable});
asMutable(context.result)->ty.emplace<BoundTypePack>(tableTypePack);

View File

@ -62,7 +62,6 @@ struct TypeCloner
void operator()(const LazyTypeVar& t);
void operator()(const UnknownTypeVar& t);
void operator()(const NeverTypeVar& t);
void operator()(const UseTypeVar& t);
void operator()(const NegationTypeVar& t);
};
@ -338,12 +337,6 @@ void TypeCloner::operator()(const NeverTypeVar& t)
defaultClone(t);
}
void TypeCloner::operator()(const UseTypeVar& t)
{
TypeId result = dest.addType(BoundTypeVar{follow(typeId)});
seenTypes[typeId] = result;
}
void TypeCloner::operator()(const NegationTypeVar& t)
{
TypeId result = dest.addType(AnyTypeVar{});

View File

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Connective.h"
namespace Luau
{
ConnectiveId ConnectiveArena::negation(ConnectiveId connective)
{
return NotNull{allocator.allocate(Negation{connective})};
}
ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Conjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Disjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Equivalence{lhs, rhs})};
}
ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy)
{
return NotNull{allocator.allocate(Proposition{def, discriminantTy})};
}
} // namespace Luau

View File

@ -107,6 +107,101 @@ NotNull<Constraint> ConstraintGraphBuilder::addConstraint(const ScopePtr& scope,
return NotNull{scope->constraints.emplace_back(std::move(c)).get()};
}
static void unionRefinements(const std::unordered_map<DefId, TypeId>& lhs, const std::unordered_map<DefId, TypeId>& rhs,
std::unordered_map<DefId, TypeId>& dest, NotNull<TypeArena> arena)
{
for (auto [def, ty] : lhs)
{
auto rhsIt = rhs.find(def);
if (rhsIt == rhs.end())
continue;
std::vector<TypeId> discriminants{{ty, rhsIt->second}};
if (auto destIt = dest.find(def); destIt != dest.end())
discriminants.push_back(destIt->second);
dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)});
}
}
static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map<DefId, TypeId>* refis, bool sense,
NotNull<TypeArena> arena, bool eq, std::vector<SingletonOrTopTypeConstraint>* constraints)
{
using RefinementMap = std::unordered_map<DefId, TypeId>;
if (!connective)
return;
else if (auto negation = get<Negation>(connective))
return computeRefinement(scope, negation->connective, refis, !sense, arena, eq, constraints);
else if (auto conjunction = get<Conjunction>(connective))
{
RefinementMap lhsRefis;
RefinementMap rhsRefis;
computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints);
computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints);
if (!sense)
unionRefinements(lhsRefis, rhsRefis, *refis, arena);
}
else if (auto disjunction = get<Disjunction>(connective))
{
RefinementMap lhsRefis;
RefinementMap rhsRefis;
computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints);
computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints);
if (sense)
unionRefinements(lhsRefis, rhsRefis, *refis, arena);
}
else if (auto equivalence = get<Equivalence>(connective))
{
computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints);
computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints);
}
else if (auto proposition = get<Proposition>(connective))
{
TypeId discriminantTy = proposition->discriminantTy;
if (!sense && !eq)
discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy});
else if (!sense && eq)
{
discriminantTy = arena->addType(BlockedTypeVar{});
constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy});
}
if (auto it = refis->find(proposition->def); it != refis->end())
(*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}});
else
(*refis)[proposition->def] = discriminantTy;
}
}
void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective)
{
if (!connective)
return;
std::unordered_map<DefId, TypeId> refinements;
std::vector<SingletonOrTopTypeConstraint> constraints;
computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints);
for (auto [def, discriminantTy] : refinements)
{
std::optional<TypeId> defTy = scope->lookup(def);
if (!defTy)
ice->ice("Every DefId must map to a type!");
TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy}});
scope->dcrRefinements[def] = resultTy;
}
for (auto& c : constraints)
addConstraint(scope, location, c);
}
void ConstraintGraphBuilder::visit(AstStatBlock* block)
{
LUAU_ASSERT(scopes.empty());
@ -250,14 +345,33 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
if (value->is<AstExprConstantNil>())
{
// HACK: we leave nil-initialized things floating under the assumption that they will later be populated.
// See the test TypeInfer/infer_locals_with_nil_value.
// Better flow awareness should make this obsolete.
// HACK: we leave nil-initialized things floating under the
// assumption that they will later be populated.
//
// See the test TypeInfer/infer_locals_with_nil_value. Better flow
// awareness should make this obsolete.
if (!varTypes[i])
varTypes[i] = freshType(scope);
}
else if (i == local->values.size - 1)
// Only function calls and vararg expressions can produce packs. All
// other expressions produce exactly one value.
else if (i != local->values.size - 1 || (!value->is<AstExprCall>() && !value->is<AstExprVarargs>()))
{
std::optional<TypeId> expectedType;
if (hasAnnotation)
expectedType = varTypes.at(i);
TypeId exprType = check(scope, value, expectedType).ty;
if (i < varTypes.size())
{
if (varTypes[i])
addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]});
else
varTypes[i] = exprType;
}
}
else
{
std::vector<TypeId> expectedTypes;
if (hasAnnotation)
@ -286,21 +400,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack});
}
}
else
{
std::optional<TypeId> expectedType;
if (hasAnnotation)
expectedType = varTypes.at(i);
TypeId exprType = check(scope, value, expectedType).ty;
if (i < varTypes.size())
{
if (varTypes[i])
addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType});
else
varTypes[i] = exprType;
}
}
}
for (size_t i = 0; i < local->vars.size; ++i)
@ -569,14 +668,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement
// TODO: Optimization opportunity, the interior scope of the condition could be
// reused for the then body, so we don't need to refine twice.
ScopePtr condScope = childScope(ifStatement->condition, scope);
check(condScope, ifStatement->condition, std::nullopt);
auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt);
ScopePtr thenScope = childScope(ifStatement->thenbody, scope);
applyRefinements(thenScope, Location{}, connective);
visit(thenScope, ifStatement->thenbody);
if (ifStatement->elsebody)
{
ScopePtr elseScope = childScope(ifStatement->elsebody, scope);
applyRefinements(elseScope, Location{}, connectiveArena.negation(connective));
visit(elseScope, ifStatement->elsebody);
}
}
@ -925,7 +1026,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
}
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType)
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType, bool forceSingleton)
{
RecursionCounter counter{&recursionCount};
@ -938,13 +1039,13 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st
Inference result;
if (auto group = expr->as<AstExprGroup>())
result = check(scope, group->expr, expectedType);
result = check(scope, group->expr, expectedType, forceSingleton);
else if (auto stringExpr = expr->as<AstExprConstantString>())
result = check(scope, stringExpr, expectedType);
result = check(scope, stringExpr, expectedType, forceSingleton);
else if (expr->is<AstExprConstantNumber>())
result = Inference{singletonTypes->numberType};
else if (auto boolExpr = expr->as<AstExprConstantBool>())
result = check(scope, boolExpr, expectedType);
result = check(scope, boolExpr, expectedType, forceSingleton);
else if (expr->is<AstExprConstantNil>())
result = Inference{singletonTypes->nilType};
else if (auto local = expr->as<AstExprLocal>())
@ -999,8 +1100,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st
return result;
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType)
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton)
{
if (forceSingleton)
return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})};
if (expectedType)
{
const TypeId expectedTy = follow(*expectedType);
@ -1020,12 +1124,15 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantSt
return Inference{singletonTypes->stringType};
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional<TypeId> expectedType)
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional<TypeId> expectedType, bool forceSingleton)
{
const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType;
if (forceSingleton)
return Inference{singletonType};
if (expectedType)
{
const TypeId expectedTy = follow(*expectedType);
const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType;
if (get<BlockedTypeVar>(expectedTy) || get<PendingExpansionTypeVar>(expectedTy))
{
@ -1045,8 +1152,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local)
{
std::optional<TypeId> resultTy;
if (auto def = dfg->getDef(local))
auto def = dfg->getDef(local);
if (def)
resultTy = scope->lookup(*def);
if (!resultTy)
@ -1058,6 +1165,9 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* loc
if (!resultTy)
return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition.
if (def)
return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)};
else
return Inference{*resultTy};
}
@ -1107,20 +1217,23 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr*
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary)
{
TypeId operandType = check(scope, unary->expr).ty;
auto [operandType, connective] = check(scope, unary->expr);
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType});
if (unary->op == AstExprUnary::Not)
return Inference{resultType, connectiveArena.negation(connective)};
else
return Inference{resultType};
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
{
TypeId leftType = check(scope, binary->left, expectedType).ty;
TypeId rightType = check(scope, binary->right, expectedType).ty;
auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType);
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType});
return Inference{resultType};
return Inference{resultType, std::move(connective)};
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType)
@ -1147,6 +1260,58 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssert
return Inference{resolveType(scope, typeAssert->annotation)};
}
std::tuple<TypeId, TypeId, ConnectiveId> ConstraintGraphBuilder::checkBinary(
const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType)
{
if (binary->op == AstExprBinary::And)
{
auto [leftType, leftConnective] = check(scope, binary->left, expectedType);
ScopePtr rightScope = childScope(binary->right, scope);
applyRefinements(rightScope, binary->right->location, leftConnective);
auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType);
return {leftType, rightType, connectiveArena.conjunction(leftConnective, rightConnective)};
}
else if (binary->op == AstExprBinary::Or)
{
auto [leftType, leftConnective] = check(scope, binary->left, expectedType);
ScopePtr rightScope = childScope(binary->right, scope);
applyRefinements(rightScope, binary->right->location, connectiveArena.negation(leftConnective));
auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType);
return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)};
}
else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe)
{
TypeId leftType = check(scope, binary->left, expectedType, true).ty;
TypeId rightType = check(scope, binary->right, expectedType, true).ty;
ConnectiveId leftConnective = nullptr;
if (auto def = dfg->getDef(binary->left))
leftConnective = connectiveArena.proposition(*def, rightType);
ConnectiveId rightConnective = nullptr;
if (auto def = dfg->getDef(binary->right))
rightConnective = connectiveArena.proposition(*def, leftType);
if (binary->op == AstExprBinary::CompareNe)
{
leftConnective = connectiveArena.negation(leftConnective);
rightConnective = connectiveArena.negation(rightConnective);
}
return {leftType, rightType, connectiveArena.equivalence(leftConnective, rightConnective)};
}
else
{
TypeId leftType = check(scope, binary->left, expectedType).ty;
TypeId rightType = check(scope, binary->right, expectedType).ty;
return {leftType, rightType, nullptr};
}
}
TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
{
std::vector<TypeId> types;
@ -1841,9 +2006,13 @@ std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::
Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack)
{
auto [tp] = pack;
const auto& [tp, connectives] = pack;
ConnectiveId connective = nullptr;
if (!connectives.empty())
connective = connectives[0];
if (auto f = first(tp))
return Inference{*f};
return Inference{*f, connective};
TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)};
@ -1851,7 +2020,7 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo
addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack});
return Inference{typeResult};
return Inference{typeResult, connective};
}
void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err)

View File

@ -440,8 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*fcc, constraint);
else if (auto hpc = get<HasPropConstraint>(*constraint))
success = tryDispatch(*hpc, constraint);
else if (auto rc = get<RefinementConstraint>(*constraint))
success = tryDispatch(*rc, constraint);
else if (auto sottc = get<SingletonOrTopTypeConstraint>(*constraint))
success = tryDispatch(*sottc, constraint);
else
LUAU_ASSERT(false);
@ -1274,25 +1274,18 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return true;
}
bool ConstraintSolver::tryDispatch(const RefinementConstraint& c, NotNull<const Constraint> constraint)
bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint)
{
// TODO: Figure out exact details on when refinements need to be blocked.
// It's possible that it never needs to be, since we can just use intersection types with the discriminant type?
if (isBlocked(c.discriminantType))
return false;
if (!constraint->scope->parent)
iceReporter.ice("No parent scope");
TypeId followed = follow(c.discriminantType);
std::optional<TypeId> previousTy = constraint->scope->parent->lookup(c.def);
if (!previousTy)
iceReporter.ice("No previous type");
std::optional<TypeId> useTy = constraint->scope->lookup(c.def);
if (!useTy)
iceReporter.ice("The def is not bound to a type");
TypeId resultTy = follow(*useTy);
std::vector<TypeId> parts{*previousTy, c.discriminantType};
asMutable(resultTy)->ty.emplace<IntersectionTypeVar>(std::move(parts));
// `nil` is a singleton type too! There's only one value of type `nil`.
if (get<SingletonTypeVar>(followed) || isNil(followed))
*asMutable(c.resultType) = NegationTypeVar{c.discriminantType};
else
*asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType};
return true;
}

View File

@ -13,16 +13,16 @@ declare bit32: {
bor: (...number) -> number,
bxor: (...number) -> number,
btest: (number, ...number) -> boolean,
rrotate: (number, number) -> number,
lrotate: (number, number) -> number,
lshift: (number, number) -> number,
arshift: (number, number) -> number,
rshift: (number, number) -> number,
bnot: (number) -> number,
extract: (number, number, number?) -> number,
replace: (number, number, number, number?) -> number,
countlz: (number) -> number,
countrz: (number) -> number,
rrotate: (x: number, disp: number) -> number,
lrotate: (x: number, disp: number) -> number,
lshift: (x: number, disp: number) -> number,
arshift: (x: number, disp: number) -> number,
rshift: (x: number, disp: number) -> number,
bnot: (x: number) -> number,
extract: (n: number, field: number, width: number?) -> number,
replace: (n: number, v: number, field: number, width: number?) -> number,
countlz: (n: number) -> number,
countrz: (n: number) -> number,
}
declare math: {
@ -93,9 +93,9 @@ type DateTypeResult = {
}
declare os: {
time: (DateTypeArg?) -> number,
date: (string?, number?) -> DateTypeResult | string,
difftime: (DateTypeResult | number, DateTypeResult | number) -> number,
time: (time: DateTypeArg?) -> number,
date: (formatString: string?, time: number?) -> DateTypeResult | string,
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
@ -145,51 +145,51 @@ declare function loadstring<A...>(src: string, chunkname: string?): (((A...) ->
declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>((A...) -> R...) -> thread,
resume: <A..., R...>(thread, A...) -> (boolean, R...),
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: (thread) -> "dead" | "running" | "normal" | "suspended",
status: (co: thread) -> "dead" | "running" | "normal" | "suspended",
-- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>((A...) -> R...) -> any,
wrap: <A..., R...>(f: (A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: (thread) -> (boolean, any)
close: (co: thread) -> (boolean, any)
}
declare table: {
concat: <V>({V}, string?, number?, number?) -> string,
insert: (<V>({V}, V) -> ()) & (<V>({V}, number, V) -> ()),
maxn: <V>({V}) -> number,
remove: <V>({V}, number?) -> V?,
sort: <V>({V}, ((V, V) -> boolean)?) -> (),
create: <V>(number, V?) -> {V},
find: <V>({V}, V, number?) -> number?,
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>(t: {V}) -> number,
remove: <V>(t: {V}, number?) -> V?,
sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(count: number, value: V?) -> {V},
find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>({V}, number?, number?) -> ...V,
unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>({V}) -> number,
foreach: <K, V>({[K]: V}, (K, V) -> ()) -> (),
getn: <V>(t: {V}) -> number,
foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>({V}, number, number, number, {V}?) -> {V},
clear: <K, V>({[K]: V}) -> (),
move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>({[K]: V}) -> boolean,
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
declare debug: {
info: (<R...>(thread, number, string) -> R...) & (<R...>(number, string) -> R...) & (<A..., R1..., R2...>((A...) -> R1..., string) -> R2...),
traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string),
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
declare utf8: {
char: (...number) -> string,
charpattern: string,
codes: (string) -> ((string, number) -> (number, number), string, number),
codepoint: (string, number?, number?) -> ...number,
len: (string, number?, number?) -> (number?, number?),
offset: (string, number?, number?) -> number,
codes: (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: (str: string, i: number?, j: number?) -> ...number,
len: (s: string, i: number?, j: number?) -> (number?, number?),
offset: (s: string, n: number?, i: number?) -> number,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.

View File

@ -7,9 +7,9 @@
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifier.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false)
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false)
@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false);
LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false);
LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false);
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf);
@ -206,6 +207,28 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s
return true;
}
NormalizedFunctionType::NormalizedFunctionType()
: parts(FFlag::LuauNegatedFunctionTypes ? std::optional<TypeIds>{TypeIds{}} : std::nullopt)
{
}
void NormalizedFunctionType::resetToTop()
{
isTop = true;
parts.emplace();
}
void NormalizedFunctionType::resetToNever()
{
isTop = false;
parts.emplace();
}
bool NormalizedFunctionType::isNever() const
{
return !isTop && (!parts || parts->empty());
}
NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes)
: tops(singletonTypes->neverType)
, booleans(singletonTypes->neverType)
@ -220,8 +243,8 @@ NormalizedType::NormalizedType(NotNull<SingletonTypes> singletonTypes)
static bool isInhabited(const NormalizedType& norm)
{
return !get<NeverTypeVar>(norm.tops) || !get<NeverTypeVar>(norm.booleans) || !norm.classes.empty() || !get<NeverTypeVar>(norm.errors) ||
!get<NeverTypeVar>(norm.nils) || !get<NeverTypeVar>(norm.numbers) || !norm.strings.isNever() ||
!get<NeverTypeVar>(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty();
!get<NeverTypeVar>(norm.nils) || !get<NeverTypeVar>(norm.numbers) || !norm.strings.isNever() || !get<NeverTypeVar>(norm.threads) ||
!norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty();
}
static int tyvarIndex(TypeId ty)
@ -317,10 +340,14 @@ static bool isNormalizedThread(TypeId ty)
static bool areNormalizedFunctions(const NormalizedFunctionType& tys)
{
if (tys)
for (TypeId ty : *tys)
if (tys.parts)
{
for (TypeId ty : *tys.parts)
{
if (!get<FunctionTypeVar>(ty) && !get<ErrorTypeVar>(ty))
return false;
}
}
return true;
}
@ -420,7 +447,7 @@ void Normalizer::clearNormal(NormalizedType& norm)
norm.strings.resetToNever();
norm.threads = singletonTypes->neverType;
norm.tables.clear();
norm.functions = std::nullopt;
norm.functions.resetToNever();
norm.tyvars.clear();
}
@ -809,20 +836,28 @@ std::optional<TypeId> Normalizer::unionOfFunctions(TypeId here, TypeId there)
void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres)
{
if (!theres)
if (FFlag::LuauNegatedFunctionTypes)
{
if (heres.isTop)
return;
if (theres.isTop)
heres.resetToTop();
}
if (theres.isNever())
return;
TypeIds tmps;
if (!heres)
if (heres.isNever())
{
tmps.insert(theres->begin(), theres->end());
heres = std::move(tmps);
tmps.insert(theres.parts->begin(), theres.parts->end());
heres.parts = std::move(tmps);
return;
}
for (TypeId here : *heres)
for (TypeId there : *theres)
for (TypeId here : *heres.parts)
for (TypeId there : *theres.parts)
{
if (std::optional<TypeId> fun = unionOfFunctions(here, there))
tmps.insert(*fun);
@ -830,28 +865,28 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF
tmps.insert(singletonTypes->errorRecoveryType(there));
}
heres = std::move(tmps);
heres.parts = std::move(tmps);
}
void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there)
{
if (!heres)
if (heres.isNever())
{
TypeIds tmps;
tmps.insert(there);
heres = std::move(tmps);
heres.parts = std::move(tmps);
return;
}
TypeIds tmps;
for (TypeId here : *heres)
for (TypeId here : *heres.parts)
{
if (std::optional<TypeId> fun = unionOfFunctions(here, there))
tmps.insert(*fun);
else
tmps.insert(singletonTypes->errorRecoveryType(there));
}
heres = std::move(tmps);
heres.parts = std::move(tmps);
}
void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there)
@ -1004,6 +1039,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
here.strings.resetToString();
else if (ptv->type == PrimitiveTypeVar::Thread)
here.threads = there;
else if (ptv->type == PrimitiveTypeVar::Function)
{
LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes);
here.functions.resetToTop();
}
else
LUAU_ASSERT(!"Unreachable");
}
@ -1036,8 +1076,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(there))
{
const NormalizedType* thereNormal = normalize(ntv->ty);
NormalizedType tn = negateNormal(*thereNormal);
if (!unionNormals(here, tn))
std::optional<NormalizedType> tn = negateNormal(*thereNormal);
if (!tn)
return false;
if (!unionNormals(here, *tn))
return false;
}
else
@ -1053,7 +1096,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
// ------- Negations
NormalizedType Normalizer::negateNormal(const NormalizedType& here)
std::optional<NormalizedType> Normalizer::negateNormal(const NormalizedType& here)
{
NormalizedType result{singletonTypes};
if (!get<NeverTypeVar>(here.tops))
@ -1092,8 +1135,22 @@ NormalizedType Normalizer::negateNormal(const NormalizedType& here)
result.threads = get<NeverTypeVar>(here.threads) ? singletonTypes->threadType : singletonTypes->neverType;
/*
* Things get weird and so, so complicated if we allow negations of
* arbitrary function types. Ordinary code can never form these kinds of
* types, so we decline to negate them.
*/
if (FFlag::LuauNegatedFunctionTypes)
{
if (here.functions.isNever())
result.functions.resetToTop();
else if (here.functions.isTop)
result.functions.resetToNever();
else
return std::nullopt;
}
// TODO: negating tables
// TODO: negating functions
// TODO: negating tyvars?
return result;
@ -1157,6 +1214,10 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty)
case PrimitiveTypeVar::Thread:
here.threads = singletonTypes->neverType;
break;
case PrimitiveTypeVar::Function:
LUAU_ASSERT(FFlag::LuauNegatedStringSingletons);
here.functions.resetToNever();
break;
}
}
@ -1738,18 +1799,20 @@ std::optional<TypeId> Normalizer::unionSaturatedFunctions(TypeId here, TypeId th
void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there)
{
if (!heres)
if (heres.isNever())
return;
for (auto it = heres->begin(); it != heres->end();)
heres.isTop = false;
for (auto it = heres.parts->begin(); it != heres.parts->end();)
{
TypeId here = *it;
if (get<ErrorTypeVar>(here))
it++;
else if (std::optional<TypeId> tmp = intersectionOfFunctions(here, there))
{
heres->erase(it);
heres->insert(*tmp);
heres.parts->erase(it);
heres.parts->insert(*tmp);
return;
}
else
@ -1757,27 +1820,27 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T
}
TypeIds tmps;
for (TypeId here : *heres)
for (TypeId here : *heres.parts)
{
if (std::optional<TypeId> tmp = unionSaturatedFunctions(here, there))
tmps.insert(*tmp);
}
heres->insert(there);
heres->insert(tmps.begin(), tmps.end());
heres.parts->insert(there);
heres.parts->insert(tmps.begin(), tmps.end());
}
void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres)
{
if (!heres)
if (heres.isNever())
return;
else if (!theres)
else if (theres.isNever())
{
heres = std::nullopt;
heres.resetToNever();
return;
}
else
{
for (TypeId there : *theres)
for (TypeId there : *theres.parts)
intersectFunctionsWithFunction(heres, there);
}
}
@ -1935,6 +1998,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
TypeId nils = here.nils;
TypeId numbers = here.numbers;
NormalizedStringType strings = std::move(here.strings);
NormalizedFunctionType functions = std::move(here.functions);
TypeId threads = here.threads;
clearNormal(here);
@ -1949,6 +2013,11 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
here.strings = std::move(strings);
else if (ptv->type == PrimitiveTypeVar::Thread)
here.threads = threads;
else if (ptv->type == PrimitiveTypeVar::Function)
{
LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes);
here.functions = std::move(functions);
}
else
LUAU_ASSERT(!"Unreachable");
}
@ -1981,8 +2050,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there)
for (TypeId part : itv->options)
{
const NormalizedType* normalPart = normalize(part);
NormalizedType negated = negateNormal(*normalPart);
intersectNormals(here, negated);
std::optional<NormalizedType> negated = negateNormal(*normalPart);
if (!negated)
return false;
intersectNormals(here, *negated);
}
}
else
@ -2016,14 +2087,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
result.insert(result.end(), norm.classes.begin(), norm.classes.end());
if (!get<NeverTypeVar>(norm.errors))
result.push_back(norm.errors);
if (norm.functions)
if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop)
result.push_back(singletonTypes->functionType);
else if (!norm.functions.isNever())
{
if (norm.functions->size() == 1)
result.push_back(*norm.functions->begin());
if (norm.functions.parts->size() == 1)
result.push_back(*norm.functions.parts->begin());
else
{
std::vector<TypeId> parts;
parts.insert(parts.end(), norm.functions->begin(), norm.functions->end());
parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end());
result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)}));
}
}
@ -2070,62 +2143,24 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
return arena->addType(UnionTypeVar{std::move(result)});
}
namespace
{
struct Replacer
{
TypeArena* arena;
TypeId sourceType;
TypeId replacedType;
DenseHashMap<TypeId, TypeId> newTypes;
Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType)
: arena(arena)
, sourceType(sourceType)
, replacedType(replacedType)
, newTypes(nullptr)
{
}
TypeId smartClone(TypeId t)
{
t = follow(t);
TypeId* res = newTypes.find(t);
if (res)
return *res;
TypeId result = shallowClone(t, *arena, TxnLog::empty());
newTypes[t] = result;
newTypes[result] = result;
return result;
}
};
} // anonymous namespace
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop)
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant};
u.anyIsTop = anyIsTop;
u.tryUnify(subTy, superTy);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
bool isSubtype(
TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice, bool anyIsTop)
bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant};
u.anyIsTop = anyIsTop;
u.tryUnify(subPack, superPack);
const bool ok = u.errors.empty() && u.log.empty();

View File

@ -10,12 +10,12 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauLvaluelessPath)
LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false)
LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false)
/*
* Prefix generic typenames with gen-
@ -225,6 +225,20 @@ struct StringifierState
result.name += s;
}
void emitLevel(Scope* scope)
{
size_t count = 0;
for (Scope* s = scope; s; s = s->parent.get())
++count;
emit(count);
emit("-");
char buffer[16];
uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF);
snprintf(buffer, sizeof(buffer), "0x%x", s);
emit(buffer);
}
void emit(TypeLevel level)
{
emit(std::to_string(level.level));
@ -296,10 +310,7 @@ struct TypeVarStringifier
if (tv->ty.valueless_by_exception())
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("* VALUELESS BY EXCEPTION *");
else
state.emit("< VALUELESS BY EXCEPTION >");
return;
}
@ -377,6 +388,9 @@ struct TypeVarStringifier
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(ftv.scope);
else
state.emit(ftv.level);
}
}
@ -399,6 +413,15 @@ struct TypeVarStringifier
}
else
state.emit(state.getName(ty));
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(gtv.scope);
else
state.emit(gtv.level);
}
}
void operator()(TypeId, const BlockedTypeVar& btv)
@ -434,6 +457,9 @@ struct TypeVarStringifier
case PrimitiveTypeVar::Thread:
state.emit("thread");
return;
case PrimitiveTypeVar::Function:
state.emit("function");
return;
default:
LUAU_ASSERT(!"Unknown primitive type");
throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type));
@ -462,10 +488,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ftv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -573,10 +596,7 @@ struct TypeVarStringifier
if (state.hasSeen(&ttv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -710,10 +730,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -780,10 +797,7 @@ struct TypeVarStringifier
if (state.hasSeen(&uv))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -828,10 +842,7 @@ struct TypeVarStringifier
void operator()(TypeId, const ErrorTypeVar& tv)
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
}
void operator()(TypeId, const LazyTypeVar& ltv)
@ -850,11 +861,6 @@ struct TypeVarStringifier
state.emit("never");
}
void operator()(TypeId ty, const UseTypeVar&)
{
stringify(follow(ty));
}
void operator()(TypeId, const NegationTypeVar& ntv)
{
state.emit("~");
@ -907,10 +913,7 @@ struct TypePackStringifier
if (tp->ty.valueless_by_exception())
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("* VALUELESS TP BY EXCEPTION *");
else
state.emit("< VALUELESS TP BY EXCEPTION >");
return;
}
@ -934,10 +937,7 @@ struct TypePackStringifier
if (state.hasSeen(&tp))
{
state.result.cycle = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLETP*");
else
state.emit("<CYCLETP>");
return;
}
@ -982,10 +982,7 @@ struct TypePackStringifier
void operator()(TypePackId, const Unifiable::Error& error)
{
state.result.error = true;
if (FFlag::LuauSpecialTypesAsterisked)
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
}
void operator()(TypePackId, const VariadicTypePack& pack)
@ -993,10 +990,7 @@ struct TypePackStringifier
state.emit("...");
if (FFlag::DebugLuauVerboseTypeNames && pack.hidden)
{
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*hidden*");
else
state.emit("<hidden>");
}
stringify(pack.ty);
}
@ -1031,6 +1025,9 @@ struct TypePackStringifier
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
if (FFlag::DebugLuauDeferredConstraintResolution)
state.emitLevel(pack.scope);
else
state.emit(pack.level);
}
@ -1204,10 +1201,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts)
{
result.truncated = true;
if (FFlag::LuauSpecialTypesAsterisked)
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
}
return result;
@ -1280,10 +1274,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts)
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{
if (FFlag::LuauSpecialTypesAsterisked)
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
}
return result;
@ -1526,9 +1517,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
{
return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\"";
}
else if constexpr (std::is_same_v<T, RefinementConstraint>)
else if constexpr (std::is_same_v<T, SingletonOrTopTypeConstraint>)
{
return "TODO";
std::string result = tos(c.resultType, opts);
std::string discriminant = tos(c.discriminantType, opts);
return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant;
}
else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch");

View File

@ -338,12 +338,6 @@ public:
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"});
}
AstType* operator()(const UseTypeVar& utv)
{
std::optional<TypeId> ty = utv.scope->lookup(utv.def);
LUAU_ASSERT(ty);
return Luau::visit(*this, (*ty)->ty);
}
AstType* operator()(const NegationTypeVar& ntv)
{
// FIXME: do the same thing we do with ErrorTypeVar

View File

@ -301,7 +301,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant};
u.anyIsTop = true;
u.tryUnify(actualRetType, expectedRetType);
const bool ok = u.errors.empty() && u.log.empty();
@ -331,16 +330,21 @@ struct TypeChecker2
if (value)
visit(value);
if (i != local->values.size - 1)
TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr;
if (i != local->values.size - 1 || maybeValueType)
{
AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr;
if (var && var->annotation)
{
TypeId varType = lookupAnnotation(var->annotation);
TypeId annotationType = lookupAnnotation(var->annotation);
TypeId valueType = value ? lookupType(value) : nullptr;
if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
reportError(TypeMismatch{varType, valueType}, value->location);
if (valueType)
{
ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType);
if (!errors.empty())
reportErrors(std::move(errors));
}
}
}
else
@ -606,7 +610,7 @@ struct TypeChecker2
visit(rhs);
TypeId rhsType = lookupType(rhs);
if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{lhsType, rhsType}, rhs->location);
}
@ -757,7 +761,7 @@ struct TypeChecker2
TypeId actualType = lookupType(number);
TypeId numberType = singletonTypes->numberType;
if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{actualType, numberType}, number->location);
}
@ -768,7 +772,7 @@ struct TypeChecker2
TypeId actualType = lookupType(string);
TypeId stringType = singletonTypes->stringType;
if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{actualType, stringType}, string->location);
}
@ -857,7 +861,7 @@ struct TypeChecker2
FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv);
if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice))
{
CloneState cloneState;
expectedType = clone(expectedType, module->internalTypes, cloneState);
@ -876,7 +880,7 @@ struct TypeChecker2
getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true);
if (ty)
{
if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{resultType, *ty}, indexName->location);
}
@ -909,7 +913,7 @@ struct TypeChecker2
TypeId inferredArgTy = *argIt;
TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice))
{
reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location);
}
@ -1203,10 +1207,10 @@ struct TypeChecker2
TypeId computedType = lookupType(expr->expr);
// Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice))
return;
if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false))
if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice))
return;
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
@ -1507,7 +1511,6 @@ struct TypeChecker2
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}};
Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant};
u.anyIsTop = true;
u.tryUnify(subTy, superTy);
return std::move(u.errors);

View File

@ -57,13 +57,6 @@ TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper)
return btv->boundTo;
else if (auto ttv = get<TableTypeVar>(mapper(ty)))
return ttv->boundTo;
else if (auto utv = get<UseTypeVar>(mapper(ty)))
{
std::optional<TypeId> ty = utv->scope->lookup(utv->def);
if (!ty)
throwRuntimeError("UseTypeVar must map to another TypeId");
return *ty;
}
else
return std::nullopt;
};
@ -761,6 +754,7 @@ SingletonTypes::SingletonTypes()
, stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}))
, booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}))
, threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}))
, functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true}))
, trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}))
, falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}))
, anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true}))
@ -946,7 +940,8 @@ void persist(TypeId ty)
queue.push_back(mtv->table);
queue.push_back(mtv->metatable);
}
else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t) || get<NegationTypeVar>(t))
else if (get<GenericTypeVar>(t) || get<AnyTypeVar>(t) || get<FreeTypeVar>(t) || get<SingletonTypeVar>(t) || get<PrimitiveTypeVar>(t) ||
get<NegationTypeVar>(t))
{
}
else

View File

@ -8,6 +8,7 @@
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h"
#include "Luau/ToString.h"
@ -23,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false)
LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(LuauNegatedFunctionTypes)
namespace Luau
{
@ -363,7 +365,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
return;
}
@ -404,7 +406,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (subGeneric && !subsumes(useScopes, subGeneric, superFree))
{
// TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}});
reportError(location, GenericError{"Generic subtype escaping scope"});
return;
}
@ -433,7 +435,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superGeneric && !subsumes(useScopes, superGeneric, subFree))
{
// TODO: a more informative error message? CLI-39912
reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}});
reportError(location, GenericError{"Generic supertype escaping scope"});
return;
}
@ -450,15 +452,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
return tryUnifyWithAny(subTy, superTy);
if (get<AnyTypeVar>(subTy))
{
if (anyIsTop)
{
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
return;
}
else
return tryUnifyWithAny(superTy, subTy);
}
if (log.get<ErrorTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy);
@ -478,7 +472,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (auto error = sharedState.cachedUnifyError.find({subTy, superTy}))
{
reportError(TypeError{location, *error});
reportError(location, *error);
return;
}
}
@ -520,6 +514,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
else if ((log.getMutable<PrimitiveTypeVar>(superTy) || log.getMutable<SingletonTypeVar>(superTy)) && log.getMutable<SingletonTypeVar>(subTy))
tryUnifySingletons(subTy, superTy);
else if (auto ptv = get<PrimitiveTypeVar>(superTy);
FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get<FunctionTypeVar>(subTy))
{
// Ok. Do nothing. forall functions F, F <: function
}
else if (log.getMutable<FunctionTypeVar>(superTy) && log.getMutable<FunctionTypeVar>(subTy))
tryUnifyFunctions(subTy, superTy, isFunctionCall);
@ -559,7 +559,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
tryUnifyNegationWithType(subTy, superTy);
else
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
if (cacheEnabled)
cacheResult(subTy, superTy, errorCount);
@ -633,9 +633,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion,
else if (failed)
{
if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}});
reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption});
else
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
}
@ -734,7 +734,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
else if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption);
else
@ -743,9 +743,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
else if (!found)
{
if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}});
reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption});
else
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}});
reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"});
}
}
@ -774,7 +774,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
if (unificationTooComplex)
reportError(*unificationTooComplex);
else if (firstFailedOption)
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}});
reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption});
}
void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall)
@ -832,11 +832,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
}
else if (!found)
{
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}});
reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"});
}
}
@ -848,37 +848,37 @@ void Unifier::tryUnifyNormalizedTypes(
if (get<UnknownTypeVar>(superNorm.tops) || get<AnyTypeVar>(superNorm.tops) || get<AnyTypeVar>(subNorm.tops))
return;
else if (get<UnknownTypeVar>(subNorm.tops))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<ErrorTypeVar>(subNorm.errors))
if (!get<ErrorTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.booleans))
{
if (!get<PrimitiveTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(subNorm.booleans))
{
if (!get<PrimitiveTypeVar>(superNorm.booleans) && stv != get<SingletonTypeVar>(superNorm.booleans))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
if (get<PrimitiveTypeVar>(subNorm.nils))
if (!get<PrimitiveTypeVar>(superNorm.nils))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.numbers))
if (!get<PrimitiveTypeVar>(superNorm.numbers))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (!isSubtype(subNorm.strings, superNorm.strings))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
if (get<PrimitiveTypeVar>(subNorm.threads))
if (!get<PrimitiveTypeVar>(superNorm.errors))
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
for (TypeId subClass : subNorm.classes)
{
@ -894,7 +894,7 @@ void Unifier::tryUnifyNormalizedTypes(
}
}
if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
for (TypeId subTable : subNorm.tables)
@ -919,21 +919,19 @@ void Unifier::tryUnifyNormalizedTypes(
return reportError(*e);
}
if (!found)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
if (subNorm.functions)
if (!subNorm.functions.isNever())
{
if (!superNorm.functions)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
if (superNorm.functions->empty())
return;
for (TypeId superFun : *superNorm.functions)
if (superNorm.functions.isNever())
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
for (TypeId superFun : *superNorm.functions.parts)
{
Unifier innerState = makeChildUnifier();
const FunctionTypeVar* superFtv = get<FunctionTypeVar>(superFun);
if (!superFtv)
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes);
innerState.tryUnify_(tgt, superFtv->retTypes);
if (innerState.errors.empty())
@ -941,7 +939,7 @@ void Unifier::tryUnifyNormalizedTypes(
else if (auto e = hasUnificationTooComplex(innerState.errors))
return reportError(*e);
else
return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}});
return reportError(location, TypeMismatch{superTy, subTy, reason, error});
}
}
@ -959,15 +957,15 @@ void Unifier::tryUnifyNormalizedTypes(
TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args)
{
if (!overloads || overloads->empty())
if (overloads.isNever())
{
reportError(TypeError{location, CannotCallNonFunction{function}});
reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack();
}
std::optional<TypePackId> result;
const FunctionTypeVar* firstFun = nullptr;
for (TypeId overload : *overloads)
for (TypeId overload : *overloads.parts)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload))
{
@ -1015,12 +1013,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized
// TODO: better error reporting?
// The logic for error reporting overload resolution
// is currently over in TypeInfer.cpp, should we move it?
reportError(TypeError{location, GenericError{"No matching overload."}});
reportError(location, GenericError{"No matching overload."});
return singletonTypes->errorRecoveryTypePack(firstFun->retTypes);
}
else
{
reportError(TypeError{location, CannotCallNonFunction{function}});
reportError(location, CannotCallNonFunction{function});
return singletonTypes->errorRecoveryTypePack();
}
}
@ -1199,7 +1197,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount)
{
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
return;
}
@ -1372,7 +1370,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
size_t actualSize = size(subTp);
if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult)
std::swap(expectedSize, actualSize);
reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}});
reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx});
while (superIter.good())
{
@ -1394,9 +1392,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
else
{
if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure)
reportError(TypeError{location, TypePackMismatch{subTp, superTp}});
reportError(location, TypePackMismatch{subTp, superTp});
else
reportError(TypeError{location, GenericError{"Failed to unify type packs"}});
reportError(location, GenericError{"Failed to unify type packs"});
}
}
@ -1408,7 +1406,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy)
ice("passed non primitive types to unifyPrimitives");
if (superPrim->type != subPrim->type)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
@ -1429,7 +1427,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
if (superPrim && superPrim->type == PrimitiveTypeVar::String && get<StringSingleton>(subSingleton) && variance == Covariant)
return;
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall)
@ -1465,21 +1463,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
}
else
{
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
}
}
else if (numGenerics != subFunction->generics.size())
{
numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}});
reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"});
}
if (numGenericPacks != subFunction->genericPacks.size())
{
numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size());
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}});
reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"});
}
for (size_t i = 0; i < numGenerics; i++)
@ -1506,11 +1504,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(
TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()});
else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
innerState.ctx = CountMismatch::FunctionResult;
innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes);
@ -1520,13 +1517,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes))
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(
TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()});
else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()});
}
log.concat(std::move(innerState.log));
@ -1608,7 +1604,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
}
else
{
reportError(TypeError{location, UnificationTooComplex{}});
reportError(location, UnificationTooComplex{});
}
}
}
@ -1626,7 +1622,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return;
}
}
@ -1644,7 +1640,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!extraProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return;
}
}
@ -1825,13 +1821,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)});
return;
}
if (!extraProperties.empty())
{
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra});
return;
}
@ -1867,14 +1863,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
std::swap(subTy, superTy);
if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free)
return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
return reportError(location, TypeMismatch{osuperTy, osubTy});
auto fail = [&](std::optional<TypeError> e) {
std::string reason = "The former's metatable does not satisfy the requirements.";
if (e)
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}});
reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e});
else
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}});
reportError(location, TypeMismatch{osuperTy, osubTy, reason});
};
// Given t1 where t1 = { lower: (t1) -> (a, b...) }
@ -1906,7 +1902,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
}
}
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
reportError(location, TypeMismatch{osuperTy, osubTy});
return;
}
@ -1947,7 +1943,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}});
reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()});
log.concat(std::move(innerState.log));
}
@ -2024,9 +2020,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
auto fail = [&]() {
if (!reversed)
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
else
reportError(TypeError{location, TypeMismatch{subTy, superTy}});
reportError(location, TypeMismatch{subTy, superTy});
};
const ClassTypeVar* superClass = get<ClassTypeVar>(superTy);
@ -2071,7 +2067,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
if (!classProp)
{
ok = false;
reportError(TypeError{location, UnknownProperty{superTy, propName}});
reportError(location, UnknownProperty{superTy, propName});
}
else
{
@ -2095,7 +2091,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
{
ok = false;
std::string msg = "Class " + superClass->name + " does not have an indexer";
reportError(TypeError{location, GenericError{msg}});
reportError(location, GenericError{msg});
}
if (!ok)
@ -2116,13 +2112,13 @@ void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy)
const NormalizedType* subNorm = normalizer->normalize(subTy);
const NormalizedType* superNorm = normalizer->normalize(superTy);
if (!subNorm || !superNorm)
return reportError(TypeError{location, UnificationTooComplex{}});
return reportError(location, UnificationTooComplex{});
// T </: ~U iff T <: U
Unifier state = makeChildUnifier();
state.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "");
if (state.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy)
@ -2132,7 +2128,7 @@ void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy)
ice("tryUnifyNegationWithType subTy must be a negation type");
// TODO: ~T </: U iff T <: U
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(location, TypeMismatch{superTy, subTy});
}
static void queueTypePack(std::vector<TypeId>& queue, DenseHashSet<TypePackId>& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack)
@ -2195,7 +2191,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
}
else if (get<Unifiable::Generic>(tail))
{
reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}});
reportError(location, GenericError{"Cannot unify variadic and generic packs"});
}
else if (get<Unifiable::Error>(tail))
{
@ -2209,7 +2205,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
}
else
{
reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}});
reportError(location, GenericError{"Failed to unify variadic packs"});
}
}
@ -2351,7 +2347,7 @@ bool Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (needle == haystack)
{
reportError(TypeError{location, OccursCheckFailed{}});
reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryType());
return true;
@ -2402,7 +2398,7 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
{
if (needle == haystack)
{
reportError(TypeError{location, OccursCheckFailed{}});
reportError(location, OccursCheckFailed{});
log.replace(needle, *singletonTypes->errorRecoveryTypePack());
return true;
@ -2423,18 +2419,31 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
Unifier Unifier::makeChildUnifier()
{
Unifier u = Unifier{normalizer, mode, scope, location, variance, &log};
u.anyIsTop = anyIsTop;
u.normalize = normalize;
u.useScopes = useScopes;
return u;
}
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`.
void Unifier::reportError(Location location, TypeErrorData data)
{
errors.emplace_back(std::move(location), std::move(data));
}
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
//
// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)`
// instead of this method.
void Unifier::reportError(TypeError err)
{
errors.push_back(std::move(err));
}
bool Unifier::isNonstrictMode() const
{
return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck);
@ -2445,7 +2454,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
if (auto e = hasUnificationTooComplex(innerErrors))
reportError(*e);
else if (!innerErrors.empty())
reportError(TypeError{location, TypeMismatch{wantedType, givenType}});
reportError(location, TypeMismatch{wantedType, givenType});
}
void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType)

View File

@ -25,6 +25,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false)
LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false)
LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false)
bool lua_telemetry_parsed_out_of_range_bin_integer = false;
bool lua_telemetry_parsed_out_of_range_hex_integer = false;
@ -2310,9 +2311,13 @@ AstExpr* Parser::parseTableConstructor()
MatchLexeme matchBrace = lexer.current();
expectAndConsume('{', "table literal");
unsigned lastElementIndent = 0;
while (lexer.current().type != '}')
{
if (FFlag::LuauTableConstructorRecovery)
lastElementIndent = lexer.current().location.begin.column;
if (lexer.current().type == '[')
{
MatchLexeme matchLocationBracket = lexer.current();
@ -2357,9 +2362,13 @@ AstExpr* Parser::parseTableConstructor()
{
nextLexeme();
}
else
else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) &&
lexer.current().location.begin.column == lastElementIndent)
{
report(lexer.current().location, "Expected ',' after table constructor element");
}
else if (lexer.current().type != '}')
{
if (lexer.current().type != '}')
break;
}
}

View File

@ -978,7 +978,8 @@ int replMain(int argc, char** argv)
if (compileFormat == CompileFormat::Null)
printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024));
else if (compileFormat == CompileFormat::CodegenNull)
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024));
printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024),
int(stats.codegen / 1024));
return failed ? 1 : 0;
}

View File

@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
namespace Luau
{
namespace CodeGen
{
enum class AddressKindA64 : uint8_t
{
imm, // reg + imm
reg, // reg + reg
// TODO:
// reg + reg << shift
// reg + sext(reg) << shift
// reg + uext(reg) << shift
// pc + offset
};
struct AddressA64
{
AddressA64(RegisterA64 base, int off = 0)
: kind(AddressKindA64::imm)
, base(base)
, offset(xzr)
, data(off)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(off >= 0 && off < 4096);
}
AddressA64(RegisterA64 base, RegisterA64 offset)
: kind(AddressKindA64::reg)
, base(base)
, offset(offset)
, data(0)
{
LUAU_ASSERT(base.kind == KindA64::x);
LUAU_ASSERT(offset.kind == KindA64::x);
}
AddressKindA64 kind;
RegisterA64 base;
RegisterA64 offset;
int data;
};
} // namespace CodeGen
} // namespace Luau

View File

@ -0,0 +1,144 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/RegisterA64.h"
#include "Luau/AddressA64.h"
#include "Luau/ConditionA64.h"
#include "Luau/Label.h"
#include <string>
#include <vector>
namespace Luau
{
namespace CodeGen
{
class AssemblyBuilderA64
{
public:
explicit AssemblyBuilderA64(bool logText);
~AssemblyBuilderA64();
// Moves
void mov(RegisterA64 dst, RegisterA64 src);
void mov(RegisterA64 dst, uint16_t src, int shift = 0);
void movk(RegisterA64 dst, uint16_t src, int shift = 0);
// Arithmetics
void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void add(RegisterA64 dst, RegisterA64 src1, int src2);
void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
void sub(RegisterA64 dst, RegisterA64 src1, int src2);
void neg(RegisterA64 dst, RegisterA64 src);
// Comparisons
// Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm
// TODO: add cmp
// Binary
// Note: shifted-register support and bitfield operations are omitted for simplicity
// TODO: support immediate arguments (they have odd encoding and forbid many values)
// TODO: support not variants for and/or/eor (required to support not...)
void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void clz(RegisterA64 dst, RegisterA64 src);
void rbit(RegisterA64 dst, RegisterA64 src);
// Load
// Note: paired loads are currently omitted for simplicity
void ldr(RegisterA64 dst, AddressA64 src);
void ldrb(RegisterA64 dst, AddressA64 src);
void ldrh(RegisterA64 dst, AddressA64 src);
void ldrsb(RegisterA64 dst, AddressA64 src);
void ldrsh(RegisterA64 dst, AddressA64 src);
void ldrsw(RegisterA64 dst, AddressA64 src);
// Store
void str(RegisterA64 src, AddressA64 dst);
void strb(RegisterA64 src, AddressA64 dst);
void strh(RegisterA64 src, AddressA64 dst);
// Control flow
// Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks
void b(ConditionA64 cond, Label& label);
void cbz(RegisterA64 src, Label& label);
void cbnz(RegisterA64 src, Label& label);
void ret();
// Run final checks
bool finalize();
// Places a label at current location and returns it
Label setLabel();
// Assigns label position to the current location
void setLabel(Label& label);
void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
uint32_t getCodeSize() const;
// Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data;
std::vector<uint32_t> code;
std::string text;
const bool logText = false;
private:
// Instruction archetypes
void place0(const char* name, uint32_t word);
void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0);
void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2);
void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op);
void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op);
void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0);
void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size);
void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond);
void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond);
void place(uint32_t word);
void placeLabel(Label& label);
void commit();
LUAU_NOINLINE void extend();
// Data
size_t allocateData(size_t size, size_t align);
// Logging of assembly in text form
LUAU_NOINLINE void log(const char* opcode);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src);
LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label);
LUAU_NOINLINE void log(const char* opcode, Label label);
LUAU_NOINLINE void log(Label label);
LUAU_NOINLINE void log(RegisterA64 reg);
LUAU_NOINLINE void log(AddressA64 addr);
uint32_t nextLabel = 1;
std::vector<Label> pendingLabels;
std::vector<uint32_t> labelLocations;
bool finalized = false;
size_t dataPos = 0;
uint32_t* codePos = nullptr;
uint32_t* codeEnd = nullptr;
};
} // namespace CodeGen
} // namespace Luau

View File

@ -2,8 +2,8 @@
#pragma once
#include "Luau/Common.h"
#include "Luau/Condition.h"
#include "Luau/Label.h"
#include "Luau/ConditionX64.h"
#include "Luau/OperandX64.h"
#include "Luau/RegisterX64.h"
@ -84,7 +84,7 @@ public:
void ret();
// Control flow
void jcc(Condition cond, Label& label);
void jcc(ConditionX64 cond, Label& label);
void jmp(Label& label);
void jmp(OperandX64 op);

View File

@ -0,0 +1,37 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
namespace Luau
{
namespace CodeGen
{
enum class ConditionA64
{
Equal,
NotEqual,
CarrySet,
CarryClear,
Minus,
Plus,
Overflow,
NoOverflow,
UnsignedGreater,
UnsignedLessEqual,
GreaterEqual,
Less,
Greater,
LessEqual,
Always,
Count
};
} // namespace CodeGen
} // namespace Luau

View File

@ -6,7 +6,7 @@ namespace Luau
namespace CodeGen
{
enum class Condition
enum class ConditionX64 : uint8_t
{
Overflow,
NoOverflow,

View File

@ -61,7 +61,7 @@ struct OperandX64
constexpr OperandX64 operator[](OperandX64&& addr) const
{
LUAU_ASSERT(cat == CategoryX64::mem);
LUAU_ASSERT(memSize != SizeX64::none && index == noreg && scale == 1 && base == noreg && imm == 0);
LUAU_ASSERT(index == noreg && scale == 1 && base == noreg && imm == 0);
LUAU_ASSERT(addr.memSize == SizeX64::none);
addr.cat = CategoryX64::mem;
@ -70,13 +70,13 @@ struct OperandX64
}
};
constexpr OperandX64 addr{SizeX64::none, noreg, 1, noreg, 0};
constexpr OperandX64 byte{SizeX64::byte, noreg, 1, noreg, 0};
constexpr OperandX64 word{SizeX64::word, noreg, 1, noreg, 0};
constexpr OperandX64 dword{SizeX64::dword, noreg, 1, noreg, 0};
constexpr OperandX64 qword{SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 xmmword{SizeX64::xmmword, noreg, 1, noreg, 0};
constexpr OperandX64 ymmword{SizeX64::ymmword, noreg, 1, noreg, 0};
constexpr OperandX64 ptr{sizeof(void*) == 4 ? SizeX64::dword : SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 operator*(RegisterX64 reg, uint8_t scale)
{

View File

@ -0,0 +1,105 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
enum class KindA64 : uint8_t
{
none,
w, // 32-bit GPR
x, // 64-bit GPR
};
struct RegisterA64
{
KindA64 kind : 3;
uint8_t index : 5;
constexpr bool operator==(RegisterA64 rhs) const
{
return kind == rhs.kind && index == rhs.index;
}
constexpr bool operator!=(RegisterA64 rhs) const
{
return !(*this == rhs);
}
};
constexpr RegisterA64 w0{KindA64::w, 0};
constexpr RegisterA64 w1{KindA64::w, 1};
constexpr RegisterA64 w2{KindA64::w, 2};
constexpr RegisterA64 w3{KindA64::w, 3};
constexpr RegisterA64 w4{KindA64::w, 4};
constexpr RegisterA64 w5{KindA64::w, 5};
constexpr RegisterA64 w6{KindA64::w, 6};
constexpr RegisterA64 w7{KindA64::w, 7};
constexpr RegisterA64 w8{KindA64::w, 8};
constexpr RegisterA64 w9{KindA64::w, 9};
constexpr RegisterA64 w10{KindA64::w, 10};
constexpr RegisterA64 w11{KindA64::w, 11};
constexpr RegisterA64 w12{KindA64::w, 12};
constexpr RegisterA64 w13{KindA64::w, 13};
constexpr RegisterA64 w14{KindA64::w, 14};
constexpr RegisterA64 w15{KindA64::w, 15};
constexpr RegisterA64 w16{KindA64::w, 16};
constexpr RegisterA64 w17{KindA64::w, 17};
constexpr RegisterA64 w18{KindA64::w, 18};
constexpr RegisterA64 w19{KindA64::w, 19};
constexpr RegisterA64 w20{KindA64::w, 20};
constexpr RegisterA64 w21{KindA64::w, 21};
constexpr RegisterA64 w22{KindA64::w, 22};
constexpr RegisterA64 w23{KindA64::w, 23};
constexpr RegisterA64 w24{KindA64::w, 24};
constexpr RegisterA64 w25{KindA64::w, 25};
constexpr RegisterA64 w26{KindA64::w, 26};
constexpr RegisterA64 w27{KindA64::w, 27};
constexpr RegisterA64 w28{KindA64::w, 28};
constexpr RegisterA64 w29{KindA64::w, 29};
constexpr RegisterA64 w30{KindA64::w, 30};
constexpr RegisterA64 wzr{KindA64::w, 31};
constexpr RegisterA64 x0{KindA64::x, 0};
constexpr RegisterA64 x1{KindA64::x, 1};
constexpr RegisterA64 x2{KindA64::x, 2};
constexpr RegisterA64 x3{KindA64::x, 3};
constexpr RegisterA64 x4{KindA64::x, 4};
constexpr RegisterA64 x5{KindA64::x, 5};
constexpr RegisterA64 x6{KindA64::x, 6};
constexpr RegisterA64 x7{KindA64::x, 7};
constexpr RegisterA64 x8{KindA64::x, 8};
constexpr RegisterA64 x9{KindA64::x, 9};
constexpr RegisterA64 x10{KindA64::x, 10};
constexpr RegisterA64 x11{KindA64::x, 11};
constexpr RegisterA64 x12{KindA64::x, 12};
constexpr RegisterA64 x13{KindA64::x, 13};
constexpr RegisterA64 x14{KindA64::x, 14};
constexpr RegisterA64 x15{KindA64::x, 15};
constexpr RegisterA64 x16{KindA64::x, 16};
constexpr RegisterA64 x17{KindA64::x, 17};
constexpr RegisterA64 x18{KindA64::x, 18};
constexpr RegisterA64 x19{KindA64::x, 19};
constexpr RegisterA64 x20{KindA64::x, 20};
constexpr RegisterA64 x21{KindA64::x, 21};
constexpr RegisterA64 x22{KindA64::x, 22};
constexpr RegisterA64 x23{KindA64::x, 23};
constexpr RegisterA64 x24{KindA64::x, 24};
constexpr RegisterA64 x25{KindA64::x, 25};
constexpr RegisterA64 x26{KindA64::x, 26};
constexpr RegisterA64 x27{KindA64::x, 27};
constexpr RegisterA64 x28{KindA64::x, 28};
constexpr RegisterA64 x29{KindA64::x, 29};
constexpr RegisterA64 x30{KindA64::x, 30};
constexpr RegisterA64 xzr{KindA64::x, 31};
constexpr RegisterA64 sp{KindA64::none, 31};
} // namespace CodeGen
} // namespace Luau

View File

@ -128,5 +128,10 @@ constexpr RegisterX64 dwordReg(RegisterX64 reg)
return RegisterX64{SizeX64::dword, reg.index};
}
constexpr RegisterX64 qwordReg(RegisterX64 reg)
{
return RegisterX64{SizeX64::qword, reg.index};
}
} // namespace CodeGen
} // namespace Luau

View File

@ -0,0 +1,607 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderA64.h"
#include "ByteUtils.h"
#include <stdarg.h>
namespace Luau
{
namespace CodeGen
{
static const uint8_t codeForCondition[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
static const char* textForCondition[] = {
"b.eq", "b.ne", "b.cs", "b.cc", "b.mi", "b.pl", "b.vs", "b.vc", "b.hi", "b.ls", "b.ge", "b.lt", "b.gt", "b.le", "b.al"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionA64::Count), "all conditions have to be covered");
const unsigned kMaxAlign = 32;
AssemblyBuilderA64::AssemblyBuilderA64(bool logText)
: logText(logText)
{
data.resize(4096);
dataPos = data.size(); // data is filled backwards
code.resize(1024);
codePos = code.data();
codeEnd = code.data() + code.size();
}
AssemblyBuilderA64::~AssemblyBuilderA64()
{
LUAU_ASSERT(finalized);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src)
{
placeSR2("mov", dst, src, 0b01'01010);
}
void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("mov", dst, src, 0b10'100101, shift);
}
void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift)
{
placeI16("movk", dst, src, 0b11'100101, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("add", dst, src1, src2, 0b00'01011, shift);
}
void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("add", dst, src1, src2, 0b00'10001);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
placeSR3("sub", dst, src1, src2, 0b10'01011, shift);
}
void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, int src2)
{
placeI12("sub", dst, src1, src2, 0b10'10001);
}
void AssemblyBuilderA64::neg(RegisterA64 dst, RegisterA64 src)
{
placeSR2("neg", dst, src, 0b10'01011);
}
void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("and", dst, src1, src2, 0b00'01010);
}
void AssemblyBuilderA64::orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("orr", dst, src1, src2, 0b01'01010);
}
void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeSR3("eor", dst, src1, src2, 0b10'01010);
}
void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsl", dst, src1, src2, 0b11010110, 0b0010'00);
}
void AssemblyBuilderA64::lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("lsr", dst, src1, src2, 0b11010110, 0b0010'01);
}
void AssemblyBuilderA64::asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("asr", dst, src1, src2, 0b11010110, 0b0010'10);
}
void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
placeR3("ror", dst, src1, src2, 0b11010110, 0b0010'11);
}
void AssemblyBuilderA64::clz(RegisterA64 dst, RegisterA64 src)
{
placeR1("clz", dst, src, 0b10'11010110'00000'00010'0);
}
void AssemblyBuilderA64::rbit(RegisterA64 dst, RegisterA64 src)
{
placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00);
}
void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldr", dst, src, 0b11100001, 0b10 | uint8_t(dst.kind == KindA64::x));
}
void AssemblyBuilderA64::ldrb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrb", dst, src, 0b11100001, 0b00);
}
void AssemblyBuilderA64::ldrh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w);
placeA("ldrh", dst, src, 0b11100001, 0b01);
}
void AssemblyBuilderA64::ldrsb(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsb", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b00);
}
void AssemblyBuilderA64::ldrsh(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w);
placeA("ldrsh", dst, src, 0b11100010 | uint8_t(dst.kind == KindA64::w), 0b01);
}
void AssemblyBuilderA64::ldrsw(RegisterA64 dst, AddressA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::x);
placeA("ldrsw", dst, src, 0b11100010, 0b10);
}
void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w);
placeA("str", src, dst, 0b11100000, 0b10 | uint8_t(src.kind == KindA64::x));
}
void AssemblyBuilderA64::strb(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strb", src, dst, 0b11100000, 0b00);
}
void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst)
{
LUAU_ASSERT(src.kind == KindA64::w);
placeA("strh", src, dst, 0b11100000, 0b01);
}
void AssemblyBuilderA64::b(ConditionA64 cond, Label& label)
{
placeBC(textForCondition[int(cond)], label, 0b0101010'0, codeForCondition[int(cond)]);
}
void AssemblyBuilderA64::cbz(RegisterA64 src, Label& label)
{
placeBR("cbz", label, 0b011010'0, src);
}
void AssemblyBuilderA64::cbnz(RegisterA64 src, Label& label)
{
placeBR("cbnz", label, 0b011010'1, src);
}
void AssemblyBuilderA64::ret()
{
place0("ret", 0b1101011'0'0'10'11111'0000'0'0'11110'00000);
}
bool AssemblyBuilderA64::finalize()
{
bool success = true;
code.resize(codePos - code.data());
// Resolve jump targets
for (Label fixup : pendingLabels)
{
// If this assertion fires, a label was used in jmp without calling setLabel
LUAU_ASSERT(labelLocations[fixup.id - 1] != ~0u);
int value = int(labelLocations[fixup.id - 1]) - int(fixup.location);
// imm19 encoding word offset, at bit offset 5
// note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB
if (value > -(1 << 18) && value < (1 << 18))
code[fixup.location] |= (value & ((1 << 19) - 1)) << 5;
else
success = false; // overflow
}
size_t dataSize = data.size() - dataPos;
// Shrink data
if (dataSize > 0)
memmove(&data[0], &data[dataPos], dataSize);
data.resize(dataSize);
finalized = true;
return success;
}
Label AssemblyBuilderA64::setLabel()
{
Label label{nextLabel++, getCodeSize()};
labelLocations.push_back(~0u);
if (logText)
log(label);
return label;
}
void AssemblyBuilderA64::setLabel(Label& label)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
label.location = getCodeSize();
labelLocations[label.id - 1] = label.location;
if (logText)
log(label);
}
void AssemblyBuilderA64::logAppend(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
text.append(buf);
}
uint32_t AssemblyBuilderA64::getCodeSize() const
{
return uint32_t(codePos - code.data());
}
void AssemblyBuilderA64::place0(const char* name, uint32_t op)
{
if (logText)
log(name);
place(op);
commit();
}
void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift)
{
if (logText)
log(name, dst, src1, src2, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
LUAU_ASSERT(shift >= 0 && shift < 64); // right shift requires changing some encoding bits
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (shift << 10) | (src2.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (0x1f << 5) | (src.index << 16) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | sf);
commit();
}
void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op)
{
if (logText)
log(name, dst, src);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src.index << 5) | (op << 10) | sf);
commit();
}
void AssemblyBuilderA64::placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op)
{
if (logText)
log(name, dst, src1, src2);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src1.kind);
LUAU_ASSERT(src2 >= 0 && src2 < (1 << 12));
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src1.index << 5) | (src2 << 10) | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift)
{
if (logText)
log(name, dst, src, shift);
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(src >= 0 && src <= 0xffff);
LUAU_ASSERT(shift == 0 || shift == 16 || shift == 32 || shift == 48);
uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0;
place(dst.index | (src << 5) | ((shift >> 4) << 21) | (op << 23) | sf);
commit();
}
void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size)
{
if (logText)
log(name, dst, src);
switch (src.kind)
{
case AddressKindA64::imm:
LUAU_ASSERT(src.data % (1 << size) == 0);
place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30));
break;
case AddressKindA64::reg:
place(dst.index | (src.base.index << 5) | (0b10 << 10) | (0b011 << 13) | (src.offset.index << 16) | (1 << 21) | (op << 22) | (size << 30));
break;
}
commit();
}
void AssemblyBuilderA64::placeBC(const char* name, Label& label, uint8_t op, uint8_t cond)
{
placeLabel(label);
if (logText)
log(name, label);
place(cond | (op << 24));
commit();
}
void AssemblyBuilderA64::placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond)
{
placeLabel(label);
if (logText)
log(name, cond, label);
LUAU_ASSERT(cond.kind == KindA64::w || cond.kind == KindA64::x);
uint32_t sf = (cond.kind == KindA64::x) ? 0x80000000 : 0;
place(cond.index | (op << 24) | sf);
commit();
}
void AssemblyBuilderA64::place(uint32_t word)
{
LUAU_ASSERT(codePos < codeEnd);
*codePos++ = word;
}
void AssemblyBuilderA64::placeLabel(Label& label)
{
if (label.location == ~0u)
{
if (label.id == 0)
{
label.id = nextLabel++;
labelLocations.push_back(~0u);
}
pendingLabels.push_back({label.id, getCodeSize()});
}
else
{
// note: if label has an assigned location we can in theory avoid patching it later, but
// we need to handle potential overflow of 19-bit offsets
LUAU_ASSERT(label.id != 0);
labelLocations[label.id - 1] = label.location;
pendingLabels.push_back({label.id, getCodeSize()});
}
}
void AssemblyBuilderA64::commit()
{
LUAU_ASSERT(codePos <= codeEnd);
if (codeEnd == codePos)
extend();
}
void AssemblyBuilderA64::extend()
{
uint32_t count = getCodeSize();
code.resize(code.size() * 2);
codePos = code.data() + count;
codeEnd = code.data() + code.size();
}
size_t AssemblyBuilderA64::allocateData(size_t size, size_t align)
{
LUAU_ASSERT(align > 0 && align <= kMaxAlign && (align & (align - 1)) == 0);
if (dataPos < size)
{
size_t oldSize = data.size();
data.resize(data.size() * 2);
memcpy(&data[oldSize], &data[0], oldSize);
memset(&data[0], 0, oldSize);
dataPos += oldSize;
}
dataPos = (dataPos - size) & ~(align - 1);
return dataPos;
}
void AssemblyBuilderA64::log(const char* opcode)
{
logAppend(" %s\n", opcode);
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
log(src2);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src1);
text.append(",");
logAppend("#%d", src2);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, AddressA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
log(src);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, int src, int shift)
{
logAppend(" %-12s", opcode);
log(dst);
text.append(",");
logAppend("#%d", src);
if (shift > 0)
logAppend(" LSL #%d", shift);
text.append("\n");
}
void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src, Label label)
{
logAppend(" %-12s", opcode);
log(src);
text.append(",");
logAppend(".L%d\n", label.id);
}
void AssemblyBuilderA64::log(const char* opcode, Label label)
{
logAppend(" %-12s.L%d\n", opcode, label.id);
}
void AssemblyBuilderA64::log(Label label)
{
logAppend(".L%d:\n", label.id);
}
void AssemblyBuilderA64::log(RegisterA64 reg)
{
switch (reg.kind)
{
case KindA64::w:
if (reg.index == 31)
logAppend("wzr");
else
logAppend("w%d", reg.index);
break;
case KindA64::x:
if (reg.index == 31)
logAppend("xzr");
else
logAppend("x%d", reg.index);
break;
case KindA64::none:
LUAU_ASSERT(!"Unexpected register kind");
}
}
void AssemblyBuilderA64::log(AddressA64 addr)
{
text.append("[");
switch (addr.kind)
{
case AddressKindA64::imm:
log(addr.base);
if (addr.data != 0)
logAppend(",#%d", addr.data);
break;
case AddressKindA64::reg:
log(addr.base);
text.append(",");
log(addr.offset);
if (addr.data != 0)
logAppend(" LSL #%d", addr.data);
break;
}
text.append("]");
}
} // namespace CodeGen
} // namespace Luau

View File

@ -13,13 +13,13 @@ namespace CodeGen
{
// TODO: more assertions on operand sizes
const uint8_t codeForCondition[] = {
static const uint8_t codeForCondition[] = {
0x0, 0x1, 0x2, 0x3, 0x2, 0x6, 0x7, 0x3, 0x4, 0xc, 0xe, 0xf, 0xd, 0x3, 0x7, 0x6, 0x2, 0x5, 0xd, 0xf, 0xe, 0xc, 0x4, 0x5, 0xa, 0xb};
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered");
static_assert(sizeof(codeForCondition) / sizeof(codeForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna", "jnae",
"jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(Condition::Count), "all conditions have to be covered");
static const char* textForCondition[] = {"jo", "jno", "jc", "jnc", "jb", "jbe", "ja", "jae", "je", "jl", "jle", "jg", "jge", "jnb", "jnbe", "jna",
"jnae", "jne", "jnl", "jnle", "jng", "jnge", "jz", "jnz", "jp", "jnp"};
static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(ConditionX64::Count), "all conditions have to be covered");
#define OP_PLUS_REG(op, reg) ((op) + (reg & 0x7))
#define OP_PLUS_CC(op, cc) ((op) + uint8_t(cc))
@ -328,7 +328,10 @@ void AssemblyBuilderX64::lea(OperandX64 lhs, OperandX64 rhs)
if (logText)
log("lea", lhs, rhs);
LUAU_ASSERT(rhs.cat == CategoryX64::mem);
LUAU_ASSERT(lhs.cat == CategoryX64::reg && rhs.cat == CategoryX64::mem && rhs.memSize == SizeX64::none);
LUAU_ASSERT(rhs.base == rip || rhs.base.size == lhs.base.size);
LUAU_ASSERT(rhs.index == noreg || rhs.index.size == lhs.base.size);
rhs.memSize = lhs.base.size;
placeBinaryRegAndRegMem(lhs, rhs, 0x8d, 0x8d);
}
@ -363,7 +366,7 @@ void AssemblyBuilderX64::ret()
commit();
}
void AssemblyBuilderX64::jcc(Condition cond, Label& label)
void AssemblyBuilderX64::jcc(ConditionX64 cond, Label& label)
{
placeJcc(textForCondition[size_t(cond)], label, codeForCondition[size_t(cond)]);
}
@ -781,7 +784,7 @@ OperandX64 AssemblyBuilderX64::bytes(const void* ptr, size_t size, size_t align)
{
size_t pos = allocateData(size, align);
memcpy(&data[pos], ptr, size);
return OperandX64(SizeX64::qword, noreg, 1, rip, int32_t(pos - data.size()));
return OperandX64(SizeX64::none, noreg, 1, rip, int32_t(pos - data.size()));
}
void AssemblyBuilderX64::logAppend(const char* fmt, ...)
@ -1072,7 +1075,7 @@ void AssemblyBuilderX64::placeVex(OperandX64 dst, OperandX64 src1, OperandX64 sr
place(AVX_3_3(setW, src1.base, dst.base.size == SizeX64::ymmword, prefix));
}
uint8_t getScaleEncoding(uint8_t scale)
static uint8_t getScaleEncoding(uint8_t scale)
{
static const uint8_t scales[9] = {0xff, 0, 1, 0xff, 2, 0xff, 0xff, 0xff, 3};

View File

@ -51,10 +51,15 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
build.logAppend("; exitNoContinueVm\n");
helpers.exitNoContinueVm = build.setLabel();
emitExit(build, /* continueInVm */ false);
if (build.logText)
build.logAppend("; continueCallInVm\n");
helpers.continueCallInVm = build.setLabel();
emitContinueCallInVm(build);
}
static int emitInst(
AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, Label* labelarr, Label& fallback)
static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i,
Label* labelarr, Label& fallback)
{
int skip = 0;
@ -86,6 +91,9 @@ static int emitInst(
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, labelarr, fallback);
break;
case LOP_CALL:
emitInstCall(build, helpers, pc, i, labelarr);
break;
case LOP_RETURN:
emitInstReturn(build, helpers, pc, i, labelarr);
break;
@ -123,19 +131,19 @@ static int emitInst(
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::LessEqual, fallback);
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::Less, fallback);
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLessEqual, fallback);
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, labelarr, Condition::NotLess, fallback);
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback);
break;
case LOP_JUMPX:
emitInstJumpX(build, pc, i, labelarr);
@ -291,19 +299,19 @@ static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauO
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::LessEqual);
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::Less);
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLessEqual);
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, Condition::NotLess);
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
@ -687,6 +695,9 @@ void compile(lua_State* L, int idx)
{
for (int i = 0; i < result->proto->sizecode; i++)
result->instTargets[i] += uintptr_t(codeStart + result->location);
LUAU_ASSERT(result->proto->sizecode);
result->entryTarget = result->instTargets[0];
}
// Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction

View File

@ -72,5 +72,59 @@ void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc)
}
}
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults)
{
// slow-path: not a function call
if (LUAU_UNLIKELY(!ttisfunction(ra)))
{
luaV_tryfuncTM(L, ra);
argtop++; // __call adds an extra self
}
Closure* ccl = clvalue(ra);
CallInfo* ci = incr_ci(L);
ci->func = ra;
ci->base = ra + 1;
ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet
ci->savedpc = NULL;
ci->flags = 0;
ci->nresults = nresults;
L->base = ci->base;
L->top = argtop;
// note: this reallocs stack, but we don't need to VM_PROTECT this
// this is because we're going to modify base/savedpc manually anyhow
// crucially, we can't use ra/argtop after this line
luaD_checkstack(L, ccl->stacksize);
return ccl;
}
void callEpilogC(lua_State* L, int nresults, int n)
{
// ci is our callinfo, cip is our parent
CallInfo* ci = L->ci;
CallInfo* cip = ci - 1;
// copy return values into parent stack (but only up to nresults!), fill the rest with nil
// note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally
StkId res = ci->func;
StkId vali = L->top - n;
StkId valend = L->top;
int i;
for (i = nresults; i != 0 && vali < valend; i--)
setobj2s(L, res++, vali++);
while (i-- > 0)
setnilvalue(res++);
// pop the stack frame
L->ci = cip;
L->base = cip->base;
L->top = (nresults == LUA_MULTRET) ? res : cip->top;
}
} // namespace CodeGen
} // namespace Luau

View File

@ -13,5 +13,8 @@ bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux);
void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc);
Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults);
void callEpilogC(lua_State* L, int nresults, int n);
} // namespace CodeGen
} // namespace Luau

View File

@ -87,7 +87,7 @@ bool initEntryFunction(NativeState& data)
unwind.allocStack(stacksize + localssize);
// Setup frame pointer
build.lea(rbp, qword[rsp + stacksize]);
build.lea(rbp, addr[rsp + stacksize]);
unwind.setupFrameReg(rbp, stacksize);
unwind.finish();
@ -113,7 +113,7 @@ bool initEntryFunction(NativeState& data)
Label returnOff = build.setLabel();
// Cleanup and exit
build.lea(rsp, qword[rbp + localssize]);
build.lea(rsp, addr[rbp + localssize]);
build.pop(r15);
build.pop(r14);
build.pop(r13);

View File

@ -14,7 +14,7 @@ namespace Luau
namespace CodeGen
{
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label)
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label)
{
// Refresher on comi/ucomi EFLAGS:
// CF only: less
@ -35,53 +35,54 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
// And because of NaN, integer check interchangeability like 'not less or equal' <-> 'greater' does not hold
switch (cond)
{
case Condition::NotLessEqual:
case ConditionX64::NotLessEqual:
// (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN
build.jcc(Condition::NotAboveEqual, label);
build.jcc(ConditionX64::NotAboveEqual, label);
break;
case Condition::LessEqual:
case ConditionX64::LessEqual:
// (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN
build.jcc(Condition::AboveEqual, label);
build.jcc(ConditionX64::AboveEqual, label);
break;
case Condition::NotLess:
case ConditionX64::NotLess:
// (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN
build.jcc(Condition::NotAbove, label);
build.jcc(ConditionX64::NotAbove, label);
break;
case Condition::Less:
case ConditionX64::Less:
// (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN
build.jcc(Condition::Above, label);
build.jcc(ConditionX64::Above, label);
break;
case Condition::NotEqual:
case ConditionX64::NotEqual:
// ZF=0 or PF=1 means != or NaN
build.jcc(Condition::NotZero, label);
build.jcc(Condition::Parity, label);
build.jcc(ConditionX64::NotZero, label);
build.jcc(ConditionX64::Parity, label);
break;
default:
LUAU_ASSERT(!"Unsupported condition");
}
}
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos)
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegAddress(rb));
if (cond == Condition::NotLessEqual || cond == Condition::LessEqual)
if (cond == ConditionX64::NotLessEqual || cond == ConditionX64::LessEqual)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessequal)]);
else if (cond == Condition::NotLess || cond == Condition::Less)
else if (cond == ConditionX64::NotLess || cond == ConditionX64::Less)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_lessthan)]);
else if (cond == Condition::NotEqual || cond == Condition::Equal)
else if (cond == ConditionX64::NotEqual || cond == ConditionX64::Equal)
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_equalval)]);
else
LUAU_ASSERT(!"Unsupported condition");
emitUpdateBase(build);
build.test(eax, eax);
build.jcc(
cond == Condition::NotLessEqual || cond == Condition::NotLess || cond == Condition::NotEqual ? Condition::Zero : Condition::NotZero, label);
build.jcc(cond == ConditionX64::NotLessEqual || cond == ConditionX64::NotLess || cond == ConditionX64::NotEqual ? ConditionX64::Zero
: ConditionX64::NotZero,
label);
}
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos)
@ -120,7 +121,7 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi
build.vucomisd(tmp, numd); // Sets ZF=1 if equal or NaN
// We don't need non-integer values
// But to skip the PF=1 check, we proceed with NaN because 0x80000000 index is out of bounds
build.jcc(Condition::NotZero, label);
build.jcc(ConditionX64::NotZero, label);
}
void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, int pcpos, TMS tm)
@ -133,8 +134,8 @@ void callArithHelper(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, in
build.mov(rArg5, tm);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegAddress(rb));
build.lea(rArg4, c);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]);
@ -146,8 +147,8 @@ void callLengthHelper(AssemblyBuilderX64& build, int ra, int rb, int pcpos)
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegAddress(rb));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_dolen)]);
emitUpdateBase(build);
@ -158,9 +159,9 @@ void callPrepareForN(AssemblyBuilderX64& build, int limit, int step, int init, i
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(limit));
build.lea(rArg3, luauRegValue(step));
build.lea(rArg4, luauRegValue(init));
build.lea(rArg2, luauRegAddress(limit));
build.lea(rArg3, luauRegAddress(step));
build.lea(rArg4, luauRegAddress(init));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]);
}
@ -169,9 +170,9 @@ void callGetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c);
build.lea(rArg4, luauRegValue(ra));
build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_gettable)]);
emitUpdateBase(build);
@ -182,9 +183,9 @@ void callSetTable(AssemblyBuilderX64& build, int rb, OperandX64 c, int ra, int p
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(rb));
build.lea(rArg2, luauRegAddress(rb));
build.lea(rArg3, c);
build.lea(rArg4, luauRegValue(ra));
build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_settable)]);
emitUpdateBase(build);
@ -197,16 +198,16 @@ static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, Register
// iscollectable(ra)
build.cmp(luauRegTag(ra), LUA_TSTRING);
build.jcc(Condition::Less, skip);
build.jcc(ConditionX64::Less, skip);
// isblack(obj2gco(o))
build.test(byte[object + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
// iswhite(gcvalue(ra))
build.mov(tmp, luauRegValue(ra));
build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT));
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
LUAU_ASSERT(object != rArg3);
build.mov(rArg3, tmp);
@ -229,12 +230,12 @@ void callBarrierTableFast(AssemblyBuilderX64& build, RegisterX64 table, Label& s
{
// isblack(obj2gco(t))
build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT));
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
// Argument setup re-ordered to avoid conflicts with table register
if (table != rArg2)
build.mov(rArg2, table);
build.lea(rArg3, qword[rArg2 + offsetof(Table, gclist)]);
build.lea(rArg3, addr[rArg2 + offsetof(Table, gclist)]);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]);
}
@ -244,13 +245,13 @@ void callCheckGc(AssemblyBuilderX64& build, int pcpos, bool savepc, Label& skip)
build.mov(rax, qword[rState + offsetof(lua_State, global)]);
build.mov(rdx, qword[rax + offsetof(global_State, totalbytes)]);
build.cmp(rdx, qword[rax + offsetof(global_State, GCthreshold)]);
build.jcc(Condition::Below, skip);
build.jcc(ConditionX64::Below, skip);
if (savepc)
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, 1);
build.mov(dwordReg(rArg2), 1);
build.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]);
emitUpdateBase(build);
@ -288,7 +289,7 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos)
build.mov(r8, qword[rState + offsetof(lua_State, global)]);
build.mov(r8, qword[r8 + offsetof(global_State, cb.interrupt)]);
build.test(r8, r8);
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
emitSetSavedPc(build, pcpos + 1); // uses rax/rdx
@ -301,7 +302,7 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos)
// Check if we need to exit
build.mov(al, byte[rState + offsetof(lua_State, status)]);
build.test(al, al);
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.sub(qword[rax + offsetof(CallInfo, savedpc)], sizeof(Instruction));
@ -329,17 +330,6 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rArg4, rConstants);
build.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback)]);
// Some instructions may interrupt the execution
if (opinfo.flags & kFallbackCheckInterrupt)
{
Label skip;
build.test(rax, rax);
build.jcc(Condition::NotZero, skip);
emitExit(build, /* continueInVm */ false);
build.setLabel(skip);
}
emitUpdateBase(build);
// Some instructions may jump to a different instruction or a completely different function
@ -359,51 +349,18 @@ void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpo
build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rax * 2 + rcx]);
}
else if (opinfo.flags & kFallbackUpdateCi)
}
void emitContinueCallInVm(AssemblyBuilderX64& build)
{
// Need to update state of the current function before we jump away
build.mov(rcx, qword[rState + offsetof(lua_State, ci)]); // L->ci
build.mov(rcx, qword[rcx + offsetof(CallInfo, func)]); // L->ci->func
build.mov(rcx, qword[rcx + offsetof(TValue, value.gc)]); // L->ci->func->value.gc aka cl
build.mov(sClosure, rcx);
build.mov(rsi, qword[rcx + offsetof(Closure, l.p)]); // cl->l.p aka proto
build.mov(rConstants, qword[rsi + offsetof(Proto, k)]); // proto->k
build.mov(rcx, qword[rsi + offsetof(Proto, code)]); // proto->code
build.mov(sCode, rcx);
RegisterX64 proto = rcx; // Sync with emitInstCall
// We'll need original instruction pointer later to handle return to interpreter
if (op == LOP_CALL)
build.mov(r9, rax);
// Get instruction index from instruction pointer
// To get instruction index from instruction pointer, we need to divide byte offset by 4
// But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out
build.sub(rax, sCode);
// We need to check if the new function can be executed natively
Label returnToInterpreter;
build.mov(rdx, qword[rsi + offsetofProtoExecData]);
build.test(rdx, rdx);
build.jcc(Condition::Zero, returnToInterpreter);
// Get new instruction location and jump to it
build.mov(rcx, qword[rdx + offsetof(NativeProto, instTargets)]);
build.jmp(qword[rax * 2 + rcx]);
build.setLabel(returnToInterpreter);
// If we are returning to the interpreter to make a call, we need to update the current instruction
if (op == LOP_CALL)
{
build.mov(rdx, qword[proto + offsetof(Proto, code)]);
build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
build.mov(qword[rax + offsetof(CallInfo, savedpc)], r9);
}
build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx);
// Continue in the interpreter
emitExit(build, /* continueInVm */ true);
}
}
} // namespace CodeGen
} // namespace Luau

View File

@ -72,6 +72,7 @@ struct ModuleHelpers
{
Label exitContinueVm;
Label exitNoContinueVm;
Label continueCallInVm;
};
inline OperandX64 luauReg(int ri)
@ -79,6 +80,11 @@ inline OperandX64 luauReg(int ri)
return xmmword[rBase + ri * sizeof(TValue)];
}
inline OperandX64 luauRegAddress(int ri)
{
return addr[rBase + ri * sizeof(TValue)];
}
inline OperandX64 luauRegValue(int ri)
{
return qword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)];
@ -99,6 +105,11 @@ inline OperandX64 luauConstant(int ki)
return xmmword[rConstants + ki * sizeof(TValue)];
}
inline OperandX64 luauConstantAddress(int ki)
{
return addr[rConstants + ki * sizeof(TValue)];
}
inline OperandX64 luauConstantTag(int ki)
{
return dword[rConstants + ki * sizeof(TValue) + offsetof(TValue, tt)];
@ -144,13 +155,13 @@ inline void setNodeValue(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64
inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{
build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::Equal, label);
build.jcc(ConditionX64::Equal, label);
}
inline void jumpIfTagIsNot(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{
build.cmp(luauRegTag(ri), tag);
build.jcc(Condition::NotEqual, label);
build.jcc(ConditionX64::NotEqual, label);
}
// Note: fallthrough label should be placed after this condition
@ -160,7 +171,7 @@ inline void jumpIfFalsy(AssemblyBuilderX64& build, int ri, Label& target, Label&
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, fallthrough); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::Equal, target); // true if boolean value is 'true'
build.jcc(ConditionX64::Equal, target); // true if boolean value is 'true'
}
// Note: fallthrough label should be placed after this condition
@ -170,13 +181,13 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label
jumpIfTagIsNot(build, ri, LUA_TBOOLEAN, target); // true if not nil or boolean
build.cmp(luauRegValueBoolean(ri), 0);
build.jcc(Condition::NotEqual, target); // true if boolean value is 'true'
build.jcc(ConditionX64::NotEqual, target); // true if boolean value is 'true'
}
inline void jumpIfMetatablePresent(AssemblyBuilderX64& build, RegisterX64 table, Label& target)
{
build.cmp(qword[table + offsetof(Table, metatable)], 0);
build.jcc(Condition::NotEqual, target);
build.jcc(ConditionX64::NotEqual, target);
}
inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& label)
@ -184,13 +195,13 @@ inline void jumpIfUnsafeEnv(AssemblyBuilderX64& build, RegisterX64 tmp, Label& l
build.mov(tmp, sClosure);
build.mov(tmp, qword[tmp + offsetof(Closure, env)]);
build.test(byte[tmp + offsetof(Table, safeenv)], 1);
build.jcc(Condition::Zero, label); // Not a safe environment
build.jcc(ConditionX64::Zero, label); // Not a safe environment
}
inline void jumpIfTableIsReadOnly(AssemblyBuilderX64& build, RegisterX64 table, Label& label)
{
build.cmp(byte[table + offsetof(Table, readonly)], 0);
build.jcc(Condition::NotEqual, label);
build.jcc(ConditionX64::NotEqual, label);
}
inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, lua_Type tag, Label& label)
@ -200,13 +211,13 @@ inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, Re
build.mov(tmp, luauNodeKeyTag(node));
build.and_(tmp, kLuaNodeTagMask);
build.cmp(tmp, tag);
build.jcc(Condition::NotEqual, label);
build.jcc(ConditionX64::NotEqual, label);
}
inline void jumpIfNodeValueTagIs(AssemblyBuilderX64& build, RegisterX64 node, lua_Type tag, Label& label)
{
build.cmp(dword[node + offsetof(LuaNode, val) + offsetof(TValue, tt)], tag);
build.jcc(Condition::Equal, label);
build.jcc(ConditionX64::Equal, label);
}
inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, OperandX64 expectedKey, Label& label)
@ -215,13 +226,13 @@ inline void jumpIfNodeKeyNotInExpectedSlot(AssemblyBuilderX64& build, RegisterX6
build.mov(tmp, expectedKey);
build.cmp(tmp, luauNodeKeyValue(node));
build.jcc(Condition::NotEqual, label);
build.jcc(ConditionX64::NotEqual, label);
jumpIfNodeValueTagIs(build, node, LUA_TNIL, label);
}
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, Condition cond, Label& label);
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, Condition cond, Label& label, int pcpos);
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, ConditionX64 cond, Label& label);
void jumpOnAnyCmpFallback(AssemblyBuilderX64& build, int ra, int rb, ConditionX64 cond, Label& label, int pcpos);
RegisterX64 getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, int ri, Label& label);
@ -242,5 +253,8 @@ void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos); // Note: only uses ra
void emitInterrupt(AssemblyBuilderX64& build, int pcpos);
void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos);
void emitContinueCallInVm(AssemblyBuilderX64& build);
void emitExitFromLastReturn(AssemblyBuilderX64& build);
} // namespace CodeGen
} // namespace Luau

View File

@ -11,9 +11,6 @@
#include "lobject.h"
#include "ltm.h"
// TODO: all uses of luauRegValue and luauConstantValue need to be audited; some need to be changed to luauReg/ConstantAddress (doesn't exist yet)
// (the problem with existing use is that it includes additional offsetof(TValue, value) which happens to be 0 but isn't guaranteed to be)
namespace Luau
{
namespace CodeGen
@ -72,6 +69,159 @@ void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc)
build.vmovups(luauReg(ra), xmm0);
}
void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
int nparams = LUAU_INSN_B(*pc) - 1;
int nresults = LUAU_INSN_C(*pc) - 1;
emitInterrupt(build, pcpos);
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.lea(rArg2, luauRegAddress(ra));
if (nparams == LUA_MULTRET)
build.mov(rArg3, qword[rState + offsetof(lua_State, top)]);
else
build.lea(rArg3, luauRegAddress(ra + 1 + nparams));
build.mov(dwordReg(rArg4), nresults);
build.call(qword[rNativeContext + offsetof(NativeContext, callProlog)]);
RegisterX64 ccl = rax; // Returned from callProlog
emitUpdateBase(build);
Label cFuncCall;
build.test(byte[ccl + offsetof(Closure, isC)], 1);
build.jcc(ConditionX64::NotZero, cFuncCall);
{
RegisterX64 proto = rcx; // Sync with emitContinueCallInVm
RegisterX64 ci = rdx;
RegisterX64 argi = rsi;
RegisterX64 argend = rdi;
build.mov(proto, qword[ccl + offsetof(Closure, l.p)]);
// Switch current Closure
build.mov(sClosure, ccl); // Last use of 'ccl'
build.mov(ci, qword[rState + offsetof(lua_State, ci)]);
Label fillnil, exitfillnil;
// argi = L->top
build.mov(argi, qword[rState + offsetof(lua_State, top)]);
// argend = L->base + p->numparams
build.movzx(eax, byte[proto + offsetof(Proto, numparams)]);
build.shl(eax, kTValueSizeLog2);
build.lea(argend, addr[rBase + rax]);
// while (argi < argend) setnilvalue(argi++);
build.setLabel(fillnil);
build.cmp(argi, argend);
build.jcc(ConditionX64::NotBelow, exitfillnil);
build.mov(dword[argi + offsetof(TValue, tt)], LUA_TNIL);
build.add(argi, sizeof(TValue));
build.jmp(fillnil); // This loop rarely runs so it's not worth repeating cmp/jcc
build.setLabel(exitfillnil);
// Set L->top to ci->top as most function expect (no vararg)
build.mov(rax, qword[ci + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
build.mov(rax, qword[proto + offsetofProtoExecData]); // We'll need this value later
// But if it is vararg, update it to 'argi'
Label skipVararg;
build.test(byte[proto + offsetof(Proto, is_vararg)], 1);
build.jcc(ConditionX64::Zero, skipVararg);
build.mov(qword[rState + offsetof(lua_State, top)], argi);
build.setLabel(skipVararg);
// Check native function data
build.test(rax, rax);
build.jcc(ConditionX64::Zero, helpers.continueCallInVm);
// Switch current constants
build.mov(rConstants, qword[proto + offsetof(Proto, k)]);
// Switch current code
build.mov(rdx, qword[proto + offsetof(Proto, code)]);
build.mov(sCode, rdx);
build.jmp(qword[rax + offsetof(NativeProto, entryTarget)]);
}
build.setLabel(cFuncCall);
{
// results = ccl->c.f(L);
build.mov(rArg1, rState);
build.call(qword[ccl + offsetof(Closure, c.f)]); // Last use of 'ccl'
RegisterX64 results = eax;
build.test(results, results); // test here will set SF=1 for a negative number and it always sets OF to 0
build.jcc(ConditionX64::Less, helpers.exitNoContinueVm); // jl jumps if SF != OF
// We have special handling for small number of expected results below
if (nresults != 0 && nresults != 1)
{
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), nresults);
build.mov(dwordReg(rArg3), results);
build.call(qword[rNativeContext + offsetof(NativeContext, callEpilogC)]);
emitUpdateBase(build);
return;
}
RegisterX64 ci = rdx;
RegisterX64 cip = rcx;
RegisterX64 vali = rsi;
build.mov(ci, qword[rState + offsetof(lua_State, ci)]);
build.lea(cip, addr[ci - sizeof(CallInfo)]);
// L->base = cip->base
build.mov(rBase, qword[cip + offsetof(CallInfo, base)]);
build.mov(qword[rState + offsetof(lua_State, base)], rBase);
if (nresults == 1)
{
// Opportunistically copy the result we expected from (L->top - results)
build.mov(vali, qword[rState + offsetof(lua_State, top)]);
build.shl(results, kTValueSizeLog2);
build.sub(vali, qwordReg(results));
build.vmovups(xmm0, xmmword[vali]);
build.vmovups(luauReg(ra), xmm0);
Label skipnil;
// If there was no result, override the value with 'nil'
build.test(results, results);
build.jcc(ConditionX64::NotZero, skipnil);
build.mov(luauRegTag(ra), LUA_TNIL);
build.setLabel(skipnil);
}
// L->ci = cip
build.mov(qword[rState + offsetof(lua_State, ci)], cip);
// L->top = cip->top
build.mov(rax, qword[cip + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
}
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr)
{
emitInterrupt(build, pcpos);
@ -85,7 +235,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
RegisterX64 nresults = esi;
build.mov(ci, qword[rState + offsetof(lua_State, ci)]);
build.lea(cip, qword[ci - sizeof(CallInfo)]);
build.lea(cip, addr[ci - sizeof(CallInfo)]);
// res = ci->func; note: we assume CALL always puts func+args and expects results to start at func
build.mov(res, qword[ci + offsetof(CallInfo, func)]);
@ -101,7 +251,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
{
// Our instruction doesn't have any results, so just fill results expected in parent with 'nil'
build.test(nresults, nresults); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
build.mov(counter, nresults);
@ -109,15 +259,15 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop);
build.jcc(ConditionX64::NotZero, repeatNilLoop);
}
else if (b == 1)
{
// Try setting our 1 result
build.test(nresults, nresults);
build.jcc(Condition::Zero, skipResultCopy);
build.jcc(ConditionX64::Zero, skipResultCopy);
build.lea(counter, dword[nresults - 1]);
build.lea(counter, addr[nresults - 1]);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[res], xmm0);
@ -125,13 +275,13 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
// Fill the rest of the expected results with 'nil'
build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
Label repeatNilLoop = build.setLabel();
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop);
build.jcc(ConditionX64::NotZero, repeatNilLoop);
}
else
{
@ -140,16 +290,16 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
// Copy return values into parent stack (but only up to nresults!)
build.test(nresults, nresults);
build.jcc(Condition::Zero, skipResultCopy);
build.jcc(ConditionX64::Zero, skipResultCopy);
// vali = ra
build.lea(vali, luauRegValue(ra));
build.lea(vali, luauRegAddress(ra));
// Copy as much as possible for MULTRET calls, and only as much as needed otherwise
if (b == LUA_MULTRET)
build.mov(valend, qword[rState + offsetof(lua_State, top)]); // valend = L->top
else
build.lea(valend, luauRegValue(ra + b)); // valend = ra + b
build.lea(valend, luauRegAddress(ra + b)); // valend = ra + b
build.mov(counter, nresults);
@ -157,26 +307,26 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.setLabel(repeatValueLoop);
build.cmp(vali, valend);
build.jcc(Condition::NotBelow, exitValueLoop);
build.jcc(ConditionX64::NotBelow, exitValueLoop);
build.vmovups(xmm0, xmmword[vali]);
build.vmovups(xmmword[res], xmm0);
build.add(vali, sizeof(TValue));
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatValueLoop);
build.jcc(ConditionX64::NotZero, repeatValueLoop);
build.setLabel(exitValueLoop);
// Fill the rest of the expected results with 'nil'
build.test(counter, counter); // test here will set SF=1 for a negative number, ZF=1 for zero and OF=0
build.jcc(Condition::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
build.jcc(ConditionX64::LessEqual, skipResultCopy); // jle jumps if SF != OF or ZF == 1
Label repeatNilLoop = build.setLabel();
build.mov(dword[res + offsetof(TValue, tt)], LUA_TNIL);
build.add(res, sizeof(TValue));
build.dec(counter);
build.jcc(Condition::NotZero, repeatNilLoop);
build.jcc(ConditionX64::NotZero, repeatNilLoop);
}
build.setLabel(skipResultCopy);
@ -191,11 +341,11 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
// Unlikely, but this might be the last return from VM
build.test(byte[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_RETURN);
build.jcc(Condition::NotZero, helpers.exitNoContinueVm);
build.jcc(ConditionX64::NotZero, helpers.exitNoContinueVm);
Label skipFixedRetTop;
build.test(nresults, nresults); // test here will set SF=1 for a negative number and it always sets OF to 0
build.jcc(Condition::Less, skipFixedRetTop); // jl jumps if SF != OF
build.jcc(ConditionX64::Less, skipFixedRetTop); // jl jumps if SF != OF
build.mov(rax, qword[cip + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax); // L->top = cip->top
build.setLabel(skipFixedRetTop);
@ -214,7 +364,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.mov(execdata, qword[proto + offsetofProtoExecData]);
build.test(execdata, execdata);
build.jcc(Condition::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data
build.jcc(ConditionX64::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data
// Change constants
build.mov(rConstants, qword[proto + offsetof(Proto, k)]);
@ -269,13 +419,13 @@ void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.mov(eax, luauRegTag(ra));
build.cmp(eax, luauRegTag(rb));
build.jcc(Condition::NotEqual, not_ ? target : exit);
build.jcc(ConditionX64::NotEqual, not_ ? target : exit);
// fast-path: number
build.cmp(eax, LUA_TNUMBER);
build.jcc(Condition::NotEqual, fallback);
build.jcc(ConditionX64::NotEqual, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), Condition::NotEqual, not_ ? target : exit);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), ConditionX64::NotEqual, not_ ? target : exit);
if (!not_)
build.jmp(target);
@ -285,10 +435,10 @@ void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc,
{
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? Condition::NotEqual : Condition::Equal, target, pcpos);
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target, pcpos);
}
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond, Label& fallback)
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = pc[1];
@ -302,7 +452,7 @@ void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pc
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), cond, target);
}
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond)
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond)
{
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
@ -324,7 +474,7 @@ void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pc
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.cmp(luauRegTag(ra), LUA_TNIL);
build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target);
build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
@ -339,7 +489,7 @@ void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit);
build.test(luauRegValueBoolean(ra), 1);
build.jcc((aux & 0x1) ^ not_ ? Condition::NotZero : Condition::Zero, target);
build.jcc((aux & 0x1) ^ not_ ? ConditionX64::NotZero : ConditionX64::Zero, target);
}
void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr)
@ -356,15 +506,15 @@ void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TV
if (not_)
{
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), build.f64(kv.value.n), Condition::NotEqual, target);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), build.f64(kv.value.n), ConditionX64::NotEqual, target);
}
else
{
// Compact equality check requires two labels, so it's not supported in generic 'jumpOnNumberCmp'
build.vmovsd(xmm0, luauRegValue(ra));
build.vucomisd(xmm0, build.f64(kv.value.n));
build.jcc(Condition::Parity, exit); // We first have to check PF=1 for NaN operands, because it also sets ZF=1
build.jcc(Condition::Zero, target); // Now that NaN is out of the way, we can check ZF=1 for equality
build.jcc(ConditionX64::Parity, exit); // We first have to check PF=1 for NaN operands, because it also sets ZF=1
build.jcc(ConditionX64::Zero, target); // Now that NaN is out of the way, we can check ZF=1 for equality
}
}
@ -381,7 +531,7 @@ void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
build.mov(rax, luauRegValue(ra));
build.cmp(rax, luauConstantValue(aux & 0xffffff));
build.jcc(not_ ? Condition::NotEqual : Condition::Equal, target);
build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, int pcpos, TMS tm, Label& fallback)
@ -437,7 +587,7 @@ void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), pcpos, tm);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), pcpos, tm);
}
void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm, Label& fallback)
@ -447,7 +597,7 @@ void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantValue(LUAU_INSN_C(*pc)), pcpos, tm);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantAddress(LUAU_INSN_C(*pc)), pcpos, tm);
}
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label& fallback)
@ -525,7 +675,7 @@ void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_B(*pc)), pcpos, TM_UNM);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_B(*pc)), pcpos, TM_UNM);
}
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
@ -563,8 +713,8 @@ void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, aux);
build.mov(rArg3, b == 0 ? 0 : 1 << (b - 1));
build.mov(dwordReg(rArg2), aux);
build.mov(dwordReg(rArg3), 1 << (b - 1));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]);
build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE);
@ -610,7 +760,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
// c = L->top - rb
build.mov(cscaled, qword[rState + offsetof(lua_State, top)]);
build.lea(tmp, luauRegValue(rb));
build.lea(tmp, luauRegAddress(rb));
build.sub(cscaled, tmp); // Using byte difference
// L->top = L->ci->top
@ -633,7 +783,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
// Resize if h->sizearray < last
build.cmp(dword[table + offsetof(Table, sizearray)], last);
build.jcc(Condition::NotBelow, skipResize);
build.jcc(ConditionX64::NotBelow, skipResize);
// Argument setup reordered to avoid conflicts
LUAU_ASSERT(rArg3 != table);
@ -676,7 +826,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
if (c == LUA_MULTRET)
{
build.cmp(offset, limit);
build.jcc(Condition::NotBelow, endLoop);
build.jcc(ConditionX64::NotBelow, endLoop);
}
build.setLabel(repeatLoop);
@ -687,7 +837,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, int pcpos
build.add(offset, sizeof(TValue));
build.cmp(offset, limit);
build.jcc(Condition::Below, repeatLoop);
build.jcc(ConditionX64::Below, repeatLoop);
build.setLabel(endLoop);
}
@ -707,7 +857,7 @@ void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
Label skip;
// TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though
build.cmp(dword[rax + offsetof(TValue, tt)], LUA_TUPVAL);
build.jcc(Condition::NotEqual, skip);
build.jcc(ConditionX64::NotEqual, skip);
// UpVal.v points to the value (either on stack, or on heap inside each UpVal, but we can deref it unconditionally)
build.mov(rax, qword[rax + offsetof(TValue, value.gc)]);
@ -746,12 +896,12 @@ void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, int p
// L->openupval != 0
build.mov(rax, qword[rState + offsetof(lua_State, openupval)]);
build.test(rax, rax);
build.jcc(Condition::Zero, skip);
build.jcc(ConditionX64::Zero, skip);
// ra <= L->openuval->v
build.lea(rcx, qword[rBase + ra * sizeof(TValue)]);
build.lea(rcx, addr[rBase + ra * sizeof(TValue)]);
build.cmp(rcx, qword[rax + offsetof(UpVal, v)]);
build.jcc(Condition::Above, skip);
build.jcc(ConditionX64::Above, skip);
build.mov(rArg2, rcx);
build.mov(rArg1, rState);
@ -771,7 +921,7 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
int nparams = customParams ? customParamCount : LUAU_INSN_B(call) - 1;
int nresults = LUAU_INSN_C(call) - 1;
int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1;
OperandX64 args = customParams ? customArgs : luauRegValue(ra + 2);
OperandX64 args = customParams ? customArgs : luauRegAddress(ra + 2);
Label& exit = labelarr[pcpos + instLen];
@ -784,7 +934,7 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
if (nresults == LUA_MULTRET)
{
// L->top = ra + n;
build.lea(rax, qword[rBase + (ra + br.actualResultCount) * sizeof(TValue)]);
build.lea(rax, addr[rBase + (ra + br.actualResultCount) * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
else if (nparams == LUA_MULTRET)
@ -822,7 +972,7 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
// TODO: for SystemV ABI we can compute the result directly into rArg6
// L->top - (ra + 1)
build.mov(rcx, qword[rState + offsetof(lua_State, top)]);
build.lea(rdx, qword[rBase + (ra + 1) * sizeof(TValue)]);
build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]);
build.sub(rcx, rdx);
build.shr(rcx, kTValueSizeLog2);
@ -840,20 +990,20 @@ static int emitInstFastCallN(AssemblyBuilderX64& build, const Instruction* pc, b
}
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.lea(rArg3, luauRegValue(arg));
build.mov(rArg4, nresults);
build.lea(rArg2, luauRegAddress(ra));
build.lea(rArg3, luauRegAddress(arg));
build.mov(dwordReg(rArg4), nresults);
build.call(rax);
build.test(eax, eax); // test here will set SF=1 for a negative number and it always sets OF to 0
build.jcc(Condition::Less, exit); // jl jumps if SF != OF
build.jcc(ConditionX64::Less, exit); // jl jumps if SF != OF
if (nresults == LUA_MULTRET)
{
// L->top = ra + n;
build.shl(rax, kTValueSizeLog2);
build.lea(rax, qword[rBase + rax + ra * sizeof(TValue)]);
build.lea(rax, addr[rBase + rax + ra * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
else if (nparams == LUA_MULTRET)
@ -875,13 +1025,13 @@ int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegValue(pc[1]), pcpos, /* instLen */ 2, labelarr);
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, /* instLen */ 2, labelarr);
}
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantValue(pc[1]), pcpos, /* instLen */ 2, labelarr);
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, /* instLen */ 2, labelarr);
}
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
@ -916,16 +1066,16 @@ void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
Label reverse;
// step <= 0
jumpOnNumberCmp(build, noreg, step, zero, Condition::LessEqual, reverse);
jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation
// false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, Condition::LessEqual, exit);
jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, exit);
build.jmp(loopExit);
// true: limit <= idx
build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, Condition::LessEqual, exit);
jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, exit);
build.jmp(loopExit);
// TOOD: place at the end of the function
@ -958,15 +1108,15 @@ void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
Label reverse, exit;
// step <= 0
jumpOnNumberCmp(build, noreg, step, zero, Condition::LessEqual, reverse);
jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, Condition::LessEqual, loopRepeat);
jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, loopRepeat);
build.jmp(exit);
// true: limit <= idx
build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, Condition::LessEqual, loopRepeat);
jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, loopRepeat);
build.setLabel(exit);
}
@ -1010,13 +1160,13 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// while (unsigned(index) < unsigned(sizearray))
Label arrayLoop = build.setLabel();
build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]);
build.jcc(Condition::NotBelow, isIpairsIter ? exit : skipArray);
build.jcc(ConditionX64::NotBelow, isIpairsIter ? exit : skipArray);
// If element is nil, we increment the index; if it's not, we still need 'index + 1' inside
build.inc(index);
build.cmp(dword[elemPtr + offsetof(TValue, tt)], LUA_TNIL);
build.jcc(Condition::Equal, isIpairsIter ? exit : skipArrayNil);
build.jcc(ConditionX64::Equal, isIpairsIter ? exit : skipArrayNil);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(index + 1)));
build.mov(luauRegValue(ra + 2), index);
@ -1045,10 +1195,10 @@ void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// Call helper to assign next node value or to signal loop exit
build.mov(rArg1, rState);
// rArg2 and rArg3 are already set
build.lea(rArg4, luauRegValue(ra));
build.lea(rArg4, luauRegAddress(ra));
build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNodeIter)]);
build.test(al, al);
build.jcc(Condition::NotZero, loopRepeat);
build.jcc(ConditionX64::NotZero, loopRepeat);
}
}
@ -1062,12 +1212,12 @@ void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc,
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, ra);
build.mov(rArg3, aux);
build.mov(dwordReg(rArg2), ra);
build.mov(dwordReg(rArg3), aux);
build.call(qword[rNativeContext + offsetof(NativeContext, forgLoopNonTableFallback)]);
emitUpdateBase(build);
build.test(al, al);
build.jcc(Condition::NotZero, loopRepeat);
build.jcc(ConditionX64::NotZero, loopRepeat);
}
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
@ -1103,7 +1253,7 @@ void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, int
build.vxorpd(xmm0, xmm0, xmm0);
build.vmovsd(xmm1, luauRegValue(ra + 2));
jumpOnNumberCmp(build, noreg, xmm0, xmm1, Condition::NotEqual, fallback);
jumpOnNumberCmp(build, noreg, xmm0, xmm1, ConditionX64::NotEqual, fallback);
build.mov(luauRegTag(ra), LUA_TNIL);
@ -1121,8 +1271,8 @@ void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction*
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.mov(rArg1, rState);
build.lea(rArg2, luauRegValue(ra));
build.mov(rArg3, pcpos + 1);
build.lea(rArg2, luauRegAddress(ra));
build.mov(dwordReg(rArg3), pcpos + 1);
build.call(qword[rNativeContext + offsetof(NativeContext, forgPrepXnextFallback)]);
build.jmp(target);
}
@ -1216,7 +1366,7 @@ void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp
// unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(Condition::BelowEqual, fallback);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
@ -1244,7 +1394,7 @@ void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, int pcp
// unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(Condition::BelowEqual, fallback);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback);
@ -1284,7 +1434,7 @@ void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], eax);
build.jcc(Condition::BelowEqual, fallback);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
@ -1296,7 +1446,7 @@ void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
callGetTable(build, LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
callGetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
}
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Label& fallback)
@ -1319,7 +1469,7 @@ void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
// unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], eax);
build.jcc(Condition::BelowEqual, fallback);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback);
@ -1335,7 +1485,7 @@ void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpo
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
callSetTable(build, LUAU_INSN_B(*pc), luauRegValue(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
callSetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc), pcpos);
}
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
@ -1348,7 +1498,7 @@ void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label&
// note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback
// this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime
build.cmp(luauConstantTag(k), LUA_TNIL);
build.jcc(Condition::Equal, fallback);
build.jcc(ConditionX64::Equal, fallback);
build.vmovups(xmm0, luauConstant(k));
build.vmovups(luauReg(ra), xmm0);
@ -1367,7 +1517,7 @@ void emitInstGetImportFallback(AssemblyBuilderX64& build, const Instruction* pc,
build.mov(rArg1, rState);
build.mov(rArg2, qword[rax + offsetof(Closure, env)]);
build.mov(rArg3, rConstants);
build.mov(rArg4, aux);
build.mov(dwordReg(rArg4), aux);
if (build.abi == ABIX64::Windows)
build.mov(sArg5, 0);
@ -1470,8 +1620,8 @@ void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos,
// luaV_concat(L, c - b + 1, c)
build.mov(rArg1, rState);
build.mov(rArg2, rc - rb + 1);
build.mov(rArg3, rc);
build.mov(dwordReg(rArg2), rc - rb + 1);
build.mov(dwordReg(rArg3), rc);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]);
emitUpdateBase(build);

View File

@ -14,7 +14,7 @@ namespace CodeGen
{
class AssemblyBuilderX64;
enum class Condition;
enum class ConditionX64 : uint8_t;
struct Label;
struct ModuleHelpers;
@ -24,14 +24,15 @@ void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback);
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond, Label& fallback);
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, Condition cond);
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback);
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond);
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);

File diff suppressed because it is too large Load Diff

View File

@ -10,84 +10,15 @@ typedef uint32_t Instruction;
typedef struct lua_TValue TValue;
typedef TValue* StkId;
const Instruction* execute_LOP_NOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOVE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETUPVAL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CLOSEUPVALS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETIMPORT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETTABLEN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_RETURN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIF(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTEQ(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPIFNOTLT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MUL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIV(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MOD(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POW(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ADDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SUBK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MULK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DIVK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MODK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_POWK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_AND(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_OR(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ANDK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_ORK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CONCAT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NOT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_MINUS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LENGTH(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_NEWTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPTABLE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORNLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGLOOP(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_INEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FORGPREP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_FORGLOOP_NEXT(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPBACK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_LOADKX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPX(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_COVERAGE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_CAPTURE(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_DEP_JUMPIFNOTEQK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL1(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_FASTCALL2K(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKNIL(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKB(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKN(lua_State* L, const Instruction* pc, StkId base, TValue* k);
const Instruction* execute_LOP_JUMPXEQKS(lua_State* L, const Instruction* pc, StkId base, TValue* k);

View File

@ -35,10 +35,9 @@ NativeState::~NativeState() = default;
void initFallbackTable(NativeState& data)
{
// TODO: lvmexecute_split.py could be taught to generate a subset of instructions we actually need
// When fallback is completely removed, remove it from includeInsts list in lvmexecute_split.py
CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE, 0);
CODEGEN_SET_FALLBACK(LOP_NAMECALL, 0);
CODEGEN_SET_FALLBACK(LOP_CALL, kFallbackUpdateCi | kFallbackCheckInterrupt);
CODEGEN_SET_FALLBACK(LOP_FORGPREP, kFallbackUpdatePc);
CODEGEN_SET_FALLBACK(LOP_GETVARARGS, 0);
CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE, 0);
@ -86,6 +85,8 @@ void initHelperFunctions(NativeState& data)
data.context.forgLoopNodeIter = forgLoopNodeIter;
data.context.forgLoopNonTableFallback = forgLoopNonTableFallback;
data.context.forgPrepXnextFallback = forgPrepXnextFallback;
data.context.callProlog = callProlog;
data.context.callEpilogC = callEpilogC;
}
} // namespace CodeGen

View File

@ -26,8 +26,6 @@ class UnwindBuilder;
using FallbackFn = const Instruction*(lua_State* L, const Instruction* pc, StkId base, TValue* k);
constexpr uint8_t kFallbackUpdatePc = 1 << 0;
constexpr uint8_t kFallbackUpdateCi = 1 << 1;
constexpr uint8_t kFallbackCheckInterrupt = 1 << 2;
struct NativeFallback
{
@ -37,7 +35,8 @@ struct NativeFallback
struct NativeProto
{
uintptr_t* instTargets = nullptr;
uintptr_t entryTarget = 0;
uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded
Proto* proto = nullptr;
uint32_t location = 0;
@ -85,6 +84,8 @@ struct NativeContext
bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr;
bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr;
void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr;
Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr;
void (*callEpilogC)(lua_State* L, int nresults, int n) = nullptr;
};
struct NativeState

View File

@ -55,18 +55,23 @@ target_sources(Luau.Compiler PRIVATE
# Luau.CodeGen Sources
target_sources(Luau.CodeGen PRIVATE
CodeGen/include/Luau/AddressA64.h
CodeGen/include/Luau/AssemblyBuilderA64.h
CodeGen/include/Luau/AssemblyBuilderX64.h
CodeGen/include/Luau/CodeAllocator.h
CodeGen/include/Luau/CodeBlockUnwind.h
CodeGen/include/Luau/CodeGen.h
CodeGen/include/Luau/Condition.h
CodeGen/include/Luau/ConditionA64.h
CodeGen/include/Luau/ConditionX64.h
CodeGen/include/Luau/Label.h
CodeGen/include/Luau/OperandX64.h
CodeGen/include/Luau/RegisterA64.h
CodeGen/include/Luau/RegisterX64.h
CodeGen/include/Luau/UnwindBuilder.h
CodeGen/include/Luau/UnwindBuilderDwarf2.h
CodeGen/include/Luau/UnwindBuilderWin.h
CodeGen/src/AssemblyBuilderA64.cpp
CodeGen/src/AssemblyBuilderX64.cpp
CodeGen/src/CodeAllocator.cpp
CodeGen/src/CodeBlockUnwind.cpp
@ -103,6 +108,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Clone.h
Analysis/include/Luau/Config.h
Analysis/include/Luau/Connective.h
Analysis/include/Luau/Constraint.h
Analysis/include/Luau/ConstraintGraphBuilder.h
Analysis/include/Luau/ConstraintSolver.h
@ -156,6 +162,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/BuiltinDefinitions.cpp
Analysis/src/Clone.cpp
Analysis/src/Config.cpp
Analysis/src/Connective.cpp
Analysis/src/Constraint.cpp
Analysis/src/ConstraintGraphBuilder.cpp
Analysis/src/ConstraintSolver.cpp
@ -299,6 +306,7 @@ if(TARGET Luau.UnitTest)
tests/AstQueryDsl.cpp
tests/ConstraintGraphBuilderFixture.cpp
tests/Fixture.cpp
tests/AssemblyBuilderA64.test.cpp
tests/AssemblyBuilderX64.test.cpp
tests/AstJsonEncoder.test.cpp
tests/AstQuery.test.cpp

View File

@ -20,14 +20,6 @@
#define LUAU_FASTMATH_END
#endif
// Some functions like floor/ceil have SSE4.1 equivalents but we currently support systems without SSE4.1
// On newer GCC and Clang we can use function multi-versioning to generate SSE4.1 code plus CPUID based dispatch.
#if !defined(__APPLE__) && (defined(__x86_64__) || defined(_M_X64)) && ((defined(__clang__) && __clang_major__ >= 14) || (defined(__GNUC__) && __GNUC__ >= 6)) && !defined(__SSE4_1__)
#define LUAU_DISPATCH_SSE41 __attribute__((target_clones("default", "sse4.1")))
#else
#define LUAU_DISPATCH_SSE41
#endif
// Used on functions that have a printf-like interface to validate them statically
#if defined(__GNUC__)
#define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg)))

View File

@ -96,7 +96,6 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
}
LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -160,7 +159,6 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
}
LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))
@ -938,7 +936,6 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId
}
LUAU_FASTMATH_BEGIN
LUAU_DISPATCH_SSE41
static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0))

View File

@ -34,7 +34,6 @@ inline bool luai_vecisnan(const float* a)
}
LUAU_FASTMATH_BEGIN
// TODO: LUAU_DISPATCH_SSE41 would be nice here, but clang-14 doesn't support it correctly on inline functions...
inline double luai_nummod(double a, double b)
{
return a - floor(a / b) * b;

View File

@ -0,0 +1,221 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderA64.h"
#include "Luau/StringUtils.h"
#include "doctest.h"
#include <string.h>
using namespace Luau::CodeGen;
static std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode)
{
std::string result = "{";
for (size_t i = 0; i < bytecode.size(); i++)
Luau::formatAppend(result, "%s0x%02x", i == 0 ? "" : ", ", bytecode[i]);
return result.append("}");
}
static std::string bytecodeAsArray(const std::vector<uint32_t>& code)
{
std::string result = "{";
for (size_t i = 0; i < code.size(); i++)
Luau::formatAppend(result, "%s0x%08x", i == 0 ? "" : ", ", code[i]);
return result.append("}");
}
class AssemblyBuilderA64Fixture
{
public:
bool check(void (*f)(AssemblyBuilderA64& build), std::vector<uint32_t> code, std::vector<uint8_t> data = {})
{
AssemblyBuilderA64 build(/* logText= */ false);
f(build);
build.finalize();
if (build.code != code)
{
printf("Expected code: %s\nReceived code: %s\n", bytecodeAsArray(code).c_str(), bytecodeAsArray(build.code).c_str());
return false;
}
if (build.data != data)
{
printf("Expected data: %s\nReceived data: %s\n", bytecodeAsArray(data).c_str(), bytecodeAsArray(build.data).c_str());
return false;
}
return true;
}
};
// armconverter.com can be used to validate instruction sequences
TEST_SUITE_BEGIN("A64Assembly");
#define SINGLE_COMPARE(inst, ...) \
CHECK(check( \
[](AssemblyBuilderA64& build) { \
build.inst; \
}, \
{__VA_ARGS__}))
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Unary")
{
SINGLE_COMPARE(neg(x0, x1), 0xCB0103E0);
SINGLE_COMPARE(neg(w0, w1), 0x4B0103E0);
SINGLE_COMPARE(clz(x0, x1), 0xDAC01020);
SINGLE_COMPARE(clz(w0, w1), 0x5AC01020);
SINGLE_COMPARE(rbit(x0, x1), 0xDAC00020);
SINGLE_COMPARE(rbit(w0, w1), 0x5AC00020);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary")
{
// reg, reg
SINGLE_COMPARE(add(x0, x1, x2), 0x8B020020);
SINGLE_COMPARE(add(w0, w1, w2), 0x0B020020);
SINGLE_COMPARE(add(x0, x1, x2, 7), 0x8B021C20);
SINGLE_COMPARE(sub(x0, x1, x2), 0xCB020020);
SINGLE_COMPARE(and_(x0, x1, x2), 0x8A020020);
SINGLE_COMPARE(orr(x0, x1, x2), 0xAA020020);
SINGLE_COMPARE(eor(x0, x1, x2), 0xCA020020);
SINGLE_COMPARE(lsl(x0, x1, x2), 0x9AC22020);
SINGLE_COMPARE(lsl(w0, w1, w2), 0x1AC22020);
SINGLE_COMPARE(lsr(x0, x1, x2), 0x9AC22420);
SINGLE_COMPARE(asr(x0, x1, x2), 0x9AC22820);
SINGLE_COMPARE(ror(x0, x1, x2), 0x9AC22C20);
// reg, imm
SINGLE_COMPARE(add(x3, x7, 78), 0x910138E3);
SINGLE_COMPARE(add(w3, w7, 78), 0x110138E3);
SINGLE_COMPARE(sub(w3, w7, 78), 0x510138E3);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads")
{
// address forms
SINGLE_COMPARE(ldr(x0, x1), 0xF9400020);
SINGLE_COMPARE(ldr(x0, AddressA64(x1, 8)), 0xF9400420);
SINGLE_COMPARE(ldr(x0, AddressA64(x1, x7)), 0xF8676820);
// load sizes
SINGLE_COMPARE(ldr(x0, x1), 0xF9400020);
SINGLE_COMPARE(ldr(w0, x1), 0xB9400020);
SINGLE_COMPARE(ldrb(w0, x1), 0x39400020);
SINGLE_COMPARE(ldrh(w0, x1), 0x79400020);
SINGLE_COMPARE(ldrsb(x0, x1), 0x39800020);
SINGLE_COMPARE(ldrsb(w0, x1), 0x39C00020);
SINGLE_COMPARE(ldrsh(x0, x1), 0x79800020);
SINGLE_COMPARE(ldrsh(w0, x1), 0x79C00020);
SINGLE_COMPARE(ldrsw(x0, x1), 0xB9800020);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores")
{
// address forms
SINGLE_COMPARE(str(x0, x1), 0xF9000020);
SINGLE_COMPARE(str(x0, AddressA64(x1, 8)), 0xF9000420);
SINGLE_COMPARE(str(x0, AddressA64(x1, x7)), 0xF8276820);
// store sizes
SINGLE_COMPARE(str(x0, x1), 0xF9000020);
SINGLE_COMPARE(str(w0, x1), 0xB9000020);
SINGLE_COMPARE(strb(w0, x1), 0x39000020);
SINGLE_COMPARE(strh(w0, x1), 0x79000020);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves")
{
SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0);
SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0);
SINGLE_COMPARE(mov(x0, 42), 0xD2800540);
SINGLE_COMPARE(mov(w0, 42), 0x52800540);
SINGLE_COMPARE(movk(x0, 42, 16), 0xF2A00540);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow")
{
// Jump back
CHECK(check(
[](AssemblyBuilderA64& build) {
Label start = build.setLabel();
build.mov(x0, x1);
build.b(ConditionA64::Equal, start);
},
{0xAA0103E0, 0x54FFFFE0}));
// Jump forward
CHECK(check(
[](AssemblyBuilderA64& build) {
Label skip;
build.b(ConditionA64::Equal, skip);
build.mov(x0, x1);
build.setLabel(skip);
},
{0x54000040, 0xAA0103E0}));
// Jumps
CHECK(check(
[](AssemblyBuilderA64& build) {
Label skip;
build.b(ConditionA64::Equal, skip);
build.cbz(x0, skip);
build.cbnz(x0, skip);
build.setLabel(skip);
},
{0x54000060, 0xB4000040, 0xB5000020}));
// Basic control flow
SINGLE_COMPARE(ret(), 0xD65F03C0);
}
TEST_CASE("LogTest")
{
AssemblyBuilderA64 build(/* logText= */ true);
build.add(w0, w1, w2);
build.add(x0, x1, x2, 2);
build.add(w7, w8, 5);
build.add(x7, x8, 5);
build.ldr(x7, x8);
build.ldr(x7, AddressA64(x8, 8));
build.ldr(x7, AddressA64(x8, x9));
build.mov(x1, x2);
build.movk(x1, 42, 16);
Label l;
build.b(ConditionA64::Plus, l);
build.cbz(x7, l);
build.setLabel(l);
build.ret();
build.finalize();
std::string expected = R"(
add w0,w1,w2
add x0,x1,x2 LSL #2
add w7,w8,#5
add x7,x8,#5
ldr x7,[x8]
ldr x7,[x8,#8]
ldr x7,[x8,x9]
mov x1,x2
movk x1,#42 LSL #16
b.pl .L1
cbz x7,.L1
.L1:
ret
)";
CHECK("\n" + build.text == expected);
}
TEST_SUITE_END();

View File

@ -8,7 +8,7 @@
using namespace Luau::CodeGen;
std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode)
static std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode)
{
std::string result = "{";
@ -21,7 +21,7 @@ std::string bytecodeAsArray(const std::vector<uint8_t>& bytecode)
class AssemblyBuilderX64Fixture
{
public:
void check(void (*f)(AssemblyBuilderX64& build), std::vector<uint8_t> code, std::vector<uint8_t> data = {})
bool check(void (*f)(AssemblyBuilderX64& build), std::vector<uint8_t> code, std::vector<uint8_t> data = {})
{
AssemblyBuilderX64 build(/* logText= */ false);
@ -32,25 +32,27 @@ public:
if (build.code != code)
{
printf("Expected code: %s\nReceived code: %s\n", bytecodeAsArray(code).c_str(), bytecodeAsArray(build.code).c_str());
CHECK(false);
return false;
}
if (build.data != data)
{
printf("Expected data: %s\nReceived data: %s\n", bytecodeAsArray(data).c_str(), bytecodeAsArray(build.data).c_str());
CHECK(false);
return false;
}
return true;
}
};
TEST_SUITE_BEGIN("x64Assembly");
#define SINGLE_COMPARE(inst, ...) \
check( \
CHECK(check( \
[](AssemblyBuilderX64& build) { \
build.inst; \
}, \
{__VA_ARGS__})
{__VA_ARGS__}))
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms")
{
@ -236,9 +238,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfShift")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea")
{
SINGLE_COMPARE(lea(rax, qword[rdx + rcx]), 0x48, 0x8d, 0x04, 0x0a);
SINGLE_COMPARE(lea(rax, qword[rdx + rax * 4]), 0x48, 0x8d, 0x04, 0x82);
SINGLE_COMPARE(lea(rax, qword[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04);
SINGLE_COMPARE(lea(rax, addr[rdx + rcx]), 0x48, 0x8d, 0x04, 0x0a);
SINGLE_COMPARE(lea(rax, addr[rdx + rax * 4]), 0x48, 0x8d, 0x04, 0x82);
SINGLE_COMPARE(lea(rax, addr[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04);
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps")
@ -280,34 +282,34 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "NopForms")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentForms")
{
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(8, AlignmentDataX64::Nop);
},
{0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00});
{0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00}));
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(32, AlignmentDataX64::Nop);
},
{0xc3, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84,
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00});
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00}));
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(8, AlignmentDataX64::Int3);
},
{0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc});
{0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc}));
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
build.ret();
build.align(8, AlignmentDataX64::Ud2);
},
{0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc});
{0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc}));
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow")
@ -349,40 +351,40 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow")
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow")
{
// Jump back
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
Label start = build.setLabel();
build.add(rsi, 1);
build.cmp(rsi, rdi);
build.jcc(Condition::Equal, start);
build.jcc(ConditionX64::Equal, start);
},
{0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff});
{0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff}));
// Jump back, but the label is set before use
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
Label start;
build.add(rsi, 1);
build.setLabel(start);
build.cmp(rsi, rdi);
build.jcc(Condition::Equal, start);
build.jcc(ConditionX64::Equal, start);
},
{0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff});
{0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff}));
// Jump forward
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
Label skip;
build.cmp(rsi, rdi);
build.jcc(Condition::Greater, skip);
build.jcc(ConditionX64::Greater, skip);
build.or_(rdi, 0x3e);
build.setLabel(skip);
},
{0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e});
{0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e}));
// Regular jump
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
Label skip;
@ -390,12 +392,12 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow")
build.and_(rdi, 0x3e);
build.setLabel(skip);
},
{0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e});
{0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e}));
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall")
{
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
Label fnB;
@ -404,10 +406,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall")
build.ret();
build.setLabel(fnB);
build.lea(rax, qword[rcx + 0x1f]);
build.lea(rax, addr[rcx + 0x1f]);
build.ret();
},
{0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3});
{0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3}));
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms")
@ -518,7 +520,7 @@ TEST_CASE("LogTest")
Label start = build.setLabel();
build.cmp(rsi, rdi);
build.jcc(Condition::Equal, start);
build.jcc(ConditionX64::Equal, start);
build.jmp(qword[rdx]);
build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]);
@ -548,7 +550,7 @@ TEST_CASE("LogTest")
build.finalize();
bool same = "\n" + build.text == R"(
std::string expected = R"(
push r12
; align 8
nop word ptr[rax+rax] ; 6-byte nop
@ -588,13 +590,14 @@ TEST_CASE("LogTest")
nop dword ptr[rax+rax] ; 8-byte nop
nop word ptr[rax+rax] ; 9-byte nop
)";
CHECK(same);
CHECK("\n" + build.text == expected);
}
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants")
{
// clang-format off
check(
CHECK(check(
[](AssemblyBuilderX64& build) {
build.xor_(rax, rax);
build.add(rax, build.i64(0x1234567887654321));
@ -625,7 +628,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants")
0x00, 0x00, 0x00, 0x00, // padding to align f64
0x00, 0x00, 0x80, 0x3f,
0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12,
});
}));
// clang-format on
}

View File

@ -1,5 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AssemblyBuilderX64.h"
#include "Luau/AssemblyBuilderA64.h"
#include "Luau/CodeAllocator.h"
#include "Luau/CodeBlockUnwind.h"
#include "Luau/UnwindBuilder.h"
@ -207,7 +208,7 @@ constexpr RegisterX64 rArg3 = rdx;
constexpr RegisterX64 rNonVol1 = r12;
constexpr RegisterX64 rNonVol2 = rbx;
TEST_CASE("GeneratedCodeExecution")
TEST_CASE("GeneratedCodeExecutionX64")
{
AssemblyBuilderX64 build(/* logText= */ false);
@ -241,7 +242,7 @@ void throwing(int64_t arg)
throw std::runtime_error("testing");
}
TEST_CASE("GeneratedCodeExecutionWithThrow")
TEST_CASE("GeneratedCodeExecutionWithThrowX64")
{
AssemblyBuilderX64 build(/* logText= */ false);
@ -267,7 +268,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrow")
build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize);
build.lea(rbp, qword[rsp + stackSize]);
build.lea(rbp, addr[rsp + stackSize]);
unwind->setupFrameReg(rbp, stackSize);
unwind->finish();
@ -281,7 +282,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrow")
build.call(rNonVol2);
// Epilogue
build.lea(rsp, qword[rbp + localsSize]);
build.lea(rsp, addr[rbp + localsSize]);
build.pop(rbp);
build.pop(rNonVol2);
build.pop(rNonVol1);
@ -317,7 +318,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrow")
}
}
TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate")
TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64")
{
AssemblyBuilderX64 build(/* logText= */ false);
@ -351,7 +352,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate")
build.sub(rsp, stackSize + localsSize);
unwind->allocStack(stackSize + localsSize);
build.lea(rbp, qword[rsp + stackSize]);
build.lea(rbp, addr[rsp + stackSize]);
unwind->setupFrameReg(rbp, stackSize);
unwind->finish();
@ -366,7 +367,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate")
Label returnOffset = build.setLabel();
// Epilogue
build.lea(rsp, qword[rbp + localsSize]);
build.lea(rsp, addr[rbp + localsSize]);
build.pop(rbp);
build.pop(r15);
build.pop(r14);
@ -432,4 +433,43 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate")
#endif
#if defined(__aarch64__)
TEST_CASE("GeneratedCodeExecutionA64")
{
AssemblyBuilderA64 build(/* logText= */ false);
Label skip;
build.cbz(x1, skip);
build.ldrsw(x1, x1);
build.cbnz(x1, skip);
build.mov(x1, 0); // doesn't execute due to cbnz above
build.setLabel(skip);
build.add(x1, x1, 1);
build.add(x0, x0, x1, /* LSL */ 1);
build.ret();
build.finalize();
size_t blockSize = 1024 * 1024;
size_t maxTotalSize = 1024 * 1024;
CodeAllocator allocator(blockSize, maxTotalSize);
uint8_t* nativeData;
size_t sizeNativeData;
uint8_t* nativeEntry;
REQUIRE(allocator.allocate(build.data.data(), build.data.size(), reinterpret_cast<uint8_t*>(build.code.data()), build.code.size() * 4, nativeData,
sizeNativeData, nativeEntry));
REQUIRE(nativeEntry);
using FunctionType = int64_t(int64_t, int*);
FunctionType* f = (FunctionType*)nativeEntry;
int input = 10;
int64_t result = f(20, &input);
CHECK(result == 42);
}
#endif
TEST_SUITE_END();

View File

@ -539,8 +539,14 @@ TEST_CASE("Debugger")
static bool singlestep = false;
static int stephits = 0;
SUBCASE("") { singlestep = false; }
SUBCASE("SingleStep") { singlestep = true; }
SUBCASE("")
{
singlestep = false;
}
SUBCASE("SingleStep")
{
singlestep = true;
}
breakhits = 0;
interruptedthread = nullptr;

View File

@ -506,13 +506,14 @@ std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name)
return std::nullopt;
}
void registerNotType(Fixture& fixture, TypeArena& arena)
void registerHiddenTypes(Fixture& fixture, TypeArena& arena)
{
TypeId t = arena.addType(GenericTypeVar{"T"});
GenericTypeDefinition genericT{t};
ScopePtr moduleScope = fixture.frontend.getGlobalScope();
moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationTypeVar{t})};
moduleScope->exportedTypeBindings["fun"] = TypeFun{{}, fixture.singletonTypes->functionType};
}
void dump(const std::vector<Constraint>& constraints)

View File

@ -186,7 +186,7 @@ std::optional<TypeId> lookupName(ScopePtr scope, const std::string& name); // Wa
std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name);
void registerNotType(Fixture& fixture, TypeArena& arena);
void registerHiddenTypes(Fixture& fixture, TypeArena& arena);
} // namespace Luau

View File

@ -3,6 +3,7 @@
#include "Fixture.h"
#include "Luau/Common.h"
#include "Luau/TypeVar.h"
#include "doctest.h"
#include "Luau/Normalize.h"
@ -19,7 +20,7 @@ struct IsSubtypeFixture : Fixture
return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice);
}
};
}
} // namespace
void createSomeClasses(Frontend& frontend)
{
@ -109,12 +110,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any")
TypeId a = requireType("a");
TypeId b = requireType("b");
// Intuition:
// We cannot use b where a is required because we cannot rely on b to return a string.
// We cannot use a where b is required because we cannot rely on a to accept non-number arguments.
// any makes things work even when it makes no sense.
CHECK(!isSubtype(b, a));
CHECK(!isSubtype(a, b));
CHECK(isSubtype(b, a));
CHECK(isSubtype(a, b));
}
TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head")
@ -199,7 +198,7 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop")
TypeId b = requireType("b");
CHECK(isSubtype(a, b));
CHECK(!isSubtype(b, a));
CHECK(isSubtype(b, a));
}
TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection")
@ -261,7 +260,7 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "tables")
TypeId d = requireType("d");
CHECK(isSubtype(a, b));
CHECK(!isSubtype(b, a));
CHECK(isSubtype(b, a));
CHECK(!isSubtype(c, a));
CHECK(!isSubtype(a, c));
@ -394,7 +393,8 @@ TEST_SUITE_END();
struct NormalizeFixture : Fixture
{
ScopedFastFlag sff{"LuauNegatedStringSingletons", true};
ScopedFastFlag sff0{"LuauNegatedStringSingletons", true};
ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true};
TypeArena arena;
InternalErrorReporter iceHandler;
@ -403,16 +403,21 @@ struct NormalizeFixture : Fixture
NormalizeFixture()
{
registerNotType(*this, arena);
registerHiddenTypes(*this, arena);
}
TypeId normal(const std::string& annotation)
const NormalizedType* toNormalizedType(const std::string& annotation)
{
CheckResult result = check("type _Res = " + annotation);
LUAU_REQUIRE_NO_ERRORS(result);
std::optional<TypeId> ty = lookupType("_Res");
REQUIRE(ty);
const NormalizedType* norm = normalizer.normalize(*ty);
return normalizer.normalize(*ty);
}
TypeId normal(const std::string& annotation)
{
const NormalizedType* norm = toNormalizedType(annotation);
REQUIRE(norm);
return normalizer.typeFromNormal(*norm);
}
@ -476,6 +481,13 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negations")
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "disjoint_negations_normalize_to_string")
{
CHECK(R"(string)" == toString(normal(R"(
(string & Not<"hello"> & Not<"world">) | (string & Not<"goodbye">)
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean")
{
CHECK("true" == toString(normal(R"(
@ -490,10 +502,43 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean_2")
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "bare_negation")
TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function")
{
CHECK("() -> ()" == toString(normal(R"(
fun & (() -> ())
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function_reverse")
{
CHECK("() -> ()" == toString(normal(R"(
(() -> ()) & fun
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function")
{
CHECK("function" == toString(normal(R"(
fun | (() -> ())
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function")
{
CHECK("(boolean | number | string | thread)?" == toString(normal(R"(
Not<fun>
)")));
}
TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated")
{
CHECK(nullptr == toNormalizedType("Not<(boolean) -> boolean>"));
}
TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean")
{
// TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function
CHECK("(number | string | thread)?" == toString(normal(R"(
CHECK("(function | number | string | thread)?" == toString(normal(R"(
Not<boolean>
)")));
}

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Parser.h"
#include "AstQueryDsl.h"
#include "Fixture.h"
#include "ScopedFlags.h"
@ -2775,4 +2776,51 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the
CHECK(2 == f->generics.size);
}
TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_between_table_members")
{
ScopedFastFlag luauTableConstructorRecovery{"LuauTableConstructorRecovery", true};
ParseResult result = tryParse(R"(
local t = {
first = 1
second = 2,
third = 3,
fouth = 4,
}
)");
REQUIRE(1 == result.errors.size());
CHECK(Location({3, 12}, {3, 18}) == result.errors[0].getLocation());
CHECK("Expected ',' after table constructor element" == result.errors[0].getMessage());
REQUIRE(1 == result.root->body.size);
AstExprTable* table = Luau::query<AstExprTable>(result.root);
REQUIRE(table);
CHECK(table->items.size == 4);
}
TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_table_member")
{
ParseResult result = tryParse(R"(
local t = {
first = 1
local ok = true
local good = ok == true
)");
REQUIRE(1 == result.errors.size());
CHECK(Location({4, 8}, {4, 13}) == result.errors[0].getLocation());
CHECK("Expected '}' (to close '{' at line 2), got 'local'" == result.errors[0].getMessage());
REQUIRE(3 == result.root->body.size);
AstExprTable* table = Luau::query<AstExprTable>(result.root);
REQUIRE(table);
CHECK(table->items.size == 1);
}
TEST_SUITE_END();

View File

@ -10,7 +10,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction);
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
LUAU_FASTFLAG(LuauFixNameMaps);
LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup);
@ -270,17 +269,9 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded")
o.maxTypeLength = 40;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()");
if (FFlag::LuauSpecialTypesAsterisked)
{
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
}
else
{
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
}
}
TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive")
{
@ -297,17 +288,9 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive")
o.maxTypeLength = 40;
CHECK_EQ(toString(requireType("f0"), o), "() -> ()");
CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()");
if (FFlag::LuauSpecialTypesAsterisked)
{
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*");
}
else
{
CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... <TRUNCATED>");
}
}
TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces")
{
@ -535,10 +518,7 @@ local function target(callback: nil) return callback(4, "hello") end
)");
LUAU_REQUIRE_ERRORS(result);
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("(nil) -> (*error-type*)", toString(requireType("target")));
else
CHECK_EQ("(nil) -> (<error-type>)", toString(requireType("target")));
}
TEST_CASE_FIXTURE(Fixture, "toStringGenericPack")

View File

@ -13,8 +13,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
TEST_SUITE_BEGIN("TypeInferAnyError");
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any")
@ -96,10 +94,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error")
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
@ -115,10 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2")
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error")
@ -233,10 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error")
CHECK_EQ("unknown", err->name);
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
@ -245,10 +234,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error")
local a = Utility.Create "Foo" {}
)");
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("a")));
else
CHECK_EQ("<error-type>", toString(requireType("a")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any")

View File

@ -8,7 +8,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
TEST_SUITE_BEGIN("BuiltinTests");
@ -685,7 +684,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked)
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("string", toString(requireType("foo")));
CHECK_EQ("*error-type*", toString(requireType("bar")));
@ -714,7 +713,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution && FFlag::LuauSpecialTypesAsterisked)
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("string", toString(requireType("foo")));
CHECK_EQ("string", toString(requireType("bar")));
@ -1016,10 +1015,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic")
CHECK_EQ("number", toString(requireType("a")));
CHECK_EQ("string", toString(requireType("b")));
CHECK_EQ("boolean", toString(requireType("c")));
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("d")));
else
CHECK_EQ("<error-type>", toString(requireType("d")));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments")

View File

@ -15,7 +15,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauInstantiateInSubtyping);
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
TEST_SUITE_BEGIN("TypeInferFunctions");
@ -980,19 +979,13 @@ TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language")
REQUIRE(tm1);
CHECK_EQ("(string) -> number", toString(tm1->wantedType));
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType));
else
CHECK_EQ("(string, <error-type>) -> number", toString(tm1->givenType));
auto tm2 = get<TypeMismatch>(result.errors[1]);
REQUIRE(tm2);
CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType));
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType));
else
CHECK_EQ("(string, <error-type>) -> number", toString(tm2->givenType));
}
TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type")
@ -1538,21 +1531,11 @@ function t:b() return 2 end -- not OK
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauSpecialTypesAsterisked)
{
CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number'
caused by:
Argument count mismatch. Function expects 1 argument, but none are specified)",
toString(result.errors[0]));
}
else
{
CHECK_EQ(R"(Type '(<error-type>) -> number' could not be converted into '() -> number'
caused by:
Argument count mismatch. Function expects 1 argument, but none are specified)",
toString(result.errors[0]));
}
}
TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic")
{
@ -1779,7 +1762,72 @@ z = y -- Not OK, so the line is colorable
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(toString(result.errors[0]), "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") -> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & ((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> false'; none of the intersection parts are compatible");
CHECK_EQ(toString(result.errors[0]),
"Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") "
"-> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & "
"((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> "
"(\"blue\" | \"red\") -> false'; none of the intersection parts are compatible");
}
TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions")
{
ScopedFastFlag sff{"LuauNegatedFunctionTypes", true};
registerHiddenTypes(*this, frontend.globalTypes);
CheckResult result = check(R"(
function foo(f: fun) end
function a() end
function id(x) return x end
foo(a)
foo(id)
foo(foo)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function")
{
ScopedFastFlag sff{"LuauNegatedFunctionTypes", true};
registerHiddenTypes(*this, frontend.globalTypes);
CheckResult result = check(R"(
local a: fun = function() end
function one(arg: () -> ()) end
function two(arg: <T>(T) -> T) end
one(a)
two(a)
)");
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK(6 == result.errors[0].location.begin.line);
CHECK(7 == result.errors[1].location.begin.line);
}
TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function")
{
ScopedFastFlag sff{"LuauNegatedFunctionTypes", true};
registerHiddenTypes(*this, frontend.globalTypes);
CheckResult result = check(R"(
local a: fun = function() end
local b: {} = a
local c: boolean = a
local d: fun = true
local e: fun = {}
)");
LUAU_REQUIRE_ERROR_COUNT(4, result);
CHECK(2 == result.errors[0].location.begin.line);
CHECK(3 == result.errors[1].location.begin.line);
CHECK(4 == result.errors[2].location.begin.line);
CHECK(5 == result.errors[3].location.begin.line);
}
TEST_SUITE_END();

View File

@ -10,7 +10,6 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
using namespace Luau;
@ -1011,10 +1010,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0);
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(t0->type));
else
CHECK_EQ("<error-type>", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err);

View File

@ -13,7 +13,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
TEST_SUITE_BEGIN("TypeInferLoops");
@ -157,10 +156,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error")
LUAU_REQUIRE_ERROR_COUNT(2, result);
TypeId p = requireType("p");
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(p));
else
CHECK_EQ("<error-type>", toString(p));
}
TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function")

View File

@ -14,8 +14,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
TEST_SUITE_BEGIN("TypeInferModules");
TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_require_basic")
@ -176,10 +174,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export")
auto hootyType = requireType(bModule, "Hooty");
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(hootyType));
else
CHECK_EQ("<error-type>", toString(hootyType));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript")
@ -282,10 +277,7 @@ local ModuleA = require(game.A)
std::optional<TypeId> oty = requireType("ModuleA");
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(*oty));
else
CHECK_EQ("<error-type>", toString(*oty));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types")

View File

@ -20,10 +20,10 @@ struct NegationFixture : Fixture
NegationFixture()
{
registerNotType(*this, arena);
registerHiddenTypes(*this, arena);
}
};
}
} // namespace
TEST_SUITE_BEGIN("Negations");

View File

@ -11,8 +11,6 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
using namespace Luau;
TEST_SUITE_BEGIN("TypeInferPrimitives");
@ -49,10 +47,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index")
REQUIRE(nat);
CHECK_EQ("string", toString(nat->ty));
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("t")));
else
CHECK_EQ("<error-type>", toString(requireType("t")));
}
TEST_CASE_FIXTURE(Fixture, "string_method")

View File

@ -633,7 +633,7 @@ struct IsSubtypeFixture : Fixture
return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice);
}
};
}
} // namespace
TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities")
{

View File

@ -7,7 +7,6 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
using namespace Luau;
@ -127,7 +126,7 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint")
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26})));
CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26})));
}
else
{
@ -153,7 +152,7 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through")
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(string?) & ~~~(false?)", toString(requireTypeAtPosition({5, 26})));
CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26})));
}
else
{
@ -178,8 +177,16 @@ TEST_CASE_FIXTURE(Fixture, "and_constraint")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({4, 26})));
}
else
{
CHECK_EQ("string", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("number", toString(requireTypeAtPosition({4, 26})));
}
CHECK_EQ("string?", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("number?", toString(requireTypeAtPosition({7, 26})));
@ -204,9 +211,17 @@ TEST_CASE_FIXTURE(Fixture, "not_and_constraint")
CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26})));
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({7, 26})));
}
else
{
CHECK_EQ("string", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("number", toString(requireTypeAtPosition({7, 26})));
}
}
TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates")
{
@ -227,9 +242,57 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates")
CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26})));
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({7, 26})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26})));
CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26})));
}
}
TEST_CASE_FIXTURE(Fixture, "a_and_b_or_a_and_c")
{
CheckResult result = check(R"(
function f(a: string?, b: number?, c: boolean)
if (a and b) or (a and c) then
local foo = a
local bar = b
local baz = c
else
local foo = a
local bar = b
local baz = c
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(string?) & (~(false?) | ~(false?))", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28})));
CHECK_EQ("boolean", toString(requireTypeAtPosition({5, 28})));
CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28})));
CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28})));
}
else
{
CHECK_EQ("string", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28})));
CHECK_EQ("true", toString(requireTypeAtPosition({5, 28}))); // oh no! :(
CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28})));
CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28})));
CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28})));
}
}
TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints")
{
@ -244,9 +307,18 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("number?", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("string?", toString(requireTypeAtPosition({4, 26})));
}
else
{
// We're going to drop support for type refinements through type assertions.
CHECK_EQ("number", toString(requireTypeAtPosition({3, 26})));
CHECK_EQ("string", toString(requireTypeAtPosition({4, 26})));
}
}
TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position")
{
@ -381,12 +453,23 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "((number | string)?) & (boolean?)"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "((number | string)?) & (boolean?)"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "((number | string)?) & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "(boolean?) & unknown"); // a ~= b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b
}
}
TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term")
{
@ -402,9 +485,17 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term")
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & number"); // a == 1
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a ~= 1
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1;
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1
}
}
TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue")
{
@ -420,8 +511,16 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello" & ((number | string)?))"); // a == "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"(((number | string)?) & ~"hello")"); // a ~= "hello"
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"((number | string)?)"); // a ~= "hello"
}
}
TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil")
@ -438,9 +537,17 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & ~nil"); // a ~= nil
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & nil"); // a == nil
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil
}
}
TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue")
{
@ -454,9 +561,18 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
ToStringOptions opts;
CHECK_EQ(toString(requireTypeAtPosition({3, 33}), opts), "(string?) & a"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36}), opts), "(string?) & a"); // a == b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b
}
}
TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal")
{
@ -470,9 +586,17 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "({| x: number |}?) & unknown"); // a ~= b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b
}
}
TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil")
{
@ -490,12 +614,23 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "(string?) & unknown"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "(string?) & string"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "(string?) & string"); // a == b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b
CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b
}
}
TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable")
{
@ -525,10 +660,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28})));
else
CHECK_EQ("<error-type>", toString(requireTypeAtPosition({3, 28})));
}
TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true")
@ -715,9 +847,17 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28})));
}
}
TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2")
{
@ -732,9 +872,17 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::DebugLuauDeferredConstraintResolution)
{
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28})));
}
else
{
CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28})));
}
}
TEST_CASE_FIXTURE(Fixture, "either_number_or_string")
{

View File

@ -79,6 +79,16 @@ TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype_multi_assignment")
{
CheckResult result = check(R"(
local a: "foo" = "foo"
local b: string, c: number = a, 10
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons")
{
CheckResult result = check(R"(

View File

@ -17,7 +17,6 @@ using namespace Luau;
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
TEST_SUITE_BEGIN("TableTests");
@ -3275,7 +3274,8 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
// LUAU_REQUIRE_ERRORS(result);
// CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}'
// caused by:
// Property 'm' is not compatible. Type '<a>(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)");
// Property 'm' is not compatible. Type '<a>(a) -> a' could not be converted into '(number) -> number'; different number of generic type
// parameters)");
// // this error message is not great since the underlying issue is that the context is invariant,
// and `(number) -> number` cannot be a subtype of `<a>(a) -> a`.
}
@ -3335,10 +3335,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated")
ToStringOptions opts;
opts.exhaustive = true;
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts));
else
CHECK_EQ("{ x: <error-type>, y: number }", toString(requireType("t"), opts));
}
TEST_SUITE_END();

View File

@ -17,7 +17,6 @@
LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauInstantiateInSubtyping);
LUAU_FASTFLAG(LuauSpecialTypesAsterisked);
using namespace Luau;
@ -237,22 +236,12 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types")
// TODO: Should we assert anything about these tests when DCR is being used?
if (!FFlag::DebugLuauDeferredConstraintResolution)
{
if (FFlag::LuauSpecialTypesAsterisked)
{
CHECK_EQ("*error-type*", toString(requireType("c")));
CHECK_EQ("*error-type*", toString(requireType("d")));
CHECK_EQ("*error-type*", toString(requireType("e")));
CHECK_EQ("*error-type*", toString(requireType("f")));
}
else
{
CHECK_EQ("<error-type>", toString(requireType("c")));
CHECK_EQ("<error-type>", toString(requireType("d")));
CHECK_EQ("<error-type>", toString(requireType("e")));
CHECK_EQ("<error-type>", toString(requireType("f")));
}
}
}
TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing")
@ -662,10 +651,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional")
std::optional<TypeFun> t0 = getMainModule()->getModuleScope()->lookupType("t0");
REQUIRE(t0);
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(t0->type));
else
CHECK_EQ("<error-type>", toString(t0->type));
auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) {
return get<OccursCheckFailed>(err);

View File

@ -9,8 +9,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
struct TryUnifyFixture : Fixture
{
TypeArena arena;
@ -124,10 +122,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("a", toString(requireType("a")));
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("b")));
else
CHECK_EQ("<error-type>", toString(requireType("b")));
}
TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained")
@ -142,10 +137,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("a", toString(requireType("a")));
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("b")));
else
CHECK_EQ("<error-type>", toString(requireType("b")));
CHECK_EQ("number", toString(requireType("c")));
}

View File

@ -6,8 +6,6 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauSpecialTypesAsterisked)
using namespace Luau;
TEST_SUITE_BEGIN("UnionTypes");
@ -199,10 +197,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
CHECK_EQ(mup->missing[0], *bTy);
CHECK_EQ(mup->key, "x");
if (FFlag::LuauSpecialTypesAsterisked)
CHECK_EQ("*error-type*", toString(requireType("r")));
else
CHECK_EQ("<error-type>", toString(requireType("r")));
}
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any")

View File

@ -217,4 +217,35 @@ TEST_CASE("Visit")
CHECK(r3 == "1231147");
}
struct MoveOnly
{
MoveOnly() = default;
MoveOnly(const MoveOnly&) = delete;
MoveOnly& operator=(const MoveOnly&) = delete;
MoveOnly(MoveOnly&&) = default;
MoveOnly& operator=(MoveOnly&&) = default;
};
TEST_CASE("Move")
{
Variant<MoveOnly> v1 = MoveOnly{};
Variant<MoveOnly> v2 = std::move(v1);
}
TEST_CASE("MoveWithCopyableAlternative")
{
Variant<std::string, MoveOnly> v1 = std::string{"Hello, world! I am longer than a normal hello world string to avoid SSO."};
Variant<std::string, MoveOnly> v2 = std::move(v1);
std::string* s1 = get_if<std::string>(&v1);
REQUIRE(s1);
CHECK(*s1 == "");
std::string* s2 = get_if<std::string>(&v2);
REQUIRE(s2);
CHECK(*s2 == "Hello, world! I am longer than a normal hello world string to avoid SSO.");
}
TEST_SUITE_END();

View File

@ -10,7 +10,9 @@ AnnotationTests.too_many_type_params
AnnotationTests.two_type_params
AnnotationTests.unknown_type_reference_generates_error
AstQuery.last_argument_function_call_type
AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method
AstQuery::getDocumentationSymbolAtPosition.overloaded_fn
AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop
AutocompleteTest.autocomplete_first_function_arg_expected_type
AutocompleteTest.autocomplete_interpolated_string
AutocompleteTest.autocomplete_oop_implicit_self
@ -106,17 +108,14 @@ GenericsTests.correctly_instantiate_polymorphic_member_functions
GenericsTests.do_not_infer_generic_functions
GenericsTests.duplicate_generic_type_packs
GenericsTests.duplicate_generic_types
GenericsTests.factories_of_generics
GenericsTests.generic_argument_count_too_few
GenericsTests.generic_argument_count_too_many
GenericsTests.generic_factories
GenericsTests.generic_functions_in_types
GenericsTests.generic_functions_should_be_memory_safe
GenericsTests.generic_table_method
GenericsTests.generic_type_pack_parentheses
GenericsTests.generic_type_pack_unification1
GenericsTests.generic_type_pack_unification2
GenericsTests.generic_type_pack_unification3
GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments
GenericsTests.infer_generic_function_function_argument
GenericsTests.infer_generic_function_function_argument_overloaded
@ -172,7 +171,6 @@ ProvisionalTests.table_insert_with_a_singleton_argument
ProvisionalTests.typeguard_inference_incomplete
ProvisionalTests.weirditer_should_not_loop_forever
ProvisionalTests.while_body_are_also_refined
RefinementTest.and_constraint
RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string
RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number
RefinementTest.assert_non_binary_expressions_actually_resolve_constraints
@ -187,28 +185,17 @@ RefinementTest.either_number_or_string
RefinementTest.eliminate_subclasses_of_instance
RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil
RefinementTest.index_on_a_refined_property
RefinementTest.invert_is_truthy_constraint
RefinementTest.invert_is_truthy_constraint_ifelse_expression
RefinementTest.is_truthy_constraint
RefinementTest.is_truthy_constraint_ifelse_expression
RefinementTest.lvalue_is_not_nil
RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering
RefinementTest.narrow_boolean_to_true_or_false
RefinementTest.narrow_property_of_a_bounded_variable
RefinementTest.narrow_this_large_union
RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true
RefinementTest.not_a_and_not_b
RefinementTest.not_a_and_not_b2
RefinementTest.not_and_constraint
RefinementTest.not_t_or_some_prop_of_t
RefinementTest.or_predicate_with_truthy_predicates
RefinementTest.parenthesized_expressions_are_followed_through
RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table
RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string
RefinementTest.refine_unknowns
RefinementTest.term_is_equal_to_an_lvalue
RefinementTest.truthy_constraint_on_properties
RefinementTest.type_assertion_expr_carry_its_constraints
RefinementTest.type_comparison_ifelse_expression
RefinementTest.type_guard_can_filter_for_intersection_of_tables
RefinementTest.type_guard_can_filter_for_overloaded_function
@ -271,6 +258,7 @@ TableTests.infer_indexer_from_value_property_in_literal
TableTests.inferred_return_type_of_free_table
TableTests.inferring_crazy_table_should_also_be_quick
TableTests.instantiate_table_cloning_3
TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound
TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound
TableTests.leaking_bad_metatable_errors
TableTests.less_exponential_blowup_please
@ -315,7 +303,6 @@ TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors
TableTests.tables_get_names_from_their_locals
TableTests.tc_member_function
TableTests.tc_member_function_2
TableTests.top_table_type
TableTests.type_mismatch_on_massive_table_is_cut_short
TableTests.unification_of_unions_in_a_self_referential_type
TableTests.unifying_tables_shouldnt_uaf2
@ -417,12 +404,12 @@ TypeInferFunctions.too_few_arguments_variadic
TypeInferFunctions.too_few_arguments_variadic_generic
TypeInferFunctions.too_few_arguments_variadic_generic2
TypeInferFunctions.too_many_arguments
TypeInferFunctions.too_many_arguments_error_location
TypeInferFunctions.too_many_return_values
TypeInferFunctions.too_many_return_values_in_parentheses
TypeInferFunctions.too_many_return_values_no_function
TypeInferFunctions.vararg_function_is_quantified
TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values
TypeInferLoops.for_in_loop_with_custom_iterator
TypeInferLoops.for_in_loop_with_next
TypeInferLoops.for_in_with_generic_next
TypeInferLoops.for_in_with_just_one_iterator_is_ok
@ -430,7 +417,6 @@ TypeInferLoops.loop_iter_no_indexer_nonstrict
TypeInferLoops.loop_iter_trailing_nil
TypeInferLoops.unreachable_code_after_infinite_loop
TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free
TypeInferModules.bound_free_table_export_is_ok
TypeInferModules.custom_require_global
TypeInferModules.do_not_modify_imported_types
TypeInferModules.module_type_conflict
@ -443,6 +429,7 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon
TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory
TypeInferOOP.methods_are_topologically_sorted
TypeInferOOP.object_constructor_can_refer_to_method_of_self
TypeInferOperators.and_or_ternary
TypeInferOperators.CallAndOrOfFunctions
TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable

View File

@ -32,6 +32,9 @@ source = """// This file is part of the Luau programming language and is license
"""
function = ""
signature = ""
includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_COVERAGE", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"]
state = 0
@ -44,7 +47,6 @@ for line in input:
if match:
inst = match[1]
signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)"
header += signature + ";\n"
function = signature + "\n"
function += "{\n"
function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n"
@ -84,7 +86,10 @@ for line in input:
function = function[:-len(finalline)]
function += " return pc;\n}\n"
if inst in includeInsts:
header += signature + ";\n"
source += function + "\n"
state = 0
# skip LUA_CUSTOM_EXECUTION code blocks

View File

@ -0,0 +1,173 @@
#!/usr/bin/python
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
# The purpose of this script is to analyze disassembly generated by objdump or
# dumpbin to print (or to compare) the stack usage of functions/methods.
# This is a quickly written script, so it is quite possible it may not handle
# all code properly.
#
# The script expects the user to create a text assembly dump to be passed to
# the script.
#
# objdump Example
# objdump --demangle --disassemble objfile.o > objfile.s
#
# dumpbin Example
# dumpbin /disasm objfile.obj > objfile.s
#
# If the script is passed a single file, then all stack size information that
# is found it printed. If two files are passed, then the script compares the
# stack usage of the two files (useful for A/B comparisons).
# Currently more than two input files are not supported. (But adding support shouldn't
# be very difficult.)
#
# Note: The script only handles x64 disassembly. Supporting x86 is likely
# trivial, but ARM support could be difficult.
# Thus far the script has been tested with MSVC on Win64 and clang on OSX.
import argparse
import re
blank_re = re.compile('\s*')
class LineReader:
def __init__(self, lines):
self.lines = list(reversed(lines))
def get_line(self):
return self.lines.pop(-1)
def peek_line(self):
return self.lines[-1]
def consume_blank_lines(self):
while blank_re.fullmatch(self.peek_line()):
self.get_line()
def is_empty(self):
return len(self.lines) == 0
def parse_objdump_assembly(in_file):
results = {}
text_section_re = re.compile('Disassembly of section __TEXT,__text:\s*')
symbol_re = re.compile('[^<]*<(.*)>:\s*')
stack_alloc = re.compile('.*subq\s*\$(\d*), %rsp\s*')
lr = LineReader(in_file.readlines())
def find_stack_alloc_size():
while True:
if lr.is_empty():
return None
if blank_re.fullmatch(lr.peek_line()):
return None
line = lr.get_line()
mo = stack_alloc.fullmatch(line)
if mo:
lr.consume_blank_lines()
return int(mo.group(1))
# Find beginning of disassembly
while not text_section_re.fullmatch(lr.get_line()):
pass
# Scan for symbols
while not lr.is_empty():
lr.consume_blank_lines()
if lr.is_empty():
break
line = lr.get_line()
mo = symbol_re.fullmatch(line)
# Found a symbol
if mo:
symbol = mo.group(1)
stack_size = find_stack_alloc_size()
if stack_size != None:
results[symbol] = stack_size
return results
def parse_dumpbin_assembly(in_file):
results = {}
file_type_re = re.compile('File Type: COFF OBJECT\s*')
symbol_re = re.compile('[^(]*\((.*)\):\s*')
summary_re = re.compile('\s*Summary\s*')
stack_alloc = re.compile('.*sub\s*rsp,([A-Z0-9]*)h\s*')
lr = LineReader(in_file.readlines())
def find_stack_alloc_size():
while True:
if lr.is_empty():
return None
if blank_re.fullmatch(lr.peek_line()):
return None
line = lr.get_line()
mo = stack_alloc.fullmatch(line)
if mo:
lr.consume_blank_lines()
return int(mo.group(1), 16) # return value in decimal
# Find beginning of disassembly
while not file_type_re.fullmatch(lr.get_line()):
pass
# Scan for symbols
while not lr.is_empty():
lr.consume_blank_lines()
if lr.is_empty():
break
line = lr.get_line()
if summary_re.fullmatch(line):
break
mo = symbol_re.fullmatch(line)
# Found a symbol
if mo:
symbol = mo.group(1)
stack_size = find_stack_alloc_size()
if stack_size != None:
results[symbol] = stack_size
return results
def main():
parser = argparse.ArgumentParser(description='Tool used for reporting or comparing the stack usage of functions/methods')
parser.add_argument('--format', choices=['dumpbin', 'objdump'], required=True, help='Specifies the program used to generate the input files')
parser.add_argument('--input', action='append', required=True, help='Input assembly file. This option may be specified multiple times.')
parser.add_argument('--md-output', action='store_true', help='Show table output in markdown format')
parser.add_argument('--only-diffs', action='store_true', help='Only show stack info when it differs between the input files')
args = parser.parse_args()
parsers = {'dumpbin': parse_dumpbin_assembly, 'objdump' : parse_objdump_assembly}
parse_func = parsers[args.format]
input_results = []
for input_name in args.input:
with open(input_name) as in_file:
results = parse_func(in_file)
input_results.append(results)
if len(input_results) == 1:
# Print out the results sorted by size
size_sorted = sorted([(size, symbol) for symbol, size in results.items()], reverse=True)
print(input_name)
for size, symbol in size_sorted:
print(f'{size:10}\t{symbol}')
print()
elif len(input_results) == 2:
common_symbols = set(input_results[0].keys()).intersection(set(input_results[1].keys()))
print(f'Found {len(common_symbols)} common symbols')
stack_sizes = sorted([(input_results[0][sym], input_results[1][sym], sym) for sym in common_symbols], reverse=True)
if args.md_output:
print('Before | After | Symbol')
print('-- | -- | --')
for size0, size1, symbol in stack_sizes:
if args.only_diffs and size0 == size1:
continue
if args.md_output:
print(f'{size0} | {size1} | {symbol}')
else:
print(f'{size0:10}\t{size1:10}\t{symbol}')
else:
print("TODO support more than 2 inputs")
if __name__ == '__main__':
main()