Sync to upstream/release/600 (#1076)

### What's Changed

- Improve readability of unions and intersections by limiting the number
of elements of those types that can be presented on a single line (gated
under `FFlag::LuauToStringSimpleCompositeTypesSingleLine`)
- Adds a new option to the compiler `--record-stats` to record and
output compilation statistics
- `if...then...else` expressions are now optimized into `AND/OR` form
when possible.

### VM

- Add a new `buffer` type to Luau based on the [buffer
RFC](https://github.com/Roblox/luau/pull/739) and additional C API
functions to work with it; this release does not include the library.
- Internal C API to work with string buffers has been updated to align
with Lua version more closely

### Native Codegen

- Added support for new X64 instruction (rev) and new A64 instruction
(bswap) in the assembler
- Simplified the way numerical loop condition is translated to IR

### New Type Solver

- Operator inference now handled by type families
- Created a new system called `Type Paths` to explain why subtyping
tests fail in order to improve the quality of error messages.
- Systematic changes to implement Data Flow analysis in the new solver
(`Breadcrumb` removed and replaced with `RefinementKey`)

---
Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Lily Brown <lbrown@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Arseny Kapoulkine <arseny.kapoulkine@gmail.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Lily Brown <lbrown@roblox.com>
Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
This commit is contained in:
Vighnesh-V 2023-10-20 18:10:30 -07:00 committed by GitHub
parent 380bb7095d
commit fd6250cf9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
128 changed files with 4722 additions and 1476 deletions

View File

@ -4,7 +4,7 @@
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include <memory> #include <memory>
@ -39,4 +39,4 @@ struct Anyification : Substitution
bool ignoreChildren(TypePackId ty) override; bool ignoreChildren(TypePackId ty) override;
}; };
} // namespace Luau } // namespace Luau

View File

@ -3,7 +3,7 @@
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
namespace Luau namespace Luau
{ {

View File

@ -3,6 +3,7 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Documentation.h" #include "Luau/Documentation.h"
#include "Luau/TypeFwd.h"
#include <memory> #include <memory>
@ -13,9 +14,6 @@ struct Binding;
struct SourceModule; struct SourceModule;
struct Module; struct Module;
struct Type;
using TypeId = const Type*;
using ScopePtr = std::shared_ptr<struct Scope>; using ScopePtr = std::shared_ptr<struct Scope>;
struct ExprOrLocal struct ExprOrLocal

View File

@ -1,75 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Variant.h"
#include <string>
#include <optional>
namespace Luau
{
using NullableBreadcrumbId = const struct Breadcrumb*;
using BreadcrumbId = NotNull<const struct Breadcrumb>;
struct FieldMetadata
{
std::string prop;
};
struct SubscriptMetadata
{
BreadcrumbId key;
};
using Metadata = Variant<FieldMetadata, SubscriptMetadata>;
struct Breadcrumb
{
NullableBreadcrumbId previous;
DefId def;
std::optional<Metadata> metadata;
std::vector<BreadcrumbId> children;
};
inline Breadcrumb* asMutable(NullableBreadcrumbId breadcrumb)
{
LUAU_ASSERT(breadcrumb);
return const_cast<Breadcrumb*>(breadcrumb);
}
template<typename T>
const T* getMetadata(NullableBreadcrumbId breadcrumb)
{
if (!breadcrumb || !breadcrumb->metadata)
return nullptr;
return get_if<T>(&*breadcrumb->metadata);
}
struct BreadcrumbArena
{
TypedAllocator<Breadcrumb> allocator;
template<typename... Args>
BreadcrumbId add(NullableBreadcrumbId previous, DefId def, Args&&... args)
{
Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, std::forward<Args>(args)...});
if (previous)
asMutable(previous)->children.push_back(NotNull{bc});
return NotNull{bc};
}
template<typename T, typename... Args>
BreadcrumbId emplace(NullableBreadcrumbId previous, DefId def, Args&&... args)
{
Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, Metadata{T{std::forward<Args>(args)...}}});
if (previous)
asMutable(previous)->children.push_back(NotNull{bc});
return NotNull{bc};
}
};
} // namespace Luau

View File

@ -4,8 +4,8 @@
#include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/Ast.h" // Used for some of the enumerations
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Type.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <string> #include <string>
#include <memory> #include <memory>
@ -16,12 +16,6 @@ namespace Luau
struct Scope; struct Scope;
struct Type;
using TypeId = const Type*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
// subType <: superType // subType <: superType
struct SubtypeConstraint struct SubtypeConstraint
{ {
@ -55,31 +49,6 @@ struct InstantiationConstraint
TypeId superType; TypeId superType;
}; };
struct UnaryConstraint
{
AstExprUnary::Op op;
TypeId operandType;
TypeId resultType;
};
// let L : leftType
// let R : rightType
// in
// L op R : resultType
struct BinaryConstraint
{
AstExprBinary::Op op;
TypeId leftType;
TypeId rightType;
TypeId resultType;
// When we dispatch this constraint, we update the key at this map to record
// the overload that we selected.
const AstNode* astFragment;
DenseHashMap<const AstNode*, TypeId>* astOriginalCallTypes;
DenseHashMap<const AstNode*, TypeId>* astOverloadResolvedTypes;
};
// iteratee is iterable // iteratee is iterable
// iterators is the iteration types. // iterators is the iteration types.
struct IterableConstraint struct IterableConstraint
@ -241,6 +210,22 @@ struct RefineConstraint
TypeId discriminant; TypeId discriminant;
}; };
// resultType ~ T0 op T1 op ... op TN
//
// op is either union or intersection. If any of the input types are blocked,
// this constraint will block unless forced.
struct SetOpConstraint
{
enum
{
Intersection,
Union
} mode;
TypeId resultType;
std::vector<TypeId> types;
};
// ty ~ reduce ty // ty ~ reduce ty
// //
// Try to reduce ty, if it is a TypeFamilyInstanceType. Otherwise, do nothing. // Try to reduce ty, if it is a TypeFamilyInstanceType. Otherwise, do nothing.
@ -257,10 +242,9 @@ struct ReducePackConstraint
TypePackId tp; TypePackId tp;
}; };
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint, using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, IterableConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint, HasPropConstraint, SetPropConstraint,
HasPropConstraint, SetPropConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint, RefineConstraint, ReduceConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint, RefineConstraint, SetOpConstraint, ReduceConstraint, ReducePackConstraint>;
ReducePackConstraint>;
struct Constraint struct Constraint
{ {

View File

@ -5,15 +5,17 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Refinement.h" #include "Luau/Refinement.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Normalize.h" #include "Luau/Normalize.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -69,11 +71,18 @@ struct ConstraintGraphBuilder
// This is null when the CGB is initially constructed. // This is null when the CGB is initially constructed.
Scope* rootScope; Scope* rootScope;
struct InferredBinding
{
Scope* scope;
Location location;
TypeIds types;
};
// During constraint generation, we only populate the Scope::bindings // During constraint generation, we only populate the Scope::bindings
// property for annotated symbols. Unannotated symbols must be handled in a // property for annotated symbols. Unannotated symbols must be handled in a
// postprocessing step because we do not yet have the full breadcrumb graph. // postprocessing step because we have not yet allocated the types that will
// We queue them up here. // be assigned to those unannotated symbols, so we queue them up here.
std::vector<std::tuple<Scope*, Symbol, BreadcrumbId>> inferredBindings; std::map<Symbol, InferredBinding> inferredBindings;
// Constraints that go straight to the solver. // Constraints that go straight to the solver.
std::vector<ConstraintPtr> constraints; std::vector<ConstraintPtr> constraints;
@ -155,6 +164,18 @@ private:
*/ */
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c); NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
struct RefinementPartition
{
// Types that we want to intersect against the type of the expression.
std::vector<TypeId> discriminantTypes;
// Sometimes the type we're discriminating against is implicitly nil.
bool shouldAppendNilType = false;
};
using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>;
void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector<ConstraintV>* constraints);
void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector<ConstraintV>* constraints);
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
@ -211,13 +232,15 @@ private:
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType); std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExpr* expr);
TypeId checkLValue(const ScopePtr& scope, AstExprLocal* local); std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprLocal* local);
TypeId checkLValue(const ScopePtr& scope, AstExprGlobal* global); std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprGlobal* global);
TypeId checkLValue(const ScopePtr& scope, AstExprIndexName* indexName); std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexName* indexName);
TypeId checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr); std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
TypeId updateProperty(const ScopePtr& scope, AstExpr* expr); TypeId updateProperty(const ScopePtr& scope, AstExpr* expr);
void updateLValueType(AstExpr* lvalue, TypeId ty);
struct FunctionSignature struct FunctionSignature
{ {
// The type of the function. // The type of the function.

View File

@ -111,8 +111,6 @@ struct ConstraintSolver
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const InstantiationConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const InstantiationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const UnaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const BinaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
@ -124,6 +122,7 @@ struct ConstraintSolver
bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const RefineConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const RefineConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SetOpConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);

View File

@ -3,29 +3,46 @@
// Do not include LValue. It should never be used here. // Do not include LValue. It should never be used here.
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Breadcrumb.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/TypedAllocator.h"
#include <unordered_map> #include <unordered_map>
namespace Luau namespace Luau
{ {
struct RefinementKey
{
const RefinementKey* parent = nullptr;
DefId def;
std::optional<std::string> propName;
};
struct RefinementKeyArena
{
TypedAllocator<RefinementKey> allocator;
const RefinementKey* leaf(DefId def);
const RefinementKey* node(const RefinementKey* parent, DefId def, const std::string& propName);
};
struct DataFlowGraph struct DataFlowGraph
{ {
DataFlowGraph(DataFlowGraph&&) = default; DataFlowGraph(DataFlowGraph&&) = default;
DataFlowGraph& operator=(DataFlowGraph&&) = default; DataFlowGraph& operator=(DataFlowGraph&&) = default;
NullableBreadcrumbId getBreadcrumb(const AstExpr* expr) const; DefId getDef(const AstExpr* expr) const;
// Look up for the rvalue breadcrumb for a compound assignment.
std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const;
BreadcrumbId getBreadcrumb(const AstLocal* local) const; DefId getDef(const AstLocal* local) const;
BreadcrumbId getBreadcrumb(const AstExprLocal* local) const;
BreadcrumbId getBreadcrumb(const AstExprGlobal* global) const;
BreadcrumbId getBreadcrumb(const AstStatDeclareGlobal* global) const; DefId getDef(const AstStatDeclareGlobal* global) const;
BreadcrumbId getBreadcrumb(const AstStatDeclareFunction* func) const; DefId getDef(const AstStatDeclareFunction* func) const;
const RefinementKey* getRefinementKey(const AstExpr* expr) const;
private: private:
DataFlowGraph() = default; DataFlowGraph() = default;
@ -33,17 +50,23 @@ private:
DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena defs; DefArena defArena;
BreadcrumbArena breadcrumbs; RefinementKeyArena keyArena;
DenseHashMap<const AstExpr*, NullableBreadcrumbId> astBreadcrumbs{nullptr}; DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
// Sometimes we don't have the AstExprLocal* but we have AstLocal*, and sometimes we need to extract that DefId. // Sometimes we don't have the AstExprLocal* but we have AstLocal*, and sometimes we need to extract that DefId.
DenseHashMap<const AstLocal*, NullableBreadcrumbId> localBreadcrumbs{nullptr}; DenseHashMap<const AstLocal*, const Def*> localDefs{nullptr};
// There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place. // There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place.
// All keys in this maps are really only statements that ambiently declares a symbol. // All keys in this maps are really only statements that ambiently declares a symbol.
DenseHashMap<const AstStat*, NullableBreadcrumbId> declaredBreadcrumbs{nullptr}; DenseHashMap<const AstStat*, const Def*> declaredDefs{nullptr};
// Compound assignments are in a weird situation where the local being assigned to is also being used at its
// previous type implicitly in an rvalue position. This map provides the previous binding.
DenseHashMap<const AstExpr*, const Def*> compoundAssignBreadcrumbs{nullptr};
DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr};
friend struct DataFlowGraphBuilder; friend struct DataFlowGraphBuilder;
}; };
@ -51,15 +74,19 @@ private:
struct DfgScope struct DfgScope
{ {
DfgScope* parent; DfgScope* parent;
DenseHashMap<Symbol, NullableBreadcrumbId> bindings{Symbol{}}; DenseHashMap<Symbol, const Def*> bindings{Symbol{}};
DenseHashMap<const Def*, std::unordered_map<std::string, NullableBreadcrumbId>> props{nullptr}; DenseHashMap<const Def*, std::unordered_map<std::string, const Def*>> props{nullptr};
NullableBreadcrumbId lookup(Symbol symbol) const; std::optional<DefId> lookup(Symbol symbol) const;
NullableBreadcrumbId lookup(DefId def, const std::string& key) const; std::optional<DefId> lookup(DefId def, const std::string& key) const;
};
struct DataFlowResult
{
DefId def;
const RefinementKey* parent = nullptr;
}; };
// Currently unsound. We do not presently track the control flow of the program.
// Additionally, we do not presently track assignments.
struct DataFlowGraphBuilder struct DataFlowGraphBuilder
{ {
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle); static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
@ -71,8 +98,8 @@ private:
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph; DataFlowGraph graph;
NotNull<DefArena> defs{&graph.defs}; NotNull<DefArena> defArena{&graph.defArena};
NotNull<BreadcrumbArena> breadcrumbs{&graph.breadcrumbs}; NotNull<RefinementKeyArena> keyArena{&graph.keyArena};
struct InternalErrorReporter* handle = nullptr; struct InternalErrorReporter* handle = nullptr;
DfgScope* moduleScope = nullptr; DfgScope* moduleScope = nullptr;
@ -105,27 +132,28 @@ private:
void visit(DfgScope* scope, AstStatDeclareClass* d); void visit(DfgScope* scope, AstStatDeclareClass* d);
void visit(DfgScope* scope, AstStatError* error); void visit(DfgScope* scope, AstStatError* error);
BreadcrumbId visitExpr(DfgScope* scope, AstExpr* e); DataFlowResult visitExpr(DfgScope* scope, AstExpr* e);
BreadcrumbId visitExpr(DfgScope* scope, AstExprLocal* l); DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group);
BreadcrumbId visitExpr(DfgScope* scope, AstExprGlobal* g); DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l);
BreadcrumbId visitExpr(DfgScope* scope, AstExprCall* c); DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g);
BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexName* i); DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c);
BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexExpr* i); DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i);
BreadcrumbId visitExpr(DfgScope* scope, AstExprFunction* f); DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i);
BreadcrumbId visitExpr(DfgScope* scope, AstExprTable* t); DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f);
BreadcrumbId visitExpr(DfgScope* scope, AstExprUnary* u); DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t);
BreadcrumbId visitExpr(DfgScope* scope, AstExprBinary* b); DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u);
BreadcrumbId visitExpr(DfgScope* scope, AstExprTypeAssertion* t); DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b);
BreadcrumbId visitExpr(DfgScope* scope, AstExprIfElse* i); DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
BreadcrumbId visitExpr(DfgScope* scope, AstExprInterpString* i); DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i);
BreadcrumbId visitExpr(DfgScope* scope, AstExprError* error); DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprError* error);
void visitLValue(DfgScope* scope, AstExpr* e, BreadcrumbId bc); void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment = false);
void visitLValue(DfgScope* scope, AstExprLocal* l, BreadcrumbId bc); void visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment);
void visitLValue(DfgScope* scope, AstExprGlobal* g, BreadcrumbId bc); void visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment);
void visitLValue(DfgScope* scope, AstExprIndexName* i, BreadcrumbId bc); void visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef);
void visitLValue(DfgScope* scope, AstExprIndexExpr* i, BreadcrumbId bc); void visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef);
void visitLValue(DfgScope* scope, AstExprError* e, BreadcrumbId bc); void visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef);
void visitType(DfgScope* scope, AstType* t); void visitType(DfgScope* scope, AstType* t);
void visitType(DfgScope* scope, AstTypeReference* r); void visitType(DfgScope* scope, AstTypeReference* r);

View File

@ -23,6 +23,7 @@ using DefId = NotNull<const Def>;
*/ */
struct Cell struct Cell
{ {
bool subscripted = false;
}; };
/** /**
@ -71,11 +72,13 @@ const T* get(DefId def)
return get_if<T>(&def->v); return get_if<T>(&def->v);
} }
bool containsSubscriptedDefinition(DefId def);
struct DefArena struct DefArena
{ {
TypedAllocator<Def> allocator; TypedAllocator<Def> allocator;
DefId freshCell(); DefId freshCell(bool subscripted = false);
// TODO: implement once we have cases where we need to merge in definitions // TODO: implement once we have cases where we need to merge in definitions
// DefId phi(const std::vector<DefId>& defs); // DefId phi(const std::vector<DefId>& defs);
}; };

View File

@ -2,11 +2,12 @@
#pragma once #pragma once
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/UnifierSharedState.h" #include "Luau/UnifierSharedState.h"
#include <optional> #include <optional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_set>
namespace Luau namespace Luau
{ {

View File

@ -5,6 +5,9 @@
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Ast.h"
#include <set>
namespace Luau namespace Luau
{ {

View File

@ -6,12 +6,11 @@
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
namespace Luau namespace Luau
{ {
struct BuiltinTypes;
struct GlobalTypes struct GlobalTypes
{ {
explicit GlobalTypes(NotNull<BuiltinTypes> builtinTypes); explicit GlobalTypes(NotNull<BuiltinTypes> builtinTypes);

View File

@ -3,13 +3,12 @@
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
namespace Luau namespace Luau
{ {
struct BuiltinTypes;
struct TxnLog; struct TxnLog;
struct TypeArena; struct TypeArena;
struct TypeCheckLimits; struct TypeCheckLimits;

View File

@ -5,6 +5,7 @@
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/TypePath.h"
#include <ostream> #include <ostream>
@ -48,4 +49,14 @@ std::ostream& operator<<(std::ostream& lhs, const TypePackVar& tv);
std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted); std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted);
std::ostream& operator<<(std::ostream& lhs, TypeId ty);
std::ostream& operator<<(std::ostream& lhs, TypePackId tp);
namespace TypePath
{
std::ostream& operator<<(std::ostream& lhs, const Path& path);
}; // namespace TypePath
} // namespace Luau } // namespace Luau

View File

@ -3,6 +3,7 @@
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/TypeFwd.h"
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
@ -10,9 +11,6 @@
namespace Luau namespace Luau
{ {
struct Type;
using TypeId = const Type*;
struct Field; struct Field;
// Deprecated. Do not use in new work. // Deprecated. Do not use in new work.

View File

@ -2,10 +2,15 @@
#pragma once #pragma once
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/UnifierSharedState.h" #include "Luau/UnifierSharedState.h"
#include <initializer_list>
#include <map>
#include <memory> #include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace Luau namespace Luau
{ {
@ -13,7 +18,6 @@ namespace Luau
struct InternalErrorReporter; struct InternalErrorReporter;
struct Module; struct Module;
struct Scope; struct Scope;
struct BuiltinTypes;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
@ -33,10 +37,15 @@ public:
using iterator = std::vector<TypeId>::iterator; using iterator = std::vector<TypeId>::iterator;
using const_iterator = std::vector<TypeId>::const_iterator; using const_iterator = std::vector<TypeId>::const_iterator;
TypeIds(const TypeIds&) = default;
TypeIds(TypeIds&&) = default;
TypeIds() = default; TypeIds() = default;
~TypeIds() = default; ~TypeIds() = default;
TypeIds(std::initializer_list<TypeId> tys);
TypeIds(const TypeIds&) = default;
TypeIds& operator=(const TypeIds&) = default;
TypeIds(TypeIds&&) = default;
TypeIds& operator=(TypeIds&&) = default; TypeIds& operator=(TypeIds&&) = default;
void insert(TypeId ty); void insert(TypeId ty);

View File

@ -4,15 +4,13 @@
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/LValue.h" #include "Luau/LValue.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <vector> #include <vector>
namespace Luau namespace Luau
{ {
struct Type;
using TypeId = const Type*;
struct TruthyPredicate; struct TruthyPredicate;
struct IsAPredicate; struct IsAPredicate;
struct TypeGuardPredicate; struct TypeGuardPredicate;

View File

@ -1,10 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Unifiable.h"
#include <vector> #include <vector>
#include <optional>
namespace Luau namespace Luau
{ {

View File

@ -4,14 +4,13 @@
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/TypedAllocator.h" #include "Luau/TypedAllocator.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
namespace Luau namespace Luau
{ {
using BreadcrumbId = NotNull<const struct Breadcrumb>; struct RefinementKey;
using DefId = NotNull<const struct Def>;
struct Type;
using TypeId = const Type*;
struct Variadic; struct Variadic;
struct Negation; struct Negation;
@ -52,7 +51,7 @@ struct Equivalence
struct Proposition struct Proposition
{ {
BreadcrumbId breadcrumb; const RefinementKey* key;
TypeId discriminantTy; TypeId discriminantTy;
}; };
@ -69,7 +68,7 @@ struct RefinementArena
RefinementId conjunction(RefinementId lhs, RefinementId rhs); RefinementId conjunction(RefinementId lhs, RefinementId rhs);
RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs);
RefinementId equivalence(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs);
RefinementId proposition(BreadcrumbId breadcrumb, TypeId discriminantTy); RefinementId proposition(const RefinementKey* key, TypeId discriminantTy);
private: private:
TypedAllocator<Refinement> allocator; TypedAllocator<Refinement> allocator;

View File

@ -2,9 +2,13 @@
#pragma once #pragma once
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/LValue.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/DenseHash.h"
#include "Luau/Symbol.h"
#include "Luau/Unifiable.h"
#include <unordered_map> #include <unordered_map>
#include <optional> #include <optional>
@ -54,6 +58,7 @@ struct Scope
std::optional<TypeId> lookup(Symbol sym) const; std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookupLValue(DefId def) const; std::optional<TypeId> lookupLValue(DefId def) const;
std::optional<TypeId> lookup(DefId def) const; std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(DefId def);
std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym); std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name) const; std::optional<TypeFun> lookupType(const Name& name) const;

View File

@ -2,7 +2,8 @@
#pragma once #pragma once
#include "Luau/Type.h" #include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
#include <set> #include <set>
@ -10,7 +11,6 @@ namespace Luau
{ {
struct TypeArena; struct TypeArena;
struct BuiltinTypes;
struct SimplifyResult struct SimplifyResult
{ {

View File

@ -2,8 +2,7 @@
#pragma once #pragma once
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypePack.h" #include "Luau/TypeFwd.h"
#include "Luau/Type.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
// We provide an implementation of substitution on types, // We provide an implementation of substitution on types,

View File

@ -1,10 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/TypePack.h"
#include "Luau/TypePairHash.h" #include "Luau/TypePairHash.h"
#include "Luau/UnifierSharedState.h" #include "Luau/UnifierSharedState.h"
#include "Luau/TypePath.h"
#include <vector> #include <vector>
#include <optional> #include <optional>
@ -23,6 +23,13 @@ struct NormalizedClassType;
struct NormalizedStringType; struct NormalizedStringType;
struct NormalizedFunctionType; struct NormalizedFunctionType;
struct SubtypingReasoning
{
Path subPath;
Path superPath;
bool operator==(const SubtypingReasoning& other) const;
};
struct SubtypingResult struct SubtypingResult
{ {
@ -31,8 +38,18 @@ struct SubtypingResult
bool normalizationTooComplex = false; bool normalizationTooComplex = false;
bool isCacheable = true; bool isCacheable = true;
/// The reason for isSubtype to be false. May not be present even if
/// isSubtype is false, depending on the input types.
std::optional<SubtypingReasoning> reasoning;
SubtypingResult& andAlso(const SubtypingResult& other); SubtypingResult& andAlso(const SubtypingResult& other);
SubtypingResult& orElse(const SubtypingResult& other); SubtypingResult& orElse(const SubtypingResult& other);
SubtypingResult& withBothComponent(TypePath::Component component);
SubtypingResult& withSuperComponent(TypePath::Component component);
SubtypingResult& withSubComponent(TypePath::Component component);
SubtypingResult& withBothPath(TypePath::Path path);
SubtypingResult& withSubPath(TypePath::Path path);
SubtypingResult& withSuperPath(TypePath::Path path);
// Only negates the `isSubtype`. // Only negates the `isSubtype`.
static SubtypingResult negate(const SubtypingResult& result); static SubtypingResult negate(const SubtypingResult& result);

View File

@ -2,16 +2,12 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/TypeFwd.h"
#include <string> #include <string>
namespace Luau namespace Luau
{ {
struct Type;
using TypeId = const Type*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct ToDotOptions struct ToDotOptions
{ {

View File

@ -2,6 +2,7 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/TypeFwd.h"
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -20,13 +21,6 @@ class AstExpr;
struct Scope; struct Scope;
struct Type;
using TypeId = const Type*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct FunctionType;
struct Constraint; struct Constraint;
struct Position; struct Position;
@ -149,4 +143,14 @@ std::string generateName(size_t n);
std::string toString(const Position& position); std::string toString(const Position& position);
std::string toString(const Location& location, int offset = 0, bool useBegin = true); std::string toString(const Location& location, int offset = 0, bool useBegin = true);
std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts);
inline std::string toString(const TypeOrPack& tyOrTp)
{
ToStringOptions opts{};
return toString(tyOrTp, opts);
}
std::string dump(const TypeOrPack& tyOrTp);
} // namespace Luau } // namespace Luau

View File

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/TypeFwd.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Refinement.h" #include "Luau/Refinement.h"
@ -9,6 +11,7 @@
#include "Luau/Predicate.h" #include "Luau/Predicate.h"
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <atomic> #include <atomic>
#include <deque> #include <deque>
@ -59,22 +62,6 @@ struct TypeFamily;
* ``` * ```
*/ */
// So... why `const T*` here rather than `T*`?
// It's because we've had problems caused by the type graph being mutated
// in ways it shouldn't be, for example mutating types from other modules.
// To try to control this, we make the use of types immutable by default,
// then provide explicit mutable access via getMutable and asMutable.
// This means we can grep for all the places we're mutating the type graph,
// and it makes it possible to provide other APIs (e.g. the txn log)
// which control mutable access to the type graph.
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct Type;
// Should never be null
using TypeId = const Type*;
using Name = std::string; using Name = std::string;
// A free type is one whose exact shape has yet to be fully determined. // A free type is one whose exact shape has yet to be fully determined.
@ -244,22 +231,6 @@ const T* get(const SingletonType* stv)
return nullptr; return nullptr;
} }
struct GenericTypeDefinition
{
TypeId ty;
std::optional<TypeId> defaultValue;
bool operator==(const GenericTypeDefinition& rhs) const;
};
struct GenericTypePackDefinition
{
TypePackId tp;
std::optional<TypePackId> defaultValue;
bool operator==(const GenericTypePackDefinition& rhs) const;
};
struct FunctionArgument struct FunctionArgument
{ {
Name name; Name name;
@ -549,42 +520,6 @@ struct TypeFamilyInstanceType
std::vector<TypePackId> packArguments; std::vector<TypePackId> packArguments;
}; };
struct TypeFun
{
// These should all be generic
std::vector<GenericTypeDefinition> typeParams;
std::vector<GenericTypePackDefinition> typePackParams;
/** The underlying type.
*
* WARNING! This is not safe to use as a type if typeParams is not empty!!
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/
TypeId type;
TypeFun() = default;
explicit TypeFun(TypeId ty)
: type(ty)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, std::vector<GenericTypePackDefinition> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
bool operator==(const TypeFun& rhs) const;
};
/** Represents a pending type alias instantiation. /** Represents a pending type alias instantiation.
* *
* In order to afford (co)recursive type aliases, we need to reason about a * In order to afford (co)recursive type aliases, we need to reason about a
@ -729,6 +664,58 @@ struct Type final
Type& operator=(const Type& rhs); Type& operator=(const Type& rhs);
}; };
struct GenericTypeDefinition
{
TypeId ty;
std::optional<TypeId> defaultValue;
bool operator==(const GenericTypeDefinition& rhs) const;
};
struct GenericTypePackDefinition
{
TypePackId tp;
std::optional<TypePackId> defaultValue;
bool operator==(const GenericTypePackDefinition& rhs) const;
};
struct TypeFun
{
// These should all be generic
std::vector<GenericTypeDefinition> typeParams;
std::vector<GenericTypePackDefinition> typePackParams;
/** The underlying type.
*
* WARNING! This is not safe to use as a type if typeParams is not empty!!
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/
TypeId type;
TypeFun() = default;
explicit TypeFun(TypeId ty)
: type(ty)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, std::vector<GenericTypePackDefinition> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
bool operator==(const TypeFun& rhs) const;
};
using SeenSet = std::set<std::pair<const void*, const void*>>; using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs); bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs);

View File

@ -1,13 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Error.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "NotNull.h"
#include "TypeCheckLimits.h"
#include <functional> #include <functional>
#include <string> #include <string>
@ -16,14 +15,7 @@
namespace Luau namespace Luau
{ {
struct Type;
using TypeId = const Type*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct TypeArena; struct TypeArena;
struct BuiltinTypes;
struct TxnLog; struct TxnLog;
class Normalizer; class Normalizer;
@ -150,6 +142,8 @@ struct BuiltinTypeFamilies
BuiltinTypeFamilies(); BuiltinTypeFamilies();
TypeFamily notFamily; TypeFamily notFamily;
TypeFamily lenFamily;
TypeFamily unmFamily;
TypeFamily addFamily; TypeFamily addFamily;
TypeFamily subFamily; TypeFamily subFamily;

View File

@ -0,0 +1,59 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Variant.h"
#include <string>
namespace Luau
{
// So... why `const T*` here rather than `T*`?
// It's because we've had problems caused by the type graph being mutated
// in ways it shouldn't be, for example mutating types from other modules.
// To try to control this, we make the use of types immutable by default,
// then provide explicit mutable access via getMutable and asMutable.
// This means we can grep for all the places we're mutating the type graph,
// and it makes it possible to provide other APIs (e.g. the txn log)
// which control mutable access to the type graph.
struct Type;
using TypeId = const Type*;
struct FreeType;
struct GenericType;
struct PrimitiveType;
struct BlockedType;
struct PendingExpansionType;
struct SingletonType;
struct FunctionType;
struct TableType;
struct MetatableType;
struct ClassType;
struct AnyType;
struct UnionType;
struct IntersectionType;
struct LazyType;
struct UnknownType;
struct NeverType;
struct NegationType;
struct TypeFamilyInstanceType;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct FreeTypePack;
struct GenericTypePack;
struct TypePack;
struct VariadicTypePack;
struct BlockedTypePack;
struct TypeFamilyInstanceTypePack;
using Name = std::string;
using ModuleName = std::string;
struct BuiltinTypes;
using TypeOrPack = Variant<TypeId, TypePackId>;
} // namespace Luau

View File

@ -9,9 +9,8 @@
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
#include "Luau/UnifierSharedState.h" #include "Luau/UnifierSharedState.h"

View File

@ -0,0 +1,45 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/Variant.h"
#include <type_traits>
namespace Luau
{
const void* ptr(TypeOrPack ty);
template<typename T>
const T* get(TypeOrPack ty)
{
if constexpr (std::is_same_v<T, TypeId>)
return ty.get_if<TypeId>();
else if constexpr (std::is_same_v<T, TypePackId>)
return ty.get_if<TypePackId>();
else if constexpr (TypeVariant::is_part_of_v<T>)
{
if (auto innerTy = ty.get_if<TypeId>())
return get<T>(*innerTy);
else
return nullptr;
}
else if constexpr (TypePackVariant::is_part_of_v<T>)
{
if (auto innerTp = ty.get_if<TypePackId>())
return get<T>(*innerTp);
else
return nullptr;
}
else
{
static_assert(always_false_v<T>, "invalid T to get from TypeOrPack");
LUAU_UNREACHABLE();
}
}
TypeOrPack follow(TypeOrPack ty);
} // namespace Luau

View File

@ -1,12 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Type.h"
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include "Luau/NotNull.h"
#include "Luau/Common.h"
#include <optional> #include <optional>
#include <set> #include <set>
#include <vector>
namespace Luau namespace Luau
{ {
@ -20,9 +23,6 @@ struct VariadicTypePack;
struct BlockedTypePack; struct BlockedTypePack;
struct TypeFamilyInstanceTypePack; struct TypeFamilyInstanceTypePack;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct FreeTypePack struct FreeTypePack
{ {
explicit FreeTypePack(TypeLevel level); explicit FreeTypePack(TypeLevel level);

View File

@ -1,7 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Type.h" #include "Luau/TypeFwd.h"
#include <utility> #include <utility>

View File

@ -0,0 +1,220 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
#include "Luau/NotNull.h"
#include "Luau/TypeOrPack.h"
#include <optional>
#include <string>
#include <vector>
namespace Luau
{
namespace TypePath
{
/// Represents a property of a class, table, or anything else with a concept of
/// a named property.
struct Property
{
/// The name of the property.
std::string name;
/// Whether to look at the read or the write type.
bool isRead = true;
explicit Property(std::string name);
Property(std::string name, bool read)
: name(std::move(name))
, isRead(read)
{
}
static Property read(std::string name);
static Property write(std::string name);
bool operator==(const Property& other) const;
};
/// Represents an index into a type or a pack. For a type, this indexes into a
/// union or intersection's list. For a pack, this indexes into the pack's nth
/// element.
struct Index
{
/// The 0-based index to use for the lookup.
size_t index;
bool operator==(const Index& other) const;
};
/// Represents fields of a type or pack that contain a type.
enum class TypeField
{
/// The metatable of a type. This could be a metatable type, a primitive
/// type, a class type, or perhaps even a string singleton type.
Metatable,
/// The lower bound of this type, if one is present.
LowerBound,
/// The upper bound of this type, if present.
UpperBound,
/// The index type.
IndexLookup,
/// The indexer result type.
IndexResult,
/// The negated type, for negations.
Negated,
/// The variadic type for a type pack.
Variadic,
};
/// Represents fields of a type or type pack that contain a type pack.
enum class PackField
{
/// What arguments this type accepts.
Arguments,
/// What this type returns when called.
Returns,
/// The tail of a type pack.
Tail,
};
/// A single component of a path, representing one inner type or type pack to
/// traverse into.
using Component = Luau::Variant<Property, Index, TypeField, PackField>;
/// A path through a type or type pack accessing a particular type or type pack
/// contained within.
///
/// Paths are always relative; to make use of a Path, you need to specify an
/// entry point. They are not canonicalized; two Paths may not compare equal but
/// may point to the same result, depending on the layout of the entry point.
///
/// Paths always descend through an entry point. This doesn't mean that they
/// cannot reach "upwards" in the actual type hierarchy in some cases, but it
/// does mean that there is no equivalent to `../` in file system paths. This is
/// intentional and unavoidable, because types and type packs don't have a
/// concept of a parent - they are a directed cyclic graph, with no hierarchy
/// that actually holds in all cases.
struct Path
{
/// The Components of this Path.
std::vector<Component> components;
/// Creates a new empty Path.
Path()
{
}
/// Creates a new Path from a list of components.
explicit Path(std::vector<Component> components)
: components(std::move(components))
{
}
/// Creates a new single-component Path.
explicit Path(Component component)
: components({component})
{
}
/// Creates a new Path by appending another Path to this one.
/// @param suffix the Path to append
/// @return a new Path representing `this + suffix`
Path append(const Path& suffix) const;
/// Creates a new Path by appending a Component to this Path.
/// @param component the Component to append
/// @return a new Path with `component` appended to it.
Path push(Component component) const;
/// Creates a new Path by prepending a Component to this Path.
/// @param component the Component to prepend
/// @return a new Path with `component` prepended to it.
Path push_front(Component component) const;
/// Creates a new Path by removing the last Component of this Path.
/// If the Path is empty, this is a no-op.
/// @return a Path with the last component removed.
Path pop() const;
/// Returns the last Component of this Path, if present.
std::optional<Component> last() const;
/// Returns whether this Path is empty, meaning it has no components at all.
/// Traversing an empty Path results in the type you started with.
bool empty() const;
bool operator==(const Path& other) const;
bool operator!=(const Path& other) const
{
return !(*this == other);
}
};
/// The canonical "empty" Path, meaning a Path with no components.
static const Path kEmpty{};
struct PathBuilder
{
std::vector<Component> components;
Path build();
PathBuilder& readProp(std::string name);
PathBuilder& writeProp(std::string name);
PathBuilder& prop(std::string name);
PathBuilder& index(size_t i);
PathBuilder& mt();
PathBuilder& lb();
PathBuilder& ub();
PathBuilder& indexKey();
PathBuilder& indexValue();
PathBuilder& negated();
PathBuilder& variadic();
PathBuilder& args();
PathBuilder& rets();
PathBuilder& tail();
};
} // namespace TypePath
using Path = TypePath::Path;
/// Converts a Path to a string for debugging purposes. This output may not be
/// terribly clear to end users of the Luau type system.
std::string toString(const TypePath::Path& path);
std::optional<TypeOrPack> traverse(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type to its end point, which must be a type.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypeId at the end of the path, or nullopt if the traversal failed.
std::optional<TypeId> traverseForType(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type pack to its end point, which must be a type.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypeId at the end of the path, or nullopt if the traversal failed.
std::optional<TypeId> traverseForType(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type to its end point, which must be a type pack.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed.
std::optional<TypePackId> traverseForPack(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type pack to its end point, which must be a type pack.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed.
std::optional<TypePackId> traverseForPack(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
} // namespace Luau

View File

@ -9,7 +9,7 @@
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/UnifierSharedState.h" #include "Luau/UnifierSharedState.h"
#include "Normalize.h" #include "Luau/Normalize.h"
#include <unordered_set> #include <unordered_set>

View File

@ -4,10 +4,10 @@
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Type.h" #include "Luau/TypePairHash.h"
#include "TypePairHash.h" #include "Luau/TypeCheckLimits.h"
#include "TypeCheckLimits.h" #include "Luau/TypeChecker2.h"
#include "TypeChecker2.h" #include "Luau/TypeFwd.h"
#include <optional> #include <optional>
#include <vector> #include <vector>
@ -16,10 +16,6 @@
namespace Luau namespace Luau
{ {
using TypeId = const struct Type*;
using TypePackId = const struct TypePackVar*;
struct BuiltinTypes;
struct InternalErrorReporter; struct InternalErrorReporter;
struct Scope; struct Scope;
struct TypeArena; struct TypeArena;

View File

@ -3,8 +3,7 @@
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Type.h" #include "Luau/TypeFwd.h"
#include "Luau/TypePack.h"
#include <utility> #include <utility>

View File

@ -44,6 +44,9 @@ private:
public: public:
using first_alternative = typename First<Ts...>::type; using first_alternative = typename First<Ts...>::type;
template<typename T>
static constexpr bool is_part_of_v = std::disjunction_v<typename std::is_same<std::decay_t<Ts>, T>...>;
Variant() Variant()
{ {
static_assert(std::is_default_constructible_v<first_alternative>, "first alternative type must be default constructible"); static_assert(std::is_default_constructible_v<first_alternative>, "first alternative type must be default constructible");

View File

@ -701,6 +701,7 @@ void TypeCloner::operator()(const FunctionType& t)
ftv->argNames = t.argNames; ftv->argNames = t.argNames;
ftv->retTypes = clone(t.retTypes, dest, cloneState); ftv->retTypes = clone(t.retTypes, dest, cloneState);
ftv->hasNoFreeOrGenericTypes = t.hasNoFreeOrGenericTypes; ftv->hasNoFreeOrGenericTypes = t.hasNoFreeOrGenericTypes;
ftv->isCheckedFunction = t.isCheckedFunction;
} }
void TypeCloner::operator()(const TableType& t) void TypeCloner::operator()(const TableType& t)

View File

@ -2,7 +2,7 @@
#include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintGraphBuilder.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Breadcrumb.h" #include "Luau/Def.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
@ -216,19 +216,7 @@ NotNull<Constraint> ConstraintGraphBuilder::addConstraint(const ScopePtr& scope,
return NotNull{constraints.emplace_back(std::move(c)).get()}; return NotNull{constraints.emplace_back(std::move(c)).get()};
} }
struct RefinementPartition void ConstraintGraphBuilder::unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector<ConstraintV>* constraints)
{
// Types that we want to intersect against the type of the expression.
std::vector<TypeId> discriminantTypes;
// Sometimes the type we're discriminating against is implicitly nil.
bool shouldAppendNilType = false;
};
using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>;
static void unionRefinements(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const RefinementContext& lhs, const RefinementContext& rhs,
RefinementContext& dest, std::vector<ConstraintV>* constraints)
{ {
const auto intersect = [&](const std::vector<TypeId>& types) { const auto intersect = [&](const std::vector<TypeId>& types) {
if (1 == types.size()) if (1 == types.size())
@ -264,44 +252,43 @@ static void unionRefinements(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeAre
} }
} }
static void computeRefinement(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const ScopePtr& scope, RefinementId refinement, void ConstraintGraphBuilder::computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector<ConstraintV>* constraints)
RefinementContext* refis, bool sense, bool eq, std::vector<ConstraintV>* constraints)
{ {
if (!refinement) if (!refinement)
return; return;
else if (auto variadic = get<Variadic>(refinement)) else if (auto variadic = get<Variadic>(refinement))
{ {
for (RefinementId refi : variadic->refinements) for (RefinementId refi : variadic->refinements)
computeRefinement(builtinTypes, arena, scope, refi, refis, sense, eq, constraints); computeRefinement(scope, refi, refis, sense, eq, constraints);
} }
else if (auto negation = get<Negation>(refinement)) else if (auto negation = get<Negation>(refinement))
return computeRefinement(builtinTypes, arena, scope, negation->refinement, refis, !sense, eq, constraints); return computeRefinement(scope, negation->refinement, refis, !sense, eq, constraints);
else if (auto conjunction = get<Conjunction>(refinement)) else if (auto conjunction = get<Conjunction>(refinement))
{ {
RefinementContext lhsRefis; RefinementContext lhsRefis;
RefinementContext rhsRefis; RefinementContext rhsRefis;
computeRefinement(builtinTypes, arena, scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints);
computeRefinement(builtinTypes, arena, scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints);
if (!sense) if (!sense)
unionRefinements(builtinTypes, arena, lhsRefis, rhsRefis, *refis, constraints); unionRefinements(lhsRefis, rhsRefis, *refis, constraints);
} }
else if (auto disjunction = get<Disjunction>(refinement)) else if (auto disjunction = get<Disjunction>(refinement))
{ {
RefinementContext lhsRefis; RefinementContext lhsRefis;
RefinementContext rhsRefis; RefinementContext rhsRefis;
computeRefinement(builtinTypes, arena, scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints);
computeRefinement(builtinTypes, arena, scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints);
if (sense) if (sense)
unionRefinements(builtinTypes, arena, lhsRefis, rhsRefis, *refis, constraints); unionRefinements(lhsRefis, rhsRefis, *refis, constraints);
} }
else if (auto equivalence = get<Equivalence>(refinement)) else if (auto equivalence = get<Equivalence>(refinement))
{ {
computeRefinement(builtinTypes, arena, scope, equivalence->lhs, refis, sense, true, constraints); computeRefinement(scope, equivalence->lhs, refis, sense, true, constraints);
computeRefinement(builtinTypes, arena, scope, equivalence->rhs, refis, sense, true, constraints); computeRefinement(scope, equivalence->rhs, refis, sense, true, constraints);
} }
else if (auto proposition = get<Proposition>(refinement)) else if (auto proposition = get<Proposition>(refinement))
{ {
@ -314,40 +301,27 @@ static void computeRefinement(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeAr
constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense});
} }
RefinementContext uncommittedRefis; for (const RefinementKey* key = proposition->key; key; key = key->parent)
uncommittedRefis.insert(proposition->breadcrumb->def, {}); {
uncommittedRefis.get(proposition->breadcrumb->def)->discriminantTypes.push_back(discriminantTy); refis->insert(key->def, {});
refis->get(key->def)->discriminantTypes.push_back(discriminantTy);
// Reached leaf node
if (!key->propName)
break;
TypeId nextDiscriminantTy = arena->addType(TableType{});
NotNull<TableType> table{getMutable<TableType>(nextDiscriminantTy)};
table->props[*key->propName] = {discriminantTy};
table->scope = scope.get();
table->state = TableState::Sealed;
discriminantTy = nextDiscriminantTy;
}
// When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`. // When the top-level expression is `t[x]`, we want to refine it into `nil`, not `never`.
if ((sense || !eq) && getMetadata<SubscriptMetadata>(proposition->breadcrumb)) LUAU_ASSERT(refis->get(proposition->key->def));
uncommittedRefis.get(proposition->breadcrumb->def)->shouldAppendNilType = true; refis->get(proposition->key->def)->shouldAppendNilType = (sense || !eq) && containsSubscriptedDefinition(proposition->key->def);
for (NullableBreadcrumbId current = proposition->breadcrumb; current && current->previous; current = current->previous)
{
LUAU_ASSERT(get<Cell>(current->def));
// If this current breadcrumb has no metadata, it's no-op for the purpose of building a discriminant type.
if (!current->metadata)
continue;
else if (auto field = getMetadata<FieldMetadata>(current))
{
TableType::Props props{{field->prop, Property{discriminantTy}}};
discriminantTy = arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed});
uncommittedRefis.insert(current->previous->def, {});
uncommittedRefis.get(current->previous->def)->discriminantTypes.push_back(discriminantTy);
}
}
// And now it's time to commit it.
for (auto& [def, partition] : uncommittedRefis)
{
(*refis).insert(def, {});
for (TypeId discriminantTy : partition.discriminantTypes)
(*refis).get(def)->discriminantTypes.push_back(discriminantTy);
(*refis).get(def)->shouldAppendNilType |= partition.shouldAppendNilType;
}
} }
} }
@ -415,7 +389,7 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo
RefinementContext refinements; RefinementContext refinements;
std::vector<ConstraintV> constraints; std::vector<ConstraintV> constraints;
computeRefinement(builtinTypes, arena, scope, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); computeRefinement(scope, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints);
for (auto& [def, partition] : refinements) for (auto& [def, partition] : refinements)
{ {
@ -586,20 +560,22 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat)
} }
} }
ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* statLocal)
{ {
std::vector<std::optional<TypeId>> varTypes; std::vector<std::optional<TypeId>> varTypes;
varTypes.reserve(local->vars.size); varTypes.reserve(statLocal->vars.size);
std::vector<TypeId> assignees; std::vector<TypeId> assignees;
assignees.reserve(local->vars.size); assignees.reserve(statLocal->vars.size);
// Used to name the first value type, even if it's not placed in varTypes, // Used to name the first value type, even if it's not placed in varTypes,
// for the purpose of synthetic name attribution. // for the purpose of synthetic name attribution.
std::optional<TypeId> firstValueType; std::optional<TypeId> firstValueType;
for (AstLocal* local : local->vars) for (AstLocal* local : statLocal->vars)
{ {
const Location location = local->location;
TypeId assignee = arena->addType(BlockedType{}); TypeId assignee = arena->addType(BlockedType{});
assignees.push_back(assignee); assignees.push_back(assignee);
@ -612,21 +588,27 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* l
varTypes.push_back(annotationTy); varTypes.push_back(annotationTy);
addConstraint(scope, local->location, SubtypeConstraint{assignee, annotationTy}); addConstraint(scope, local->location, SubtypeConstraint{assignee, annotationTy});
scope->bindings[local] = Binding{annotationTy, location};
} }
else else
{
varTypes.push_back(std::nullopt); varTypes.push_back(std::nullopt);
BreadcrumbId bc = dfg->getBreadcrumb(local); inferredBindings[local] = {scope.get(), location, {assignee}};
scope->lvalueTypes[bc->def] = assignee; }
DefId def = dfg->getDef(local);
scope->lvalueTypes[def] = assignee;
} }
TypePackId resultPack = checkPack(scope, local->values, varTypes).tp; TypePackId resultPack = checkPack(scope, statLocal->values, varTypes).tp;
addConstraint(scope, local->location, UnpackConstraint{arena->addTypePack(std::move(assignees)), resultPack}); addConstraint(scope, statLocal->location, UnpackConstraint{arena->addTypePack(std::move(assignees)), resultPack});
if (local->vars.size == 1 && local->values.size == 1 && firstValueType && scope.get() == rootScope) if (statLocal->vars.size == 1 && statLocal->values.size == 1 && firstValueType && scope.get() == rootScope)
{ {
AstLocal* var = local->vars.data[0]; AstLocal* var = statLocal->vars.data[0];
AstExpr* value = local->values.data[0]; AstExpr* value = statLocal->values.data[0];
if (value->is<AstExprTable>()) if (value->is<AstExprTable>())
addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true});
@ -639,29 +621,12 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* l
} }
} }
for (size_t i = 0; i < local->vars.size; ++i) if (statLocal->values.size > 0)
{
AstLocal* l = local->vars.data[i];
Location location = l->location;
std::optional<TypeId> annotation = varTypes[i];
BreadcrumbId bc = dfg->getBreadcrumb(l);
if (annotation)
scope->bindings[l] = Binding{*annotation, location};
else
{
scope->bindings[l] = Binding{builtinTypes->neverType, location};
inferredBindings.emplace_back(scope.get(), l, bc);
}
}
if (local->values.size > 0)
{ {
// To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'.
for (size_t i = 0; i < local->values.size && i < local->vars.size; ++i) for (size_t i = 0; i < statLocal->values.size && i < statLocal->vars.size; ++i)
{ {
const AstExprCall* call = local->values.data[i]->as<AstExprCall>(); const AstExprCall* call = statLocal->values.data[i]->as<AstExprCall>();
if (!call) if (!call)
continue; continue;
@ -679,7 +644,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* l
if (!module) if (!module)
continue; continue;
const Name name{local->vars.data[i]->name.value}; const Name name{statLocal->vars.data[i]->name.value};
scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedTypeBindings[name] = module->exportedTypeBindings;
scope->importedModules[name] = moduleInfo->name; scope->importedModules[name] = moduleInfo->name;
@ -719,9 +684,9 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for
ScopePtr forScope = childScope(for_, scope); ScopePtr forScope = childScope(for_, scope);
forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location}; forScope->bindings[for_->var] = Binding{annotationTy, for_->var->location};
BreadcrumbId bc = dfg->getBreadcrumb(for_->var); DefId def = dfg->getDef(for_->var);
forScope->lvalueTypes[bc->def] = annotationTy; forScope->lvalueTypes[def] = annotationTy;
forScope->rvalueRefinements[bc->def] = annotationTy; forScope->rvalueRefinements[def] = annotationTy;
visit(forScope, for_->body); visit(forScope, for_->body);
@ -750,8 +715,8 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* f
TypeId assignee = arena->addType(BlockedType{}); TypeId assignee = arena->addType(BlockedType{});
variableTypes.push_back(assignee); variableTypes.push_back(assignee);
BreadcrumbId bc = dfg->getBreadcrumb(var); DefId def = dfg->getDef(var);
loopScope->lvalueTypes[bc->def] = assignee; loopScope->lvalueTypes[def] = assignee;
} }
TypePackId variablePack = arena->addTypePack(std::move(variableTypes)); TypePackId variablePack = arena->addTypePack(std::move(variableTypes));
@ -803,11 +768,11 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFun
FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location);
sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location};
BreadcrumbId bc = dfg->getBreadcrumb(function->name); DefId def = dfg->getDef(function->name);
scope->lvalueTypes[bc->def] = functionType; scope->lvalueTypes[def] = functionType;
scope->rvalueRefinements[bc->def] = functionType; scope->rvalueRefinements[def] = functionType;
sig.bodyScope->lvalueTypes[bc->def] = sig.signature; sig.bodyScope->lvalueTypes[def] = sig.signature;
sig.bodyScope->rvalueRefinements[bc->def] = sig.signature; sig.bodyScope->rvalueRefinements[def] = sig.signature;
Checkpoint start = checkpoint(this); Checkpoint start = checkpoint(this);
checkFunctionBody(sig.bodyScope, function->func); checkFunctionBody(sig.bodyScope, function->func);
@ -848,11 +813,8 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction
std::unordered_set<Constraint*> excludeList; std::unordered_set<Constraint*> excludeList;
const NullableBreadcrumbId functionBreadcrumb = dfg->getBreadcrumb(function->name); DefId def = dfg->getDef(function->name);
std::optional<TypeId> existingFunctionTy = scope->lookupLValue(def);
std::optional<TypeId> existingFunctionTy;
if (functionBreadcrumb)
existingFunctionTy = scope->lookupLValue(functionBreadcrumb->def);
if (AstExprLocal* localName = function->name->as<AstExprLocal>()) if (AstExprLocal* localName = function->name->as<AstExprLocal>())
{ {
@ -867,12 +829,8 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction
scope->bindings[localName->local] = Binding{generalizedType, localName->location}; scope->bindings[localName->local] = Binding{generalizedType, localName->location};
sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location};
sig.bodyScope->lvalueTypes[def] = sig.signature;
if (functionBreadcrumb) sig.bodyScope->rvalueRefinements[def] = sig.signature;
{
sig.bodyScope->lvalueTypes[functionBreadcrumb->def] = sig.signature;
sig.bodyScope->rvalueRefinements[functionBreadcrumb->def] = sig.signature;
}
} }
else if (AstExprGlobal* globalName = function->name->as<AstExprGlobal>()) else if (AstExprGlobal* globalName = function->name->as<AstExprGlobal>())
{ {
@ -882,17 +840,14 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction
generalizedType = *existingFunctionTy; generalizedType = *existingFunctionTy;
sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location};
sig.bodyScope->lvalueTypes[def] = sig.signature;
if (functionBreadcrumb) sig.bodyScope->rvalueRefinements[def] = sig.signature;
{
sig.bodyScope->lvalueTypes[functionBreadcrumb->def] = sig.signature;
sig.bodyScope->rvalueRefinements[functionBreadcrumb->def] = sig.signature;
}
} }
else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>()) else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>())
{ {
Checkpoint check1 = checkpoint(this); Checkpoint check1 = checkpoint(this);
TypeId lvalueType = checkLValue(scope, indexName); std::optional<TypeId> lvalueType = checkLValue(scope, indexName);
LUAU_ASSERT(lvalueType);
Checkpoint check2 = checkpoint(this); Checkpoint check2 = checkpoint(this);
forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) {
@ -901,10 +856,13 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction
// TODO figure out how to populate the location field of the table Property. // TODO figure out how to populate the location field of the table Property.
if (get<FreeType>(lvalueType)) if (lvalueType)
asMutable(lvalueType)->ty.emplace<BoundType>(generalizedType); {
else if (get<FreeType>(*lvalueType))
addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); asMutable(*lvalueType)->ty.emplace<BoundType>(generalizedType);
else
addConstraint(scope, indexName->location, SubtypeConstraint{*lvalueType, generalizedType});
}
} }
else if (AstExprError* err = function->name->as<AstExprError>()) else if (AstExprError* err = function->name->as<AstExprError>())
{ {
@ -914,8 +872,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction
if (generalizedType == nullptr) if (generalizedType == nullptr)
ice->ice("generalizedType == nullptr", function->location); ice->ice("generalizedType == nullptr", function->location);
if (functionBreadcrumb) scope->rvalueRefinements[def] = generalizedType;
scope->rvalueRefinements[functionBreadcrumb->def] = generalizedType;
checkFunctionBody(sig.bodyScope, function->func); checkFunctionBody(sig.bodyScope, function->func);
Checkpoint end = checkpoint(this); Checkpoint end = checkpoint(this);
@ -997,18 +954,23 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign*
for (AstExpr* lvalue : assign->vars) for (AstExpr* lvalue : assign->vars)
{ {
TypeId upperBound = follow(checkLValue(scope, lvalue));
if (get<FreeType>(upperBound))
expectedTypes.push_back(std::nullopt);
else
expectedTypes.push_back(upperBound);
TypeId assignee = arena->addType(BlockedType{}); TypeId assignee = arena->addType(BlockedType{});
assignees.push_back(assignee); assignees.push_back(assignee);
addConstraint(scope, lvalue->location, SubtypeConstraint{assignee, upperBound});
if (NullableBreadcrumbId bc = dfg->getBreadcrumb(lvalue)) std::optional<TypeId> upperBound = follow(checkLValue(scope, lvalue));
scope->lvalueTypes[bc->def] = assignee; if (upperBound)
{
if (get<FreeType>(*upperBound))
expectedTypes.push_back(std::nullopt);
else
expectedTypes.push_back(*upperBound);
addConstraint(scope, lvalue->location, SubtypeConstraint{assignee, *upperBound});
}
DefId def = dfg->getDef(lvalue);
scope->lvalueTypes[def] = assignee;
updateLValueType(lvalue, assignee);
} }
TypePackId resultPack = checkPack(scope, assign->values, expectedTypes).tp; TypePackId resultPack = checkPack(scope, assign->values, expectedTypes).tp;
@ -1019,15 +981,15 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign*
ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign)
{ {
// We need to tweak the BinaryConstraint that we emit, so we cannot use the std::optional<TypeId> varTy = checkLValue(scope, assign->var);
// strategy of falsifying an AST fragment.
TypeId varTy = checkLValue(scope, assign->var);
TypeId valueTy = check(scope, assign->value).ty;
TypeId resultType = arena->addType(BlockedType{}); AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value};
addConstraint(scope, assign->location, TypeId resultTy = check(scope, &binop).ty;
BinaryConstraint{assign->op, varTy, valueTy, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); if (varTy)
addConstraint(scope, assign->location, SubtypeConstraint{resultType, varTy}); addConstraint(scope, assign->location, SubtypeConstraint{resultTy, *varTy});
DefId def = dfg->getDef(assign->var);
scope->lvalueTypes[def] = resultTy;
return ControlFlow::None; return ControlFlow::None;
} }
@ -1138,9 +1100,9 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareG
module->declaredGlobals[globalName] = globalTy; module->declaredGlobals[globalName] = globalTy;
rootScope->bindings[global->name] = Binding{globalTy, global->location}; rootScope->bindings[global->name] = Binding{globalTy, global->location};
BreadcrumbId bc = dfg->getBreadcrumb(global); DefId def = dfg->getDef(global);
rootScope->lvalueTypes[bc->def] = globalTy; rootScope->lvalueTypes[def] = globalTy;
rootScope->rvalueRefinements[bc->def] = globalTy; rootScope->rvalueRefinements[def] = globalTy;
return ControlFlow::None; return ControlFlow::None;
} }
@ -1310,9 +1272,9 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareF
module->declaredGlobals[fnName] = fnType; module->declaredGlobals[fnName] = fnType;
scope->bindings[global->name] = Binding{fnType, global->location}; scope->bindings[global->name] = Binding{fnType, global->location};
BreadcrumbId bc = dfg->getBreadcrumb(global); DefId def = dfg->getDef(global);
rootScope->lvalueTypes[bc->def] = fnType; rootScope->lvalueTypes[def] = fnType;
rootScope->rvalueRefinements[bc->def] = fnType; rootScope->rvalueRefinements[def] = fnType;
return ControlFlow::None; return ControlFlow::None;
} }
@ -1409,10 +1371,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
exprArgs.push_back(indexExpr->expr); exprArgs.push_back(indexExpr->expr);
if (auto bc = dfg->getBreadcrumb(indexExpr->expr)) if (auto key = dfg->getRefinementKey(indexExpr->expr))
{ {
TypeId discriminantTy = arena->addType(BlockedType{}); TypeId discriminantTy = arena->addType(BlockedType{});
returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); returnRefinements.push_back(refinementArena.proposition(key, discriminantTy));
discriminantTypes.push_back(discriminantTy); discriminantTypes.push_back(discriminantTy);
} }
else else
@ -1423,10 +1385,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
{ {
exprArgs.push_back(arg); exprArgs.push_back(arg);
if (auto bc = dfg->getBreadcrumb(arg)) if (auto key = dfg->getRefinementKey(arg))
{ {
TypeId discriminantTy = arena->addType(BlockedType{}); TypeId discriminantTy = arena->addType(BlockedType{});
returnRefinements.push_back(refinementArena.proposition(NotNull{bc}, discriminantTy)); returnRefinements.push_back(refinementArena.proposition(key, discriminantTy));
discriminantTypes.push_back(discriminantTy); discriminantTypes.push_back(discriminantTy);
} }
else else
@ -1525,9 +1487,12 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
{ {
scope->bindings[targetLocal->local].typeId = resultTy; scope->bindings[targetLocal->local].typeId = resultTy;
BreadcrumbId bc = dfg->getBreadcrumb(targetLocal); DefId def = dfg->getDef(targetLocal);
scope->lvalueTypes[bc->def] = resultTy; // TODO: typestates: track this as an assignment scope->lvalueTypes[def] = resultTy; // TODO: typestates: track this as an assignment
scope->rvalueRefinements[bc->def] = resultTy; // TODO: typestates: track this as an assignment scope->rvalueRefinements[def] = resultTy; // TODO: typestates: track this as an assignment
if (auto it = inferredBindings.find(targetLocal->local); it != inferredBindings.end())
it->second.types.insert(resultTy);
} }
return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}};
@ -1686,27 +1651,51 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local)
{ {
BreadcrumbId bc = dfg->getBreadcrumb(local); const RefinementKey* key = dfg->getRefinementKey(local);
std::optional<DefId> rvalueDef = dfg->getRValueDefForCompoundAssign(local);
LUAU_ASSERT(key || rvalueDef);
if (auto ty = scope->lookup(bc->def)) std::optional<TypeId> maybeTy;
return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)};
// if we have a refinement key, we can look up its type.
if (key)
maybeTy = scope->lookup(key->def);
// if the current def doesn't have a type, we might be doing a compound assignment
// and therefore might need to look at the rvalue def instead.
if (!maybeTy && rvalueDef)
maybeTy = scope->lookup(*rvalueDef);
if (maybeTy)
{
TypeId ty = follow(*maybeTy);
if (auto it = inferredBindings.find(local->local); it != inferredBindings.end())
it->second.types.insert(ty);
return Inference{ty, refinementArena.proposition(key, builtinTypes->truthyType)};
}
else else
ice->ice("CGB: AstExprLocal came before its declaration?"); ice->ice("CGB: AstExprLocal came before its declaration?");
} }
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global)
{ {
BreadcrumbId bc = dfg->getBreadcrumb(global); const RefinementKey* key = dfg->getRefinementKey(global);
std::optional<DefId> rvalueDef = dfg->getRValueDefForCompoundAssign(global);
LUAU_ASSERT(key || rvalueDef);
// we'll use whichever of the two definitions we have here.
DefId def = key ? key->def : *rvalueDef;
/* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any
* global that is not already in-scope is definitely an unknown symbol. * global that is not already in-scope is definitely an unknown symbol.
*/ */
if (auto ty = scope->lookup(bc->def)) if (auto ty = scope->lookup(def))
return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)};
else if (auto ty = scope->lookup(global->name)) else if (auto ty = scope->lookup(global->name))
{ {
rootScope->rvalueRefinements[bc->def] = *ty; rootScope->rvalueRefinements[key->def] = *ty;
return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)};
} }
else else
{ {
@ -1720,19 +1709,19 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName*
TypeId obj = check(scope, indexName->expr).ty; TypeId obj = check(scope, indexName->expr).ty;
TypeId result = arena->addType(BlockedType{}); TypeId result = arena->addType(BlockedType{});
NullableBreadcrumbId bc = dfg->getBreadcrumb(indexName); const RefinementKey* key = dfg->getRefinementKey(indexName);
if (bc) if (key)
{ {
if (auto ty = scope->lookup(bc->def)) if (auto ty = scope->lookup(key->def))
return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)};
scope->rvalueRefinements[bc->def] = result; scope->rvalueRefinements[key->def] = result;
} }
addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value});
if (bc) if (key)
return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)};
else else
return Inference{result}; return Inference{result};
} }
@ -1743,13 +1732,13 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr*
TypeId indexType = check(scope, indexExpr->index).ty; TypeId indexType = check(scope, indexExpr->index).ty;
TypeId result = freshType(scope); TypeId result = freshType(scope);
NullableBreadcrumbId bc = dfg->getBreadcrumb(indexExpr); const RefinementKey* key = dfg->getRefinementKey(indexExpr);
if (bc) if (key)
{ {
if (auto ty = scope->lookup(bc->def)) if (auto ty = scope->lookup(key->def))
return Inference{*ty, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)};
scope->rvalueRefinements[bc->def] = result; scope->rvalueRefinements[key->def] = result;
} }
TableIndexer indexer{indexType, result}; TableIndexer indexer{indexType, result};
@ -1757,8 +1746,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr*
addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType});
if (bc) if (key)
return Inference{result, refinementArena.proposition(NotNull{bc}, builtinTypes->truthyType)}; return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)};
else else
return Inference{result}; return Inference{result};
} }
@ -1812,12 +1801,28 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* una
addConstraint(scope, unary->location, ReduceConstraint{resultType}); addConstraint(scope, unary->location, ReduceConstraint{resultType});
return Inference{resultType, refinementArena.negation(refinement)}; return Inference{resultType, refinementArena.negation(refinement)};
} }
default: case AstExprUnary::Op::Len:
{ {
TypeId resultType = arena->addType(BlockedType{}); TypeId resultType = arena->addType(TypeFamilyInstanceType{
addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); NotNull{&kBuiltinTypeFamilies.lenFamily},
return Inference{resultType}; {operandType},
{},
});
addConstraint(scope, unary->location, ReduceConstraint{resultType});
return Inference{resultType, refinementArena.negation(refinement)};
} }
case AstExprUnary::Op::Minus:
{
TypeId resultType = arena->addType(TypeFamilyInstanceType{
NotNull{&kBuiltinTypeFamilies.unmFamily},
{operandType},
{},
});
addConstraint(scope, unary->location, ReduceConstraint{resultType});
return Inference{resultType, refinementArena.negation(refinement)};
}
default: // msvc can't prove that this is exhaustive.
LUAU_UNREACHABLE();
} }
} }
@ -1978,13 +1983,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* bi
addConstraint(scope, binary->location, ReduceConstraint{resultType}); addConstraint(scope, binary->location, ReduceConstraint{resultType});
return Inference{resultType, std::move(refinement)}; return Inference{resultType, std::move(refinement)};
} }
default: case AstExprBinary::Op::Op__Count:
{ ice->ice("Op__Count should never be generated in an AST.");
TypeId resultType = arena->addType(BlockedType{}); default: // msvc can't prove that this is exhaustive.
addConstraint(scope, binary->location, LUAU_UNREACHABLE();
BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes});
return Inference{resultType, std::move(refinement)};
}
} }
} }
@ -2056,9 +2058,10 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
TypeId leftType = check(scope, binary->left).ty; TypeId leftType = check(scope, binary->left).ty;
TypeId rightType = check(scope, binary->right).ty; TypeId rightType = check(scope, binary->right).ty;
NullableBreadcrumbId bc = dfg->getBreadcrumb(typeguard->target); const RefinementKey* key = dfg->getRefinementKey(typeguard->target);
if (!bc) if (!key)
return {leftType, rightType, nullptr}; return {leftType, rightType, nullptr};
auto augmentForErrorSupression = [&](TypeId ty) -> TypeId { auto augmentForErrorSupression = [&](TypeId ty) -> TypeId {
return arena->addType(UnionType{{ty, builtinTypes->errorType}}); return arena->addType(UnionType{{ty, builtinTypes->errorType}});
}; };
@ -2096,7 +2099,7 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
discriminantTy = ty; discriminantTy = ty;
} }
RefinementId proposition = refinementArena.proposition(NotNull{bc}, discriminantTy); RefinementId proposition = refinementArena.proposition(key, discriminantTy);
if (binary->op == AstExprBinary::CompareEq) if (binary->op == AstExprBinary::CompareEq)
return {leftType, rightType, proposition}; return {leftType, rightType, proposition};
else if (binary->op == AstExprBinary::CompareNe) else if (binary->op == AstExprBinary::CompareNe)
@ -2111,13 +2114,8 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
TypeId leftType = check(scope, binary->left, {}, true).ty; TypeId leftType = check(scope, binary->left, {}, true).ty;
TypeId rightType = check(scope, binary->right, {}, true).ty; TypeId rightType = check(scope, binary->right, {}, true).ty;
RefinementId leftRefinement = nullptr; RefinementId leftRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->left), rightType);
if (auto bc = dfg->getBreadcrumb(binary->left)) RefinementId rightRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->right), leftType);
leftRefinement = refinementArena.proposition(NotNull{bc}, rightType);
RefinementId rightRefinement = nullptr;
if (auto bc = dfg->getBreadcrumb(binary->right))
rightRefinement = refinementArena.proposition(NotNull{bc}, leftType);
if (binary->op == AstExprBinary::CompareNe) if (binary->op == AstExprBinary::CompareNe)
{ {
@ -2135,7 +2133,7 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
} }
} }
TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) std::optional<TypeId> ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
{ {
if (auto local = expr->as<AstExprLocal>()) if (auto local = expr->as<AstExprLocal>())
return checkLValue(scope, local); return checkLValue(scope, local);
@ -2154,24 +2152,42 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
ice->ice("checkLValue is inexhaustive"); ice->ice("checkLValue is inexhaustive");
} }
TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprLocal* local) std::optional<TypeId> ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprLocal* local)
{ {
std::optional<TypeId> upperBound = scope->lookup(Symbol{local->local}); /*
LUAU_ASSERT(upperBound); * The caller of this method uses the returned type to emit the proper
return *upperBound; * SubtypeConstraint.
*
* At this point during constraint generation, the binding table is only
* populated by symbols that have type annotations.
*
* If this local has an interesting type annotation, it is important that we
* return that.
*/
std::optional<TypeId> annotatedTy = scope->lookup(local->local);
if (annotatedTy)
return annotatedTy;
/*
* As a safety measure, we'll assert that no type has yet been ascribed to
* the corresponding def. We'll populate this when we generate
* constraints for assignment and compound assignment statements.
*/
LUAU_ASSERT(!scope->lookupLValue(dfg->getDef(local)));
return std::nullopt;
} }
TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprGlobal* global) std::optional<TypeId> ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprGlobal* global)
{ {
return scope->lookup(Symbol{global->name}).value_or(builtinTypes->errorRecoveryType()); return scope->lookup(Symbol{global->name});
} }
TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprIndexName* indexName) std::optional<TypeId> ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprIndexName* indexName)
{ {
return updateProperty(scope, indexName); return updateProperty(scope, indexName);
} }
TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr) std::optional<TypeId> ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr)
{ {
return updateProperty(scope, indexExpr); return updateProperty(scope, indexExpr);
} }
@ -2226,7 +2242,7 @@ TypeId ConstraintGraphBuilder::updateProperty(const ScopePtr& scope, AstExpr* ex
return check(scope, expr).ty; return check(scope, expr).ty;
Symbol sym; Symbol sym;
NullableBreadcrumbId bc = nullptr; const Def* def = nullptr;
std::vector<std::string> segments; std::vector<std::string> segments;
std::vector<AstExpr*> exprs; std::vector<AstExpr*> exprs;
@ -2236,13 +2252,13 @@ TypeId ConstraintGraphBuilder::updateProperty(const ScopePtr& scope, AstExpr* ex
if (auto global = e->as<AstExprGlobal>()) if (auto global = e->as<AstExprGlobal>())
{ {
sym = global->name; sym = global->name;
bc = dfg->getBreadcrumb(global); def = dfg->getDef(global);
break; break;
} }
else if (auto local = e->as<AstExprLocal>()) else if (auto local = e->as<AstExprLocal>())
{ {
sym = local->local; sym = local->local;
bc = dfg->getBreadcrumb(local); def = dfg->getDef(local);
break; break;
} }
else if (auto indexName = e->as<AstExprIndexName>()) else if (auto indexName = e->as<AstExprIndexName>())
@ -2275,20 +2291,12 @@ TypeId ConstraintGraphBuilder::updateProperty(const ScopePtr& scope, AstExpr* ex
std::reverse(begin(segments), end(segments)); std::reverse(begin(segments), end(segments));
std::reverse(begin(exprs), end(exprs)); std::reverse(begin(exprs), end(exprs));
auto lookupResult = scope->lookupEx(sym); LUAU_ASSERT(def);
std::optional<std::pair<TypeId, Scope*>> lookupResult = scope->lookupEx(NotNull{def});
if (!lookupResult) if (!lookupResult)
return check(scope, expr).ty; return check(scope, expr).ty;
const auto [subjectBinding, symbolScope] = std::move(*lookupResult);
LUAU_ASSERT(bc); const auto [subjectType, subjectScope] = *lookupResult;
std::optional<TypeId> subjectTy = scope->lookup(bc->def);
/* If we have a breadcrumb but no type, it can only mean that we're setting
* a property of some builtin table. This isn't legal, but we still want to
* wire up the constraints properly so that we can report why it is not
* legal.
*/
TypeId subjectType = subjectTy.value_or(subjectBinding->typeId);
TypeId propTy = freshType(scope); TypeId propTy = freshType(scope);
@ -2311,20 +2319,29 @@ TypeId ConstraintGraphBuilder::updateProperty(const ScopePtr& scope, AstExpr* ex
if (!subjectType->persistent) if (!subjectType->persistent)
{ {
symbolScope->bindings[sym].typeId = updatedType; subjectScope->bindings[sym].typeId = updatedType;
// This can fail if the user is erroneously trying to augment a builtin // This can fail if the user is erroneously trying to augment a builtin
// table like os or string. // table like os or string.
if (auto bc = dfg->getBreadcrumb(e)) if (auto key = dfg->getRefinementKey(e))
{ {
symbolScope->lvalueTypes[bc->def] = updatedType; subjectScope->lvalueTypes[key->def] = updatedType;
symbolScope->rvalueRefinements[bc->def] = updatedType; subjectScope->rvalueRefinements[key->def] = updatedType;
} }
} }
return propTy; return propTy;
} }
void ConstraintGraphBuilder::updateLValueType(AstExpr* lvalue, TypeId ty)
{
if (auto local = lvalue->as<AstExprLocal>())
{
if (auto it = inferredBindings.find(local->local); it != inferredBindings.end())
it->second.types.insert(ty);
}
}
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType) Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType)
{ {
const bool expectedTypeIsFree = expectedType && get<FreeType>(follow(*expectedType)); const bool expectedTypeIsFree = expectedType && get<FreeType>(follow(*expectedType));
@ -2525,9 +2542,9 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location}); argNames.emplace_back(FunctionArgument{fn->self->name.value, fn->self->location});
signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location}; signatureScope->bindings[fn->self] = Binding{selfType, fn->self->location};
BreadcrumbId bc = dfg->getBreadcrumb(fn->self); DefId def = dfg->getDef(fn->self);
signatureScope->lvalueTypes[bc->def] = selfType; signatureScope->lvalueTypes[def] = selfType;
signatureScope->rvalueRefinements[bc->def] = selfType; signatureScope->rvalueRefinements[def] = selfType;
} }
for (size_t i = 0; i < fn->args.size; ++i) for (size_t i = 0; i < fn->args.size; ++i)
@ -2552,14 +2569,13 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
signatureScope->bindings[local] = Binding{argTy, local->location}; signatureScope->bindings[local] = Binding{argTy, local->location};
else else
{ {
BreadcrumbId bc = dfg->getBreadcrumb(local);
signatureScope->bindings[local] = Binding{builtinTypes->neverType, local->location}; signatureScope->bindings[local] = Binding{builtinTypes->neverType, local->location};
inferredBindings.emplace_back(signatureScope.get(), local, bc); inferredBindings[local] = {signatureScope.get(), {}};
} }
BreadcrumbId bc = dfg->getBreadcrumb(local); DefId def = dfg->getDef(local);
signatureScope->lvalueTypes[bc->def] = argTy; signatureScope->lvalueTypes[def] = argTy;
signatureScope->rvalueRefinements[bc->def] = argTy; signatureScope->rvalueRefinements[def] = argTy;
} }
TypePackId varargPack = nullptr; TypePackId varargPack = nullptr;
@ -2854,7 +2870,10 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
} }
else if (auto boolAnnotation = ty->as<AstTypeSingletonBool>()) else if (auto boolAnnotation = ty->as<AstTypeSingletonBool>())
{ {
result = arena->addType(SingletonType(BooleanSingleton{boolAnnotation->value})); if (boolAnnotation->value)
result = builtinTypes->trueType;
else
result = builtinTypes->falseType;
} }
else if (auto stringAnnotation = ty->as<AstTypeSingletonString>()) else if (auto stringAnnotation = ty->as<AstTypeSingletonString>())
{ {
@ -3042,10 +3061,8 @@ struct GlobalPrepopulator : AstVisitor
TypeId bt = arena->addType(BlockedType{}); TypeId bt = arena->addType(BlockedType{});
globalScope->bindings[g->name] = Binding{bt}; globalScope->bindings[g->name] = Binding{bt};
NullableBreadcrumbId bc = dfg->getBreadcrumb(function->name); DefId def = dfg->getDef(function->name);
LUAU_ASSERT(bc); globalScope->lvalueTypes[def] = bt;
globalScope->lvalueTypes[bc->def] = bt;
} }
return true; return true;
@ -3064,32 +3081,16 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope,
void ConstraintGraphBuilder::fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block) void ConstraintGraphBuilder::fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block)
{ {
std::deque<BreadcrumbId> queue; for (const auto& [symbol, p] : inferredBindings)
for (const auto& [scope, symbol, breadcrumb] : inferredBindings)
{ {
LUAU_ASSERT(queue.empty()); const auto& [scope, location, types] = p;
queue.push_back(breadcrumb); std::vector<TypeId> tys(types.begin(), types.end());
TypeId ty = builtinTypes->neverType; TypeId ty = arena->addType(BlockedType{});
addConstraint(globalScope, Location{}, SetOpConstraint{SetOpConstraint::Union, ty, std::move(tys)});
while (!queue.empty()) scope->bindings[symbol] = Binding{ty, location};
{
const BreadcrumbId bc = queue.front();
queue.pop_front();
TypeId* lvalueType = scope->lvalueTypes.find(bc->def);
if (!lvalueType)
continue;
ty = simplifyUnion(builtinTypes, arena, ty, *lvalueType).result;
for (BreadcrumbId child : bc->children)
queue.push_back(child);
}
scope->bindings[symbol].typeId = ty;
} }
} }

View File

@ -516,10 +516,6 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*gc, constraint, force); success = tryDispatch(*gc, constraint, force);
else if (auto ic = get<InstantiationConstraint>(*constraint)) else if (auto ic = get<InstantiationConstraint>(*constraint))
success = tryDispatch(*ic, constraint, force); success = tryDispatch(*ic, constraint, force);
else if (auto uc = get<UnaryConstraint>(*constraint))
success = tryDispatch(*uc, constraint, force);
else if (auto bc = get<BinaryConstraint>(*constraint))
success = tryDispatch(*bc, constraint, force);
else if (auto ic = get<IterableConstraint>(*constraint)) else if (auto ic = get<IterableConstraint>(*constraint))
success = tryDispatch(*ic, constraint, force); success = tryDispatch(*ic, constraint, force);
else if (auto nc = get<NameConstraint>(*constraint)) else if (auto nc = get<NameConstraint>(*constraint))
@ -542,6 +538,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*uc, constraint); success = tryDispatch(*uc, constraint);
else if (auto rc = get<RefineConstraint>(*constraint)) else if (auto rc = get<RefineConstraint>(*constraint))
success = tryDispatch(*rc, constraint, force); success = tryDispatch(*rc, constraint, force);
else if (auto soc = get<SetOpConstraint>(*constraint))
success = tryDispatch(*soc, constraint, force);
else if (auto rc = get<ReduceConstraint>(*constraint)) else if (auto rc = get<ReduceConstraint>(*constraint))
success = tryDispatch(*rc, constraint, force); success = tryDispatch(*rc, constraint, force);
else if (auto rpc = get<ReducePackConstraint>(*constraint)) else if (auto rpc = get<ReducePackConstraint>(*constraint))
@ -652,335 +650,6 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull<con
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull<const Constraint> constraint, bool force)
{
TypeId operandType = follow(c.operandType);
if (isBlocked(operandType))
return block(operandType, constraint);
if (!force && get<FreeType>(operandType))
return block(operandType, constraint);
LUAU_ASSERT(get<BlockedType>(c.resultType));
switch (c.op)
{
case AstExprUnary::Not:
{
asMutable(c.resultType)->ty.emplace<BoundType>(builtinTypes->booleanType);
unblock(c.resultType, constraint->location);
return true;
}
case AstExprUnary::Len:
{
// __len must return a number.
asMutable(c.resultType)->ty.emplace<BoundType>(builtinTypes->numberType);
unblock(c.resultType, constraint->location);
return true;
}
case AstExprUnary::Minus:
{
if (isNumber(operandType) || get<AnyType>(operandType) || get<ErrorType>(operandType) || get<NeverType>(operandType))
{
asMutable(c.resultType)->ty.emplace<BoundType>(c.operandType);
}
else if (std::optional<TypeId> mm = findMetatableEntry(builtinTypes, errors, operandType, "__unm", constraint->location))
{
TypeId mmTy = follow(*mm);
if (get<FreeType>(mmTy) && !force)
return block(mmTy, constraint);
TypePackId argPack = arena->addTypePack(TypePack{{operandType}, {}});
TypePackId retPack = arena->addTypePack(BlockedTypePack{});
TypeId res = freshType(arena, builtinTypes, constraint->scope);
asMutable(c.resultType)->ty.emplace<BoundType>(res);
pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{retPack, arena->addTypePack(TypePack{{c.resultType}})});
pushConstraint(constraint->scope, constraint->location, FunctionCallConstraint{mmTy, argPack, retPack, nullptr});
}
else
{
asMutable(c.resultType)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
}
unblock(c.resultType, constraint->location);
return true;
}
}
LUAU_ASSERT(false);
return false;
}
bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Constraint> constraint, bool force)
{
TypeId leftType = follow(c.leftType);
TypeId rightType = follow(c.rightType);
TypeId resultType = follow(c.resultType);
LUAU_ASSERT(get<BlockedType>(resultType));
bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or;
/* Compound assignments create constraints of the form
*
* A <: Binary<op, A, B>
*
* This constraint is the one that is meant to unblock A, so it doesn't
* make any sense to stop and wait for someone else to do it.
*/
// If any is present, the expression must evaluate to any as well.
bool leftAny = get<AnyType>(leftType) || get<ErrorType>(leftType);
bool rightAny = get<AnyType>(rightType) || get<ErrorType>(rightType);
bool anyPresent = leftAny || rightAny;
if (isBlocked(leftType) && leftType != resultType)
return block(c.leftType, constraint);
if (isBlocked(rightType) && rightType != resultType)
return block(c.rightType, constraint);
if (!force)
{
// Logical expressions may proceed if the LHS is free.
if (hasTypeInIntersection<FreeType>(leftType) && !isLogical)
return block(leftType, constraint);
}
// Logical expressions may proceed if the LHS is free.
if (isBlocked(leftType) || (hasTypeInIntersection<FreeType>(leftType) && !isLogical))
{
asMutable(resultType)->ty.emplace<BoundType>(errorRecoveryType());
unblock(resultType, constraint->location);
return true;
}
// Metatables go first, even if there is primitive behavior.
if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end())
{
LUAU_ASSERT(FFlag::LuauFloorDivision || c.op != AstExprBinary::Op::FloorDiv);
// Metatables are not the same. The metamethod will not be invoked.
if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) &&
getMetatable(leftType, builtinTypes) != getMetatable(rightType, builtinTypes))
{
// TODO: Boolean singleton false? The result is _always_ boolean false.
asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->booleanType);
unblock(resultType, constraint->location);
return true;
}
std::optional<TypeId> mm;
// The LHS metatable takes priority over the RHS metatable, where
// present.
if (std::optional<TypeId> leftMm = findMetatableEntry(builtinTypes, errors, leftType, it->second, constraint->location))
mm = leftMm;
else if (std::optional<TypeId> rightMm = findMetatableEntry(builtinTypes, errors, rightType, it->second, constraint->location))
mm = rightMm;
if (mm)
{
std::optional<TypeId> instantiatedMm = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, *mm);
if (!instantiatedMm)
{
reportError(CodeTooComplex{}, constraint->location);
return true;
}
// TODO: Is a table with __call legal here?
// TODO: Overloads
if (const FunctionType* ftv = get<FunctionType>(follow(*instantiatedMm)))
{
TypePackId inferredArgs;
// For >= and > we invoke __lt and __le respectively with
// swapped argument ordering.
if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt)
{
inferredArgs = arena->addTypePack({rightType, leftType});
}
else
{
inferredArgs = arena->addTypePack({leftType, rightType});
}
unify(constraint->scope, constraint->location, inferredArgs, ftv->argTypes);
TypeId mmResult;
// Comparison operations always evaluate to a boolean,
// regardless of what the metamethod returns.
switch (c.op)
{
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
mmResult = builtinTypes->booleanType;
break;
default:
if (get<NeverType>(leftType) || get<NeverType>(rightType))
mmResult = builtinTypes->neverType;
else
mmResult = first(ftv->retTypes).value_or(errorRecoveryType());
}
asMutable(resultType)->ty.emplace<BoundType>(mmResult);
unblock(resultType, constraint->location);
(*c.astOriginalCallTypes)[c.astFragment] = *mm;
(*c.astOverloadResolvedTypes)[c.astFragment] = *instantiatedMm;
return true;
}
}
// If there's no metamethod available, fall back to primitive behavior.
}
switch (c.op)
{
// For arithmetic operators, if the LHS is a number, the RHS must be a
// number as well. The result will also be a number.
case AstExprBinary::Op::Add:
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::FloorDiv:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
{
LUAU_ASSERT(FFlag::LuauFloorDivision || c.op != AstExprBinary::Op::FloorDiv);
const NormalizedType* normLeftTy = normalizer->normalize(leftType);
if (hasTypeInIntersection<FreeType>(leftType) && force)
asMutable(leftType)->ty.emplace<BoundType>(anyPresent ? builtinTypes->anyType : builtinTypes->numberType);
// We want to check if the left type has tops because `any` is a valid type for the lhs
if (normLeftTy && (normLeftTy->isExactlyNumber() || get<AnyType>(normLeftTy->tops)))
{
unify(constraint->scope, constraint->location, leftType, rightType);
asMutable(resultType)->ty.emplace<BoundType>(anyPresent ? builtinTypes->anyType : leftType);
unblock(resultType, constraint->location);
return true;
}
else if (get<NeverType>(leftType) || get<NeverType>(rightType))
{
unify(constraint->scope, constraint->location, leftType, rightType);
asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->neverType);
unblock(resultType, constraint->location);
return true;
}
break;
}
// For concatenation, if the LHS is a string, the RHS must be a string as
// well. The result will also be a string.
case AstExprBinary::Op::Concat:
{
if (hasTypeInIntersection<FreeType>(leftType) && force)
asMutable(leftType)->ty.emplace<BoundType>(anyPresent ? builtinTypes->anyType : builtinTypes->stringType);
const NormalizedType* leftNormTy = normalizer->normalize(leftType);
if (leftNormTy && leftNormTy->isSubtypeOfString())
{
unify(constraint->scope, constraint->location, leftType, rightType);
asMutable(resultType)->ty.emplace<BoundType>(anyPresent ? builtinTypes->anyType : leftType);
unblock(resultType, constraint->location);
return true;
}
else if (get<NeverType>(leftType) || get<NeverType>(rightType))
{
unify(constraint->scope, constraint->location, leftType, rightType);
asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->neverType);
unblock(resultType, constraint->location);
return true;
}
break;
}
// Inexact comparisons require that the types be both numbers or both
// strings, and evaluate to a boolean.
case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe:
case AstExprBinary::Op::CompareLt:
{
const NormalizedType* lt = normalizer->normalize(leftType);
const NormalizedType* rt = normalizer->normalize(rightType);
// If the lhs is any, comparisons should be valid.
if (lt && rt && (lt->isExactlyNumber() || get<AnyType>(lt->tops)) && rt->isExactlyNumber())
{
asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->booleanType);
unblock(resultType, constraint->location);
return true;
}
if (lt && rt && (lt->isSubtypeOfString() || get<AnyType>(lt->tops)) && rt->isSubtypeOfString())
{
asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->booleanType);
unblock(resultType, constraint->location);
return true;
}
if (get<NeverType>(leftType) || get<NeverType>(rightType))
{
asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->booleanType);
unblock(resultType, constraint->location);
return true;
}
break;
}
// == and ~= always evaluate to a boolean, and impose no other constraints
// on their parameters.
case AstExprBinary::Op::CompareEq:
case AstExprBinary::Op::CompareNe:
asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->booleanType);
unblock(resultType, constraint->location);
return true;
// And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is
// truthy.
case AstExprBinary::Op::And:
{
TypeId leftFilteredTy = simplifyIntersection(builtinTypes, arena, leftType, builtinTypes->falsyType).result;
asMutable(resultType)->ty.emplace<BoundType>(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result);
unblock(resultType, constraint->location);
return true;
}
// Or evaluates to the LHS type if the LHS is truthy, and the RHS type if
// LHS is falsey.
case AstExprBinary::Op::Or:
{
TypeId leftFilteredTy = simplifyIntersection(builtinTypes, arena, leftType, builtinTypes->truthyType).result;
asMutable(resultType)->ty.emplace<BoundType>(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result);
unblock(resultType, constraint->location);
return true;
}
default:
iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location);
break;
}
// We failed to either evaluate a metamethod or invoke primitive behavior.
unify(constraint->scope, constraint->location, leftType, errorRecoveryType());
unify(constraint->scope, constraint->location, rightType, errorRecoveryType());
asMutable(resultType)->ty.emplace<BoundType>(errorRecoveryType());
unblock(resultType, constraint->location);
return true;
}
bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{ {
/* /*
@ -1946,6 +1615,32 @@ bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNull<const Cons
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const SetOpConstraint& c, NotNull<const Constraint> constraint, bool force)
{
bool blocked = false;
for (TypeId ty : c.types)
{
if (isBlocked(ty))
{
blocked = true;
block(ty, constraint);
}
}
if (blocked && !force)
return false;
LUAU_ASSERT(SetOpConstraint::Union == c.mode);
TypeId res = builtinTypes->neverType;
for (TypeId ty : c.types)
res = simplifyUnion(builtinTypes, arena, res, ty).result;
asMutable(c.resultType)->ty.emplace<BoundType>(res);
return true;
}
bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force)
{ {
TypeId ty = follow(c.ty); TypeId ty = follow(c.ty);

View File

@ -2,11 +2,12 @@
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Breadcrumb.h" #include "Luau/Def.h"
#include "Luau/Common.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Refinement.h"
#include <algorithm> #include <algorithm>
#include <optional>
LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
@ -14,74 +15,81 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau namespace Luau
{ {
NullableBreadcrumbId DataFlowGraph::getBreadcrumb(const AstExpr* expr) const const RefinementKey* RefinementKeyArena::leaf(DefId def)
{ {
// We need to skip through AstExprGroup because DFG doesn't try its best to transitively return allocator.allocate(RefinementKey{nullptr, def, std::nullopt});
while (auto group = expr->as<AstExprGroup>()) }
expr = group->expr;
if (auto bc = astBreadcrumbs.find(expr)) const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId def, const std::string& propName)
return *bc; {
return allocator.allocate(RefinementKey{parent, def, propName});
}
DefId DataFlowGraph::getDef(const AstExpr* expr) const
{
auto def = astDefs.find(expr);
LUAU_ASSERT(def);
return NotNull{*def};
}
std::optional<DefId> DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const
{
auto def = compoundAssignBreadcrumbs.find(expr);
return def ? std::optional<DefId>(*def) : std::nullopt;
}
DefId DataFlowGraph::getDef(const AstLocal* local) const
{
auto def = localDefs.find(local);
LUAU_ASSERT(def);
return NotNull{*def};
}
DefId DataFlowGraph::getDef(const AstStatDeclareGlobal* global) const
{
auto def = declaredDefs.find(global);
LUAU_ASSERT(def);
return NotNull{*def};
}
DefId DataFlowGraph::getDef(const AstStatDeclareFunction* func) const
{
auto def = declaredDefs.find(func);
LUAU_ASSERT(def);
return NotNull{*def};
}
const RefinementKey* DataFlowGraph::getRefinementKey(const AstExpr* expr) const
{
if (auto key = astRefinementKeys.find(expr))
return *key;
return nullptr; return nullptr;
} }
BreadcrumbId DataFlowGraph::getBreadcrumb(const AstLocal* local) const std::optional<DefId> DfgScope::lookup(Symbol symbol) const
{
auto bc = localBreadcrumbs.find(local);
LUAU_ASSERT(bc);
return NotNull{*bc};
}
BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprLocal* local) const
{
auto bc = astBreadcrumbs.find(local);
LUAU_ASSERT(bc);
return NotNull{*bc};
}
BreadcrumbId DataFlowGraph::getBreadcrumb(const AstExprGlobal* global) const
{
auto bc = astBreadcrumbs.find(global);
LUAU_ASSERT(bc);
return NotNull{*bc};
}
BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareGlobal* global) const
{
auto bc = declaredBreadcrumbs.find(global);
LUAU_ASSERT(bc);
return NotNull{*bc};
}
BreadcrumbId DataFlowGraph::getBreadcrumb(const AstStatDeclareFunction* func) const
{
auto bc = declaredBreadcrumbs.find(func);
LUAU_ASSERT(bc);
return NotNull{*bc};
}
NullableBreadcrumbId DfgScope::lookup(Symbol symbol) const
{ {
for (const DfgScope* current = this; current; current = current->parent) for (const DfgScope* current = this; current; current = current->parent)
{ {
if (auto breadcrumb = current->bindings.find(symbol)) if (auto def = current->bindings.find(symbol))
return *breadcrumb; return NotNull{*def};
} }
return nullptr; return std::nullopt;
} }
NullableBreadcrumbId DfgScope::lookup(DefId def, const std::string& key) const std::optional<DefId> DfgScope::lookup(DefId def, const std::string& key) const
{ {
for (const DfgScope* current = this; current; current = current->parent) for (const DfgScope* current = this; current; current = current->parent)
{ {
if (auto map = props.find(def)) if (auto map = props.find(def))
{ {
if (auto it = map->find(key); it != map->end()) if (auto it = map->find(key); it != map->end())
return it->second; return NotNull{it->second};
} }
} }
return nullptr; return std::nullopt;
} }
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle) DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
@ -95,8 +103,8 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalE
if (FFlag::DebugLuauFreezeArena) if (FFlag::DebugLuauFreezeArena)
{ {
builder.defs->allocator.freeze(); builder.defArena->allocator.freeze();
builder.breadcrumbs->allocator.freeze(); builder.keyArena->allocator.freeze();
} }
return std::move(builder.graph); return std::move(builder.graph);
@ -217,10 +225,10 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e)
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
{ {
// We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`)
std::vector<BreadcrumbId> bcs; std::vector<DefId> defs;
bcs.reserve(l->values.size); defs.reserve(l->values.size);
for (AstExpr* e : l->values) for (AstExpr* e : l->values)
bcs.push_back(visitExpr(scope, e)); defs.push_back(visitExpr(scope, e).def);
for (size_t i = 0; i < l->vars.size; ++i) for (size_t i = 0; i < l->vars.size; ++i)
{ {
@ -228,10 +236,12 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
if (local->annotation) if (local->annotation)
visitType(scope, local->annotation); visitType(scope, local->annotation);
// We need to create a new breadcrumb with new defs to intentionally avoid alias tracking. // We need to create a new def to intentionally avoid alias tracking, but we'd like to
BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell(), i < bcs.size() ? bcs[i]->metadata : std::nullopt); // make sure that the non-aliased defs are also marked as a subscript for refinements.
graph.localBreadcrumbs[local] = bc; bool subscripted = i < defs.size() && containsSubscriptedDefinition(defs[i]);
scope->bindings[local] = bc; DefId def = defArena->freshCell(subscripted);
graph.localDefs[local] = def;
scope->bindings[local] = def;
} }
} }
@ -247,10 +257,9 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
if (f->var->annotation) if (f->var->annotation)
visitType(forScope, f->var->annotation); visitType(forScope, f->var->annotation);
// TODO: RangeMetadata. DefId def = defArena->freshCell();
BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); graph.localDefs[f->var] = def;
graph.localBreadcrumbs[f->var] = bc; scope->bindings[f->var] = def;
scope->bindings[f->var] = bc;
// TODO(controlflow): entry point has a back edge from exit point // TODO(controlflow): entry point has a back edge from exit point
visit(forScope, f->body); visit(forScope, f->body);
@ -265,10 +274,9 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
if (local->annotation) if (local->annotation)
visitType(forScope, local->annotation); visitType(forScope, local->annotation);
// TODO: IterMetadata (different from RangeMetadata) DefId def = defArena->freshCell();
BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); graph.localDefs[local] = def;
graph.localBreadcrumbs[local] = bc; forScope->bindings[local] = def;
forScope->bindings[local] = bc;
} }
// TODO(controlflow): entry point has a back edge from exit point // TODO(controlflow): entry point has a back edge from exit point
@ -281,11 +289,15 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
{ {
for (size_t i = 0; i < std::max(a->vars.size, a->values.size); ++i) std::vector<DefId> defs;
defs.reserve(a->values.size);
for (AstExpr* e : a->values)
defs.push_back(visitExpr(scope, e).def);
for (size_t i = 0; i < a->vars.size; ++i)
{ {
BreadcrumbId bc = i < a->values.size ? visitExpr(scope, a->values.data[i]) : breadcrumbs->add(nullptr, defs->freshCell()); AstExpr* v = a->vars.data[i];
if (i < a->vars.size) visitLValue(scope, v, i < defs.size() ? defs[i] : defArena->freshCell());
visitLValue(scope, a->vars.data[i], bc);
} }
} }
@ -297,9 +309,9 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
// //
// local a = 5 -- a-1 // local a = 5 -- a-1
// a += 5 -- a-2 = a-1 + 5 // a += 5 -- a-2 = a-1 + 5
//
// We can't just visit `c->var` as a rvalue and then separately traverse `c->var` as an lvalue, since that's O(n^2). // We can't just visit `c->var` as a rvalue and then separately traverse `c->var` as an lvalue, since that's O(n^2).
visitLValue(scope, c->var, visitExpr(scope, c->value)); DefId def = visitExpr(scope, c->value).def;
visitLValue(scope, c->var, def, /* isCompoundAssignment */ true);
} }
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
@ -314,14 +326,17 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
// //
// which is evidence that references to variables must be a phi node of all possible definitions, // which is evidence that references to variables must be a phi node of all possible definitions,
// but for bug compatibility, we'll assume the same thing here. // but for bug compatibility, we'll assume the same thing here.
visitLValue(scope, f->name, visitExpr(scope, f->func)); DefId prototype = defArena->freshCell();
visitLValue(scope, f->name, prototype);
visitExpr(scope, f->func);
} }
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l)
{ {
BreadcrumbId bc = visitExpr(scope, l->func); DefId def = defArena->freshCell();
graph.localBreadcrumbs[l->name] = bc; graph.localDefs[l->name] = def;
scope->bindings[l->name] = bc; scope->bindings[l->name] = def;
visitExpr(scope, l->func);
} }
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t)
@ -334,20 +349,18 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t)
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d)
{ {
// TODO: AmbientDeclarationMetadata. DefId def = defArena->freshCell();
BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); graph.declaredDefs[d] = def;
graph.declaredBreadcrumbs[d] = bc; scope->bindings[d->name] = def;
scope->bindings[d->name] = bc;
visitType(scope, d->type); visitType(scope, d->type);
} }
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d)
{ {
// TODO: AmbientDeclarationMetadata. DefId def = defArena->freshCell();
BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); graph.declaredDefs[d] = def;
graph.declaredBreadcrumbs[d] = bc; scope->bindings[d->name] = def;
scope->bindings[d->name] = bc;
DfgScope* unreachable = childScope(scope); DfgScope* unreachable = childScope(scope);
visitGenerics(unreachable, d->generics); visitGenerics(unreachable, d->generics);
@ -375,116 +388,125 @@ void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error)
visitExpr(unreachable, e); visitExpr(unreachable, e);
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
{ {
if (auto g = e->as<AstExprGroup>()) auto go = [&]() -> DataFlowResult {
return visitExpr(scope, g->expr); if (auto g = e->as<AstExprGroup>())
else if (auto c = e->as<AstExprConstantNil>()) return visitExpr(scope, g);
return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as<AstExprConstantNil>())
else if (auto c = e->as<AstExprConstantBool>()) return {defArena->freshCell(), nullptr}; // ok
return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as<AstExprConstantBool>())
else if (auto c = e->as<AstExprConstantNumber>()) return {defArena->freshCell(), nullptr}; // ok
return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as<AstExprConstantNumber>())
else if (auto c = e->as<AstExprConstantString>()) return {defArena->freshCell(), nullptr}; // ok
return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto c = e->as<AstExprConstantString>())
else if (auto l = e->as<AstExprLocal>()) return {defArena->freshCell(), nullptr}; // ok
return visitExpr(scope, l); else if (auto l = e->as<AstExprLocal>())
else if (auto g = e->as<AstExprGlobal>()) return visitExpr(scope, l);
return visitExpr(scope, g); else if (auto g = e->as<AstExprGlobal>())
else if (auto v = e->as<AstExprVarargs>()) return visitExpr(scope, g);
return breadcrumbs->add(nullptr, defs->freshCell()); // ok else if (auto v = e->as<AstExprVarargs>())
else if (auto c = e->as<AstExprCall>()) return {defArena->freshCell(), nullptr}; // ok
return visitExpr(scope, c); else if (auto c = e->as<AstExprCall>())
else if (auto i = e->as<AstExprIndexName>()) return visitExpr(scope, c);
return visitExpr(scope, i); else if (auto i = e->as<AstExprIndexName>())
else if (auto i = e->as<AstExprIndexExpr>()) return visitExpr(scope, i);
return visitExpr(scope, i); else if (auto i = e->as<AstExprIndexExpr>())
else if (auto f = e->as<AstExprFunction>()) return visitExpr(scope, i);
return visitExpr(scope, f); else if (auto f = e->as<AstExprFunction>())
else if (auto t = e->as<AstExprTable>()) return visitExpr(scope, f);
return visitExpr(scope, t); else if (auto t = e->as<AstExprTable>())
else if (auto u = e->as<AstExprUnary>()) return visitExpr(scope, t);
return visitExpr(scope, u); else if (auto u = e->as<AstExprUnary>())
else if (auto b = e->as<AstExprBinary>()) return visitExpr(scope, u);
return visitExpr(scope, b); else if (auto b = e->as<AstExprBinary>())
else if (auto t = e->as<AstExprTypeAssertion>()) return visitExpr(scope, b);
return visitExpr(scope, t); else if (auto t = e->as<AstExprTypeAssertion>())
else if (auto i = e->as<AstExprIfElse>()) return visitExpr(scope, t);
return visitExpr(scope, i); else if (auto i = e->as<AstExprIfElse>())
else if (auto i = e->as<AstExprInterpString>()) return visitExpr(scope, i);
return visitExpr(scope, i); else if (auto i = e->as<AstExprInterpString>())
else if (auto error = e->as<AstExprError>()) return visitExpr(scope, i);
return visitExpr(scope, error); else if (auto error = e->as<AstExprError>())
else return visitExpr(scope, error);
handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); else
handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr");
};
auto [def, key] = go();
graph.astDefs[e] = def;
if (key)
graph.astRefinementKeys[e] = key;
return {def, key};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGroup* group)
{ {
NullableBreadcrumbId breadcrumb = scope->lookup(l->local); return visitExpr(scope, group->expr);
if (!breadcrumb)
handle->ice("DFG: AstExprLocal came before its declaration?");
graph.astBreadcrumbs[l] = breadcrumb;
return NotNull{breadcrumb};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l)
{ {
NullableBreadcrumbId bc = scope->lookup(g->name); if (auto def = scope->lookup(l->local))
if (!bc)
{ {
bc = breadcrumbs->add(nullptr, defs->freshCell()); const RefinementKey* key = keyArena->leaf(*def);
moduleScope->bindings[g->name] = bc; return {*def, key};
} }
graph.astBreadcrumbs[g] = bc; handle->ice("DFG: AstExprLocal came before its declaration?");
return NotNull{bc};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g)
{
if (auto def = scope->lookup(g->name))
return {*def, keyArena->leaf(*def)};
DefId def = defArena->freshCell();
moduleScope->bindings[g->name] = def;
return {def, keyArena->leaf(def)};
}
DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c)
{ {
visitExpr(scope, c->func); visitExpr(scope, c->func);
for (AstExpr* arg : c->args) for (AstExpr* arg : c->args)
visitExpr(scope, arg); visitExpr(scope, arg);
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i)
{ {
BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); auto [parentDef, parentKey] = visitExpr(scope, i->expr);
std::string key = i->index.value; std::string index = i->index.value;
NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; auto& propDef = moduleScope->props[parentDef][index];
if (!propBreadcrumb) if (!propDef)
propBreadcrumb = breadcrumbs->emplace<FieldMetadata>(parentBreadcrumb, defs->freshCell(), key); propDef = defArena->freshCell();
graph.astBreadcrumbs[i] = propBreadcrumb; return {NotNull{propDef}, keyArena->node(parentKey, NotNull{propDef}, index)};
return NotNull{propBreadcrumb};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i)
{ {
BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); auto [parentDef, parentKey] = visitExpr(scope, i->expr);
BreadcrumbId key = visitExpr(scope, i->index); visitExpr(scope, i->index);
if (auto string = i->index->as<AstExprConstantString>()) if (auto string = i->index->as<AstExprConstantString>())
{ {
std::string key{string->value.data, string->value.size}; std::string index{string->value.data, string->value.size};
NullableBreadcrumbId& propBreadcrumb = moduleScope->props[parentBreadcrumb->def][key]; auto& propDef = moduleScope->props[parentDef][index];
if (!propBreadcrumb) if (!propDef)
propBreadcrumb = breadcrumbs->emplace<FieldMetadata>(parentBreadcrumb, defs->freshCell(), key); propDef = defArena->freshCell();
graph.astBreadcrumbs[i] = NotNull{propBreadcrumb}; return {NotNull{propDef}, keyArena->node(parentKey, NotNull{propDef}, index)};
return NotNull{propBreadcrumb};
} }
return breadcrumbs->emplace<SubscriptMetadata>(nullptr, defs->freshCell(), key); return {defArena->freshCell(/* subscripted= */true), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f)
{ {
DfgScope* signatureScope = childScope(scope); DfgScope* signatureScope = childScope(scope);
@ -493,10 +515,9 @@ BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f
// There's no syntax for `self` to have an annotation if using `function t:m()` // There's no syntax for `self` to have an annotation if using `function t:m()`
LUAU_ASSERT(!self->annotation); LUAU_ASSERT(!self->annotation);
// TODO: ParameterMetadata. DefId def = defArena->freshCell();
BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); graph.localDefs[self] = def;
graph.localBreadcrumbs[self] = bc; signatureScope->bindings[self] = def;
signatureScope->bindings[self] = bc;
} }
for (AstLocal* param : f->args) for (AstLocal* param : f->args)
@ -504,10 +525,9 @@ BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f
if (param->annotation) if (param->annotation)
visitType(signatureScope, param->annotation); visitType(signatureScope, param->annotation);
// TODO: ParameterMetadata. DefId def = defArena->freshCell();
BreadcrumbId bc = breadcrumbs->add(nullptr, defs->freshCell()); graph.localDefs[param] = def;
graph.localBreadcrumbs[param] = bc; signatureScope->bindings[param] = def;
signatureScope->bindings[param] = bc;
} }
if (f->varargAnnotation) if (f->varargAnnotation)
@ -526,10 +546,10 @@ BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f
// g() --> 5 // g() --> 5
visit(signatureScope, f->body); visit(signatureScope, f->body);
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t)
{ {
for (AstExprTable::Item item : t->items) for (AstExprTable::Item item : t->items)
{ {
@ -538,120 +558,132 @@ BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t)
visitExpr(scope, item.value); visitExpr(scope, item.value);
} }
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u)
{ {
visitExpr(scope, u->expr); visitExpr(scope, u->expr);
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b)
{ {
visitExpr(scope, b->left); visitExpr(scope, b->left);
visitExpr(scope, b->right); visitExpr(scope, b->right);
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t)
{ {
// TODO: TypeAssertionMetadata? auto [def, key] = visitExpr(scope, t->expr);
BreadcrumbId bc = visitExpr(scope, t->expr);
visitType(scope, t->annotation); visitType(scope, t->annotation);
return bc; return {def, key};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i)
{ {
visitExpr(scope, i->condition); visitExpr(scope, i->condition);
visitExpr(scope, i->trueExpr); visitExpr(scope, i->trueExpr);
visitExpr(scope, i->falseExpr); visitExpr(scope, i->falseExpr);
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i)
{ {
for (AstExpr* e : i->expressions) for (AstExpr* e : i->expressions)
visitExpr(scope, e); visitExpr(scope, e);
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
BreadcrumbId DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error)
{ {
DfgScope* unreachable = childScope(scope); DfgScope* unreachable = childScope(scope);
for (AstExpr* e : error->expressions) for (AstExpr* e : error->expressions)
visitExpr(unreachable, e); visitExpr(unreachable, e);
return breadcrumbs->add(nullptr, defs->freshCell()); return {defArena->freshCell(), nullptr};
} }
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, BreadcrumbId bc) void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment)
{ {
if (auto l = e->as<AstExprLocal>()) if (auto l = e->as<AstExprLocal>())
return visitLValue(scope, l, bc); return visitLValue(scope, l, incomingDef, isCompoundAssignment);
else if (auto g = e->as<AstExprGlobal>()) else if (auto g = e->as<AstExprGlobal>())
return visitLValue(scope, g, bc); return visitLValue(scope, g, incomingDef, isCompoundAssignment);
else if (auto i = e->as<AstExprIndexName>()) else if (auto i = e->as<AstExprIndexName>())
return visitLValue(scope, i, bc); return visitLValue(scope, i, incomingDef);
else if (auto i = e->as<AstExprIndexExpr>()) else if (auto i = e->as<AstExprIndexExpr>())
return visitLValue(scope, i, bc); return visitLValue(scope, i, incomingDef);
else if (auto error = e->as<AstExprError>()) else if (auto error = e->as<AstExprError>())
return visitLValue(scope, error, bc); return visitLValue(scope, error, incomingDef);
else else
handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue");
} }
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, BreadcrumbId bc) void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment)
{ {
// In order to avoid alias tracking, we need to clip the reference to the parent breadcrumb // We need to keep the previous breadcrumb around for a compound assignment.
// as well as the def that was about to be assigned onto this lvalue. However, we want to if (isCompoundAssignment)
// copy the metadata so that refinements can be consistent. {
BreadcrumbId updated = breadcrumbs->add(scope->lookup(l->local), defs->freshCell(), bc->metadata); if (auto def = scope->lookup(l->local))
graph.astBreadcrumbs[l] = updated; graph.compoundAssignBreadcrumbs[l] = *def;
}
// In order to avoid alias tracking, we need to clip the reference to the parent def.
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[l] = updated;
scope->bindings[l->local] = updated; scope->bindings[l->local] = updated;
} }
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, BreadcrumbId bc) void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment)
{ {
// In order to avoid alias tracking, we need to clip the reference to the parent breadcrumb // We need to keep the previous breadcrumb around for a compound assignment.
// as well as the def that was about to be assigned onto this lvalue. However, we want to if (isCompoundAssignment)
// copy the metadata so that refinements can be consistent. {
BreadcrumbId updated = breadcrumbs->add(scope->lookup(g->name), defs->freshCell(), bc->metadata); if (auto def = scope->lookup(g->name))
graph.astBreadcrumbs[g] = updated; graph.compoundAssignBreadcrumbs[g] = *def;
}
// In order to avoid alias tracking, we need to clip the reference to the parent def.
DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astDefs[g] = updated;
scope->bindings[g->name] = updated; scope->bindings[g->name] = updated;
} }
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, BreadcrumbId bc) void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef)
{ {
BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); DefId parentDef = visitExpr(scope, i->expr).def;
BreadcrumbId updated = breadcrumbs->add(scope->props[parentBreadcrumb->def][i->index.value], defs->freshCell(), bc->metadata); DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astBreadcrumbs[i] = updated; graph.astDefs[i] = updated;
scope->props[parentBreadcrumb->def][i->index.value] = updated; scope->props[parentDef][i->index.value] = updated;
} }
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, BreadcrumbId bc) void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef)
{ {
BreadcrumbId parentBreadcrumb = visitExpr(scope, i->expr); DefId parentDef = visitExpr(scope, i->expr).def;
visitExpr(scope, i->index); visitExpr(scope, i->index);
if (auto string = i->index->as<AstExprConstantString>()) if (auto string = i->index->as<AstExprConstantString>())
{ {
BreadcrumbId updated = breadcrumbs->add(scope->props[parentBreadcrumb->def][string->value.data], defs->freshCell(), bc->metadata); DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef));
graph.astBreadcrumbs[i] = updated; graph.astDefs[i] = updated;
scope->props[parentBreadcrumb->def][string->value.data] = updated; scope->props[parentDef][string->value.data] = updated;
} }
graph.astDefs[i] = defArena->freshCell();
} }
void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprError* error, BreadcrumbId bc) void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprError* error, DefId incomingDef)
{ {
visitExpr(scope, error); DefId def = visitExpr(scope, error).def;
graph.astDefs[error] = def;
} }
void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t)

View File

@ -1,12 +1,22 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/Common.h"
namespace Luau namespace Luau
{ {
DefId DefArena::freshCell() bool containsSubscriptedDefinition(DefId def)
{ {
return NotNull{allocator.allocate(Def{Cell{}})}; if (auto cell = get<Cell>(def))
return cell->subscripted;
LUAU_ASSERT(!"Phi nodes not implemented yet");
return false;
}
DefId DefArena::freshCell(bool subscripted)
{
return NotNull{allocator.allocate(Def{Cell{subscripted}})};
} }
} // namespace Luau } // namespace Luau

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/IostreamHelpers.h" #include "Luau/IostreamHelpers.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypePath.h"
namespace Luau namespace Luau
{ {
@ -236,4 +237,34 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv)
return stream << toString(tv); return stream << toString(tv);
} }
std::ostream& operator<<(std::ostream& stream, TypeId ty)
{
// we commonly use a null pointer when a type may not be present; we need to
// account for that here.
if (!ty)
return stream << "<nullptr>";
return stream << toString(ty);
}
std::ostream& operator<<(std::ostream& stream, TypePackId tp)
{
// we commonly use a null pointer when a type may not be present; we need to
// account for that here.
if (!tp)
return stream << "<nullptr>";
return stream << toString(tp);
}
namespace TypePath
{
std::ostream& operator<<(std::ostream& stream, const Path& path)
{
return stream << toString(path);
}
} // namespace TypePath
} // namespace Luau } // namespace Luau

View File

@ -325,6 +325,7 @@ struct NonStrictTypeChecker
return; return;
TypeId fnTy = *originalCallTy; TypeId fnTy = *originalCallTy;
// TODO: how should we link this to the passed in context here
NonStrictContext fresh{}; NonStrictContext fresh{};
if (auto fn = get<FunctionType>(follow(fnTy))) if (auto fn = get<FunctionType>(follow(fnTy)))
{ {
@ -351,28 +352,20 @@ struct NonStrictTypeChecker
// We will compare arg and ~number // We will compare arg and ~number
AstExpr* arg = call->args.data[i]; AstExpr* arg = call->args.data[i];
TypeId expectedArgType = argTypes[i]; TypeId expectedArgType = argTypes[i];
NullableBreadcrumbId bc = dfg->getBreadcrumb(arg); DefId def = dfg->getDef(arg);
// TODO: Cache negations created here!!! // TODO: Cache negations created here!!!
// See Jira Ticket: https://roblox.atlassian.net/browse/CLI-87539 // See Jira Ticket: https://roblox.atlassian.net/browse/CLI-87539
if (bc) TypeId runTimeErrorTy = arena.addType(NegationType{expectedArgType});
{ fresh.context[def.get()] = runTimeErrorTy;
TypeId runTimeErrorTy = arena.addType(NegationType{expectedArgType});
DefId def = bc->def;
fresh.context[def.get()] = runTimeErrorTy;
}
else
{
std::cout << "bad" << std::endl;
}
} }
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types // Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
AstName name = getIdentifier(call->func);
for (size_t i = 0; i < call->args.size; i++) for (size_t i = 0; i < call->args.size; i++)
{ {
AstExpr* arg = call->args.data[i]; AstExpr* arg = call->args.data[i];
// TODO: pipe in name of checked function to report Error
if (auto runTimeFailureType = willRunTimeError(arg, fresh)) if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, "", i}, arg->location); reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, name.value, i}, arg->location);
} }
} }
} }
@ -401,25 +394,22 @@ struct NonStrictTypeChecker
// If this fragment of the ast will run time error, return the type that causes this // If this fragment of the ast will run time error, return the type that causes this
std::optional<TypeId> willRunTimeError(AstExpr* fragment, const NonStrictContext& context) std::optional<TypeId> willRunTimeError(AstExpr* fragment, const NonStrictContext& context)
{ {
DefId def = dfg->getDef(fragment);
if (NullableBreadcrumbId bc = dfg->getBreadcrumb(fragment)) if (std::optional<TypeId> contextTy = context.find(def))
{ {
std::optional<TypeId> contextTy = context.find(bc->def);
if (contextTy)
{
TypeId actualType = lookupType(fragment); TypeId actualType = lookupType(fragment);
SubtypingResult r = subtyping.isSubtype(actualType, *contextTy); SubtypingResult r = subtyping.isSubtype(actualType, *contextTy);
if (r.normalizationTooComplex) if (r.normalizationTooComplex)
reportError(NormalizationTooComplex{}, fragment->location); reportError(NormalizationTooComplex{}, fragment->location);
if (!r.isSubtype && !r.isErrorSuppressing) if (!r.isSubtype && !r.isErrorSuppressing)
reportError(TypeMismatch{actualType, *contextTy}, fragment->location); reportError(TypeMismatch{actualType, *contextTy}, fragment->location);
if (r.isSubtype) if (r.isSubtype)
return {actualType}; return {actualType};
}
} }
return {}; return {};
} }
}; };

View File

@ -22,6 +22,13 @@ LUAU_FASTFLAG(DebugLuauReadWriteProperties)
namespace Luau namespace Luau
{ {
TypeIds::TypeIds(std::initializer_list<TypeId> tys)
{
for (TypeId ty : tys)
insert(ty);
}
void TypeIds::insert(TypeId ty) void TypeIds::insert(TypeId ty)
{ {
ty = follow(ty); ty = follow(ty);

View File

@ -1,37 +1,60 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Refinement.h" #include "Luau/Refinement.h"
#include <algorithm>
namespace Luau namespace Luau
{ {
RefinementId RefinementArena::variadic(const std::vector<RefinementId>& refis) RefinementId RefinementArena::variadic(const std::vector<RefinementId>& refis)
{ {
bool hasRefinements = false;
for (RefinementId r : refis)
hasRefinements |= bool(r);
if (!hasRefinements)
return nullptr;
return NotNull{allocator.allocate(Variadic{refis})}; return NotNull{allocator.allocate(Variadic{refis})};
} }
RefinementId RefinementArena::negation(RefinementId refinement) RefinementId RefinementArena::negation(RefinementId refinement)
{ {
if (!refinement)
return nullptr;
return NotNull{allocator.allocate(Negation{refinement})}; return NotNull{allocator.allocate(Negation{refinement})};
} }
RefinementId RefinementArena::conjunction(RefinementId lhs, RefinementId rhs) RefinementId RefinementArena::conjunction(RefinementId lhs, RefinementId rhs)
{ {
if (!lhs && !rhs)
return nullptr;
return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; return NotNull{allocator.allocate(Conjunction{lhs, rhs})};
} }
RefinementId RefinementArena::disjunction(RefinementId lhs, RefinementId rhs) RefinementId RefinementArena::disjunction(RefinementId lhs, RefinementId rhs)
{ {
if (!lhs && !rhs)
return nullptr;
return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; return NotNull{allocator.allocate(Disjunction{lhs, rhs})};
} }
RefinementId RefinementArena::equivalence(RefinementId lhs, RefinementId rhs) RefinementId RefinementArena::equivalence(RefinementId lhs, RefinementId rhs)
{ {
if (!lhs && !rhs)
return nullptr;
return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; return NotNull{allocator.allocate(Equivalence{lhs, rhs})};
} }
RefinementId RefinementArena::proposition(BreadcrumbId breadcrumb, TypeId discriminantTy) RefinementId RefinementArena::proposition(const RefinementKey* key, TypeId discriminantTy)
{ {
return NotNull{allocator.allocate(Proposition{breadcrumb, discriminantTy})}; if (!key)
return nullptr;
return NotNull{allocator.allocate(Proposition{key, discriminantTy})};
} }
} // namespace Luau } // namespace Luau

View File

@ -38,6 +38,23 @@ std::optional<TypeId> Scope::lookup(Symbol sym) const
return std::nullopt; return std::nullopt;
} }
std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(DefId def)
{
Scope* s = this;
while (true)
{
TypeId* it = s->lvalueTypes.find(def);
if (it)
return std::pair{*it, s};
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
std::optional<std::pair<Binding*, Scope*>> Scope::lookupEx(Symbol sym) std::optional<std::pair<Binding*, Scope*>> Scope::lookupEx(Symbol sym)
{ {
Scope* s = this; Scope* s = this;

View File

@ -75,6 +75,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
clone.dcrMagicRefinement = a.dcrMagicRefinement; clone.dcrMagicRefinement = a.dcrMagicRefinement;
clone.tags = a.tags; clone.tags = a.tags;
clone.argNames = a.argNames; clone.argNames = a.argNames;
clone.isCheckedFunction = a.isCheckedFunction;
return dest.addType(std::move(clone)); return dest.addType(std::move(clone));
} }
else if constexpr (std::is_same_v<T, TableType>) else if constexpr (std::is_same_v<T, TableType>)

View File

@ -11,6 +11,7 @@
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypePath.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include <algorithm> #include <algorithm>
@ -44,8 +45,19 @@ struct VarianceFlipper
} }
}; };
bool SubtypingReasoning::operator==(const SubtypingReasoning& other) const
{
return subPath == other.subPath && superPath == other.superPath;
}
SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other) SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other)
{ {
// If this result is a subtype, we take the other result's reasoning. If
// this result is not a subtype, we keep the current reasoning, even if the
// other isn't a subtype.
if (isSubtype)
reasoning = other.reasoning;
isSubtype &= other.isSubtype; isSubtype &= other.isSubtype;
// `|=` is intentional here, we want to preserve error related flags. // `|=` is intentional here, we want to preserve error related flags.
isErrorSuppressing |= other.isErrorSuppressing; isErrorSuppressing |= other.isErrorSuppressing;
@ -57,6 +69,11 @@ SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other)
SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other) SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other)
{ {
// If the other result is not a subtype, we take the other result's
// reasoning.
if (!other.isSubtype)
reasoning = other.reasoning;
isSubtype |= other.isSubtype; isSubtype |= other.isSubtype;
isErrorSuppressing |= other.isErrorSuppressing; isErrorSuppressing |= other.isErrorSuppressing;
normalizationTooComplex |= other.normalizationTooComplex; normalizationTooComplex |= other.normalizationTooComplex;
@ -65,6 +82,56 @@ SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other)
return *this; return *this;
} }
SubtypingResult& SubtypingResult::withBothComponent(TypePath::Component component)
{
return withSubComponent(component).withSuperComponent(component);
}
SubtypingResult& SubtypingResult::withSubComponent(TypePath::Component component)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->subPath = reasoning->subPath.push_front(component);
return *this;
}
SubtypingResult& SubtypingResult::withSuperComponent(TypePath::Component component)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->superPath = reasoning->superPath.push_front(component);
return *this;
}
SubtypingResult& SubtypingResult::withBothPath(TypePath::Path path)
{
return withSubPath(path).withSuperPath(path);
}
SubtypingResult& SubtypingResult::withSubPath(TypePath::Path path)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->subPath = path.append(reasoning->subPath);
return *this;
}
SubtypingResult& SubtypingResult::withSuperPath(TypePath::Path path)
{
if (!reasoning)
reasoning = SubtypingReasoning{Path(), Path()};
reasoning->superPath = path.append(reasoning->superPath);
return *this;
}
SubtypingResult SubtypingResult::negate(const SubtypingResult& result) SubtypingResult SubtypingResult::negate(const SubtypingResult& result)
{ {
return SubtypingResult{ return SubtypingResult{
@ -287,7 +354,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
else if (get<ErrorType>(subTy)) else if (get<ErrorType>(subTy))
result = {false, true}; result = {false, true};
else if (auto p = get2<NegationType, NegationType>(subTy, superTy)) else if (auto p = get2<NegationType, NegationType>(subTy, superTy))
result = isCovariantWith(env, p.first->ty, p.second->ty); result = isCovariantWith(env, p.first->ty, p.second->ty).withBothComponent(TypePath::TypeField::Negated);
else if (auto subNegation = get<NegationType>(subTy)) else if (auto subNegation = get<NegationType>(subTy))
result = isCovariantWith(env, subNegation, superTy); result = isCovariantWith(env, subNegation, superTy);
else if (auto superNegation = get<NegationType>(superTy)) else if (auto superNegation = get<NegationType>(superTy))
@ -350,9 +417,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
for (size_t i = 0; i < headSize; ++i) for (size_t i = 0; i < headSize; ++i)
{ {
results.push_back(isCovariantWith(env, subHead[i], superHead[i])); results.push_back(isCovariantWith(env, subHead[i], superHead[i]).withBothComponent(TypePath::Index{i}));
if (!results.back().isSubtype) if (!results.back().isSubtype)
return {false}; return results.back();
} }
// Handle mismatched head sizes // Handle mismatched head sizes
@ -364,7 +431,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
if (auto vt = get<VariadicTypePack>(*subTail)) if (auto vt = get<VariadicTypePack>(*subTail))
{ {
for (size_t i = headSize; i < superHead.size(); ++i) for (size_t i = headSize; i < superHead.size(); ++i)
results.push_back(isCovariantWith(env, vt->ty, superHead[i])); results.push_back(isCovariantWith(env, vt->ty, superHead[i])
.withSubComponent(TypePath::TypeField::Variadic)
.withSuperComponent(TypePath::Index{i}));
} }
else if (auto gt = get<GenericTypePack>(*subTail)) else if (auto gt = get<GenericTypePack>(*subTail))
{ {
@ -379,7 +448,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail);
if (TypePackId* other = env.mappedGenericPacks.find(*subTail)) if (TypePackId* other = env.mappedGenericPacks.find(*subTail))
results.push_back(isCovariantWith(env, *other, superTailPack)); // TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isCovariantWith(env, *other, superTailPack).withSubComponent(TypePath::PackField::Tail));
else else
env.mappedGenericPacks.try_insert(*subTail, superTailPack); env.mappedGenericPacks.try_insert(*subTail, superTailPack);
@ -393,7 +463,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
// //
// (T) -> () </: <X>(X) -> () // (T) -> () </: <X>(X) -> ()
// //
return {false}; return SubtypingResult{false}.withSubComponent(TypePath::PackField::Tail);
} }
} }
else else
@ -409,7 +479,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
if (auto vt = get<VariadicTypePack>(*superTail)) if (auto vt = get<VariadicTypePack>(*superTail))
{ {
for (size_t i = headSize; i < subHead.size(); ++i) for (size_t i = headSize; i < subHead.size(); ++i)
results.push_back(isCovariantWith(env, subHead[i], vt->ty)); results.push_back(isCovariantWith(env, subHead[i], vt->ty)
.withSubComponent(TypePath::Index{i})
.withSuperComponent(TypePath::TypeField::Variadic));
} }
else if (auto gt = get<GenericTypePack>(*superTail)) else if (auto gt = get<GenericTypePack>(*superTail))
{ {
@ -424,7 +496,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail);
if (TypePackId* other = env.mappedGenericPacks.find(*superTail)) if (TypePackId* other = env.mappedGenericPacks.find(*superTail))
results.push_back(isCovariantWith(env, *other, subTailPack)); // TODO: TypePath can't express "slice of a pack + its tail".
results.push_back(isCovariantWith(env, *other, subTailPack).withSuperComponent(TypePath::PackField::Tail));
else else
env.mappedGenericPacks.try_insert(*superTail, subTailPack); env.mappedGenericPacks.try_insert(*superTail, subTailPack);
@ -437,7 +510,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
// For any non-generic type T: // For any non-generic type T:
// //
// () -> T </: <X...>() -> X... // () -> T </: <X...>() -> X...
return {false}; return SubtypingResult{false}.withSuperComponent(TypePath::PackField::Tail);
} }
} }
else else
@ -453,12 +526,14 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
{ {
if (auto p = get2<VariadicTypePack, VariadicTypePack>(*subTail, *superTail)) if (auto p = get2<VariadicTypePack, VariadicTypePack>(*subTail, *superTail))
{ {
results.push_back(isCovariantWith(env, p)); // Variadic component is added by the isCovariantWith
// implementation; no need to add it here.
results.push_back(isCovariantWith(env, p).withBothComponent(TypePath::PackField::Tail));
} }
else if (auto p = get2<GenericTypePack, GenericTypePack>(*subTail, *superTail)) else if (auto p = get2<GenericTypePack, GenericTypePack>(*subTail, *superTail))
{ {
bool ok = bindGeneric(env, *subTail, *superTail); bool ok = bindGeneric(env, *subTail, *superTail);
results.push_back({ok}); results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail));
} }
else if (get2<VariadicTypePack, GenericTypePack>(*subTail, *superTail)) else if (get2<VariadicTypePack, GenericTypePack>(*subTail, *superTail))
{ {
@ -466,12 +541,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
{ {
// <A...>(A...) -> number <: (...number) -> number // <A...>(A...) -> number <: (...number) -> number
bool ok = bindGeneric(env, *subTail, *superTail); bool ok = bindGeneric(env, *subTail, *superTail);
results.push_back({ok}); results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail));
} }
else else
{ {
// (number) -> ...number </: <A...>(number) -> A... // (number) -> ...number </: <A...>(number) -> A...
results.push_back({false}); results.push_back(SubtypingResult{false}.withBothComponent(TypePath::PackField::Tail));
} }
} }
else if (get2<GenericTypePack, VariadicTypePack>(*subTail, *superTail)) else if (get2<GenericTypePack, VariadicTypePack>(*subTail, *superTail))
@ -479,13 +554,13 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
if (variance == Variance::Contravariant) if (variance == Variance::Contravariant)
{ {
// (...number) -> number </: <A...>(A...) -> number // (...number) -> number </: <A...>(A...) -> number
results.push_back({false}); results.push_back(SubtypingResult{false}.withBothComponent(TypePath::PackField::Tail));
} }
else else
{ {
// <A...>() -> A... <: () -> ...number // <A...>() -> A... <: () -> ...number
bool ok = bindGeneric(env, *subTail, *superTail); bool ok = bindGeneric(env, *subTail, *superTail);
results.push_back({ok}); results.push_back(SubtypingResult{ok}.withBothComponent(TypePath::PackField::Tail));
} }
} }
else else
@ -496,12 +571,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
{ {
if (get<VariadicTypePack>(*subTail)) if (get<VariadicTypePack>(*subTail))
{ {
return {false}; return SubtypingResult{false}.withSubComponent(TypePath::PackField::Tail);
} }
else if (get<GenericTypePack>(*subTail)) else if (get<GenericTypePack>(*subTail))
{ {
bool ok = bindGeneric(env, *subTail, builtinTypes->emptyTypePack); bool ok = bindGeneric(env, *subTail, builtinTypes->emptyTypePack);
return {ok}; return SubtypingResult{ok}.withSubComponent(TypePath::PackField::Tail);
} }
else else
unexpected(*subTail); unexpected(*subTail);
@ -525,10 +600,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
if (variance == Variance::Contravariant) if (variance == Variance::Contravariant)
{ {
bool ok = bindGeneric(env, builtinTypes->emptyTypePack, *superTail); bool ok = bindGeneric(env, builtinTypes->emptyTypePack, *superTail);
results.push_back({ok}); results.push_back(SubtypingResult{ok}.withSuperComponent(TypePath::PackField::Tail));
} }
else else
results.push_back({false}); results.push_back(SubtypingResult{false}.withSuperComponent(TypePath::PackField::Tail));
} }
else else
iceReporter->ice("Subtyping test encountered the unexpected type pack: " + toString(*superTail)); iceReporter->ice("Subtyping test encountered the unexpected type pack: " + toString(*superTail));
@ -540,7 +615,15 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
template<typename SubTy, typename SuperTy> template<typename SubTy, typename SuperTy>
SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy) SubtypingResult Subtyping::isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy)
{ {
return isCovariantWith(env, superTy, subTy); SubtypingResult result = isCovariantWith(env, superTy, subTy);
// If we don't swap the paths here, we will end up producing an invalid path
// whenever we involve contravariance. We'll end up appending path
// components that should belong to the supertype to the subtype, and vice
// versa.
if (result.reasoning)
std::swap(result.reasoning->subPath, result.reasoning->superPath);
return result;
} }
template<typename SubTy, typename SuperTy> template<typename SubTy, typename SuperTy>
@ -602,8 +685,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
{ {
// As per TAPL: T <: A | B iff T <: A || T <: B // As per TAPL: T <: A | B iff T <: A || T <: B
std::vector<SubtypingResult> subtypings; std::vector<SubtypingResult> subtypings;
size_t i = 0;
for (TypeId ty : superUnion) for (TypeId ty : superUnion)
subtypings.push_back(isCovariantWith(env, subTy, ty)); subtypings.push_back(isCovariantWith(env, subTy, ty).withSuperComponent(TypePath::Index{i++}));
return SubtypingResult::any(subtypings); return SubtypingResult::any(subtypings);
} }
@ -611,8 +695,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Unio
{ {
// As per TAPL: A | B <: T iff A <: T && B <: T // As per TAPL: A | B <: T iff A <: T && B <: T
std::vector<SubtypingResult> subtypings; std::vector<SubtypingResult> subtypings;
size_t i = 0;
for (TypeId ty : subUnion) for (TypeId ty : subUnion)
subtypings.push_back(isCovariantWith(env, ty, superTy)); subtypings.push_back(isCovariantWith(env, ty, superTy).withSubComponent(TypePath::Index{i++}));
return SubtypingResult::all(subtypings); return SubtypingResult::all(subtypings);
} }
@ -620,8 +705,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
{ {
// As per TAPL: T <: A & B iff T <: A && T <: B // As per TAPL: T <: A & B iff T <: A && T <: B
std::vector<SubtypingResult> subtypings; std::vector<SubtypingResult> subtypings;
size_t i = 0;
for (TypeId ty : superIntersection) for (TypeId ty : superIntersection)
subtypings.push_back(isCovariantWith(env, subTy, ty)); subtypings.push_back(isCovariantWith(env, subTy, ty).withSuperComponent(TypePath::Index{i++}));
return SubtypingResult::all(subtypings); return SubtypingResult::all(subtypings);
} }
@ -629,8 +715,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Inte
{ {
// As per TAPL: A & B <: T iff A <: T || B <: T // As per TAPL: A & B <: T iff A <: T || B <: T
std::vector<SubtypingResult> subtypings; std::vector<SubtypingResult> subtypings;
size_t i = 0;
for (TypeId ty : subIntersection) for (TypeId ty : subIntersection)
subtypings.push_back(isCovariantWith(env, ty, superTy)); subtypings.push_back(isCovariantWith(env, ty, superTy).withSubComponent(TypePath::Index{i++}));
return SubtypingResult::any(subtypings); return SubtypingResult::any(subtypings);
} }
@ -638,23 +725,25 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
{ {
TypeId negatedTy = follow(subNegation->ty); TypeId negatedTy = follow(subNegation->ty);
SubtypingResult result;
// In order to follow a consistent codepath, rather than folding the // In order to follow a consistent codepath, rather than folding the
// isCovariantWith test down to its conclusion here, we test the subtyping test // isCovariantWith test down to its conclusion here, we test the subtyping test
// of the result of negating the type for never, unknown, any, and error. // of the result of negating the type for never, unknown, any, and error.
if (is<NeverType>(negatedTy)) if (is<NeverType>(negatedTy))
{ {
// ¬never ~ unknown // ¬never ~ unknown
return isCovariantWith(env, builtinTypes->unknownType, superTy); result = isCovariantWith(env, builtinTypes->unknownType, superTy);
} }
else if (is<UnknownType>(negatedTy)) else if (is<UnknownType>(negatedTy))
{ {
// ¬unknown ~ never // ¬unknown ~ never
return isCovariantWith(env, builtinTypes->neverType, superTy); result = isCovariantWith(env, builtinTypes->neverType, superTy);
} }
else if (is<AnyType>(negatedTy)) else if (is<AnyType>(negatedTy))
{ {
// ¬any ~ any // ¬any ~ any
return isCovariantWith(env, negatedTy, superTy); result = isCovariantWith(env, negatedTy, superTy);
} }
else if (auto u = get<UnionType>(negatedTy)) else if (auto u = get<UnionType>(negatedTy))
{ {
@ -668,7 +757,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy)); subtypings.push_back(isCovariantWith(env, &negatedTmp, superTy));
} }
return SubtypingResult::all(subtypings); result = SubtypingResult::all(subtypings);
} }
else if (auto i = get<IntersectionType>(negatedTy)) else if (auto i = get<IntersectionType>(negatedTy))
{ {
@ -687,7 +776,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
} }
} }
return SubtypingResult::any(subtypings); result = SubtypingResult::any(subtypings);
} }
else if (is<ErrorType, FunctionType, TableType, MetatableType>(negatedTy)) else if (is<ErrorType, FunctionType, TableType, MetatableType>(negatedTy))
{ {
@ -697,28 +786,32 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Nega
// subtype of other stuff. // subtype of other stuff.
else else
{ {
return {false}; result = {false};
} }
return result.withSubComponent(TypePath::TypeField::Negated);
} }
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation)
{ {
TypeId negatedTy = follow(superNegation->ty); TypeId negatedTy = follow(superNegation->ty);
SubtypingResult result;
if (is<NeverType>(negatedTy)) if (is<NeverType>(negatedTy))
{ {
// ¬never ~ unknown // ¬never ~ unknown
return isCovariantWith(env, subTy, builtinTypes->unknownType); result = isCovariantWith(env, subTy, builtinTypes->unknownType);
} }
else if (is<UnknownType>(negatedTy)) else if (is<UnknownType>(negatedTy))
{ {
// ¬unknown ~ never // ¬unknown ~ never
return isCovariantWith(env, subTy, builtinTypes->neverType); result = isCovariantWith(env, subTy, builtinTypes->neverType);
} }
else if (is<AnyType>(negatedTy)) else if (is<AnyType>(negatedTy))
{ {
// ¬any ~ any // ¬any ~ any
return isSubtype(subTy, negatedTy); result = isSubtype(subTy, negatedTy);
} }
else if (auto u = get<UnionType>(negatedTy)) else if (auto u = get<UnionType>(negatedTy))
{ {
@ -737,7 +830,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
} }
} }
return SubtypingResult::all(subtypings); result = SubtypingResult::all(subtypings);
} }
else if (auto i = get<IntersectionType>(negatedTy)) else if (auto i = get<IntersectionType>(negatedTy))
{ {
@ -756,53 +849,55 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
} }
} }
return SubtypingResult::any(subtypings); result = SubtypingResult::any(subtypings);
} }
else if (auto p = get2<PrimitiveType, PrimitiveType>(subTy, negatedTy)) else if (auto p = get2<PrimitiveType, PrimitiveType>(subTy, negatedTy))
{ {
// number <: ¬boolean // number <: ¬boolean
// number </: ¬number // number </: ¬number
return {p.first->type != p.second->type}; result = {p.first->type != p.second->type};
} }
else if (auto p = get2<SingletonType, PrimitiveType>(subTy, negatedTy)) else if (auto p = get2<SingletonType, PrimitiveType>(subTy, negatedTy))
{ {
// "foo" </: ¬string // "foo" </: ¬string
if (get<StringSingleton>(p.first) && p.second->type == PrimitiveType::String) if (get<StringSingleton>(p.first) && p.second->type == PrimitiveType::String)
return {false}; result = {false};
// false </: ¬boolean // false </: ¬boolean
else if (get<BooleanSingleton>(p.first) && p.second->type == PrimitiveType::Boolean) else if (get<BooleanSingleton>(p.first) && p.second->type == PrimitiveType::Boolean)
return {false}; result = {false};
// other cases are true // other cases are true
else else
return {true}; result = {true};
} }
else if (auto p = get2<PrimitiveType, SingletonType>(subTy, negatedTy)) else if (auto p = get2<PrimitiveType, SingletonType>(subTy, negatedTy))
{ {
if (p.first->type == PrimitiveType::String && get<StringSingleton>(p.second)) if (p.first->type == PrimitiveType::String && get<StringSingleton>(p.second))
return {false}; result = {false};
else if (p.first->type == PrimitiveType::Boolean && get<BooleanSingleton>(p.second)) else if (p.first->type == PrimitiveType::Boolean && get<BooleanSingleton>(p.second))
return {false}; result = {false};
else else
return {true}; result = {true};
} }
// the top class type is not actually a primitive type, so the negation of // the top class type is not actually a primitive type, so the negation of
// any one of them includes the top class type. // any one of them includes the top class type.
else if (auto p = get2<ClassType, PrimitiveType>(subTy, negatedTy)) else if (auto p = get2<ClassType, PrimitiveType>(subTy, negatedTy))
return {true}; result = {true};
else if (auto p = get<PrimitiveType>(negatedTy); p && is<TableType, MetatableType>(subTy)) else if (auto p = get<PrimitiveType>(negatedTy); p && is<TableType, MetatableType>(subTy))
return {p->type != PrimitiveType::Table}; result = {p->type != PrimitiveType::Table};
else if (auto p = get2<FunctionType, PrimitiveType>(subTy, negatedTy)) else if (auto p = get2<FunctionType, PrimitiveType>(subTy, negatedTy))
return {p.second->type != PrimitiveType::Function}; result = {p.second->type != PrimitiveType::Function};
else if (auto p = get2<SingletonType, SingletonType>(subTy, negatedTy)) else if (auto p = get2<SingletonType, SingletonType>(subTy, negatedTy))
return {*p.first != *p.second}; result = {*p.first != *p.second};
else if (auto p = get2<ClassType, ClassType>(subTy, negatedTy)) else if (auto p = get2<ClassType, ClassType>(subTy, negatedTy))
return SubtypingResult::negate(isCovariantWith(env, p.first, p.second)); result = SubtypingResult::negate(isCovariantWith(env, p.first, p.second));
else if (get2<FunctionType, ClassType>(subTy, negatedTy)) else if (get2<FunctionType, ClassType>(subTy, negatedTy))
return {true}; result = {true};
else if (is<ErrorType, FunctionType, TableType, MetatableType>(negatedTy)) else if (is<ErrorType, FunctionType, TableType, MetatableType>(negatedTy))
iceReporter->ice("attempting to negate a non-testable type"); iceReporter->ice("attempting to negate a non-testable type");
else
result = {false};
return {false}; return result.withSuperComponent(TypePath::TypeField::Negated);
} }
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const PrimitiveType* superPrim) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const PrimitiveType* superPrim)
@ -836,16 +931,19 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
{ {
std::vector<SubtypingResult> results; std::vector<SubtypingResult> results;
if (auto it = subTable->props.find(name); it != subTable->props.end()) if (auto it = subTable->props.find(name); it != subTable->props.end())
results.push_back(isInvariantWith(env, it->second.type(), prop.type())); results.push_back(isInvariantWith(env, it->second.type(), prop.type())
.withBothComponent(TypePath::Property(name)));
if (subTable->indexer) if (subTable->indexer)
{ {
if (isInvariantWith(env, subTable->indexer->indexType, builtinTypes->stringType).isSubtype) if (isInvariantWith(env, subTable->indexer->indexType, builtinTypes->stringType).isSubtype)
results.push_back(isInvariantWith(env, subTable->indexer->indexResultType, prop.type())); results.push_back(isInvariantWith(env, subTable->indexer->indexResultType, prop.type())
.withSubComponent(TypePath::TypeField::IndexResult)
.withSuperComponent(TypePath::Property(name)));
} }
if (results.empty()) if (results.empty())
return {false}; return SubtypingResult{false};
result.andAlso(SubtypingResult::all(results)); result.andAlso(SubtypingResult::all(results));
} }
@ -863,7 +961,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt)
{ {
return isCovariantWith(env, subMt->table, superMt->table).andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable)); return isCovariantWith(env, subMt->table, superMt->table)
.andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable).withBothComponent(TypePath::TypeField::Metatable));
} }
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable)
@ -900,7 +999,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Clas
for (const auto& [name, prop] : superTable->props) for (const auto& [name, prop] : superTable->props)
{ {
if (auto classProp = lookupClassProp(subClass, name)) if (auto classProp = lookupClassProp(subClass, name))
result.andAlso(isInvariantWith(env, prop.type(), classProp->type())); result.andAlso(isInvariantWith(env, prop.type(), classProp->type()).withBothComponent(TypePath::Property(name)));
else else
return SubtypingResult{false}; return SubtypingResult{false};
} }
@ -913,10 +1012,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Func
SubtypingResult result; SubtypingResult result;
{ {
VarianceFlipper vf{&variance}; VarianceFlipper vf{&variance};
result.orElse(isContravariantWith(env, subFunction->argTypes, superFunction->argTypes)); result.orElse(isContravariantWith(env, subFunction->argTypes, superFunction->argTypes).withBothComponent(TypePath::PackField::Arguments));
} }
result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes)); result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes).withBothComponent(TypePath::PackField::Returns));
return result; return result;
} }
@ -933,7 +1032,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Prim
if (auto it = mttv->props.find("__index"); it != mttv->props.end()) if (auto it = mttv->props.find("__index"); it != mttv->props.end())
{ {
if (auto stringTable = get<TableType>(it->second.type())) if (auto stringTable = get<TableType>(it->second.type()))
result.orElse(isCovariantWith(env, stringTable, superTable)); result.orElse(
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().prop("__index").build()));
} }
} }
} }
@ -954,7 +1054,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Sing
if (auto it = mttv->props.find("__index"); it != mttv->props.end()) if (auto it = mttv->props.find("__index"); it != mttv->props.end())
{ {
if (auto stringTable = get<TableType>(it->second.type())) if (auto stringTable = get<TableType>(it->second.type()))
result.orElse(isCovariantWith(env, stringTable, superTable)); result.orElse(
isCovariantWith(env, stringTable, superTable).withSubPath(TypePath::PathBuilder().mt().prop("__index").build()));
} }
} }
} }
@ -965,7 +1066,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Sing
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TableIndexer& subIndexer, const TableIndexer& superIndexer) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const TableIndexer& subIndexer, const TableIndexer& superIndexer)
{ {
return isInvariantWith(env, subIndexer.indexType, superIndexer.indexType) return isInvariantWith(env, subIndexer.indexType, superIndexer.indexType)
.andAlso(isInvariantWith(env, superIndexer.indexResultType, subIndexer.indexResultType)); .withBothComponent(TypePath::TypeField::IndexLookup)
.andAlso(isInvariantWith(env, superIndexer.indexResultType, subIndexer.indexResultType).withBothComponent(TypePath::TypeField::IndexResult));
} }
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedType* subNorm, const NormalizedType* superNorm) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedType* subNorm, const NormalizedType* superNorm)
@ -1092,11 +1194,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
{ {
std::vector<SubtypingResult> results; std::vector<SubtypingResult> results;
size_t i = 0;
for (TypeId subTy : subTypes) for (TypeId subTy : subTypes)
{ {
results.emplace_back(); results.emplace_back();
for (TypeId superTy : superTypes) for (TypeId superTy : superTypes)
results.back().orElse(isCovariantWith(env, subTy, superTy)); results.back().orElse(isCovariantWith(env, subTy, superTy).withBothComponent(TypePath::Index{i++}));
} }
return SubtypingResult::all(results); return SubtypingResult::all(results);
@ -1104,7 +1207,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic) SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic)
{ {
return isCovariantWith(env, subVariadic->ty, superVariadic->ty); return isCovariantWith(env, subVariadic->ty, superVariadic->ty).withBothComponent(TypePath::TypeField::Variadic);
} }
bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId superTy) bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId superTy)

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Common.h"
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
@ -10,6 +11,7 @@
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeFamily.h" #include "Luau/TypeFamily.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
#include "Luau/TypeOrPack.h"
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
@ -620,6 +622,12 @@ struct TypeStringifier
state.emit(">"); state.emit(">");
} }
if (FFlag::DebugLuauDeferredConstraintResolution)
{
if (ftv.isCheckedFunction)
state.emit("@checked ");
}
state.emit("("); state.emit("(");
if (state.opts.functionTypeArguments) if (state.opts.functionTypeArguments)
@ -1686,21 +1694,6 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
std::string superStr = tos(c.superType); std::string superStr = tos(c.superType);
return subStr + " ~ inst " + superStr; return subStr + " ~ inst " + superStr;
} }
else if constexpr (std::is_same_v<T, UnaryConstraint>)
{
std::string resultStr = tos(c.resultType);
std::string operandStr = tos(c.operandType);
return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">";
}
else if constexpr (std::is_same_v<T, BinaryConstraint>)
{
std::string resultStr = tos(c.resultType);
std::string leftStr = tos(c.leftType);
std::string rightStr = tos(c.rightType);
return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">";
}
else if constexpr (std::is_same_v<T, IterableConstraint>) else if constexpr (std::is_same_v<T, IterableConstraint>)
{ {
std::string iteratorStr = tos(c.iterator); std::string iteratorStr = tos(c.iterator);
@ -1756,6 +1749,23 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
const char* op = c.mode == RefineConstraint::Union ? "union" : "intersect"; const char* op = c.mode == RefineConstraint::Union ? "union" : "intersect";
return tos(c.resultType) + " ~ refine " + tos(c.type) + " " + op + " " + tos(c.discriminant); return tos(c.resultType) + " ~ refine " + tos(c.type) + " " + op + " " + tos(c.discriminant);
} }
else if constexpr (std::is_same_v<T, SetOpConstraint>)
{
const char* op = c.mode == SetOpConstraint::Union ? " | " : " & ";
std::string res = tos(c.resultType) + " ~ ";
bool first = true;
for (TypeId t : c.types)
{
if (first)
first = false;
else
res += op;
res += tos(t);
}
return res;
}
else if constexpr (std::is_same_v<T, ReduceConstraint>) else if constexpr (std::is_same_v<T, ReduceConstraint>)
return "reduce " + tos(c.ty); return "reduce " + tos(c.ty);
else if constexpr (std::is_same_v<T, ReducePackConstraint>) else if constexpr (std::is_same_v<T, ReducePackConstraint>)
@ -1834,4 +1844,24 @@ std::string toString(const Location& location, int offset, bool useBegin)
} }
} }
std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts)
{
if (const TypeId* ty = get<TypeId>(tyOrTp))
return toString(*ty, opts);
else if (const TypePackId* tp = get<TypePackId>(tyOrTp))
return toString(*tp, opts);
else
LUAU_UNREACHABLE();
}
std::string dump(const TypeOrPack& tyOrTp)
{
ToStringOptions opts;
opts.exhaustive = true;
opts.functionTypeArguments = true;
std::string s = toString(tyOrTp, opts);
printf("%s\n", s.c_str());
return s;
}
} // namespace Luau } // namespace Luau

View File

@ -15,9 +15,11 @@
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TypeFamily.h" #include "Luau/TypeFamily.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePack.h"
#include "Luau/TypePath.h"
#include "Luau/TypeUtils.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
#include <algorithm> #include <algorithm>
@ -2395,7 +2397,27 @@ struct TypeChecker2
reportError(NormalizationTooComplex{}, location); reportError(NormalizationTooComplex{}, location);
if (!r.isSubtype && !r.isErrorSuppressing) if (!r.isSubtype && !r.isErrorSuppressing)
reportError(TypeMismatch{superTy, subTy}, location); {
if (r.reasoning)
{
std::optional<TypeOrPack> subLeaf = traverse(subTy, r.reasoning->subPath, builtinTypes);
std::optional<TypeOrPack> superLeaf = traverse(superTy, r.reasoning->superPath, builtinTypes);
if (!subLeaf || !superLeaf)
ice->ice("Subtyping test returned a reasoning with an invalid path", location);
if (!get2<TypeId, TypeId>(*subLeaf, *superLeaf) && !get2<TypePackId, TypePackId>(*subLeaf, *superLeaf))
ice->ice("Subtyping test returned a reasoning where one path ends at a type and the other ends at a pack.", location);
std::string reason = "type " + toString(subTy) + toString(r.reasoning->subPath) + " (" + toString(*subLeaf) +
") is not a subtype of " + toString(superTy) + toString(r.reasoning->superPath) + " (" + toString(*superLeaf) +
")";
reportError(TypeMismatch{superTy, subTy, reason}, location);
}
else
reportError(TypeMismatch{superTy, subTy}, location);
}
return r.isSubtype; return r.isSubtype;
} }

View File

@ -338,6 +338,150 @@ TypeFamilyReductionResult<TypeId> notFamilyFn(const std::vector<TypeId>& typePar
return {ctx->builtins->booleanType, false, {}, {}}; return {ctx->builtins->booleanType, false, {}, {}};
} }
TypeFamilyReductionResult<TypeId> lenFamilyFn(const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
{
if (typeParams.size() != 1 || !packParams.empty())
{
ctx->ice->ice("len type family: encountered a type family instance without the required argument structure");
LUAU_ASSERT(false);
}
TypeId operandTy = follow(typeParams.at(0));
const NormalizedType* normTy = ctx->normalizer->normalize(operandTy);
// if the type failed to normalize, we can't reduce, but know nothing about inhabitance.
if (!normTy)
return {std::nullopt, false, {}, {}};
// if the operand type is error suppressing, we can immediately reduce to `number`.
if (normTy->shouldSuppressErrors())
return {ctx->builtins->numberType, false, {}, {}};
// if we have a `never`, we can never observe that the operator didn't work.
if (is<NeverType>(operandTy))
return {ctx->builtins->neverType, false, {}, {}};
// if we're checking the length of a string, that works!
if (normTy->isSubtypeOfString())
return {ctx->builtins->numberType, false, {}, {}};
// we use the normalized operand here in case there was an intersection or union.
TypeId normalizedOperand = ctx->normalizer->typeFromNormal(*normTy);
if (normTy->hasTopTable() || get<TableType>(normalizedOperand))
return {ctx->builtins->numberType, false, {}, {}};
// otherwise, we wait to see if the operand type is resolved
if (isPending(operandTy, ctx->solver))
return {std::nullopt, false, {operandTy}, {}};
// findMetatableEntry demands the ability to emit errors, so we must give it
// the necessary state to do that, even if we intend to just eat the errors.
ErrorVec dummy;
std::optional<TypeId> mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__len", Location{});
if (!mmType)
return {std::nullopt, true, {}, {}};
mmType = follow(*mmType);
if (isPending(*mmType, ctx->solver))
return {std::nullopt, false, {*mmType}, {}};
const FunctionType* mmFtv = get<FunctionType>(*mmType);
if (!mmFtv)
return {std::nullopt, true, {}, {}};
std::optional<TypeId> instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType);
if (!instantiatedMmType)
return {std::nullopt, true, {}, {}};
const FunctionType* instantiatedMmFtv = get<FunctionType>(*instantiatedMmType);
if (!instantiatedMmFtv)
return {ctx->builtins->errorRecoveryType(), false, {}, {}};
TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy});
Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice};
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice, ctx->scope};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}};
// `len` must return a `number`.
return {ctx->builtins->numberType, false, {}, {}};
}
TypeFamilyReductionResult<TypeId> unmFamilyFn(
const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx)
{
if (typeParams.size() != 1 || !packParams.empty())
{
ctx->ice->ice("unm type family: encountered a type family instance without the required argument structure");
LUAU_ASSERT(false);
}
TypeId operandTy = follow(typeParams.at(0));
const NormalizedType* normTy = ctx->normalizer->normalize(operandTy);
// if the operand failed to normalize, we can't reduce, but know nothing about inhabitance.
if (!normTy)
return {std::nullopt, false, {}, {}};
// if the operand is error suppressing, we can just go ahead and reduce.
if (normTy->shouldSuppressErrors())
return {operandTy, false, {}, {}};
// if we have a `never`, we can never observe that the operation didn't work.
if (is<NeverType>(operandTy))
return {ctx->builtins->neverType, false, {}, {}};
// If the type is exactly `number`, we can reduce now.
if (normTy->isExactlyNumber())
return {ctx->builtins->numberType, false, {}, {}};
// otherwise, check if we need to wait on the type to be further resolved
if (isPending(operandTy, ctx->solver))
return {std::nullopt, false, {operandTy}, {}};
// findMetatableEntry demands the ability to emit errors, so we must give it
// the necessary state to do that, even if we intend to just eat the errors.
ErrorVec dummy;
std::optional<TypeId> mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__unm", Location{});
if (!mmType)
return {std::nullopt, true, {}, {}};
mmType = follow(*mmType);
if (isPending(*mmType, ctx->solver))
return {std::nullopt, false, {*mmType}, {}};
const FunctionType* mmFtv = get<FunctionType>(*mmType);
if (!mmFtv)
return {std::nullopt, true, {}, {}};
std::optional<TypeId> instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType);
if (!instantiatedMmType)
return {std::nullopt, true, {}, {}};
const FunctionType* instantiatedMmFtv = get<FunctionType>(*instantiatedMmType);
if (!instantiatedMmFtv)
return {ctx->builtins->errorRecoveryType(), false, {}, {}};
TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy});
Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice};
if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes))
return {std::nullopt, true, {}, {}}; // occurs check failed
Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice, ctx->scope};
if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes).isSubtype) // TODO: is this the right variance?
return {std::nullopt, true, {}, {}};
if (std::optional<TypeId> ret = first(instantiatedMmFtv->retTypes))
return {*ret, false, {}, {}};
else
return {std::nullopt, true, {}, {}};
}
TypeFamilyReductionResult<TypeId> numericBinopFamilyFn( TypeFamilyReductionResult<TypeId> numericBinopFamilyFn(
const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx, const std::string metamethod) const std::vector<TypeId>& typeParams, const std::vector<TypePackId>& packParams, NotNull<TypeFamilyContext> ctx, const std::string metamethod)
{ {
@ -816,6 +960,8 @@ TypeFamilyReductionResult<TypeId> eqFamilyFn(const std::vector<TypeId>& typePara
BuiltinTypeFamilies::BuiltinTypeFamilies() BuiltinTypeFamilies::BuiltinTypeFamilies()
: notFamily{"not", notFamilyFn} : notFamily{"not", notFamilyFn}
, lenFamily{"len", lenFamilyFn}
, unmFamily{"unm", unmFamilyFn}
, addFamily{"add", addFamilyFn} , addFamily{"add", addFamilyFn}
, subFamily{"sub", subFamilyFn} , subFamily{"sub", subFamilyFn}
, mulFamily{"mul", mulFamilyFn} , mulFamily{"mul", mulFamilyFn}
@ -834,6 +980,14 @@ BuiltinTypeFamilies::BuiltinTypeFamilies()
void BuiltinTypeFamilies::addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const void BuiltinTypeFamilies::addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const
{ {
// make a type function for a one-argument type family
auto mkUnaryTypeFamily = [&](const TypeFamily* family) {
TypeId t = arena->addType(GenericType{"T"});
GenericTypeDefinition genericT{t};
return TypeFun{{genericT}, arena->addType(TypeFamilyInstanceType{NotNull{family}, {t}, {}})};
};
// make a type function for a two-argument type family // make a type function for a two-argument type family
auto mkBinaryTypeFamily = [&](const TypeFamily* family) { auto mkBinaryTypeFamily = [&](const TypeFamily* family) {
TypeId t = arena->addType(GenericType{"T"}); TypeId t = arena->addType(GenericType{"T"});
@ -844,6 +998,9 @@ void BuiltinTypeFamilies::addToScope(NotNull<TypeArena> arena, NotNull<Scope> sc
return TypeFun{{genericT, genericU}, arena->addType(TypeFamilyInstanceType{NotNull{family}, {t, u}, {}})}; return TypeFun{{genericT, genericU}, arena->addType(TypeFamilyInstanceType{NotNull{family}, {t, u}, {}})};
}; };
scope->exportedTypeBindings[lenFamily.name] = mkUnaryTypeFamily(&lenFamily);
scope->exportedTypeBindings[unmFamily.name] = mkUnaryTypeFamily(&unmFamily);
scope->exportedTypeBindings[addFamily.name] = mkBinaryTypeFamily(&addFamily); scope->exportedTypeBindings[addFamily.name] = mkBinaryTypeFamily(&addFamily);
scope->exportedTypeBindings[subFamily.name] = mkBinaryTypeFamily(&subFamily); scope->exportedTypeBindings[subFamily.name] = mkBinaryTypeFamily(&subFamily);
scope->exportedTypeBindings[mulFamily.name] = mkBinaryTypeFamily(&mulFamily); scope->exportedTypeBindings[mulFamily.name] = mkBinaryTypeFamily(&mulFamily);

View File

@ -0,0 +1,29 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeOrPack.h"
#include "Luau/Common.h"
namespace Luau
{
const void* ptr(TypeOrPack tyOrTp)
{
if (auto ty = get<TypeId>(tyOrTp))
return static_cast<const void*>(*ty);
else if (auto tp = get<TypePackId>(tyOrTp))
return static_cast<const void*>(*tp);
else
LUAU_UNREACHABLE();
}
TypeOrPack follow(TypeOrPack tyOrTp)
{
if (auto ty = get<TypeId>(tyOrTp))
return follow(*ty);
else if (auto tp = get<TypePackId>(tyOrTp))
return follow(*tp);
else
LUAU_UNREACHABLE();
}
} // namespace Luau

633
Analysis/src/TypePath.cpp Normal file
View File

@ -0,0 +1,633 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypePath.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include <optional>
#include <sstream>
#include <type_traits>
#include <unordered_set>
LUAU_FASTFLAG(DebugLuauReadWriteProperties);
// Maximum number of steps to follow when traversing a path. May not always
// equate to the number of components in a path, depending on the traversal
// logic.
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypePathMaximumTraverseSteps, 100);
namespace Luau
{
namespace TypePath
{
Property::Property(std::string name)
: name(std::move(name))
{
LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties);
}
Property Property::read(std::string name)
{
return Property(std::move(name), true);
}
Property Property::write(std::string name)
{
return Property(std::move(name), false);
}
bool Property::operator==(const Property& other) const
{
return name == other.name && isRead == other.isRead;
}
bool Index::operator==(const Index& other) const
{
return index == other.index;
}
Path Path::append(const Path& suffix) const
{
std::vector<Component> joined(components);
joined.reserve(suffix.components.size());
joined.insert(joined.end(), suffix.components.begin(), suffix.components.end());
return Path(std::move(joined));
}
Path Path::push(Component component) const
{
std::vector<Component> joined(components);
joined.push_back(component);
return Path(std::move(joined));
}
Path Path::push_front(Component component) const
{
std::vector<Component> joined{};
joined.reserve(components.size() + 1);
joined.push_back(std::move(component));
joined.insert(joined.end(), components.begin(), components.end());
return Path(std::move(joined));
}
Path Path::pop() const
{
if (empty())
return kEmpty;
std::vector<Component> popped(components);
popped.pop_back();
return Path(std::move(popped));
}
std::optional<Component> Path::last() const
{
if (empty())
return std::nullopt;
return components.back();
}
bool Path::empty() const
{
return components.empty();
}
bool Path::operator==(const Path& other) const
{
return components == other.components;
}
Path PathBuilder::build()
{
return Path(std::move(components));
}
PathBuilder& PathBuilder::readProp(std::string name)
{
LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties);
components.push_back(Property{std::move(name), true});
return *this;
}
PathBuilder& PathBuilder::writeProp(std::string name)
{
LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties);
components.push_back(Property{std::move(name), false});
return *this;
}
PathBuilder& PathBuilder::prop(std::string name)
{
LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties);
components.push_back(Property{std::move(name)});
return *this;
}
PathBuilder& PathBuilder::index(size_t i)
{
components.push_back(Index{i});
return *this;
}
PathBuilder& PathBuilder::mt()
{
components.push_back(TypeField::Metatable);
return *this;
}
PathBuilder& PathBuilder::lb()
{
components.push_back(TypeField::LowerBound);
return *this;
}
PathBuilder& PathBuilder::ub()
{
components.push_back(TypeField::UpperBound);
return *this;
}
PathBuilder& PathBuilder::indexKey()
{
components.push_back(TypeField::IndexLookup);
return *this;
}
PathBuilder& PathBuilder::indexValue()
{
components.push_back(TypeField::IndexResult);
return *this;
}
PathBuilder& PathBuilder::negated()
{
components.push_back(TypeField::Negated);
return *this;
}
PathBuilder& PathBuilder::variadic()
{
components.push_back(TypeField::Variadic);
return *this;
}
PathBuilder& PathBuilder::args()
{
components.push_back(PackField::Arguments);
return *this;
}
PathBuilder& PathBuilder::rets()
{
components.push_back(PackField::Returns);
return *this;
}
PathBuilder& PathBuilder::tail()
{
components.push_back(PackField::Tail);
return *this;
}
} // namespace TypePath
namespace
{
struct TraversalState
{
TraversalState(TypeId root, NotNull<BuiltinTypes> builtinTypes)
: current(root)
, builtinTypes(builtinTypes)
{
}
TraversalState(TypePackId root, NotNull<BuiltinTypes> builtinTypes)
: current(root)
, builtinTypes(builtinTypes)
{
}
TypeOrPack current;
NotNull<BuiltinTypes> builtinTypes;
DenseHashSet<const void*> seen{nullptr};
int steps = 0;
void updateCurrent(TypeId ty)
{
LUAU_ASSERT(ty);
current = follow(ty);
}
void updateCurrent(TypePackId tp)
{
LUAU_ASSERT(tp);
current = follow(tp);
}
bool haveCycle()
{
const void* currentPtr = ptr(current);
if (seen.contains(currentPtr))
return true;
else
seen.insert(currentPtr);
return false;
}
bool tooLong()
{
return ++steps > DFInt::LuauTypePathMaximumTraverseSteps;
}
bool checkInvariants()
{
return haveCycle() || tooLong();
}
bool traverse(const TypePath::Property& property)
{
auto currentType = get<TypeId>(current);
if (!currentType)
return false;
if (checkInvariants())
return false;
const Property* prop = nullptr;
if (auto t = get<TableType>(*currentType))
{
auto it = t->props.find(property.name);
if (it != t->props.end())
{
prop = &it->second;
}
}
else if (auto c = get<ClassType>(*currentType))
{
prop = lookupClassProp(c, property.name);
}
else if (auto m = getMetatable(*currentType, builtinTypes))
{
// Weird: rather than use findMetatableEntry, which requires a lot
// of stuff that we don't have and don't want to pull in, we use the
// path traversal logic to grab __index and then re-enter the lookup
// logic there.
updateCurrent(*m);
if (!traverse(TypePath::Property{"__index"}))
return false;
return traverse(property);
}
if (prop)
{
std::optional<TypeId> maybeType;
if (FFlag::DebugLuauReadWriteProperties)
maybeType = property.isRead ? prop->readType() : prop->writeType();
else
maybeType = prop->type();
if (maybeType)
{
updateCurrent(*maybeType);
return true;
}
}
return false;
}
bool traverse(const TypePath::Index& index)
{
if (checkInvariants())
return false;
if (auto currentType = get<TypeId>(current))
{
if (auto u = get<UnionType>(*currentType))
{
auto it = begin(u);
std::advance(it, index.index);
if (it != end(u))
{
updateCurrent(*it);
return true;
}
}
else if (auto i = get<IntersectionType>(*currentType))
{
auto it = begin(i);
std::advance(it, index.index);
if (it != end(i))
{
updateCurrent(*it);
return true;
}
}
}
else
{
auto currentPack = get<TypePackId>(current);
LUAU_ASSERT(currentPack);
if (get<TypePack>(*currentPack))
{
auto it = begin(*currentPack);
for (size_t i = 0; i < index.index && it != end(*currentPack); ++i)
++it;
if (it != end(*currentPack))
{
updateCurrent(*it);
return true;
}
}
}
return false;
}
bool traverse(TypePath::TypeField field)
{
if (checkInvariants())
return false;
switch (field)
{
case TypePath::TypeField::Metatable:
if (auto currentType = get<TypeId>(current))
{
if (std::optional<TypeId> mt = getMetatable(*currentType, builtinTypes))
{
updateCurrent(*mt);
return true;
}
}
return false;
case TypePath::TypeField::LowerBound:
case TypePath::TypeField::UpperBound:
if (auto ft = get<FreeType>(current))
{
updateCurrent(field == TypePath::TypeField::LowerBound ? ft->lowerBound : ft->upperBound);
return true;
}
return false;
case TypePath::TypeField::IndexLookup:
case TypePath::TypeField::IndexResult:
{
const TableIndexer* indexer = nullptr;
if (auto tt = get<TableType>(current); tt && tt->indexer)
indexer = &(*tt->indexer);
// Note: we don't appear to walk the class hierarchy for indexers
else if (auto ct = get<ClassType>(current); ct && ct->indexer)
indexer = &(*ct->indexer);
if (indexer)
{
updateCurrent(field == TypePath::TypeField::IndexLookup ? indexer->indexType : indexer->indexResultType);
return true;
}
return false;
}
case TypePath::TypeField::Negated:
if (auto nt = get<NegationType>(current))
{
updateCurrent(nt->ty);
return true;
}
return false;
case TypePath::TypeField::Variadic:
if (auto vtp = get<VariadicTypePack>(current))
{
updateCurrent(vtp->ty);
return true;
}
return false;
}
return false;
}
bool traverse(TypePath::PackField field)
{
if (checkInvariants())
return false;
switch (field)
{
case TypePath::PackField::Arguments:
case TypePath::PackField::Returns:
if (auto ft = get<FunctionType>(current))
{
updateCurrent(field == TypePath::PackField::Arguments ? ft->argTypes : ft->retTypes);
return true;
}
return false;
case TypePath::PackField::Tail:
if (auto currentPack = get<TypePackId>(current))
{
auto it = begin(*currentPack);
while (it != end(*currentPack))
++it;
if (auto tail = it.tail())
{
updateCurrent(*tail);
return true;
}
}
return false;
}
return false;
}
};
} // namespace
std::string toString(const TypePath::Path& path)
{
std::stringstream result;
bool first = true;
auto strComponent = [&](auto&& c) {
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, TypePath::Property>)
{
result << '[';
if (FFlag::DebugLuauReadWriteProperties)
{
if (c.isRead)
result << "read ";
else
result << "write ";
}
result << '"' << c.name << '"' << ']';
}
else if constexpr (std::is_same_v<T, TypePath::Index>)
{
result << '[' << std::to_string(c.index) << ']';
}
else if constexpr (std::is_same_v<T, TypePath::TypeField>)
{
if (!first)
result << '.';
switch (c)
{
case TypePath::TypeField::Metatable:
result << "metatable";
break;
case TypePath::TypeField::LowerBound:
result << "lowerBound";
break;
case TypePath::TypeField::UpperBound:
result << "upperBound";
break;
case TypePath::TypeField::IndexLookup:
result << "indexer";
break;
case TypePath::TypeField::IndexResult:
result << "indexResult";
break;
case TypePath::TypeField::Negated:
result << "negated";
break;
case TypePath::TypeField::Variadic:
result << "variadic";
break;
}
result << "()";
}
else if constexpr (std::is_same_v<T, TypePath::PackField>)
{
if (!first)
result << '.';
switch (c)
{
case TypePath::PackField::Arguments:
result << "arguments";
break;
case TypePath::PackField::Returns:
result << "returns";
break;
case TypePath::PackField::Tail:
result << "tail";
break;
}
result << "()";
}
else
{
static_assert(always_false_v<T>, "Unhandled Component variant");
}
first = false;
};
for (const TypePath::Component& component : path.components)
Luau::visit(strComponent, component);
return result.str();
}
static bool traverse(TraversalState& state, const Path& path)
{
auto step = [&state](auto&& c) {
return state.traverse(c);
};
for (const TypePath::Component& component : path.components)
{
bool stepSuccess = visit(step, component);
if (!stepSuccess)
return false;
}
return true;
}
std::optional<TypeOrPack> traverse(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes)
{
TraversalState state(follow(root), builtinTypes);
if (traverse(state, path))
return state.current;
else
return std::nullopt;
}
std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
std::optional<TypeId> traverseForType(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes)
{
TraversalState state(follow(root), builtinTypes);
if (traverse(state, path))
{
auto ty = get<TypeId>(state.current);
return ty ? std::make_optional(*ty) : std::nullopt;
}
else
return std::nullopt;
}
std::optional<TypeId> traverseForType(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes)
{
TraversalState state(follow(root), builtinTypes);
if (traverse(state, path))
{
auto ty = get<TypeId>(state.current);
return ty ? std::make_optional(*ty) : std::nullopt;
}
else
return std::nullopt;
}
std::optional<TypePackId> traverseForPack(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes)
{
TraversalState state(follow(root), builtinTypes);
if (traverse(state, path))
{
auto ty = get<TypePackId>(state.current);
return ty ? std::make_optional(*ty) : std::nullopt;
}
else
return std::nullopt;
}
std::optional<TypePackId> traverseForPack(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes)
{
TraversalState state(follow(root), builtinTypes);
if (traverse(state, path))
{
auto ty = get<TypePackId>(state.current);
return ty ? std::make_optional(*ty) : std::nullopt;
}
else
return std::nullopt;
}
} // namespace Luau

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Common.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"

View File

@ -33,6 +33,13 @@ enum class CompileFormat
Null Null
}; };
enum class RecordStats
{
None,
Total,
Split
};
struct GlobalOptions struct GlobalOptions
{ {
int optimizationLevel = 1; int optimizationLevel = 1;
@ -122,6 +129,57 @@ struct CompileStats
double codegenTime; double codegenTime;
Luau::CodeGen::LoweringStats lowerStats; Luau::CodeGen::LoweringStats lowerStats;
void serializeToJson(FILE* fp)
{
// use compact one-line formatting to reduce file length
fprintf(fp, "{\
\"lines\": %zu, \
\"bytecode\": %zu, \
\"codegen\": %zu, \
\"readTime\": %f, \
\"miscTime\": %f, \
\"parseTime\": %f, \
\"compileTime\": %f, \
\"codegenTime\": %f, \
\"lowerStats\": {\
\"totalFunctions\": %u, \
\"skippedFunctions\": %u, \
\"spillsToSlot\": %d, \
\"spillsToRestore\": %d, \
\"maxSpillSlotsUsed\": %u, \
\"blocksPreOpt\": %u, \
\"blocksPostOpt\": %u, \
\"maxBlockInstructions\": %u, \
\"regAllocErrors\": %d, \
\"loweringErrors\": %d\
}}",
lines, bytecode, codegen, readTime, miscTime, parseTime, compileTime, codegenTime, lowerStats.totalFunctions, lowerStats.skippedFunctions,
lowerStats.spillsToSlot, lowerStats.spillsToRestore, lowerStats.maxSpillSlotsUsed, lowerStats.blocksPreOpt, lowerStats.blocksPostOpt,
lowerStats.maxBlockInstructions, lowerStats.regAllocErrors, lowerStats.loweringErrors);
}
CompileStats& operator+=(const CompileStats& that)
{
this->lines += that.lines;
this->bytecode += that.bytecode;
this->codegen += that.codegen;
this->readTime += that.readTime;
this->miscTime += that.miscTime;
this->parseTime += that.parseTime;
this->compileTime += that.compileTime;
this->codegenTime += that.codegenTime;
this->lowerStats += that.lowerStats;
return *this;
}
CompileStats operator+(const CompileStats& other) const
{
CompileStats result(*this);
result += other;
return result;
}
}; };
static double recordDeltaTime(double& timer) static double recordDeltaTime(double& timer)
@ -254,6 +312,7 @@ static void displayHelp(const char* argv0)
printf(" -g<n>: compile with debug level n (default 1, n should be between 0 and 2).\n"); printf(" -g<n>: compile with debug level n (default 1, n should be between 0 and 2).\n");
printf(" --target=<target>: compile code for specific architecture (a64, x64, a64_nf, x64_ms).\n"); printf(" --target=<target>: compile code for specific architecture (a64, x64, a64_nf, x64_ms).\n");
printf(" --timetrace: record compiler time tracing information into trace.json\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n");
printf(" --record-stats=<style>: records compilation stats in stats.json (total, split).\n");
} }
static int assertionHandler(const char* expr, const char* file, int line, const char* function) static int assertionHandler(const char* expr, const char* file, int line, const char* function)
@ -270,6 +329,7 @@ int main(int argc, char** argv)
CompileFormat compileFormat = CompileFormat::Text; CompileFormat compileFormat = CompileFormat::Text;
Luau::CodeGen::AssemblyOptions::Target assemblyTarget = Luau::CodeGen::AssemblyOptions::Host; Luau::CodeGen::AssemblyOptions::Target assemblyTarget = Luau::CodeGen::AssemblyOptions::Host;
RecordStats recordStats = RecordStats::None;
for (int i = 1; i < argc; i++) for (int i = 1; i < argc; i++)
{ {
@ -320,6 +380,20 @@ int main(int argc, char** argv)
{ {
FFlag::DebugLuauTimeTracing.value = true; FFlag::DebugLuauTimeTracing.value = true;
} }
else if (strncmp(argv[i], "--record-stats=", 15) == 0)
{
const char* value = argv[i] + 15;
if (strcmp(value, "total") == 0)
recordStats = RecordStats::Total;
else if (strcmp(value, "split") == 0)
recordStats = RecordStats::Split;
else
{
fprintf(stderr, "Error: unknown 'style' for '--record-stats'\n");
return 1;
}
}
else if (strncmp(argv[i], "--fflags=", 9) == 0) else if (strncmp(argv[i], "--fflags=", 9) == 0)
{ {
setLuauFlags(argv[i] + 9); setLuauFlags(argv[i] + 9);
@ -351,11 +425,23 @@ int main(int argc, char** argv)
_setmode(_fileno(stdout), _O_BINARY); _setmode(_fileno(stdout), _O_BINARY);
#endif #endif
const size_t fileCount = files.size();
CompileStats stats = {}; CompileStats stats = {};
std::vector<CompileStats> fileStats;
if (recordStats == RecordStats::Split)
fileStats.reserve(fileCount);
int failed = 0; int failed = 0;
for (const std::string& path : files) for (const std::string& path : files)
failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, stats); {
CompileStats fileStat = {};
failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, fileStat);
stats += fileStat;
if (recordStats == RecordStats::Split)
fileStats.push_back(fileStat);
}
if (compileFormat == CompileFormat::Null) if (compileFormat == CompileFormat::Null)
{ {
@ -374,5 +460,35 @@ int main(int argc, char** argv)
stats.lowerStats.maxSpillSlotsUsed); stats.lowerStats.maxSpillSlotsUsed);
} }
if (recordStats != RecordStats::None)
{
FILE* fp = fopen("stats.json", "w");
if (!fp)
{
fprintf(stderr, "Unable to open 'stats.json'\n");
return 1;
}
if (recordStats == RecordStats::Total)
{
stats.serializeToJson(fp);
}
else if (recordStats == RecordStats::Split)
{
fprintf(fp, "{\n");
for (size_t i = 0; i < fileCount; ++i)
{
fprintf(fp, "\"%s\": ", files[i].c_str());
fileStats[i].serializeToJson(fp);
fprintf(fp, i == (fileCount - 1) ? "\n" : ",\n");
}
fprintf(fp, "}");
}
fclose(fp);
}
return failed ? 1 : 0; return failed ? 1 : 0;
} }

View File

@ -72,6 +72,7 @@ public:
void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void clz(RegisterA64 dst, RegisterA64 src); void clz(RegisterA64 dst, RegisterA64 src);
void rbit(RegisterA64 dst, RegisterA64 src); void rbit(RegisterA64 dst, RegisterA64 src);
void rev(RegisterA64 dst, RegisterA64 src);
// Shifts with immediates // Shifts with immediates
// Note: immediate value must be in [0, 31] or [0, 63] range based on register type // Note: immediate value must be in [0, 31] or [0, 63] range based on register type

View File

@ -106,6 +106,7 @@ public:
void bsr(RegisterX64 dst, OperandX64 src); void bsr(RegisterX64 dst, OperandX64 src);
void bsf(RegisterX64 dst, OperandX64 src); void bsf(RegisterX64 dst, OperandX64 src);
void bswap(RegisterX64 dst);
// Code alignment // Code alignment
void nop(uint32_t length = 1); void nop(uint32_t length = 1);

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include <algorithm>
#include <string> #include <string>
#include <stddef.h> #include <stddef.h>
@ -66,6 +67,8 @@ struct AssemblyOptions
Target target = Host; Target target = Host;
unsigned int flags = 0;
bool outputBinary = false; bool outputBinary = false;
bool includeAssembly = false; bool includeAssembly = false;
@ -79,12 +82,39 @@ struct AssemblyOptions
struct LoweringStats struct LoweringStats
{ {
unsigned totalFunctions = 0;
unsigned skippedFunctions = 0;
int spillsToSlot = 0; int spillsToSlot = 0;
int spillsToRestore = 0; int spillsToRestore = 0;
unsigned maxSpillSlotsUsed = 0; unsigned maxSpillSlotsUsed = 0;
unsigned blocksPreOpt = 0;
unsigned blocksPostOpt = 0;
unsigned maxBlockInstructions = 0;
int regAllocErrors = 0; int regAllocErrors = 0;
int loweringErrors = 0; int loweringErrors = 0;
LoweringStats operator+(const LoweringStats& other) const
{
LoweringStats result(*this);
result += other;
return result;
}
LoweringStats& operator+=(const LoweringStats& that)
{
this->totalFunctions += that.totalFunctions;
this->skippedFunctions += that.skippedFunctions;
this->spillsToSlot += that.spillsToSlot;
this->spillsToRestore += that.spillsToRestore;
this->maxSpillSlotsUsed = std::max(this->maxSpillSlotsUsed, that.maxSpillSlotsUsed);
this->blocksPreOpt += that.blocksPreOpt;
this->blocksPostOpt += that.blocksPostOpt;
this->maxBlockInstructions = std::max(this->maxBlockInstructions, that.maxBlockInstructions);
this->regAllocErrors += that.regAllocErrors;
this->loweringErrors += that.loweringErrors;
return *this;
}
}; };
// Generates assembly for target function and all inner functions // Generates assembly for target function and all inner functions

View File

@ -219,6 +219,14 @@ enum class IrCmd : uint8_t
// E: block (if false) // E: block (if false)
JUMP_CMP_NUM, JUMP_CMP_NUM,
// Perform jump based on a numerical loop condition (step > 0 ? idx <= limit : limit <= idx)
// A: double (index)
// B: double (limit)
// C: double (step)
// D: block (if true)
// E: block (if false)
JUMP_FORN_LOOP_COND,
// Perform a conditional jump based on cached table node slot matching the actual table node slot for a key // Perform a conditional jump based on cached table node slot matching the actual table node slot for a key
// A: pointer (LuaNode) // A: pointer (LuaNode)
// B: Kn // B: Kn

View File

@ -97,6 +97,7 @@ inline bool isBlockTerminator(IrCmd cmd)
case IrCmd::JUMP_CMP_INT: case IrCmd::JUMP_CMP_INT:
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_NUM:
case IrCmd::JUMP_FORN_LOOP_COND:
case IrCmd::JUMP_SLOT_MATCH: case IrCmd::JUMP_SLOT_MATCH:
case IrCmd::RETURN: case IrCmd::RETURN:
case IrCmd::FORGLOOP: case IrCmd::FORGLOOP:

View File

@ -254,6 +254,14 @@ void AssemblyBuilderA64::rbit(RegisterA64 dst, RegisterA64 src)
placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00); placeR1("rbit", dst, src, 0b10'11010110'00000'0000'00);
} }
void AssemblyBuilderA64::rev(RegisterA64 dst, RegisterA64 src)
{
LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x);
LUAU_ASSERT(dst.kind == src.kind);
placeR1("rev", dst, src, 0b10'11010110'00000'0000'10 | int(dst.kind == KindA64::x));
}
void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, uint8_t src2) void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, uint8_t src2)
{ {
int size = dst.kind == KindA64::x ? 64 : 32; int size = dst.kind == KindA64::x ? 64 : 32;

View File

@ -541,6 +541,19 @@ void AssemblyBuilderX64::bsf(RegisterX64 dst, OperandX64 src)
commit(); commit();
} }
void AssemblyBuilderX64::bswap(RegisterX64 dst)
{
if (logText)
log("bswap", dst);
LUAU_ASSERT(dst.size == SizeX64::dword || dst.size == SizeX64::qword);
placeRex(dst);
place(0x0f);
place(OP_PLUS_REG(0xc8, dst.index));
commit();
}
void AssemblyBuilderX64::nop(uint32_t length) void AssemblyBuilderX64::nop(uint32_t length)
{ {
while (length != 0) while (length != 0)

View File

@ -42,6 +42,9 @@
LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false)
LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize, false) LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize, false)
LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering, false) LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering, false)
LUAU_FASTINTVARIABLE(CodegenHeuristicsInstructionLimit, 1'048'576) // 1 M
LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 65'536) // 64 K
LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockInstructionLimit, 65'536) // 64 K
namespace Luau namespace Luau
{ {

View File

@ -45,11 +45,19 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
template<typename AssemblyBuilder> template<typename AssemblyBuilder>
static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options, LoweringStats* stats) static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options, LoweringStats* stats)
{ {
Proto* root = clvalue(func)->l.p;
if ((options.flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
return std::string();
std::vector<Proto*> protos; std::vector<Proto*> protos;
gatherFunctions(protos, clvalue(func)->l.p, /* flags= */ 0); gatherFunctions(protos, root, options.flags);
protos.erase(std::remove_if(protos.begin(), protos.end(), [](Proto* p) { return p == nullptr; }), protos.end()); protos.erase(std::remove_if(protos.begin(), protos.end(), [](Proto* p) { return p == nullptr; }), protos.end());
if (stats)
stats->totalFunctions += unsigned(protos.size());
if (protos.empty()) if (protos.empty())
{ {
build.finalize(); // to avoid assertion in AssemblyBuilder dtor build.finalize(); // to avoid assertion in AssemblyBuilder dtor
@ -77,6 +85,9 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
{ {
if (build.logText) if (build.logText)
build.logAppend("; skipping (can't lower)\n"); build.logAppend("; skipping (can't lower)\n");
if (stats)
stats->skippedFunctions += 1;
} }
if (build.logText) if (build.logText)

View File

@ -23,6 +23,9 @@
LUAU_FASTFLAG(DebugCodegenNoOpt) LUAU_FASTFLAG(DebugCodegenNoOpt)
LUAU_FASTFLAG(DebugCodegenOptSize) LUAU_FASTFLAG(DebugCodegenOptSize)
LUAU_FASTFLAG(DebugCodegenSkipNumbering) LUAU_FASTFLAG(DebugCodegenSkipNumbering)
LUAU_FASTINT(CodegenHeuristicsInstructionLimit)
LUAU_FASTINT(CodegenHeuristicsBlockLimit)
LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit)
namespace Luau namespace Luau
{ {
@ -222,8 +225,41 @@ inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, const std::ve
template<typename AssemblyBuilder> template<typename AssemblyBuilder>
inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats)
{ {
helpers.bytecodeInstructionCount += unsigned(ir.function.instructions.size());
if (helpers.bytecodeInstructionCount >= unsigned(FInt::CodegenHeuristicsInstructionLimit.value))
return false;
killUnusedBlocks(ir.function); killUnusedBlocks(ir.function);
unsigned preOptBlockCount = 0;
unsigned maxBlockInstructions = 0;
for (const IrBlock& block : ir.function.blocks)
{
preOptBlockCount += (block.kind != IrBlockKind::Dead);
unsigned blockInstructions = block.finish - block.start;
maxBlockInstructions = std::max(maxBlockInstructions, blockInstructions);
};
helpers.preOptBlockCount += preOptBlockCount;
// we update stats before checking the heuristic so that even if we bail out
// our stats include information about the limit that was exceeded.
if (stats)
{
stats->blocksPreOpt += preOptBlockCount;
stats->maxBlockInstructions = maxBlockInstructions;
}
// we use helpers.blocksPreOpt instead of stats.blocksPreOpt since
// stats can be null across some code paths.
if (helpers.preOptBlockCount >= unsigned(FInt::CodegenHeuristicsBlockLimit.value))
return false;
if (maxBlockInstructions >= unsigned(FInt::CodegenHeuristicsBlockInstructionLimit.value))
return false;
computeCfgInfo(ir.function); computeCfgInfo(ir.function);
if (!FFlag::DebugCodegenNoOpt) if (!FFlag::DebugCodegenNoOpt)
@ -241,6 +277,15 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers&
// In order to allocate registers during lowering, we need to know where instruction results are last used // In order to allocate registers during lowering, we need to know where instruction results are last used
updateLastUseLocations(ir.function, sortedBlocks); updateLastUseLocations(ir.function, sortedBlocks);
if (stats)
{
for (const IrBlock& block : ir.function.blocks)
{
if (block.kind != IrBlockKind::Dead)
++stats->blocksPostOpt;
}
}
return lowerIr(build, ir, sortedBlocks, helpers, proto, options, stats); return lowerIr(build, ir, sortedBlocks, helpers, proto, options, stats);
} }

View File

@ -31,6 +31,9 @@ struct ModuleHelpers
// A64 // A64
Label continueCall; // x0: closure Label continueCall; // x0: closure
unsigned bytecodeInstructionCount = 0;
unsigned preOptBlockCount = 0;
}; };
} // namespace CodeGen } // namespace CodeGen

View File

@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
return "JUMP_EQ_POINTER"; return "JUMP_EQ_POINTER";
case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_NUM:
return "JUMP_CMP_NUM"; return "JUMP_CMP_NUM";
case IrCmd::JUMP_FORN_LOOP_COND:
return "JUMP_FORN_LOOP_COND";
case IrCmd::JUMP_SLOT_MATCH: case IrCmd::JUMP_SLOT_MATCH:
return "JUMP_SLOT_MATCH"; return "JUMP_SLOT_MATCH";
case IrCmd::TABLE_LEN: case IrCmd::TABLE_LEN:

View File

@ -812,6 +812,30 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
jumpOrFallthrough(blockOp(inst.e), next); jumpOrFallthrough(blockOp(inst.e), next);
break; break;
} }
case IrCmd::JUMP_FORN_LOOP_COND:
{
RegisterA64 index = tempDouble(inst.a);
RegisterA64 limit = tempDouble(inst.b);
Label direct;
// step > 0
build.fcmpz(tempDouble(inst.c));
build.b(getConditionFP(IrCondition::Greater), direct);
// !(limit <= index)
build.fcmp(limit, index);
build.b(getConditionFP(IrCondition::NotLessEqual), labelOp(inst.e));
build.b(labelOp(inst.d));
// !(index <= limit)
build.setLabel(direct);
build.fcmp(index, limit);
build.b(getConditionFP(IrCondition::NotLessEqual), labelOp(inst.e));
jumpOrFallthrough(blockOp(inst.d), next);
break;
}
// IrCmd::JUMP_SLOT_MATCH implemented below // IrCmd::JUMP_SLOT_MATCH implemented below
case IrCmd::TABLE_LEN: case IrCmd::TABLE_LEN:
{ {

View File

@ -699,6 +699,36 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
jumpOrFallthrough(blockOp(inst.e), next); jumpOrFallthrough(blockOp(inst.e), next);
break; break;
} }
case IrCmd::JUMP_FORN_LOOP_COND:
{
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
ScopedRegX64 tmp3{regs, SizeX64::xmmword};
RegisterX64 index = inst.a.kind == IrOpKind::Inst ? regOp(inst.a) : tmp1.reg;
RegisterX64 limit = inst.b.kind == IrOpKind::Inst ? regOp(inst.b) : tmp2.reg;
if (inst.a.kind != IrOpKind::Inst)
build.vmovsd(tmp1.reg, memRegDoubleOp(inst.a));
if (inst.b.kind != IrOpKind::Inst)
build.vmovsd(tmp2.reg, memRegDoubleOp(inst.b));
Label direct;
// step > 0
jumpOnNumberCmp(build, tmp3.reg, memRegDoubleOp(inst.c), build.f64(0.0), IrCondition::Greater, direct);
// !(limit <= index)
jumpOnNumberCmp(build, noreg, limit, index, IrCondition::NotLessEqual, labelOp(inst.e));
build.jmp(labelOp(inst.d));
// !(index <= limit)
build.setLabel(direct);
jumpOnNumberCmp(build, noreg, index, limit, IrCondition::NotLessEqual, labelOp(inst.e));
jumpOrFallthrough(blockOp(inst.d), next);
break;
}
case IrCmd::TABLE_LEN: case IrCmd::TABLE_LEN:
{ {
IrCallWrapperX64 callWrap(regs, build, index); IrCallWrapperX64 callWrap(regs, build, index);

View File

@ -15,6 +15,7 @@
LUAU_FASTFLAGVARIABLE(LuauImproveForN2, false) LUAU_FASTFLAGVARIABLE(LuauImproveForN2, false)
LUAU_FASTFLAG(LuauReduceStackSpills) LUAU_FASTFLAG(LuauReduceStackSpills)
LUAU_FASTFLAGVARIABLE(LuauInlineArrConstOffset, false) LUAU_FASTFLAGVARIABLE(LuauInlineArrConstOffset, false)
LUAU_FASTFLAGVARIABLE(LuauLowerAltLoopForn, false)
namespace Luau namespace Luau
{ {
@ -680,26 +681,35 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1));
build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp direct = build.block(IrBlockKind::Internal); if (FFlag::LuauLowerAltLoopForn)
IrOp reverse = build.block(IrBlockKind::Internal); {
IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1));
IrOp zero = build.constDouble(0.0); build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopStart, loopExit);
IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); }
else
{
IrOp direct = build.block(IrBlockKind::Internal);
IrOp reverse = build.block(IrBlockKind::Internal);
// step > 0 IrOp zero = build.constDouble(0.0);
// note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1));
build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse);
// Condition to start the loop: step > 0 ? idx <= limit : limit <= idx // step > 0
// We invert the condition so that loopStart is the fallthrough (false) label // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64
build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse);
// step > 0 is false, check limit <= idx // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx
build.beginBlock(reverse); // We invert the condition so that loopStart is the fallthrough (false) label
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
// step > 0 is true, check idx <= limit // step > 0 is false, check limit <= idx
build.beginBlock(direct); build.beginBlock(reverse);
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
// step > 0 is true, check idx <= limit
build.beginBlock(direct);
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
}
} }
else else
{ {
@ -713,6 +723,24 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart);
} }
} }
else if (FFlag::LuauLowerAltLoopForn)
{
// When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails
// Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant
// To avoid that overhead for an extreemely rare case (that doesn't even typecheck), we exit to VM to handle it
IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0));
build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1));
build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2));
build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos));
IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0));
IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1));
IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2));
build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopStart, loopExit);
}
else else
{ {
IrOp direct = build.block(IrBlockKind::Internal); IrOp direct = build.block(IrBlockKind::Internal);
@ -770,7 +798,6 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
LUAU_ASSERT(!build.loopStepStack.empty()); LUAU_ASSERT(!build.loopStepStack.empty());
IrOp stepK = build.loopStepStack.back(); IrOp stepK = build.loopStepStack.back();
IrOp zero = build.constDouble(0.0);
IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0));
IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK; IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK;
@ -780,22 +807,31 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
if (stepK.kind == IrOpKind::Undef) if (stepK.kind == IrOpKind::Undef)
{ {
IrOp direct = build.block(IrBlockKind::Internal); if (FFlag::LuauLowerAltLoopForn)
IrOp reverse = build.block(IrBlockKind::Internal); {
build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopRepeat, loopExit);
}
else
{
IrOp direct = build.block(IrBlockKind::Internal);
IrOp reverse = build.block(IrBlockKind::Internal);
// step > 0 IrOp zero = build.constDouble(0.0);
// note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64
build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse);
// Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx // step > 0
// note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64
build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse);
// step > 0 is false, check limit <= idx // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx
build.beginBlock(reverse);
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
// step > 0 is true, check idx <= limit // step > 0 is false, check limit <= idx
build.beginBlock(direct); build.beginBlock(reverse);
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
// step > 0 is true, check idx <= limit
build.beginBlock(direct);
build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
}
} }
else else
{ {
@ -808,6 +844,19 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit);
} }
} }
else if (FFlag::LuauLowerAltLoopForn)
{
build.inst(IrCmd::INTERRUPT, build.constUint(pcpos));
IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0));
IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1));
IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2));
idx = build.inst(IrCmd::ADD_NUM, idx, step);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx);
build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopRepeat, loopExit);
}
else else
{ {
build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); build.inst(IrCmd::INTERRUPT, build.constUint(pcpos));

View File

@ -75,6 +75,7 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::JUMP_CMP_INT: case IrCmd::JUMP_CMP_INT:
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_NUM:
case IrCmd::JUMP_FORN_LOOP_COND:
case IrCmd::JUMP_SLOT_MATCH: case IrCmd::JUMP_SLOT_MATCH:
return IrValueKind::None; return IrValueKind::None;
case IrCmd::TABLE_LEN: case IrCmd::TABLE_LEN:

View File

@ -19,6 +19,7 @@ LUAU_FASTFLAGVARIABLE(LuauReuseHashSlots2, false)
LUAU_FASTFLAGVARIABLE(LuauKeepVmapLinear, false) LUAU_FASTFLAGVARIABLE(LuauKeepVmapLinear, false)
LUAU_FASTFLAGVARIABLE(LuauMergeTagLoads, false) LUAU_FASTFLAGVARIABLE(LuauMergeTagLoads, false)
LUAU_FASTFLAGVARIABLE(LuauReuseArrSlots, false) LUAU_FASTFLAGVARIABLE(LuauReuseArrSlots, false)
LUAU_FASTFLAG(LuauLowerAltLoopForn)
namespace Luau namespace Luau
{ {
@ -782,6 +783,46 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
} }
break; break;
} }
case IrCmd::JUMP_FORN_LOOP_COND:
{
std::optional<double> step = function.asDoubleOp(inst.c.kind == IrOpKind::Constant ? inst.c : state.tryGetValue(inst.c));
if (!step)
break;
std::optional<double> idx = function.asDoubleOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<double> limit = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (*step > 0)
{
if (idx && limit)
{
if (compare(*idx, *limit, IrCondition::NotLessEqual))
replace(function, block, index, {IrCmd::JUMP, inst.e});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
else
{
replace(function, block, index, IrInst{IrCmd::JUMP_CMP_NUM, inst.a, inst.b, build.cond(IrCondition::NotLessEqual), inst.e, inst.d});
}
}
else
{
if (idx && limit)
{
if (compare(*limit, *idx, IrCondition::NotLessEqual))
replace(function, block, index, {IrCmd::JUMP, inst.e});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
else
{
replace(function, block, index, IrInst{IrCmd::JUMP_CMP_NUM, inst.b, inst.a, build.cond(IrCondition::NotLessEqual), inst.e, inst.d});
}
}
break;
}
case IrCmd::GET_UPVALUE: case IrCmd::GET_UPVALUE:
state.invalidate(inst.a); state.invalidate(inst.a);
break; break;
@ -1282,6 +1323,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector<uint8_t>& visite
if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback) if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback)
{ {
if (FFlag::LuauLowerAltLoopForn && getLiveOutValueCount(function, target) != 0)
break;
// Make sure block ordering guarantee is checked at lowering time // Make sure block ordering guarantee is checked at lowering time
block->expectedNextBlock = function.getBlockIndex(target); block->expectedNextBlock = function.getBlockIndex(target);

View File

@ -33,6 +33,7 @@ LUAU_FASTFLAG(LuauFloorDivision)
LUAU_FASTFLAGVARIABLE(LuauCompileFixContinueValidation2, false) LUAU_FASTFLAGVARIABLE(LuauCompileFixContinueValidation2, false)
LUAU_FASTFLAGVARIABLE(LuauCompileContinueCloseUpvals, false) LUAU_FASTFLAGVARIABLE(LuauCompileContinueCloseUpvals, false)
LUAU_FASTFLAGVARIABLE(LuauCompileIfElseAndOr, false)
namespace Luau namespace Luau
{ {
@ -1569,6 +1570,23 @@ struct Compiler
} }
} }
void compileExprIfElseAndOr(bool and_, uint8_t creg, AstExpr* other, uint8_t target)
{
int32_t cid = getConstantIndex(other);
if (cid >= 0 && cid <= 255)
{
bytecode.emitABC(and_ ? LOP_ANDK : LOP_ORK, target, creg, uint8_t(cid));
}
else
{
RegScope rs(this);
uint8_t oreg = compileExprAuto(other, rs);
bytecode.emitABC(and_ ? LOP_AND : LOP_OR, target, creg, oreg);
}
}
void compileExprIfElse(AstExprIfElse* expr, uint8_t target, bool targetTemp) void compileExprIfElse(AstExprIfElse* expr, uint8_t target, bool targetTemp)
{ {
if (isConstant(expr->condition)) if (isConstant(expr->condition))
@ -1584,6 +1602,20 @@ struct Compiler
} }
else else
{ {
if (FFlag::LuauCompileIfElseAndOr)
{
// Optimization: convert some if..then..else expressions into and/or when the other side has no side effects and is very cheap to compute
// if v then v else e => v or e
// if v then e else v => v and e
if (int creg = getExprLocalReg(expr->condition); creg >= 0)
{
if (creg == getExprLocalReg(expr->trueExpr) && (getExprLocalReg(expr->falseExpr) >= 0 || isConstant(expr->falseExpr)))
return compileExprIfElseAndOr(/* and_= */ false, uint8_t(creg), expr->falseExpr, target);
else if (creg == getExprLocalReg(expr->falseExpr) && (getExprLocalReg(expr->trueExpr) >= 0 || isConstant(expr->trueExpr)))
return compileExprIfElseAndOr(/* and_= */ true, uint8_t(creg), expr->trueExpr, target);
}
}
std::vector<size_t> elseJump; std::vector<size_t> elseJump;
compileConditionValue(expr->condition, nullptr, elseJump, false); compileConditionValue(expr->condition, nullptr, elseJump, false);
compileExpr(expr->trueExpr, target, targetTemp); compileExpr(expr->trueExpr, target, targetTemp);

View File

@ -152,7 +152,6 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstJsonEncoder.h
Analysis/include/Luau/AstQuery.h Analysis/include/Luau/AstQuery.h
Analysis/include/Luau/Autocomplete.h Analysis/include/Luau/Autocomplete.h
Analysis/include/Luau/Breadcrumb.h
Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Cancellation.h
Analysis/include/Luau/Clone.h Analysis/include/Luau/Clone.h
@ -196,15 +195,18 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Transpiler.h Analysis/include/Luau/Transpiler.h
Analysis/include/Luau/TxnLog.h Analysis/include/Luau/TxnLog.h
Analysis/include/Luau/Type.h Analysis/include/Luau/Type.h
Analysis/include/Luau/TypePairHash.h
Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeArena.h
Analysis/include/Luau/TypeAttach.h Analysis/include/Luau/TypeAttach.h
Analysis/include/Luau/TypeChecker2.h Analysis/include/Luau/TypeChecker2.h
Analysis/include/Luau/TypeCheckLimits.h Analysis/include/Luau/TypeCheckLimits.h
Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypedAllocator.h
Analysis/include/Luau/TypeFamily.h Analysis/include/Luau/TypeFamily.h
Analysis/include/Luau/TypeFwd.h
Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypeInfer.h
Analysis/include/Luau/TypeOrPack.h
Analysis/include/Luau/TypePack.h Analysis/include/Luau/TypePack.h
Analysis/include/Luau/TypePairHash.h
Analysis/include/Luau/TypePath.h
Analysis/include/Luau/TypeUtils.h Analysis/include/Luau/TypeUtils.h
Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifiable.h
Analysis/include/Luau/Unifier.h Analysis/include/Luau/Unifier.h
@ -259,7 +261,9 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/TypedAllocator.cpp Analysis/src/TypedAllocator.cpp
Analysis/src/TypeFamily.cpp Analysis/src/TypeFamily.cpp
Analysis/src/TypeInfer.cpp Analysis/src/TypeInfer.cpp
Analysis/src/TypeOrPack.cpp
Analysis/src/TypePack.cpp Analysis/src/TypePack.cpp
Analysis/src/TypePath.cpp
Analysis/src/TypeUtils.cpp Analysis/src/TypeUtils.cpp
Analysis/src/Unifiable.cpp Analysis/src/Unifiable.cpp
Analysis/src/Unifier.cpp Analysis/src/Unifier.cpp
@ -276,6 +280,7 @@ target_sources(Luau.VM PRIVATE
VM/src/laux.cpp VM/src/laux.cpp
VM/src/lbaselib.cpp VM/src/lbaselib.cpp
VM/src/lbitlib.cpp VM/src/lbitlib.cpp
VM/src/lbuffer.cpp
VM/src/lbuiltins.cpp VM/src/lbuiltins.cpp
VM/src/lcorolib.cpp VM/src/lcorolib.cpp
VM/src/ldblib.cpp VM/src/ldblib.cpp
@ -304,6 +309,7 @@ target_sources(Luau.VM PRIVATE
VM/src/lvmutils.cpp VM/src/lvmutils.cpp
VM/src/lapi.h VM/src/lapi.h
VM/src/lbuffer.h
VM/src/lbuiltins.h VM/src/lbuiltins.h
VM/src/lbytecode.h VM/src/lbytecode.h
VM/src/lcommon.h VM/src/lcommon.h
@ -371,8 +377,6 @@ if(TARGET Luau.UnitTest)
tests/AstQueryDsl.cpp tests/AstQueryDsl.cpp
tests/AstQueryDsl.h tests/AstQueryDsl.h
tests/AstVisitor.test.cpp tests/AstVisitor.test.cpp
tests/RegisterCallbacks.h
tests/RegisterCallbacks.cpp
tests/Autocomplete.test.cpp tests/Autocomplete.test.cpp
tests/BuiltinDefinitions.test.cpp tests/BuiltinDefinitions.test.cpp
tests/ClassFixture.cpp tests/ClassFixture.cpp
@ -386,11 +390,14 @@ if(TARGET Luau.UnitTest)
tests/CostModel.test.cpp tests/CostModel.test.cpp
tests/DataFlowGraph.test.cpp tests/DataFlowGraph.test.cpp
tests/DenseHash.test.cpp tests/DenseHash.test.cpp
tests/DiffAsserts.cpp
tests/DiffAsserts.h
tests/Differ.test.cpp tests/Differ.test.cpp
tests/Error.test.cpp tests/Error.test.cpp
tests/Fixture.cpp tests/Fixture.cpp
tests/Fixture.h tests/Fixture.h
tests/Frontend.test.cpp tests/Frontend.test.cpp
tests/InsertionOrderedMap.test.cpp
tests/IostreamOptional.h tests/IostreamOptional.h
tests/IrBuilder.test.cpp tests/IrBuilder.test.cpp
tests/IrCallWrapperX64.test.cpp tests/IrCallWrapperX64.test.cpp
@ -405,6 +412,8 @@ if(TARGET Luau.UnitTest)
tests/Normalize.test.cpp tests/Normalize.test.cpp
tests/NotNull.test.cpp tests/NotNull.test.cpp
tests/Parser.test.cpp tests/Parser.test.cpp
tests/RegisterCallbacks.cpp
tests/RegisterCallbacks.h
tests/RequireTracer.test.cpp tests/RequireTracer.test.cpp
tests/RuntimeLimits.test.cpp tests/RuntimeLimits.test.cpp
tests/ScopedFlags.h tests/ScopedFlags.h
@ -446,11 +455,11 @@ if(TARGET Luau.UnitTest)
tests/TypeInfer.unionTypes.test.cpp tests/TypeInfer.unionTypes.test.cpp
tests/TypeInfer.unknownnever.test.cpp tests/TypeInfer.unknownnever.test.cpp
tests/TypePack.test.cpp tests/TypePack.test.cpp
tests/TypePath.test.cpp
tests/TypeVar.test.cpp tests/TypeVar.test.cpp
tests/Unifier2.test.cpp tests/Unifier2.test.cpp
tests/Variant.test.cpp tests/Variant.test.cpp
tests/VisitType.test.cpp tests/VisitType.test.cpp
tests/InsertionOrderedMap.test.cpp
tests/main.cpp) tests/main.cpp)
endif() endif()

View File

@ -85,6 +85,7 @@ enum lua_Type
LUA_TFUNCTION, LUA_TFUNCTION,
LUA_TUSERDATA, LUA_TUSERDATA,
LUA_TTHREAD, LUA_TTHREAD,
LUA_TBUFFER,
// values below this line are used in GCObject tags but may never show up in TValue type tags // values below this line are used in GCObject tags but may never show up in TValue type tags
LUA_TPROTO, LUA_TPROTO,
@ -162,6 +163,7 @@ LUA_API void* lua_touserdata(lua_State* L, int idx);
LUA_API void* lua_touserdatatagged(lua_State* L, int idx, int tag); LUA_API void* lua_touserdatatagged(lua_State* L, int idx, int tag);
LUA_API int lua_userdatatag(lua_State* L, int idx); LUA_API int lua_userdatatag(lua_State* L, int idx);
LUA_API lua_State* lua_tothread(lua_State* L, int idx); LUA_API lua_State* lua_tothread(lua_State* L, int idx);
LUA_API void* lua_tobuffer(lua_State* L, int idx, size_t* len);
LUA_API const void* lua_topointer(lua_State* L, int idx); LUA_API const void* lua_topointer(lua_State* L, int idx);
/* /*
@ -188,6 +190,8 @@ LUA_API void lua_pushlightuserdata(lua_State* L, void* p);
LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag);
LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*));
LUA_API void* lua_newbuffer(lua_State* L, size_t sz);
/* /*
** get functions (Lua -> stack) ** get functions (Lua -> stack)
*/ */
@ -359,6 +363,7 @@ LUA_API void lua_unref(lua_State* L, int ref);
#define lua_isboolean(L, n) (lua_type(L, (n)) == LUA_TBOOLEAN) #define lua_isboolean(L, n) (lua_type(L, (n)) == LUA_TBOOLEAN)
#define lua_isvector(L, n) (lua_type(L, (n)) == LUA_TVECTOR) #define lua_isvector(L, n) (lua_type(L, (n)) == LUA_TVECTOR)
#define lua_isthread(L, n) (lua_type(L, (n)) == LUA_TTHREAD) #define lua_isthread(L, n) (lua_type(L, (n)) == LUA_TTHREAD)
#define lua_isbuffer(L, n) (lua_type(L, (n)) == LUA_TBUFFER)
#define lua_isnone(L, n) (lua_type(L, (n)) == LUA_TNONE) #define lua_isnone(L, n) (lua_type(L, (n)) == LUA_TNONE)
#define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL)

View File

@ -43,6 +43,8 @@ LUALIB_API void luaL_checkany(lua_State* L, int narg);
LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname); LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname);
LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname); LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname);
LUALIB_API void* luaL_checkbuffer(lua_State* L, int narg, size_t* len);
LUALIB_API void luaL_where(lua_State* L, int lvl); LUALIB_API void luaL_where(lua_State* L, int lvl);
LUALIB_API LUA_PRINTF_ATTR(2, 3) l_noret luaL_errorL(lua_State* L, const char* fmt, ...); LUALIB_API LUA_PRINTF_ATTR(2, 3) l_noret luaL_errorL(lua_State* L, const char* fmt, ...);
@ -74,7 +76,7 @@ LUALIB_API const char* luaL_typename(lua_State* L, int idx);
// generic buffer manipulation // generic buffer manipulation
struct luaL_Buffer struct luaL_Strbuf
{ {
char* p; // current position in buffer char* p; // current position in buffer
char* end; // end of the current buffer char* end; // end of the current buffer
@ -82,26 +84,27 @@ struct luaL_Buffer
struct TString* storage; struct TString* storage;
char buffer[LUA_BUFFERSIZE]; char buffer[LUA_BUFFERSIZE];
}; };
typedef struct luaL_Buffer luaL_Buffer; typedef struct luaL_Strbuf luaL_Strbuf;
// compatibility typedef: this type is called luaL_Buffer in Lua headers
// renamed to luaL_Strbuf to reduce confusion with internal VM buffer type
typedef struct luaL_Strbuf luaL_Buffer;
// when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack
// in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1)
// with the exception of luaL_addvalue that expects the value at the top and string buffer further away (top-2) // with the exception of luaL_addvalue that expects the value at the top and string buffer further away (top-2)
// functions that accept a 'boxloc' support string buffer placement at any location in the stack
// all the buffer users we have in Luau match this pattern, but it's something to keep in mind for new uses of buffers
#define luaL_addchar(B, c) ((void)((B)->p < (B)->end || luaL_extendbuffer(B, 1, -1)), (*(B)->p++ = (char)(c))) #define luaL_addchar(B, c) ((void)((B)->p < (B)->end || luaL_prepbuffsize(B, 1)), (*(B)->p++ = (char)(c)))
#define luaL_addstring(B, s) luaL_addlstring(B, s, strlen(s), -1) #define luaL_addstring(B, s) luaL_addlstring(B, s, strlen(s))
LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B); LUALIB_API void luaL_buffinit(lua_State* L, luaL_Strbuf* B);
LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size); LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Strbuf* B, size_t size);
LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc); LUALIB_API char* luaL_prepbuffsize(luaL_Buffer* B, size_t size);
LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc); LUALIB_API void luaL_addlstring(luaL_Strbuf* B, const char* s, size_t l);
LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t l, int boxloc); LUALIB_API void luaL_addvalue(luaL_Strbuf* B);
LUALIB_API void luaL_addvalue(luaL_Buffer* B); LUALIB_API void luaL_addvalueany(luaL_Strbuf* B, int idx);
LUALIB_API void luaL_addvalueany(luaL_Buffer* B, int idx); LUALIB_API void luaL_pushresult(luaL_Strbuf* B);
LUALIB_API void luaL_pushresult(luaL_Buffer* B); LUALIB_API void luaL_pushresultsize(luaL_Strbuf* B, size_t size);
LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size);
// builtin libraries // builtin libraries
LUALIB_API int luaopen_base(lua_State* L); LUALIB_API int luaopen_base(lua_State* L);

View File

@ -11,6 +11,7 @@
#include "ludata.h" #include "ludata.h"
#include "lvm.h" #include "lvm.h"
#include "lnumutils.h" #include "lnumutils.h"
#include "lbuffer.h"
#include <string.h> #include <string.h>
@ -483,6 +484,8 @@ int lua_objlen(lua_State* L, int idx)
return tsvalue(o)->len; return tsvalue(o)->len;
case LUA_TUSERDATA: case LUA_TUSERDATA:
return uvalue(o)->len; return uvalue(o)->len;
case LUA_TBUFFER:
return bufvalue(o)->len;
case LUA_TTABLE: case LUA_TTABLE:
return luaH_getn(hvalue(o)); return luaH_getn(hvalue(o));
default: default:
@ -533,25 +536,32 @@ lua_State* lua_tothread(lua_State* L, int idx)
return (!ttisthread(o)) ? NULL : thvalue(o); return (!ttisthread(o)) ? NULL : thvalue(o);
} }
void* lua_tobuffer(lua_State* L, int idx, size_t* len)
{
StkId o = index2addr(L, idx);
if (!ttisbuffer(o))
return NULL;
Buffer* b = bufvalue(o);
if (len)
*len = b->len;
return b->data;
}
const void* lua_topointer(lua_State* L, int idx) const void* lua_topointer(lua_State* L, int idx)
{ {
StkId o = index2addr(L, idx); StkId o = index2addr(L, idx);
switch (ttype(o)) switch (ttype(o))
{ {
case LUA_TSTRING:
return tsvalue(o);
case LUA_TTABLE:
return hvalue(o);
case LUA_TFUNCTION:
return clvalue(o);
case LUA_TTHREAD:
return thvalue(o);
case LUA_TUSERDATA: case LUA_TUSERDATA:
return uvalue(o)->data; return uvalue(o)->data;
case LUA_TLIGHTUSERDATA: case LUA_TLIGHTUSERDATA:
return pvalue(o); return pvalue(o);
default: default:
return NULL; return iscollectable(o) ? gcvalue(o) : NULL;
} }
} }
@ -1271,6 +1281,16 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*))
return u->data; return u->data;
} }
void* lua_newbuffer(lua_State* L, size_t sz)
{
luaC_checkGC(L);
luaC_threadbarrier(L);
Buffer* b = luaB_newbuffer(L, sz);
setbufvalue(L, L->top, b);
api_incr_top(L);
return b->data;
}
static const char* aux_upvalue(StkId fi, int n, TValue** val) static const char* aux_upvalue(StkId fi, int n, TValue** val)
{ {
Closure* f; Closure* f;

View File

@ -132,6 +132,14 @@ void* luaL_checkudata(lua_State* L, int ud, const char* tname)
luaL_typeerrorL(L, ud, tname); // else error luaL_typeerrorL(L, ud, tname); // else error
} }
void* luaL_checkbuffer(lua_State* L, int narg, size_t* len)
{
void* b = lua_tobuffer(L, narg, len);
if (!b)
tag_error(L, narg, LUA_TBUFFER);
return b;
}
void luaL_checkstack(lua_State* L, int space, const char* mes) void luaL_checkstack(lua_State* L, int space, const char* mes)
{ {
if (!lua_checkstack(L, space)) if (!lua_checkstack(L, space))
@ -360,24 +368,7 @@ static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desired
return newsize; return newsize;
} }
void luaL_buffinit(lua_State* L, luaL_Buffer* B) static char* extendstrbuf(luaL_Strbuf* B, size_t additionalsize, int boxloc)
{
// start with an internal buffer
B->p = B->buffer;
B->end = B->p + LUA_BUFFERSIZE;
B->L = L;
B->storage = nullptr;
}
char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size)
{
luaL_buffinit(L, B);
luaL_reservebuffer(B, size, -1);
return B->p;
}
char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc)
{ {
lua_State* L = B->L; lua_State* L = B->L;
@ -408,22 +399,39 @@ char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc)
return B->p; return B->p;
} }
void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) void luaL_buffinit(lua_State* L, luaL_Strbuf* B)
{ {
if (size_t(B->end - B->p) < size) // start with an internal buffer
luaL_extendbuffer(B, size - (B->end - B->p), boxloc); B->p = B->buffer;
B->end = B->p + LUA_BUFFERSIZE;
B->L = L;
B->storage = nullptr;
} }
void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len, int boxloc) char* luaL_buffinitsize(lua_State* L, luaL_Strbuf* B, size_t size)
{
luaL_buffinit(L, B);
return luaL_prepbuffsize(B, size);
}
char* luaL_prepbuffsize(luaL_Strbuf* B, size_t size)
{
if (size_t(B->end - B->p) < size)
return extendstrbuf(B, size - (B->end - B->p), -1);
return B->p;
}
void luaL_addlstring(luaL_Strbuf* B, const char* s, size_t len)
{ {
if (size_t(B->end - B->p) < len) if (size_t(B->end - B->p) < len)
luaL_extendbuffer(B, len - (B->end - B->p), boxloc); extendstrbuf(B, len - (B->end - B->p), -1);
memcpy(B->p, s, len); memcpy(B->p, s, len);
B->p += len; B->p += len;
} }
void luaL_addvalue(luaL_Buffer* B) void luaL_addvalue(luaL_Strbuf* B)
{ {
lua_State* L = B->L; lua_State* L = B->L;
@ -431,7 +439,7 @@ void luaL_addvalue(luaL_Buffer* B)
if (const char* s = lua_tolstring(L, -1, &vl)) if (const char* s = lua_tolstring(L, -1, &vl))
{ {
if (size_t(B->end - B->p) < vl) if (size_t(B->end - B->p) < vl)
luaL_extendbuffer(B, vl - (B->end - B->p), -2); extendstrbuf(B, vl - (B->end - B->p), -2);
memcpy(B->p, s, vl); memcpy(B->p, s, vl);
B->p += vl; B->p += vl;
@ -440,7 +448,7 @@ void luaL_addvalue(luaL_Buffer* B)
} }
} }
void luaL_addvalueany(luaL_Buffer* B, int idx) void luaL_addvalueany(luaL_Strbuf* B, int idx)
{ {
lua_State* L = B->L; lua_State* L = B->L;
@ -465,28 +473,29 @@ void luaL_addvalueany(luaL_Buffer* B, int idx)
double n = lua_tonumber(L, idx); double n = lua_tonumber(L, idx);
char s[LUAI_MAXNUM2STR]; char s[LUAI_MAXNUM2STR];
char* e = luai_num2str(s, n); char* e = luai_num2str(s, n);
luaL_addlstring(B, s, e - s, -1); luaL_addlstring(B, s, e - s);
break; break;
} }
case LUA_TSTRING: case LUA_TSTRING:
{ {
size_t len; size_t len;
const char* s = lua_tolstring(L, idx, &len); const char* s = lua_tolstring(L, idx, &len);
luaL_addlstring(B, s, len, -1); luaL_addlstring(B, s, len);
break; break;
} }
default: default:
{ {
size_t len; size_t len;
const char* s = luaL_tolstring(L, idx, &len); luaL_tolstring(L, idx, &len);
luaL_addlstring(B, s, len, -2); // note: luaL_addlstring assumes box is stored at top of stack, so we can't call it here
lua_pop(L, 1); // instead we use luaL_addvalue which will take the string from the top of the stack and add that
luaL_addvalue(B);
} }
} }
} }
void luaL_pushresult(luaL_Buffer* B) void luaL_pushresult(luaL_Strbuf* B)
{ {
lua_State* L = B->L; lua_State* L = B->L;
@ -510,7 +519,7 @@ void luaL_pushresult(luaL_Buffer* B)
} }
} }
void luaL_pushresultsize(luaL_Buffer* B, size_t size) void luaL_pushresultsize(luaL_Strbuf* B, size_t size)
{ {
B->p += size; B->p += size;
luaL_pushresult(B); luaL_pushresult(B);

24
VM/src/lbuffer.cpp Normal file
View File

@ -0,0 +1,24 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "lbuffer.h"
#include "lgc.h"
#include "lmem.h"
#include <string.h>
Buffer* luaB_newbuffer(lua_State* L, size_t s)
{
if (s > MAX_BUFFER_SIZE)
luaM_toobig(L);
Buffer* b = luaM_newgco(L, Buffer, sizebuffer(s), L->activememcat);
luaC_init(L, b, LUA_TBUFFER);
b->len = unsigned(s);
memset(b->data, 0, b->len);
return b;
}
void luaB_freebuffer(lua_State* L, Buffer* b, lua_Page* page)
{
luaM_freegco(L, b, sizebuffer(b->len), b->memcat, page);
}

13
VM/src/lbuffer.h Normal file
View File

@ -0,0 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "lobject.h"
// buffer size limit
#define MAX_BUFFER_SIZE (1 << 30)
// GCObject size has to be at least 16 bytes, so a minimum of 8 bytes is always reserved
#define sizebuffer(len) (offsetof(Buffer, data) + ((len) < 8 ? 8 : (len)))
LUAI_FUNC Buffer* luaB_newbuffer(lua_State* L, size_t s);
LUAI_FUNC void luaB_freebuffer(lua_State* L, Buffer* u, struct lua_Page* page);

View File

@ -110,7 +110,7 @@ static int db_traceback(lua_State* L)
int level = luaL_optinteger(L, arg + 2, (L == L1) ? 1 : 0); int level = luaL_optinteger(L, arg + 2, (L == L1) ? 1 : 0);
luaL_argcheck(L, level >= 0, arg + 2, "level can't be negative"); luaL_argcheck(L, level >= 0, arg + 2, "level can't be negative");
luaL_Buffer buf; luaL_Strbuf buf;
luaL_buffinit(L, &buf); luaL_buffinit(L, &buf);
if (msg) if (msg)
@ -137,7 +137,7 @@ static int db_traceback(lua_State* L)
*--lineptr = '0' + (r % 10); *--lineptr = '0' + (r % 10);
luaL_addchar(&buf, ':'); luaL_addchar(&buf, ':');
luaL_addlstring(&buf, lineptr, lineend - lineptr, -1); luaL_addlstring(&buf, lineptr, lineend - lineptr);
} }
if (ar.name) if (ar.name)

View File

@ -17,6 +17,8 @@
#include <string.h> #include <string.h>
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauHandlerClose, false)
/* /*
** {====================================================== ** {======================================================
** Error-recovery functions ** Error-recovery functions
@ -407,7 +409,7 @@ static void resume_handle(lua_State* L, void* ud)
L->ci = restoreci(L, old_ci); L->ci = restoreci(L, old_ci);
// close eventual pending closures; this means it's now safe to restore stack // close eventual pending closures; this means it's now safe to restore stack
luaF_close(L, L->base); luaF_close(L, DFFlag::LuauHandlerClose ? L->ci->base : L->base);
// finish cont call and restore stack to previous ci top // finish cont call and restore stack to previous ci top
luau_poscall(L, L->top - n); luau_poscall(L, L->top - n);

View File

@ -10,6 +10,7 @@
#include "ldo.h" #include "ldo.h"
#include "lmem.h" #include "lmem.h"
#include "ludata.h" #include "ludata.h"
#include "lbuffer.h"
#include <string.h> #include <string.h>
@ -275,6 +276,11 @@ static void reallymarkobject(global_State* g, GCObject* o)
g->gray = o; g->gray = o;
break; break;
} }
case LUA_TBUFFER:
{
gray2black(o); // buffers are never gray
return;
}
case LUA_TPROTO: case LUA_TPROTO:
{ {
gco2p(o)->gclist = g->gray; gco2p(o)->gclist = g->gray;
@ -618,6 +624,9 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page)
case LUA_TUSERDATA: case LUA_TUSERDATA:
luaU_freeudata(L, gco2u(o), page); luaU_freeudata(L, gco2u(o), page);
break; break;
case LUA_TBUFFER:
luaB_freebuffer(L, gco2buf(o), page);
break;
default: default:
LUAU_ASSERT(0); LUAU_ASSERT(0);
} }

View File

@ -136,7 +136,8 @@ LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v);
LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist); LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist);
LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_validate(lua_State* L);
LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat));
LUAI_FUNC void luaC_enumheap(lua_State* L, void* context, void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, const char* name), LUAI_FUNC void luaC_enumheap(lua_State* L, void* context,
void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, size_t size, const char* name),
void (*edge)(void* context, void* from, void* to, const char* name)); void (*edge)(void* context, void* from, void* to, const char* name));
LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L);
LUAI_FUNC const char* luaC_statename(int state); LUAI_FUNC const char* luaC_statename(int state);

View File

@ -9,6 +9,7 @@
#include "lstring.h" #include "lstring.h"
#include "ltable.h" #include "ltable.h"
#include "ludata.h" #include "ludata.h"
#include "lbuffer.h"
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
@ -166,6 +167,9 @@ static void validateobj(global_State* g, GCObject* o)
validatestack(g, gco2th(o)); validatestack(g, gco2th(o));
break; break;
case LUA_TBUFFER:
break;
case LUA_TPROTO: case LUA_TPROTO:
validateproto(g, gco2p(o)); validateproto(g, gco2p(o));
break; break;
@ -473,6 +477,11 @@ static void dumpthread(FILE* f, lua_State* th)
fprintf(f, "}"); fprintf(f, "}");
} }
static void dumpbuffer(FILE* f, Buffer* b)
{
fprintf(f, "{\"type\":\"buffer\",\"cat\":%d,\"size\":%d}", b->memcat, int(sizebuffer(b->len)));
}
static void dumpproto(FILE* f, Proto* p) static void dumpproto(FILE* f, Proto* p)
{ {
size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo +
@ -541,6 +550,9 @@ static void dumpobj(FILE* f, GCObject* o)
case LUA_TTHREAD: case LUA_TTHREAD:
return dumpthread(f, gco2th(o)); return dumpthread(f, gco2th(o));
case LUA_TBUFFER:
return dumpbuffer(f, gco2buf(o));
case LUA_TPROTO: case LUA_TPROTO:
return dumpproto(f, gco2p(o)); return dumpproto(f, gco2p(o));
@ -607,7 +619,7 @@ struct EnumContext
{ {
lua_State* L; lua_State* L;
void* context; void* context;
void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, const char* name); void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, size_t size, const char* name);
void (*edge)(void* context, void* from, void* to, const char* name); void (*edge)(void* context, void* from, void* to, const char* name);
}; };
@ -617,9 +629,9 @@ static void* enumtopointer(GCObject* gco)
return gco->gch.tt == LUA_TUSERDATA ? (void*)gco2u(gco)->data : (void*)gco; return gco->gch.tt == LUA_TUSERDATA ? (void*)gco2u(gco)->data : (void*)gco;
} }
static void enumnode(EnumContext* ctx, GCObject* gco, const char* objname) static void enumnode(EnumContext* ctx, GCObject* gco, size_t size, const char* objname)
{ {
ctx->node(ctx->context, enumtopointer(gco), gco->gch.tt, gco->gch.memcat, objname); ctx->node(ctx->context, enumtopointer(gco), gco->gch.tt, gco->gch.memcat, size, objname);
} }
static void enumedge(EnumContext* ctx, GCObject* from, GCObject* to, const char* edgename) static void enumedge(EnumContext* ctx, GCObject* from, GCObject* to, const char* edgename)
@ -638,13 +650,15 @@ static void enumedges(EnumContext* ctx, GCObject* from, TValue* data, size_t siz
static void enumstring(EnumContext* ctx, TString* ts) static void enumstring(EnumContext* ctx, TString* ts)
{ {
enumnode(ctx, obj2gco(ts), NULL); enumnode(ctx, obj2gco(ts), ts->len, NULL);
} }
static void enumtable(EnumContext* ctx, Table* h) static void enumtable(EnumContext* ctx, Table* h)
{ {
size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue);
// Provide a name for a special registry table // Provide a name for a special registry table
enumnode(ctx, obj2gco(h), h == hvalue(registry(ctx->L)) ? "registry" : NULL); enumnode(ctx, obj2gco(h), size, h == hvalue(registry(ctx->L)) ? "registry" : NULL);
if (h->node != &luaH_dummynode) if (h->node != &luaH_dummynode)
{ {
@ -703,7 +717,7 @@ static void enumclosure(EnumContext* ctx, Closure* cl)
{ {
if (cl->isC) if (cl->isC)
{ {
enumnode(ctx, obj2gco(cl), cl->c.debugname); enumnode(ctx, obj2gco(cl), sizeCclosure(cl->nupvalues), cl->c.debugname);
} }
else else
{ {
@ -716,7 +730,7 @@ static void enumclosure(EnumContext* ctx, Closure* cl)
else else
snprintf(buf, sizeof(buf), "%s:%d", p->debugname ? getstr(p->debugname) : "", p->linedefined); snprintf(buf, sizeof(buf), "%s:%d", p->debugname ? getstr(p->debugname) : "", p->linedefined);
enumnode(ctx, obj2gco(cl), buf); enumnode(ctx, obj2gco(cl), sizeLclosure(cl->nupvalues), buf);
} }
enumedge(ctx, obj2gco(cl), obj2gco(cl->env), "env"); enumedge(ctx, obj2gco(cl), obj2gco(cl->env), "env");
@ -737,7 +751,26 @@ static void enumclosure(EnumContext* ctx, Closure* cl)
static void enumudata(EnumContext* ctx, Udata* u) static void enumudata(EnumContext* ctx, Udata* u)
{ {
enumnode(ctx, obj2gco(u), NULL); const char* name = NULL;
if (Table* h = u->metatable)
{
if (h->node != &luaH_dummynode)
{
for (int i = 0; i < sizenode(h); ++i)
{
const LuaNode& n = h->node[i];
if (ttisstring(&n.key) && ttisstring(&n.val) && strcmp(svalue(&n.key), "__type") == 0)
{
name = svalue(&n.val);
break;
}
}
}
}
enumnode(ctx, obj2gco(u), sizeudata(u->len), name);
if (u->metatable) if (u->metatable)
enumedge(ctx, obj2gco(u), obj2gco(u->metatable), "metatable"); enumedge(ctx, obj2gco(u), obj2gco(u->metatable), "metatable");
@ -745,6 +778,8 @@ static void enumudata(EnumContext* ctx, Udata* u)
static void enumthread(EnumContext* ctx, lua_State* th) static void enumthread(EnumContext* ctx, lua_State* th)
{ {
size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci;
Closure* tcl = NULL; Closure* tcl = NULL;
for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci)
{ {
@ -766,11 +801,11 @@ static void enumthread(EnumContext* ctx, lua_State* th)
else else
snprintf(buf, sizeof(buf), "%s:%d", p->debugname ? getstr(p->debugname) : "", p->linedefined); snprintf(buf, sizeof(buf), "%s:%d", p->debugname ? getstr(p->debugname) : "", p->linedefined);
enumnode(ctx, obj2gco(th), buf); enumnode(ctx, obj2gco(th), size, buf);
} }
else else
{ {
enumnode(ctx, obj2gco(th), NULL); enumnode(ctx, obj2gco(th), size, NULL);
} }
enumedge(ctx, obj2gco(th), obj2gco(th->gt), "globals"); enumedge(ctx, obj2gco(th), obj2gco(th->gt), "globals");
@ -779,9 +814,17 @@ static void enumthread(EnumContext* ctx, lua_State* th)
enumedges(ctx, obj2gco(th), th->stack, th->top - th->stack, "stack"); enumedges(ctx, obj2gco(th), th->stack, th->top - th->stack, "stack");
} }
static void enumbuffer(EnumContext* ctx, Buffer* b)
{
enumnode(ctx, obj2gco(b), sizebuffer(b->len), NULL);
}
static void enumproto(EnumContext* ctx, Proto* p) static void enumproto(EnumContext* ctx, Proto* p)
{ {
enumnode(ctx, obj2gco(p), p->source ? getstr(p->source) : NULL); size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo +
sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues;
enumnode(ctx, obj2gco(p), size, p->source ? getstr(p->source) : NULL);
if (p->sizek) if (p->sizek)
enumedges(ctx, obj2gco(p), p->k, p->sizek, "constants"); enumedges(ctx, obj2gco(p), p->k, p->sizek, "constants");
@ -792,7 +835,7 @@ static void enumproto(EnumContext* ctx, Proto* p)
static void enumupval(EnumContext* ctx, UpVal* uv) static void enumupval(EnumContext* ctx, UpVal* uv)
{ {
enumnode(ctx, obj2gco(uv), NULL); enumnode(ctx, obj2gco(uv), sizeof(UpVal), NULL);
if (iscollectable(uv->v)) if (iscollectable(uv->v))
enumedge(ctx, obj2gco(uv), gcvalue(uv->v), "value"); enumedge(ctx, obj2gco(uv), gcvalue(uv->v), "value");
@ -817,6 +860,9 @@ static void enumobj(EnumContext* ctx, GCObject* o)
case LUA_TTHREAD: case LUA_TTHREAD:
return enumthread(ctx, gco2th(o)); return enumthread(ctx, gco2th(o));
case LUA_TBUFFER:
return enumbuffer(ctx, gco2buf(o));
case LUA_TPROTO: case LUA_TPROTO:
return enumproto(ctx, gco2p(o)); return enumproto(ctx, gco2p(o));
@ -834,7 +880,7 @@ static bool enumgco(void* context, lua_Page* page, GCObject* gco)
return false; return false;
} }
void luaC_enumheap(lua_State* L, void* context, void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, const char* name), void luaC_enumheap(lua_State* L, void* context, void (*node)(void* context, void* ptr, uint8_t tt, uint8_t memcat, size_t size, const char* name),
void (*edge)(void* context, void* from, void* to, const char* name)) void (*edge)(void* context, void* from, void* to, const char* name))
{ {
global_State* g = L->global; global_State* g = L->global;

View File

@ -118,6 +118,7 @@ static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table
static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header");
static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for userdata header"); static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for userdata header");
static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header");
static_assert(offsetof(Buffer, data) == ABISWITCH(8, 8, 8), "size mismatch for buffer header");
const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kSizeClasses = LUA_SIZECLASSES;
const size_t kMaxSmallSize = 512; const size_t kMaxSmallSize = 512;

View File

@ -58,6 +58,7 @@ typedef struct lua_TValue
#define ttisboolean(o) (ttype(o) == LUA_TBOOLEAN) #define ttisboolean(o) (ttype(o) == LUA_TBOOLEAN)
#define ttisuserdata(o) (ttype(o) == LUA_TUSERDATA) #define ttisuserdata(o) (ttype(o) == LUA_TUSERDATA)
#define ttisthread(o) (ttype(o) == LUA_TTHREAD) #define ttisthread(o) (ttype(o) == LUA_TTHREAD)
#define ttisbuffer(o) (ttype(o) == LUA_TBUFFER)
#define ttislightuserdata(o) (ttype(o) == LUA_TLIGHTUSERDATA) #define ttislightuserdata(o) (ttype(o) == LUA_TLIGHTUSERDATA)
#define ttisvector(o) (ttype(o) == LUA_TVECTOR) #define ttisvector(o) (ttype(o) == LUA_TVECTOR)
#define ttisupval(o) (ttype(o) == LUA_TUPVAL) #define ttisupval(o) (ttype(o) == LUA_TUPVAL)
@ -74,6 +75,7 @@ typedef struct lua_TValue
#define hvalue(o) check_exp(ttistable(o), &(o)->value.gc->h) #define hvalue(o) check_exp(ttistable(o), &(o)->value.gc->h)
#define bvalue(o) check_exp(ttisboolean(o), (o)->value.b) #define bvalue(o) check_exp(ttisboolean(o), (o)->value.b)
#define thvalue(o) check_exp(ttisthread(o), &(o)->value.gc->th) #define thvalue(o) check_exp(ttisthread(o), &(o)->value.gc->th)
#define bufvalue(o) check_exp(ttisbuffer(o), &(o)->value.gc->buf)
#define upvalue(o) check_exp(ttisupval(o), &(o)->value.gc->uv) #define upvalue(o) check_exp(ttisupval(o), &(o)->value.gc->uv)
#define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) #define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0))
@ -156,6 +158,14 @@ typedef struct lua_TValue
checkliveness(L->global, i_o); \ checkliveness(L->global, i_o); \
} }
#define setbufvalue(L, obj, x) \
{ \
TValue* i_o = (obj); \
i_o->value.gc = cast_to(GCObject*, (x)); \
i_o->tt = LUA_TBUFFER; \
checkliveness(L->global, i_o); \
}
#define setclvalue(L, obj, x) \ #define setclvalue(L, obj, x) \
{ \ { \
TValue* i_o = (obj); \ TValue* i_o = (obj); \
@ -254,6 +264,19 @@ typedef struct Udata
}; };
} Udata; } Udata;
typedef struct Buffer
{
CommonHeader;
unsigned int len;
union
{
char data[1]; // buffer is allocated right after the header
L_Umaxalign dummy; // ensures maximum alignment for data
};
} Buffer;
/* /*
** Function Prototypes ** Function Prototypes
*/ */

View File

@ -161,7 +161,7 @@ static int os_date(lua_State* L)
cc[0] = '%'; cc[0] = '%';
cc[2] = '\0'; cc[2] = '\0';
luaL_Buffer b; luaL_Strbuf b;
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
for (; *s; s++) for (; *s; s++)
{ {
@ -179,7 +179,7 @@ static int os_date(lua_State* L)
char buff[200]; // should be big enough for any conversion result char buff[200]; // should be big enough for any conversion result
cc[1] = *(++s); cc[1] = *(++s);
reslen = strftime(buff, sizeof(buff), cc, stm); reslen = strftime(buff, sizeof(buff), cc, stm);
luaL_addlstring(&b, buff, reslen, -1); luaL_addlstring(&b, buff, reslen);
} }
} }
luaL_pushresult(&b); luaL_pushresult(&b);

View File

@ -282,6 +282,7 @@ union GCObject
struct Proto p; struct Proto p;
struct UpVal uv; struct UpVal uv;
struct lua_State th; // thread struct lua_State th; // thread
struct Buffer buf;
}; };
// macros to convert a GCObject into a specific value // macros to convert a GCObject into a specific value
@ -292,6 +293,7 @@ union GCObject
#define gco2p(o) check_exp((o)->gch.tt == LUA_TPROTO, &((o)->p)) #define gco2p(o) check_exp((o)->gch.tt == LUA_TPROTO, &((o)->p))
#define gco2uv(o) check_exp((o)->gch.tt == LUA_TUPVAL, &((o)->uv)) #define gco2uv(o) check_exp((o)->gch.tt == LUA_TUPVAL, &((o)->uv))
#define gco2th(o) check_exp((o)->gch.tt == LUA_TTHREAD, &((o)->th)) #define gco2th(o) check_exp((o)->gch.tt == LUA_TTHREAD, &((o)->th))
#define gco2buf(o) check_exp((o)->gch.tt == LUA_TBUFFER, &((o)->buf))
// macro to convert any Lua object into a GCObject // macro to convert any Lua object into a GCObject
#define obj2gco(v) check_exp(iscollectable(v), cast_to(GCObject*, (v) + 0)) #define obj2gco(v) check_exp(iscollectable(v), cast_to(GCObject*, (v) + 0))

View File

@ -48,7 +48,7 @@ static int str_reverse(lua_State* L)
{ {
size_t l; size_t l;
const char* s = luaL_checklstring(L, 1, &l); const char* s = luaL_checklstring(L, 1, &l);
luaL_Buffer b; luaL_Strbuf b;
char* ptr = luaL_buffinitsize(L, &b, l); char* ptr = luaL_buffinitsize(L, &b, l);
while (l--) while (l--)
*ptr++ = s[l]; *ptr++ = s[l];
@ -60,7 +60,7 @@ static int str_lower(lua_State* L)
{ {
size_t l; size_t l;
const char* s = luaL_checklstring(L, 1, &l); const char* s = luaL_checklstring(L, 1, &l);
luaL_Buffer b; luaL_Strbuf b;
char* ptr = luaL_buffinitsize(L, &b, l); char* ptr = luaL_buffinitsize(L, &b, l);
for (size_t i = 0; i < l; i++) for (size_t i = 0; i < l; i++)
*ptr++ = tolower(uchar(s[i])); *ptr++ = tolower(uchar(s[i]));
@ -72,7 +72,7 @@ static int str_upper(lua_State* L)
{ {
size_t l; size_t l;
const char* s = luaL_checklstring(L, 1, &l); const char* s = luaL_checklstring(L, 1, &l);
luaL_Buffer b; luaL_Strbuf b;
char* ptr = luaL_buffinitsize(L, &b, l); char* ptr = luaL_buffinitsize(L, &b, l);
for (size_t i = 0; i < l; i++) for (size_t i = 0; i < l; i++)
*ptr++ = toupper(uchar(s[i])); *ptr++ = toupper(uchar(s[i]));
@ -95,7 +95,7 @@ static int str_rep(lua_State* L)
if (l > MAXSSIZE / (size_t)n) // may overflow? if (l > MAXSSIZE / (size_t)n) // may overflow?
luaL_error(L, "resulting string too large"); luaL_error(L, "resulting string too large");
luaL_Buffer b; luaL_Strbuf b;
char* ptr = luaL_buffinitsize(L, &b, l * n); char* ptr = luaL_buffinitsize(L, &b, l * n);
const char* start = ptr; const char* start = ptr;
@ -151,7 +151,7 @@ static int str_char(lua_State* L)
{ {
int n = lua_gettop(L); // number of arguments int n = lua_gettop(L); // number of arguments
luaL_Buffer b; luaL_Strbuf b;
char* ptr = luaL_buffinitsize(L, &b, n); char* ptr = luaL_buffinitsize(L, &b, n);
for (int i = 1; i <= n; i++) for (int i = 1; i <= n; i++)
@ -750,12 +750,12 @@ static int gmatch(lua_State* L)
return 1; return 1;
} }
static void add_s(MatchState* ms, luaL_Buffer* b, const char* s, const char* e) static void add_s(MatchState* ms, luaL_Strbuf* b, const char* s, const char* e)
{ {
size_t l, i; size_t l, i;
const char* news = lua_tolstring(ms->L, 3, &l); const char* news = lua_tolstring(ms->L, 3, &l);
luaL_reservebuffer(b, l, -1); luaL_prepbuffsize(b, l);
for (i = 0; i < l; i++) for (i = 0; i < l; i++)
{ {
@ -771,7 +771,7 @@ static void add_s(MatchState* ms, luaL_Buffer* b, const char* s, const char* e)
luaL_addchar(b, news[i]); luaL_addchar(b, news[i]);
} }
else if (news[i] == '0') else if (news[i] == '0')
luaL_addlstring(b, s, e - s, -1); luaL_addlstring(b, s, e - s);
else else
{ {
push_onecapture(ms, news[i] - '1', s, e); push_onecapture(ms, news[i] - '1', s, e);
@ -781,7 +781,7 @@ static void add_s(MatchState* ms, luaL_Buffer* b, const char* s, const char* e)
} }
} }
static void add_value(MatchState* ms, luaL_Buffer* b, const char* s, const char* e, int tr) static void add_value(MatchState* ms, luaL_Strbuf* b, const char* s, const char* e, int tr)
{ {
lua_State* L = ms->L; lua_State* L = ms->L;
switch (tr) switch (tr)
@ -826,7 +826,7 @@ static int str_gsub(lua_State* L)
int anchor = (*p == '^'); int anchor = (*p == '^');
int n = 0; int n = 0;
MatchState ms; MatchState ms;
luaL_Buffer b; luaL_Strbuf b;
luaL_argexpected(L, tr == LUA_TNUMBER || tr == LUA_TSTRING || tr == LUA_TFUNCTION || tr == LUA_TTABLE, 3, "string/function/table"); luaL_argexpected(L, tr == LUA_TNUMBER || tr == LUA_TSTRING || tr == LUA_TFUNCTION || tr == LUA_TTABLE, 3, "string/function/table");
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
if (anchor) if (anchor)
@ -854,7 +854,7 @@ static int str_gsub(lua_State* L)
if (anchor) if (anchor)
break; break;
} }
luaL_addlstring(&b, src, ms.src_end - src, -1); luaL_addlstring(&b, src, ms.src_end - src);
luaL_pushresult(&b); luaL_pushresult(&b);
lua_pushinteger(L, n); // number of substitutions lua_pushinteger(L, n); // number of substitutions
return 2; return 2;
@ -869,12 +869,12 @@ static int str_gsub(lua_State* L)
// maximum size of each format specification (such as '%-099.99d') // maximum size of each format specification (such as '%-099.99d')
#define MAX_FORMAT 32 #define MAX_FORMAT 32
static void addquoted(lua_State* L, luaL_Buffer* b, int arg) static void addquoted(lua_State* L, luaL_Strbuf* b, int arg)
{ {
size_t l; size_t l;
const char* s = luaL_checklstring(L, arg, &l); const char* s = luaL_checklstring(L, arg, &l);
luaL_reservebuffer(b, l + 2, -1); luaL_prepbuffsize(b, l + 2);
luaL_addchar(b, '"'); luaL_addchar(b, '"');
while (l--) while (l--)
@ -891,12 +891,12 @@ static void addquoted(lua_State* L, luaL_Buffer* b, int arg)
} }
case '\r': case '\r':
{ {
luaL_addlstring(b, "\\r", 2, -1); luaL_addlstring(b, "\\r", 2);
break; break;
} }
case '\0': case '\0':
{ {
luaL_addlstring(b, "\\000", 4, -1); luaL_addlstring(b, "\\000", 4);
break; break;
} }
default: default:
@ -958,7 +958,7 @@ static int str_format(lua_State* L)
size_t sfl; size_t sfl;
const char* strfrmt = luaL_checklstring(L, arg, &sfl); const char* strfrmt = luaL_checklstring(L, arg, &sfl);
const char* strfrmt_end = strfrmt + sfl; const char* strfrmt_end = strfrmt + sfl;
luaL_Buffer b; luaL_Strbuf b;
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
while (strfrmt < strfrmt_end) while (strfrmt < strfrmt_end)
{ {
@ -1029,7 +1029,7 @@ static int str_format(lua_State* L)
// no precision and string is too long to be formatted, or no format necessary to begin with // no precision and string is too long to be formatted, or no format necessary to begin with
if (form[2] == '\0' || (!strchr(form, '.') && l >= 100)) if (form[2] == '\0' || (!strchr(form, '.') && l >= 100))
{ {
luaL_addlstring(&b, s, l, -1); luaL_addlstring(&b, s, l);
continue; // skip the `luaL_addlstring' at the end continue; // skip the `luaL_addlstring' at the end
} }
else else
@ -1048,7 +1048,7 @@ static int str_format(lua_State* L)
luaL_error(L, "invalid option '%%%c' to 'format'", *(strfrmt - 1)); luaL_error(L, "invalid option '%%%c' to 'format'", *(strfrmt - 1));
} }
} }
luaL_addlstring(&b, buff, strlen(buff), -1); luaL_addlstring(&b, buff, strlen(buff));
} }
} }
luaL_pushresult(&b); luaL_pushresult(&b);
@ -1344,7 +1344,7 @@ static KOption getdetails(Header* h, size_t totalsize, const char** fmt, int* ps
** the size of a Lua integer, correcting the extra sign-extension ** the size of a Lua integer, correcting the extra sign-extension
** bytes if necessary (by default they would be zeros). ** bytes if necessary (by default they would be zeros).
*/ */
static void packint(luaL_Buffer* b, unsigned long long n, int islittle, int size, int neg) static void packint(luaL_Strbuf* b, unsigned long long n, int islittle, int size, int neg)
{ {
LUAU_ASSERT(size <= MAXINTSIZE); LUAU_ASSERT(size <= MAXINTSIZE);
char buff[MAXINTSIZE]; char buff[MAXINTSIZE];
@ -1360,7 +1360,7 @@ static void packint(luaL_Buffer* b, unsigned long long n, int islittle, int size
for (i = SZINT; i < size; i++) // correct extra bytes for (i = SZINT; i < size; i++) // correct extra bytes
buff[islittle ? i : size - 1 - i] = (char)MC; buff[islittle ? i : size - 1 - i] = (char)MC;
} }
luaL_addlstring(b, buff, size, -1); // add result to buffer luaL_addlstring(b, buff, size); // add result to buffer
} }
/* /*
@ -1384,7 +1384,7 @@ static void copywithendian(volatile char* dest, volatile const char* src, int si
static int str_pack(lua_State* L) static int str_pack(lua_State* L)
{ {
luaL_Buffer b; luaL_Strbuf b;
Header h; Header h;
const char* fmt = luaL_checkstring(L, 1); // format string const char* fmt = luaL_checkstring(L, 1); // format string
int arg = 1; // current argument to pack int arg = 1; // current argument to pack
@ -1434,7 +1434,7 @@ static int str_pack(lua_State* L)
u.n = n; u.n = n;
// move 'u' to final result, correcting endianness if needed // move 'u' to final result, correcting endianness if needed
copywithendian(buff, u.buff, size, h.islittle); copywithendian(buff, u.buff, size, h.islittle);
luaL_addlstring(&b, buff, size, -1); luaL_addlstring(&b, buff, size);
break; break;
} }
case Kchar: case Kchar:
@ -1442,8 +1442,8 @@ static int str_pack(lua_State* L)
size_t len; size_t len;
const char* s = luaL_checklstring(L, arg, &len); const char* s = luaL_checklstring(L, arg, &len);
luaL_argcheck(L, len <= (size_t)size, arg, "string longer than given size"); luaL_argcheck(L, len <= (size_t)size, arg, "string longer than given size");
luaL_addlstring(&b, s, len, -1); // add string luaL_addlstring(&b, s, len); // add string
while (len++ < (size_t)size) // pad extra space while (len++ < (size_t)size) // pad extra space
luaL_addchar(&b, LUAL_PACKPADBYTE); luaL_addchar(&b, LUAL_PACKPADBYTE);
break; break;
} }
@ -1453,7 +1453,7 @@ static int str_pack(lua_State* L)
const char* s = luaL_checklstring(L, arg, &len); const char* s = luaL_checklstring(L, arg, &len);
luaL_argcheck(L, size >= (int)sizeof(size_t) || len < ((size_t)1 << (size * NB)), arg, "string length does not fit in given size"); luaL_argcheck(L, size >= (int)sizeof(size_t) || len < ((size_t)1 << (size * NB)), arg, "string length does not fit in given size");
packint(&b, len, h.islittle, size, 0); // pack length packint(&b, len, h.islittle, size, 0); // pack length
luaL_addlstring(&b, s, len, -1); luaL_addlstring(&b, s, len);
totalsize += len; totalsize += len;
break; break;
} }
@ -1462,7 +1462,7 @@ static int str_pack(lua_State* L)
size_t len; size_t len;
const char* s = luaL_checklstring(L, arg, &len); const char* s = luaL_checklstring(L, arg, &len);
luaL_argcheck(L, strlen(s) == len, arg, "string contains zeros"); luaL_argcheck(L, strlen(s) == len, arg, "string contains zeros");
luaL_addlstring(&b, s, len, -1); luaL_addlstring(&b, s, len);
luaL_addchar(&b, '\0'); // add zero at the end luaL_addchar(&b, '\0'); // add zero at the end
totalsize += len + 1; totalsize += len + 1;
break; break;

View File

@ -217,7 +217,7 @@ static int tmove(lua_State* L)
return 1; return 1;
} }
static void addfield(lua_State* L, luaL_Buffer* b, int i) static void addfield(lua_State* L, luaL_Strbuf* b, int i)
{ {
int tt = lua_rawgeti(L, 1, i); int tt = lua_rawgeti(L, 1, i);
if (tt != LUA_TSTRING && tt != LUA_TNUMBER) if (tt != LUA_TSTRING && tt != LUA_TNUMBER)
@ -227,7 +227,7 @@ static void addfield(lua_State* L, luaL_Buffer* b, int i)
static int tconcat(lua_State* L) static int tconcat(lua_State* L)
{ {
luaL_Buffer b; luaL_Strbuf b;
size_t lsep; size_t lsep;
int i, last; int i, last;
const char* sep = luaL_optlstring(L, 2, "", &lsep); const char* sep = luaL_optlstring(L, 2, "", &lsep);
@ -238,7 +238,7 @@ static int tconcat(lua_State* L)
for (; i < last; i++) for (; i < last; i++)
{ {
addfield(L, &b, i); addfield(L, &b, i);
luaL_addlstring(&b, sep, lsep, -1); luaL_addlstring(&b, sep, lsep);
} }
if (i == last) // add last value (if interval was not empty) if (i == last) // add last value (if interval was not empty)
addfield(L, &b, i); addfield(L, &b, i);

View File

@ -28,6 +28,7 @@ const char* const luaT_typenames[] = {
"function", "function",
"userdata", "userdata",
"thread", "thread",
"buffer",
}; };
const char* const luaT_eventname[] = { const char* const luaT_eventname[] = {

View File

@ -170,12 +170,12 @@ static int utfchar(lua_State* L)
} }
else else
{ {
luaL_Buffer b; luaL_Strbuf b;
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
for (int i = 1; i <= n; i++) for (int i = 1; i <= n; i++)
{ {
int l = buffutfchar(L, i, buff, &charstr); int l = buffutfchar(L, i, buff, &charstr);
luaL_addlstring(&b, charstr, l, -1); luaL_addlstring(&b, charstr, l);
} }
luaL_pushresult(&b); luaL_pushresult(&b);
} }

View File

@ -1129,6 +1129,7 @@ reentry:
case LUA_TSTRING: case LUA_TSTRING:
case LUA_TFUNCTION: case LUA_TFUNCTION:
case LUA_TTHREAD: case LUA_TTHREAD:
case LUA_TBUFFER:
pc += gcvalue(ra) == gcvalue(rb) ? LUAU_INSN_D(insn) : 1; pc += gcvalue(ra) == gcvalue(rb) ? LUAU_INSN_D(insn) : 1;
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
VM_NEXT(); VM_NEXT();
@ -1243,6 +1244,7 @@ reentry:
case LUA_TSTRING: case LUA_TSTRING:
case LUA_TFUNCTION: case LUA_TFUNCTION:
case LUA_TTHREAD: case LUA_TTHREAD:
case LUA_TBUFFER:
pc += gcvalue(ra) != gcvalue(rb) ? LUAU_INSN_D(insn) : 1; pc += gcvalue(ra) != gcvalue(rb) ? LUAU_INSN_D(insn) : 1;
LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode));
VM_NEXT(); VM_NEXT();

View File

@ -363,6 +363,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message)
if (luau_load(globalState, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) if (luau_load(globalState, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0)
{ {
Luau::CodeGen::AssemblyOptions options; Luau::CodeGen::AssemblyOptions options;
options.flags = Luau::CodeGen::CodeGen_ColdFunctions;
options.outputBinary = true; options.outputBinary = true;
options.target = kFuzzCodegenTarget; options.target = kFuzzCodegenTarget;
Luau::CodeGen::getAssembly(globalState, -1, options); Luau::CodeGen::getAssembly(globalState, -1, options);
@ -384,7 +385,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message)
if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0)
{ {
if (useCodegen) if (useCodegen)
Luau::CodeGen::compile(L, -1); Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions);
interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout;

167
stats/compiler-stats.py Normal file
View File

@ -0,0 +1,167 @@
#!/usr/bin/python3
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
import argparse
import json
from collections import Counter
import pandas as pd
## needed for 'to_markdown' method for pandas data frame
import tabulate
def getArgs():
parser = argparse.ArgumentParser(description='Analyze compiler statistics')
parser.add_argument('--bytecode-bin-factor', dest='bytecodeBinFactor',default=10,help='Bytecode bin size as a multiple of 1000 (10 by default)')
parser.add_argument('--block-bin-factor', dest='blockBinFactor',default=1,help='Block bin size as a multiple of 1000 (1 by default)')
parser.add_argument('--block-instruction-bin-factor', dest='blockInstructionBinFactor',default=1,help='Block bin size as a multiple of 1000 (1 by default)')
parser.add_argument('statsFile', help='stats.json file generated by running luau-compile')
args = parser.parse_args()
return args
def readStats(statsFile):
with open(statsFile) as f:
stats = json.load(f)
scripts = []
functionCounts = []
bytecodeLengths = []
blockPreOptCounts = []
blockPostOptCounts = []
maxBlockInstructionCounts = []
for path, fileStat in stats.items():
scripts.append(path)
functionCounts.append(fileStat['lowerStats']['totalFunctions'] - fileStat['lowerStats']['skippedFunctions'])
bytecodeLengths.append(fileStat['bytecode'])
blockPreOptCounts.append(fileStat['lowerStats']['blocksPreOpt'])
blockPostOptCounts.append(fileStat['lowerStats']['blocksPostOpt'])
maxBlockInstructionCounts.append(fileStat['lowerStats']['maxBlockInstructions'])
stats_df = pd.DataFrame({
'Script': scripts,
'FunctionCount': functionCounts,
'BytecodeLength': bytecodeLengths,
'BlockPreOptCount': blockPreOptCounts,
'BlockPostOptCount': blockPostOptCounts,
'MaxBlockInstructionCount': maxBlockInstructionCounts
})
return stats_df
def analyzeBytecodeStats(stats_df, config):
binFactor = config.bytecodeBinFactor
divisor = binFactor * 1000
totalScriptCount = len(stats_df.index)
lengthLabels = []
scriptCounts = []
scriptPercs = []
counter = Counter()
for index, row in stats_df.iterrows():
value = row['BytecodeLength']
factor = int(value / divisor)
counter[factor] += 1
for factor, scriptCount in sorted(counter.items()):
left = factor * binFactor
right = left + binFactor
lengthLabel = '{left}K-{right}K'.format(left=left, right=right)
lengthLabels.append(lengthLabel)
scriptCounts.append(scriptCount)
scriptPerc = round(scriptCount * 100 / totalScriptCount, 1)
scriptPercs.append(scriptPerc)
bcode_df = pd.DataFrame({
'BytecodeLength': lengthLabels,
'ScriptCount': scriptCounts,
'ScriptPerc': scriptPercs
})
return bcode_df
def analyzeBlockStats(stats_df, config, field):
binFactor = config.blockBinFactor
divisor = binFactor * 1000
totalScriptCount = len(stats_df.index)
blockLabels = []
scriptCounts = []
scriptPercs = []
counter = Counter()
for index, row in stats_df.iterrows():
value = row[field]
factor = int(value / divisor)
counter[factor] += 1
for factor, scriptCount in sorted(counter.items()):
left = factor * binFactor
right = left + binFactor
blockLabel = '{left}K-{right}K'.format(left=left, right=right)
blockLabels.append(blockLabel)
scriptCounts.append(scriptCount)
scriptPerc = round((scriptCount * 100) / totalScriptCount, 1)
scriptPercs.append(scriptPerc)
block_df = pd.DataFrame({
field: blockLabels,
'ScriptCount': scriptCounts,
'ScriptPerc': scriptPercs
})
return block_df
def analyzeMaxBlockInstructionStats(stats_df, config):
binFactor = config.blockInstructionBinFactor
divisor = binFactor * 1000
totalScriptCount = len(stats_df.index)
blockLabels = []
scriptCounts = []
scriptPercs = []
counter = Counter()
for index, row in stats_df.iterrows():
value = row['MaxBlockInstructionCount']
factor = int(value / divisor)
counter[factor] += 1
for factor, scriptCount in sorted(counter.items()):
left = factor * binFactor
right = left + binFactor
blockLabel = '{left}K-{right}K'.format(left=left, right=right)
blockLabels.append(blockLabel)
scriptCounts.append(scriptCount)
scriptPerc = round((scriptCount * 100) / totalScriptCount, 1)
scriptPercs.append(scriptPerc)
block_df = pd.DataFrame({
'MaxBlockInstructionCount': blockLabels,
'ScriptCount': scriptCounts,
'ScriptPerc': scriptPercs
})
return block_df
if __name__ == '__main__':
config = getArgs()
stats_df = readStats(config.statsFile)
bcode_df = analyzeBytecodeStats(stats_df, config)
print(bcode_df.to_markdown())
block_df = analyzeBlockStats(stats_df, config, 'BlockPreOptCount')
print(block_df.to_markdown())
block_df = analyzeBlockStats(stats_df, config, 'BlockPostOptCount')
print(block_df.to_markdown())
block_df = analyzeMaxBlockInstructionStats(stats_df, config)
print(block_df.to_markdown())

View File

@ -76,6 +76,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Unary")
SINGLE_COMPARE(clz(w0, w1), 0x5AC01020); SINGLE_COMPARE(clz(w0, w1), 0x5AC01020);
SINGLE_COMPARE(rbit(x0, x1), 0xDAC00020); SINGLE_COMPARE(rbit(x0, x1), 0xDAC00020);
SINGLE_COMPARE(rbit(w0, w1), 0x5AC00020); SINGLE_COMPARE(rbit(w0, w1), 0x5AC00020);
SINGLE_COMPARE(rev(w0, w1), 0x5AC00820);
SINGLE_COMPARE(rev(x0, x1), 0xDAC00C20);
} }
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary")

View File

@ -548,6 +548,10 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")
SINGLE_COMPARE(ud2(), 0x0f, 0x0b); SINGLE_COMPARE(ud2(), 0x0f, 0x0b);
SINGLE_COMPARE(bsr(eax, edx), 0x0f, 0xbd, 0xc2); SINGLE_COMPARE(bsr(eax, edx), 0x0f, 0xbd, 0xc2);
SINGLE_COMPARE(bsf(eax, edx), 0x0f, 0xbc, 0xc2); SINGLE_COMPARE(bsf(eax, edx), 0x0f, 0xbc, 0xc2);
SINGLE_COMPARE(bswap(eax), 0x0f, 0xc8);
SINGLE_COMPARE(bswap(r12d), 0x41, 0x0f, 0xcc);
SINGLE_COMPARE(bswap(rax), 0x48, 0x0f, 0xc8);
SINGLE_COMPARE(bswap(r12), 0x49, 0x0f, 0xcc);
} }
TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelLea") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelLea")

View File

@ -83,13 +83,13 @@ ClassFixture::ClassFixture()
TypeId vector2MetaType = arena.addType(TableType{}); TypeId vector2MetaType = arena.addType(TableType{});
TypeId vector2InstanceType = arena.addType(ClassType{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); vector2InstanceType = arena.addType(ClassType{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"});
getMutable<ClassType>(vector2InstanceType)->props = { getMutable<ClassType>(vector2InstanceType)->props = {
{"X", {numberType}}, {"X", {numberType}},
{"Y", {numberType}}, {"Y", {numberType}},
}; };
TypeId vector2Type = arena.addType(ClassType{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); vector2Type = arena.addType(ClassType{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"});
getMutable<ClassType>(vector2Type)->props = { getMutable<ClassType>(vector2Type)->props = {
{"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}},
}; };

Some files were not shown because too many files have changed in this diff Show More