Sync to upstream/release/651 (#1513)

### What's New?

* Fragment Autocomplete: a new API allows for type checking a small
fragment of code against an existing file, significantly speeding up
autocomplete performance in large files.

### New Solver

* E-Graphs have landed: this is an ongoing approach to make the new type
solver simplify types in a more consistent and principled manner, based
on similar work (see: https://egraphs-good.github.io/).
* Adds support for exporting / local user type functions (previously
they were always exported).
* Fixes a set of bugs in which the new solver will fail to complete
inference for simple expressions with just literals and operators.

### General Updates
* Requiring a path with a ".lua" or ".luau" extension will now have a
bespoke error suggesting to remove said extension.
* Fixes a bug in which whether two `Luau::Symbol`s are equal depends on
whether the new solver is enabled.

---

Internal Contributors:

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: David Cope <dcope@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
This commit is contained in:
Hunter Goldstein 2024-11-08 13:41:45 -08:00 committed by GitHub
parent 26b2307a8b
commit a36a3c41cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
62 changed files with 7318 additions and 2459 deletions

View File

@ -1,10 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/AutocompleteTypes.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include <unordered_map>
#include <string> #include <string>
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -16,90 +16,8 @@ struct Frontend;
struct SourceModule; struct SourceModule;
struct Module; struct Module;
struct TypeChecker; struct TypeChecker;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using ModuleName = std::string;
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau } // namespace Luau

View File

@ -0,0 +1,92 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Type.h"
#include <unordered_map>
namespace Luau
{
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View File

@ -5,6 +5,7 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/ControlFlow.h" #include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/InsertionOrderedMap.h" #include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
@ -15,7 +16,6 @@
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -109,6 +109,9 @@ struct ConstraintGenerator
// Needed to be able to enable error-suppression preservation for immediate refinements. // Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
// Needed to register all available type functions for execution at later stages. // Needed to register all available type functions for execution at later stages.
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// Needed to resolve modules to make 'require' import types properly. // Needed to resolve modules to make 'require' import types properly.
@ -128,6 +131,7 @@ struct ConstraintGenerator
ConstraintGenerator( ConstraintGenerator(
ModulePtr module, ModulePtr module,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
@ -405,6 +409,7 @@ private:
TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
// make an intersect type function of these two types // make an intersect type function of these two types
TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program);
/** Scan the program for global definitions. /** Scan the program for global definitions.
* *
@ -435,6 +440,8 @@ private:
const ScopePtr& scope, const ScopePtr& scope,
Location location Location location
); );
TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right);
}; };
/** Borrow a vector of pointers from a vector of owning pointers to constraints. /** Borrow a vector of pointers from a vector of owning pointers to constraints.

View File

@ -5,6 +5,7 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -64,6 +65,7 @@ struct ConstraintSolver
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer; NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve. // The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints; std::vector<NotNull<Constraint>> constraints;
@ -117,6 +119,7 @@ struct ConstraintSolver
explicit ConstraintSolver( explicit ConstraintSolver(
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
@ -384,6 +387,10 @@ public:
**/ **/
void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst); void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts);
TypeId simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId errorRecoveryType() const; TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const; TypePackId errorRecoveryTypePack() const;

View File

@ -0,0 +1,50 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/NotNull.h"
#include "Luau/DenseHash.h"
#include <memory>
#include <optional>
#include <vector>
namespace Luau
{
struct TypeArena;
}
// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent
// the complexity from leaking outside its implementation sources.
namespace Luau::EqSatSimplification
{
struct Simplifier;
using SimplifierPtr = std::unique_ptr<Simplifier, void (*)(Simplifier*)>;
SimplifierPtr newSimplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
} // namespace Luau::EqSatSimplification
namespace Luau
{
struct EqSatSimplificationResult
{
TypeId result;
// New type function applications that were created by the reduction phase.
// We return these so that the ConstraintSolver can know to try to reduce
// them.
std::vector<TypeId> newTypeFunctions;
};
using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect.
using Luau::EqSatSimplification::Simplifier; // NOLINT
using Luau::EqSatSimplification::SimplifierPtr;
std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simplifier, TypeId ty);
} // namespace Luau

View File

@ -0,0 +1,363 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/EGraph.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/Lexer.h" // For Allocator
#include "Luau/NotNull.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
struct TypeFunction;
}
namespace Luau::EqSatSimplification
{
using StringId = uint32_t;
using Id = Luau::EqSat::Id;
LUAU_EQSAT_UNIT(TNil);
LUAU_EQSAT_UNIT(TBoolean);
LUAU_EQSAT_UNIT(TNumber);
LUAU_EQSAT_UNIT(TString);
LUAU_EQSAT_UNIT(TThread);
LUAU_EQSAT_UNIT(TTopFunction);
LUAU_EQSAT_UNIT(TTopTable);
LUAU_EQSAT_UNIT(TTopClass);
LUAU_EQSAT_UNIT(TBuffer);
// Used for any type that eqsat can't do anything interesting with.
LUAU_EQSAT_ATOM(TOpaque, TypeId);
LUAU_EQSAT_ATOM(SBoolean, bool);
LUAU_EQSAT_ATOM(SString, StringId);
LUAU_EQSAT_ATOM(TFunction, TypeId);
LUAU_EQSAT_ATOM(TImportedTable, TypeId);
LUAU_EQSAT_ATOM(TClass, TypeId);
LUAU_EQSAT_UNIT(TAny);
LUAU_EQSAT_UNIT(TError);
LUAU_EQSAT_UNIT(TUnknown);
LUAU_EQSAT_UNIT(TNever);
LUAU_EQSAT_NODE_SET(Union);
LUAU_EQSAT_NODE_SET(Intersection);
LUAU_EQSAT_NODE_ARRAY(Negation, 1);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*);
LUAU_EQSAT_UNIT(TNoRefine);
LUAU_EQSAT_UNIT(Invalid);
// enodes are immutable, but types are cyclic. We need a way to tie the knot.
// We handle this by generating TBound nodes at points where we encounter cycles.
// Each TBound has an ordinal that we later map onto the type.
// We use a substitution rule to replace all TBound nodes with their referrent.
LUAU_EQSAT_ATOM(TBound, size_t);
// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it.
struct TTable
{
explicit TTable(Id basis);
TTable(Id basis, std::vector<StringId> propNames_, std::vector<Id> propTypes_);
// All TTables extend some other table. This may be TTopTable.
//
// It will frequently be a TImportedTable, in which case we can reuse things
// like source location and documentation info.
Id getBasis() const;
EqSat::Slice<const Id> propTypes() const;
// TODO: Also support read-only table props
// TODO: Indexer type, index result type.
std::vector<StringId> propNames;
// The enode interface
EqSat::Slice<Id> mutableOperands();
EqSat::Slice<const Id> operands() const;
bool operator==(const TTable& rhs) const;
bool operator!=(const TTable& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const TTable& value) const;
};
private:
// The first element of this vector is the basis. Subsequent elements are
// property types. As we add other things like read-only properties and
// indexers, the structure of this array is likely to change.
//
// We encode our data in this way so that the operands() method can properly
// return a Slice<Id>.
std::vector<Id> storage;
};
using EType = EqSat::Language<
TNil,
TBoolean,
TNumber,
TString,
TThread,
TTopFunction,
TTopTable,
TTopClass,
TBuffer,
TOpaque,
SBoolean,
SString,
TFunction,
TTable,
TImportedTable,
TClass,
TAny,
TError,
TUnknown,
TNever,
Union,
Intersection,
Negation,
TTypeFun,
Invalid,
TNoRefine,
TBound>;
struct StringCache
{
Allocator allocator;
DenseHashMap<size_t, StringId> strings{{}};
std::vector<std::string_view> views;
StringId add(std::string_view s);
std::string_view asStringView(StringId id) const;
std::string asString(StringId id) const;
};
using EGraph = Luau::EqSat::EGraph<EType, struct Simplify>;
struct Simplify
{
using Data = bool;
template<typename T>
Data make(const EGraph&, const T&) const;
void join(Data& left, const Data& right) const;
};
struct Subst
{
Id eclass;
Id newClass;
std::string desc;
Subst(Id eclass, Id newClass, std::string desc = "");
};
struct Simplifier
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
EGraph egraph;
StringCache stringCache;
// enodes are immutable but types can be cyclic, so we need some way to
// encode the cycle. This map is used to connect TBound nodes to the right
// eclass.
//
// The cyclicIntersection rewrite rule uses this to sense when a cycle can
// be deleted from an intersection or union.
std::unordered_map<size_t, Id> mappingIdToClass;
std::vector<Subst> substs;
using RewriteRuleFn = void (Simplifier::*)(Id id);
Simplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
// Utilities
const EqSat::EClass<EType, Simplify::Data>& get(Id id) const;
Id find(Id id) const;
Id add(EType enode);
template<typename Tag>
const Tag* isTag(Id id) const;
template<typename Tag>
const Tag* isTag(const EType& enode) const;
void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there);
// Rewrite rules
void simplifyUnion(Id id);
void uninhabitedIntersection(Id id);
void intersectWithNegatedClass(Id id);
void intersectWithNoRefine(Id id);
void cyclicIntersectionOfUnion(Id id);
void cyclicUnionOfIntersection(Id id);
void expandNegation(Id id);
void intersectionOfUnion(Id id);
void intersectTableProperty(Id id);
void uninhabitedTable(Id id);
void unneededTableModification(Id id);
void builtinTypeFunctions(Id id);
void iffyTypeFunctions(Id id);
};
template<typename Tag>
struct QueryIterator
{
QueryIterator();
QueryIterator(EGraph* egraph, Id eclass);
bool operator==(const QueryIterator& other) const;
bool operator!=(const QueryIterator& other) const;
std::pair<const Tag*, size_t> operator*() const;
QueryIterator& operator++();
QueryIterator& operator++(int);
private:
EGraph* egraph = nullptr;
Id eclass;
size_t index = 0;
};
template<typename Tag>
struct Query
{
EGraph* egraph;
Id eclass;
Query(EGraph* egraph, Id eclass)
: egraph(egraph)
, eclass(eclass)
{
}
QueryIterator<Tag> begin()
{
return QueryIterator<Tag>{egraph, eclass};
}
QueryIterator<Tag> end()
{
return QueryIterator<Tag>{};
}
};
template<typename Tag>
QueryIterator<Tag>::QueryIterator()
: egraph(nullptr)
, eclass(Id{0})
, index(0)
{
}
template<typename Tag>
QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
: egraph(egraph_)
, eclass(eclass)
, index(0)
{
const auto& ecl = (*egraph)[eclass];
static constexpr const int idx = EType::VariantTy::getTypeId<Tag>();
for (const auto& enode : ecl.nodes)
{
if (enode.index() < idx)
++index;
else
break;
}
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx)
{
egraph = nullptr;
index = 0;
}
}
template<typename Tag>
bool QueryIterator<Tag>::operator==(const QueryIterator<Tag>& rhs) const
{
if (egraph == nullptr && rhs.egraph == nullptr)
return true;
return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index;
}
template<typename Tag>
bool QueryIterator<Tag>::operator!=(const QueryIterator<Tag>& rhs) const
{
return !(*this == rhs);
}
template<typename Tag>
std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
{
LUAU_ASSERT(egraph != nullptr);
EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index];
Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result);
return {result, index};
}
// pre-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{
const auto& ecl = (*egraph)[eclass];
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId<Tag>())
{
egraph = nullptr;
index = 0;
}
return *this;
}
// post-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++(int)
{
QueryIterator<Tag> res = *this;
++res;
return res;
}
} // namespace Luau::EqSatSimplification

View File

@ -32,7 +32,11 @@ struct ModuleInfo
bool optional = false; bool optional = false;
}; };
using RequireSuggestion = std::string; struct RequireSuggestion
{
std::string label;
std::string fullPath;
};
using RequireSuggestions = std::vector<RequireSuggestion>; using RequireSuggestions = std::vector<RequireSuggestion>;
struct FileResolver struct FileResolver

View File

@ -3,9 +3,10 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/Autocomplete.h" #include "Luau/AutocompleteTypes.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Frontend.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -27,13 +28,23 @@ struct FragmentParseResult
std::string fragmentToParse; std::string fragmentToParse;
AstStatBlock* root = nullptr; AstStatBlock* root = nullptr;
std::vector<AstNode*> ancestry; std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>(); std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
}; };
struct FragmentTypeCheckResult struct FragmentTypeCheckResult
{ {
ModulePtr incrementalModule = nullptr; ModulePtr incrementalModule = nullptr;
Scope* freshScope = nullptr; ScopePtr freshScope;
std::vector<AstNode*> ancestry;
};
struct FragmentAutocompleteResult
{
ModulePtr incrementalModule;
Scope* freshScope;
TypeArena arenaForAutocomplete;
AutocompleteResult acResults;
}; };
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
@ -48,11 +59,11 @@ FragmentTypeCheckResult typecheckFragment(
std::string_view src std::string_view src
); );
AutocompleteResult fragmentAutocomplete( FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend, Frontend& frontend,
std::string_view src, std::string_view src,
const ModuleName& moduleName, const ModuleName& moduleName,
Position& cursorPosition, Position cursorPosition,
std::optional<FrontendOptions> opts, std::optional<FrontendOptions> opts,
StringCompletionCallback callback StringCompletionCallback callback
); );

View File

@ -44,6 +44,7 @@ struct ToStringOptions
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections

View File

@ -31,6 +31,7 @@ namespace Luau
struct TypeArena; struct TypeArena;
struct Scope; struct Scope;
using ScopePtr = std::shared_ptr<Scope>; using ScopePtr = std::shared_ptr<Scope>;
struct Module;
struct TypeFunction; struct TypeFunction;
struct Constraint; struct Constraint;
@ -598,6 +599,18 @@ struct ClassType
} }
}; };
// Data required to initialize a user-defined function and its environment
struct UserDefinedFunctionData
{
// Store a weak module reference to ensure the lifetime requirements are preserved
std::weak_ptr<Module> owner;
// References to AST elements are owned by the Module allocator which also stores this type
AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, AstStatTypeFunction*> environment{""};
};
/** /**
* An instance of a type function that has not yet been reduced to a more concrete * An instance of a type function that has not yet been reduced to a more concrete
* type. The constraint solver receives a constraint to reduce each * type. The constraint solver receives a constraint to reduce each
@ -613,17 +626,20 @@ struct TypeFunctionInstanceType
std::vector<TypePackId> packArguments; std::vector<TypePackId> packArguments;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
UserDefinedFunctionData userFuncData;
TypeFunctionInstanceType( TypeFunctionInstanceType(
NotNull<const TypeFunction> function, NotNull<const TypeFunction> function,
std::vector<TypeId> typeArguments, std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments, std::vector<TypePackId> packArguments,
std::optional<AstName> userFuncName = std::nullopt std::optional<AstName> userFuncName,
UserDefinedFunctionData userFuncData
) )
: function(function) : function(function)
, typeArguments(typeArguments) , typeArguments(typeArguments)
, packArguments(packArguments) , packArguments(packArguments)
, userFuncName(userFuncName) , userFuncName(userFuncName)
, userFuncData(userFuncData)
{ {
} }
@ -640,6 +656,13 @@ struct TypeFunctionInstanceType
, packArguments(packArguments) , packArguments(packArguments)
{ {
} }
TypeFunctionInstanceType(NotNull<const TypeFunction> function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: function{function}
, typeArguments(typeArguments)
, packArguments(packArguments)
{
}
}; };
/** Represents a pending type alias instantiation. /** Represents a pending type alias instantiation.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,27 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AutocompleteTypes.h"
namespace Luau
{
struct Module;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
using ModuleName = std::string;
AutocompleteResult autocomplete_(
const ModulePtr& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
std::vector<AstNode*>& ancestry,
Scope* globalScope,
const ScopePtr& scopeAtPosition,
Position position,
FileResolver* fileResolver,
StringCompletionCallback callback
);
} // namespace Luau

View File

@ -33,7 +33,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions) LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2);
namespace Luau namespace Luau
{ {
@ -426,7 +426,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze);
} }
if (FFlag::AutocompleteRequirePathSuggestions) if (FFlag::AutocompleteRequirePathSuggestions2)
{ {
TypeId requireTy = getGlobalBinding(globals, "require"); TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName); attachTag(requireTy, kRequireTagName);

View File

@ -3,6 +3,8 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions)
namespace Luau namespace Luau
{ {
@ -46,6 +48,21 @@ struct ReferenceCountInitializer : TypeOnceVisitor
// ClassTypes never contain free types. // ClassTypes never contain free types.
return false; return false;
} }
bool visit(TypeId, const TypeFunctionInstanceType&) override
{
// We do not consider reference counted types that are inside a type
// function to be part of the reachable reference counted types.
// Otherwise, code can be constructed in just the right way such
// that two type functions both claim to mutate a free type, which
// prevents either type function from trying to generalize it, so
// we potentially get stuck.
//
// The default behavior here is `true` for "visit the child types"
// of this type, hence:
return !FFlag::LuauDontRefCountTypesInTypeFunctions;
}
}; };
bool isReferenceCountedType(const TypeId typ) bool isReferenceCountedType(const TypeId typ)

View File

@ -10,6 +10,7 @@
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/NotNull.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/Refinement.h" #include "Luau/Refinement.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
@ -30,11 +31,13 @@
LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTINT(LuauCheckRecursionLimit)
LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(DebugLuauEqSatSimplification)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations)
namespace Luau namespace Luau
@ -172,6 +175,7 @@ bool hasFreeType(TypeId ty)
ConstraintGenerator::ConstraintGenerator( ConstraintGenerator::ConstraintGenerator(
ModulePtr module, ModulePtr module,
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
@ -188,6 +192,7 @@ ConstraintGenerator::ConstraintGenerator(
, rootScope(nullptr) , rootScope(nullptr)
, dfg(dfg) , dfg(dfg)
, normalizer(normalizer) , normalizer(normalizer)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, moduleResolver(moduleResolver) , moduleResolver(moduleResolver)
, ice(ice) , ice(ice)
@ -257,7 +262,7 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
d = follow(d); d = follow(d);
if (d == ty) if (d == ty)
continue; continue;
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; domainTy = simplifyUnion(scope, Location{}, domainTy, d);
} }
LUAU_ASSERT(get<BlockedType>(ty)); LUAU_ASSERT(get<BlockedType>(ty));
@ -267,7 +272,15 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block) void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block)
{ {
// We prepopulate global data in the resumeScope to avoid writing data into the old modules scopes
prepopulateGlobalScopeForFragmentTypecheck(globalScope, resumeScope, block);
// Pre
// We need to pop the interior types,
interiorTypes.emplace_back();
visitBlockWithoutChildScope(resumeScope, block); visitBlockWithoutChildScope(resumeScope, block);
// Post
interiorTypes.pop_back();
fillInInferredBindings(resumeScope, block); fillInInferredBindings(resumeScope, block);
if (logger) if (logger)
@ -282,7 +295,7 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat
d = follow(d); d = follow(d);
if (d == ty) if (d == ty)
continue; continue;
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; domainTy = simplifyUnion(resumeScope, resumeScope->location, domainTy, d);
} }
LUAU_ASSERT(get<BlockedType>(ty)); LUAU_ASSERT(get<BlockedType>(ty));
@ -711,7 +724,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
continue; continue;
} }
if (scope->parent != globalScope) if (!FFlag::LuauUserTypeFunExportedAndLocal && scope->parent != globalScope)
{ {
reportError(function->location, GenericError{"Local user-defined functions are not supported yet"}); reportError(function->location, GenericError{"Local user-defined functions are not supported yet"});
continue; continue;
@ -740,17 +753,26 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function)) if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function))
reportError(function->location, GenericError{*error}); reportError(function->location, GenericError{*error});
TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ UserDefinedFunctionData udtfData;
NotNull{&builtinTypeFunctions().userFunc},
std::move(typeParams), if (FFlag::LuauUserTypeFunExportedAndLocal)
{}, {
function->name, udtfData.owner = module;
}); udtfData.definition = function;
}
TypeId typeFunctionTy = arena->addType(
TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData}
);
TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};
// Set type bindings and definition locations for this user-defined type function // Set type bindings and definition locations for this user-defined type function
scope->privateTypeBindings[function->name.value] = std::move(typeFunction); if (FFlag::LuauUserTypeFunExportedAndLocal && function->exported)
scope->exportedTypeBindings[function->name.value] = std::move(typeFunction);
else
scope->privateTypeBindings[function->name.value] = std::move(typeFunction);
aliasDefinitionLocations[function->name.value] = function->location; aliasDefinitionLocations[function->name.value] = function->location;
} }
else if (auto classDeclaration = stat->as<AstStatDeclareClass>()) else if (auto classDeclaration = stat->as<AstStatDeclareClass>())
@ -780,6 +802,55 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location; classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location;
} }
} }
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Additional pass for user-defined type functions to fill in their environments completely
for (AstStat* stat : block->body)
{
if (auto function = stat->as<AstStatTypeFunction>())
{
// Find the type function we have already created
TypeFunctionInstanceType* mainTypeFun = nullptr;
if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
if (!mainTypeFun)
{
if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
}
// Fill it with all visible type functions
if (mainTypeFun)
{
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
{
for (auto& [name, tf] : curr->privateTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
for (auto& [name, tf] : curr->exportedTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
}
}
}
}
}
} }
ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block)
@ -900,12 +971,8 @@ ControlFlow ConstraintGenerator::visitBlockWithoutChildScope_DEPRECATED(const Sc
if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function)) if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function))
reportError(function->location, GenericError{*error}); reportError(function->location, GenericError{*error});
TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ TypeId typeFunctionTy =
NotNull{&builtinTypeFunctions().userFunc}, arena->addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, {}});
std::move(typeParams),
{},
function->name,
});
TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};
@ -2807,7 +2874,7 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local
case ErrorSuppression::DoNotSuppress: case ErrorSuppression::DoNotSuppress:
break; break;
case ErrorSuppression::Suppress: case ErrorSuppression::Suppress:
ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result; ty = simplifyUnion(scope, local->location, *ty, builtinTypes->errorType);
break; break;
case ErrorSuppression::NormalizationFailed: case ErrorSuppression::NormalizationFailed:
reportError(local->local->annotation->location, NormalizationTooComplex{}); reportError(local->local->annotation->location, NormalizationTooComplex{});
@ -3673,6 +3740,32 @@ TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location locati
return resultType; return resultType;
} }
struct FragmentTypeCheckGlobalPrepopulator : AstVisitor
{
const NotNull<Scope> globalScope;
const NotNull<Scope> currentScope;
const NotNull<const DataFlowGraph> dfg;
FragmentTypeCheckGlobalPrepopulator(NotNull<Scope> globalScope, NotNull<Scope> currentScope, NotNull<const DataFlowGraph> dfg)
: globalScope(globalScope)
, currentScope(currentScope)
, dfg(dfg)
{
}
bool visit(AstExprGlobal* global) override
{
if (auto ty = globalScope->lookup(global->name))
{
DefId def = dfg->getDef(global);
// We only want to write into the current scope the type of the global
currentScope->lvalueTypes[def] = *ty;
}
return true;
}
};
struct GlobalPrepopulator : AstVisitor struct GlobalPrepopulator : AstVisitor
{ {
const NotNull<Scope> globalScope; const NotNull<Scope> globalScope;
@ -3719,6 +3812,14 @@ struct GlobalPrepopulator : AstVisitor
} }
}; };
void ConstraintGenerator::prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program)
{
FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg};
if (prepareModuleScope)
prepareModuleScope(module->name, resumeScope);
program->visit(&gp);
}
void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program)
{ {
GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg}; GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg};
@ -3870,6 +3971,24 @@ TypeId ConstraintGenerator::createTypeFunctionInstance(
return result; return result;
} }
TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(UnionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId tyFun : res->newTypeFunctions)
addConstraint(scope, location, ReduceConstraint{tyFun});
return res->result;
}
else
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
}
std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints) std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints)
{ {
std::vector<NotNull<Constraint>> result; std::vector<NotNull<Constraint>> result;

View File

@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings)
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
namespace Luau namespace Luau
@ -320,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor
ConstraintSolver::ConstraintSolver( ConstraintSolver::ConstraintSolver(
NotNull<Normalizer> normalizer, NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
@ -333,6 +335,7 @@ ConstraintSolver::ConstraintSolver(
: arena(normalizer->arena) : arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes) , builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer) , normalizer(normalizer)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, constraints(std::move(constraints)) , constraints(std::move(constraints))
, rootScope(rootScope) , rootScope(rootScope)
@ -1802,7 +1805,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
upperTable->props[c.propName] = rhsType; upperTable->props[c.propName] = rhsType;
// Food for thought: Could we block if simplification encounters a blocked type? // Food for thought: Could we block if simplification encounters a blocked type?
lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result; lhsFree->upperBound = simplifyIntersection(constraint->scope, constraint->location, lhsFreeUpperBound, newUpperBound);
bind(constraint, c.propType, rhsType); bind(constraint, c.propType, rhsType);
return true; return true;
@ -2016,7 +2019,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull<const
} }
} }
TypeId res = simplifyIntersection(builtinTypes, arena, std::move(parts)).result; TypeId res = simplifyIntersection(constraint->scope, constraint->location, std::move(parts));
unify(constraint, rhsType, res); unify(constraint, rhsType, res);
} }
@ -2596,9 +2599,9 @@ std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTa
// if we're in an lvalue context, we need the _common_ type here. // if we're in an lvalue context, we need the _common_ type here.
if (context == ValueContext::LValue) if (context == ValueContext::LValue)
return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)};
return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; return {{}, simplifyUnion(constraint->scope, constraint->location, one, two)};
} }
// if we're in an lvalue context, we need the _common_ type here. // if we're in an lvalue context, we need the _common_ type here.
else if (context == ValueContext::LValue) else if (context == ValueContext::LValue)
@ -2630,7 +2633,7 @@ std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTa
{ {
TypeId one = *begin(options); TypeId one = *begin(options);
TypeId two = *(++begin(options)); TypeId two = *(++begin(options));
return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)};
} }
else else
return {{}, arena->addType(IntersectionType{std::vector<TypeId>(begin(options), end(options))})}; return {{}, arena->addType(IntersectionType{std::vector<TypeId>(begin(options), end(options))})};
@ -3019,6 +3022,63 @@ bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty)
return false; return false;
} }
TypeId ConstraintSolver::simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(IntersectionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyIntersection(builtinTypes, arena, left, right).result;
}
TypeId ConstraintSolver::simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(IntersectionType{std::vector(parts.begin(), parts.end())});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyIntersection(builtinTypes, arena, std::move(parts)).result;
}
TypeId ConstraintSolver::simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(UnionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
}
TypeId ConstraintSolver::errorRecoveryType() const TypeId ConstraintSolver::errorRecoveryType() const
{ {
return builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType();

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
#include "Luau/ParseOptions.h" #include "Luau/ParseOptions.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -18,11 +19,14 @@
#include "Luau/ParseOptions.h" #include "Luau/ParseOptions.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "AutocompleteCore.h"
LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2); LUAU_FASTFLAG(LuauStoreDFGOnModule2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
namespace namespace
{ {
@ -41,7 +45,6 @@ void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K,
} // namespace } // namespace
namespace Luau namespace Luau
{ {
@ -88,14 +91,22 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
} }
std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) /**
* Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that
* document and attempts to get the concrete text between those points. It returns a pair of:
* - start offset that represents an index in the source `char*` corresponding to startPos
* - length, that represents how many more bytes to read to get to endPos.
* Example - your document is "foo bar baz" and getDocumentOffsets is passed (1, 4) - (1, 8). This function returns the pair {3, 7},
* which corresponds to the string " bar "
*/
std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos)
{ {
unsigned int lineCount = 0; size_t lineCount = 0;
unsigned int colCount = 0; size_t colCount = 0;
unsigned int docOffset = 0; size_t docOffset = 0;
unsigned int startOffset = 0; size_t startOffset = 0;
unsigned int endOffset = 0; size_t endOffset = 0;
bool foundStart = false; bool foundStart = false;
bool foundEnd = false; bool foundEnd = false;
for (char c : src) for (char c : src)
@ -115,6 +126,13 @@ std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view&
foundEnd = true; foundEnd = true;
} }
// We put a cursor position that extends beyond the extents of the current line
if (foundStart && !foundEnd && (lineCount > endPos.line))
{
foundEnd = true;
endOffset = docOffset - 1;
}
if (c == '\n') if (c == '\n')
{ {
lineCount++; lineCount++;
@ -125,20 +143,24 @@ std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view&
docOffset++; docOffset++;
} }
if (foundStart && !foundEnd)
endOffset = src.length();
unsigned int min = std::min(startOffset, endOffset); size_t min = std::min(startOffset, endOffset);
unsigned int len = std::max(startOffset, endOffset) - min; size_t len = std::max(startOffset, endOffset) - min;
return {min, len}; return {min, len};
} }
ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos) ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement)
{ {
LUAU_ASSERT(module->hasModuleScope()); LUAU_ASSERT(module->hasModuleScope());
ScopePtr closest = module->getModuleScope(); ScopePtr closest = module->getModuleScope();
// find the scope the nearest statement belonged to.
for (auto [loc, sc] : module->scopes) for (auto [loc, sc] : module->scopes)
{ {
if (loc.begin <= cursorPos && closest->location.begin <= loc.begin) if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin)
closest = sc; closest = sc;
} }
@ -152,13 +174,27 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
opts.allowDeclarationSyntax = false; opts.allowDeclarationSyntax = false;
opts.captureComments = false; opts.captureComments = false;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)}; opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)};
AstStat* enclosingStatement = result.nearestStatement; AstStat* nearestStatement = result.nearestStatement;
const Position& endPos = cursorPos; const Location& rootSpan = srcModule.root->location;
// If the statement starts on a previous line, grab the statement beginning // Did we append vs did we insert inline
// otherwise, grab the statement end to whatever is being typed right now bool appended = cursorPos >= rootSpan.end;
const Position& startPos = // statement spans multiple lines
enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end; bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line;
const Position endPos = cursorPos;
// We start by re-parsing everything (we'll refine this as we go)
Position startPos = srcModule.root->location.begin;
// If we added to the end of the sourceModule, use the end of the nearest location
if (appended && multiline)
startPos = nearestStatement->location.end;
// Statement spans one line && cursorPos is on a different line
else if (!multiline && cursorPos.line != nearestStatement->location.end.line)
startPos = nearestStatement->location.end;
else
startPos = nearestStatement->location.begin;
auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos);
@ -173,10 +209,11 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry); std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end); std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end);
fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end());
if (enclosingStatement == nullptr) if (nearestStatement == nullptr)
enclosingStatement = p.root; nearestStatement = p.root;
fragmentResult.root = std::move(p.root); fragmentResult.root = std::move(p.root);
fragmentResult.ancestry = std::move(fabricatedAncestry); fragmentResult.ancestry = std::move(fabricatedAncestry);
fragmentResult.nearestStatement = nearestStatement;
return fragmentResult; return fragmentResult;
} }
@ -205,7 +242,7 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
return incrementalModule; return incrementalModule;
} }
FragmentTypeCheckResult typeCheckFragmentHelper( FragmentTypeCheckResult typecheckFragment_(
Frontend& frontend, Frontend& frontend,
AstStatBlock* root, AstStatBlock* root,
const ModulePtr& stale, const ModulePtr& stale,
@ -245,15 +282,18 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
/// Create a DataFlowGraph just for the surrounding context /// Create a DataFlowGraph just for the surrounding context
auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler); auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
/// Contraint Generator /// Contraint Generator
ConstraintGenerator cg{ ConstraintGenerator cg{
incrementalModule, incrementalModule,
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull{&frontend.moduleResolver}, NotNull{&frontend.moduleResolver},
frontend.builtinTypes, frontend.builtinTypes,
iceHandler, iceHandler,
frontend.globals.globalScope, stale->getModuleScope(),
nullptr, nullptr,
nullptr, nullptr,
NotNull{&updatedDfg}, NotNull{&updatedDfg},
@ -262,7 +302,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
cg.rootScope = stale->getModuleScope().get(); cg.rootScope = stale->getModuleScope().get();
// Any additions to the scope must occur in a fresh scope // Any additions to the scope must occur in a fresh scope
auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope); auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope}); incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
// closest Scope -> children = { ...., freshChildOfNearestScope} // closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy // We need to trim nearestChild from the scope hierarcy
@ -274,9 +314,11 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
LUAU_ASSERT(back == freshChildOfNearestScope.get()); LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back(); closestScope->children.pop_back();
/// Initialize the constraint solver and run it /// Initialize the constraint solver and run it
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope), NotNull(cg.rootScope),
borrowConstraints(cg.constraints), borrowConstraints(cg.constraints),
@ -307,7 +349,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
freeze(incrementalModule->internalTypes); freeze(incrementalModule->internalTypes);
freeze(incrementalModule->interfaceTypes); freeze(incrementalModule->interfaceTypes);
return {std::move(incrementalModule), freshChildOfNearestScope.get()}; return {std::move(incrementalModule), std::move(freshChildOfNearestScope)};
} }
@ -327,27 +369,51 @@ FragmentTypeCheckResult typecheckFragment(
} }
ModulePtr module = frontend.moduleResolver.getModule(moduleName); ModulePtr module = frontend.moduleResolver.getModule(moduleName);
const ScopePtr& closestScope = findClosestScope(module, cursorPos); FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos);
FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos);
FrontendOptions frontendOptions = opts.value_or(frontend.options); FrontendOptions frontendOptions = opts.value_or(frontend.options);
return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions); const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement);
FragmentTypeCheckResult result =
typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions);
result.ancestry = std::move(parseResult.ancestry);
return result;
} }
AutocompleteResult fragmentAutocomplete(
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend, Frontend& frontend,
std::string_view src, std::string_view src,
const ModuleName& moduleName, const ModuleName& moduleName,
Position& cursorPosition, Position cursorPosition,
const FrontendOptions& opts, std::optional<FrontendOptions> opts,
StringCompletionCallback callback StringCompletionCallback callback
) )
{ {
LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauSolverV2);
LUAU_ASSERT(FFlag::LuauAllowFragmentParsing); LUAU_ASSERT(FFlag::LuauAllowFragmentParsing);
LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2); LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2);
return {}; LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {};
}
auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src);
TypeArena arenaForFragmentAutocomplete;
auto result = Luau::autocomplete_(
tcResult.incrementalModule,
frontend.builtinTypes,
&arenaForFragmentAutocomplete,
tcResult.ancestry,
frontend.globals.globalScope.get(),
tcResult.freshScope,
cursorPosition,
frontend.fileResolver,
callback
);
return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)};
} }
} // namespace Luau } // namespace Luau

View File

@ -10,6 +10,7 @@
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h" #include "Luau/NonStrictTypeChecker.h"
#include "Luau/Parser.h" #include "Luau/Parser.h"
@ -46,7 +47,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection)
LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2) LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2)
@ -287,8 +287,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
std::vector<RequireCycle> getRequireCycles( std::vector<RequireCycle> getRequireCycles(
const FileResolver* resolver, const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes, const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
const SourceNode* start, const SourceNode* start
bool stopAtFirst = false
) )
{ {
std::vector<RequireCycle> result; std::vector<RequireCycle> result;
@ -358,9 +357,6 @@ std::vector<RequireCycle> getRequireCycles(
{ {
result.push_back({depLocation, std::move(cycle)}); result.push_back({depLocation, std::move(cycle)});
if (stopAtFirst)
return result;
// note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start // note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start
// so it's safe to *only* clear seen vector when we find a cycle // so it's safe to *only* clear seen vector when we find a cycle
// if we don't do it, we will not have correct reporting for some cycles // if we don't do it, we will not have correct reporting for some cycles
@ -884,18 +880,11 @@ void Frontend::addBuildQueueItems(
data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete);
data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; data.recordJsonLog = FFlag::DebugLuauLogSolverToJson;
const Mode mode = sourceModule->mode.value_or(data.config.mode);
// in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely // all correct programs must be acyclic so this code triggers rarely
if (cycleDetected) if (cycleDetected)
{ data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get());
if (FFlag::LuauMoreThoroughCycleDetection)
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), false);
else
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck);
}
data.options = frontendOptions; data.options = frontendOptions;
@ -1334,6 +1323,7 @@ ModulePtr check(
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes);
TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}}; TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}};
if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation)
@ -1342,6 +1332,7 @@ ModulePtr check(
ConstraintGenerator cg{ ConstraintGenerator cg{
result, result,
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
moduleResolver, moduleResolver,
builtinTypes, builtinTypes,
@ -1358,6 +1349,7 @@ ModulePtr check(
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope), NotNull(cg.rootScope),
borrowConstraints(cg.constraints), borrowConstraints(cg.constraints),

View File

@ -132,7 +132,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
return dest.addType(NegationType{a.ty}); return dest.addType(NegationType{a.ty});
else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>) else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>)
{ {
TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName}; TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData};
return dest.addType(std::move(clone)); return dest.addType(std::move(clone));
} }
else else

View File

@ -4,6 +4,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSymbolEquality)
namespace Luau namespace Luau
{ {
@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const
return local == rhs.local; return local == rhs.local;
else if (global.value) else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
else if (FFlag::LuauSolverV2) else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else else
return false; return false;

View File

@ -870,6 +870,8 @@ struct TypeStringifier
return; return;
} }
LUAU_ASSERT(uv.options.size() > 1);
bool optional = false; bool optional = false;
bool hasNonNilDisjunct = false; bool hasNonNilDisjunct = false;
@ -878,7 +880,7 @@ struct TypeStringifier
{ {
el = follow(el); el = follow(el);
if (isNil(el)) if (state.opts.useQuestionMarks && isNil(el))
{ {
optional = true; optional = true;
continue; continue;

View File

@ -51,6 +51,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_FASTFLAG(LuauUserTypeFunFixRegister) LUAU_FASTFLAG(LuauUserTypeFunFixRegister)
LUAU_FASTFLAG(LuauRemoveNotAnyHack) LUAU_FASTFLAG(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState)
LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
@ -610,10 +611,29 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
NotNull<TypeFunctionContext> ctx NotNull<TypeFunctionContext> ctx
) )
{ {
if (!ctx->userFuncName) auto typeFunction = getMutable<TypeFunctionInstanceType>(instance);
if (FFlag::LuauUserTypeFunExportedAndLocal)
{ {
ctx->ice->ice("all user-defined type functions must have an associated function definition"); if (typeFunction->userFuncData.owner.expired())
return {std::nullopt, true, {}, {}}; {
ctx->ice->ice("user-defined type function module has expired");
return {std::nullopt, true, {}, {}};
}
if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition)
{
ctx->ice->ice("all user-defined type functions must have an associated function definition");
return {std::nullopt, true, {}, {}};
}
}
else
{
if (!ctx->userFuncName)
{
ctx->ice->ice("all user-defined type functions must have an associated function definition");
return {std::nullopt, true, {}, {}};
}
} }
if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation)
@ -632,7 +652,22 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
return {std::nullopt, false, {ty}, {}}; return {std::nullopt, false, {ty}, {}};
} }
AstName name = *ctx->userFuncName; if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Ensure that whole type function environment is registered
for (auto& [name, definition] : typeFunction->userFuncData.environment)
{
if (std::optional<std::string> error = ctx->typeFunctionRuntime->registerFunction(definition))
{
// Failure to register at this point means that original definition had to error out and should not have been present in the
// environment
ctx->ice->ice("user-defined type function reference cannot be registered");
return {std::nullopt, true, {}, {}};
}
}
}
AstName name = FFlag::LuauUserTypeFunExportedAndLocal ? typeFunction->userFuncData.definition->name : *ctx->userFuncName;
lua_State* global = ctx->typeFunctionRuntime->state.get(); lua_State* global = ctx->typeFunctionRuntime->state.get();
@ -643,8 +678,44 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
lua_State* L = lua_newthread(global); lua_State* L = lua_newthread(global);
LuauTempThreadPopper popper(global); LuauTempThreadPopper popper(global);
lua_getglobal(global, name.value); if (FFlag::LuauUserTypeFunExportedAndLocal)
lua_xmove(global, L, 1); {
// Fetch the function we want to evaluate
lua_pushlightuserdata(L, typeFunction->userFuncData.definition);
lua_gettable(L, LUA_REGISTRYINDEX);
if (!lua_isfunction(L, -1))
{
ctx->ice->ice("user-defined type function reference cannot be found in the registry");
return {std::nullopt, true, {}, {}};
}
// Build up the environment
lua_getfenv(L, -1);
lua_setreadonly(L, -1, false);
for (auto& [name, definition] : typeFunction->userFuncData.environment)
{
lua_pushlightuserdata(L, definition);
lua_gettable(L, LUA_REGISTRYINDEX);
if (!lua_isfunction(L, -1))
{
ctx->ice->ice("user-defined type function reference cannot be found in the registry");
return {std::nullopt, true, {}, {}};
}
lua_setfield(L, -2, name.c_str());
}
lua_setreadonly(L, -1, true);
lua_pop(L, 1);
}
else
{
lua_getglobal(global, name.value);
lua_xmove(global, L, 1);
}
if (FFlag::LuauUserDefinedTypeFunctionResetState) if (FFlag::LuauUserDefinedTypeFunctionResetState)
resetTypeFunctionState(L); resetTypeFunctionState(L);
@ -693,7 +764,7 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get());
// At least 1 error occured while deserializing // At least 1 error occurred while deserializing
if (runtimeBuilder->errors.size() > 0) if (runtimeBuilder->errors.size() > 0)
return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()};
@ -935,6 +1006,23 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
prepareState(); prepareState();
lua_State* global = state.get();
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Fetch to check if function is already registered
lua_pushlightuserdata(global, function);
lua_gettable(global, LUA_REGISTRYINDEX);
if (!lua_isnil(global, -1))
{
lua_pop(global, 1);
return std::nullopt;
}
lua_pop(global, 1);
}
AstName name = function->name; AstName name = function->name;
// Construct ParseResult containing the type function // Construct ParseResult containing the type function
@ -961,7 +1049,6 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
std::string bytecode = builder.getBytecode(); std::string bytecode = builder.getBytecode();
lua_State* global = state.get();
// Separate sandboxed thread for individual execution and private globals // Separate sandboxed thread for individual execution and private globals
lua_State* L = lua_newthread(global); lua_State* L = lua_newthread(global);
@ -989,9 +1076,19 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
return format("Could not find '%s' type function in the global scope", name.value); return format("Could not find '%s' type function in the global scope", name.value);
} }
// Store resulting function in the global environment if (FFlag::LuauUserTypeFunExportedAndLocal)
lua_xmove(L, global, 1); {
lua_setglobal(global, name.value); // Store resulting function in the registry
lua_pushlightuserdata(global, function);
lua_xmove(L, global, 1);
lua_settable(global, LUA_REGISTRYINDEX);
}
else
{
// Store resulting function in the global environment
lua_xmove(L, global, 1);
lua_setglobal(global, name.value);
}
return std::nullopt; return std::nullopt;
} }

View File

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Location.h"
#include "Luau/DenseHash.h"
#include "Luau/Common.h"
#include <vector>
namespace Luau
{
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
char data[8192];
};
Page* root;
size_t offset;
};
}

View File

@ -316,16 +316,18 @@ public:
enum QuoteStyle enum QuoteStyle
{ {
Quoted, QuotedSimple,
QuotedRaw,
Unquoted Unquoted
}; };
AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle = Quoted); AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool isQuoted() const;
AstArray<char> value; AstArray<char> value;
QuoteStyle quoteStyle = Quoted; QuoteStyle quoteStyle;
}; };
class AstExprLocal : public AstExpr class AstExprLocal : public AstExpr
@ -876,13 +878,14 @@ class AstStatTypeFunction : public AstStat
public: public:
LUAU_RTTI(AstStatTypeFunction); LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstExprFunction* body; AstExprFunction* body;
bool exported;
}; };
class AstStatDeclareGlobal : public AstStat class AstStatDeclareGlobal : public AstStat

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Allocator.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
@ -11,40 +12,6 @@
namespace Luau namespace Luau
{ {
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
char data[8192];
};
Page* root;
size_t offset;
};
struct Lexeme struct Lexeme
{ {
enum Type enum Type

66
Ast/src/Allocator.cpp Normal file
View File

@ -0,0 +1,66 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Allocator.h"
namespace Luau
{
Allocator::Allocator()
: root(static_cast<Page*>(operator new(sizeof(Page))))
, offset(0)
{
root->next = nullptr;
}
Allocator::Allocator(Allocator&& rhs)
: root(rhs.root)
, offset(rhs.offset)
{
rhs.root = nullptr;
rhs.offset = 0;
}
Allocator::~Allocator()
{
Page* page = root;
while (page)
{
Page* next = page->next;
operator delete(page);
page = next;
}
}
void* Allocator::allocate(size_t size)
{
constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double);
if (root)
{
uintptr_t data = reinterpret_cast<uintptr_t>(root->data);
uintptr_t result = (data + offset + align - 1) & ~(align - 1);
if (result + size <= data + sizeof(root->data))
{
offset = result - data + size;
return reinterpret_cast<void*>(result);
}
}
// allocate new page
size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data);
void* pageData = operator new(offsetof(Page, data) + pageSize);
Page* page = static_cast<Page*>(pageData);
page->next = root;
root = page;
offset = size;
return page->data;
}
}

View File

@ -92,6 +92,11 @@ void AstExprConstantString::visit(AstVisitor* visitor)
visitor->visit(this); visitor->visit(this);
} }
bool AstExprConstantString::isQuoted() const
{
return quoteStyle == QuoteStyle::QuotedSimple || quoteStyle == QuoteStyle::QuotedRaw;
}
AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue) AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, local(local) , local(local)
@ -760,11 +765,18 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
} }
} }
AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body) AstStatTypeFunction::AstStatTypeFunction(
const Location& location,
const AstName& name,
const Location& nameLocation,
AstExprFunction* body,
bool exported
)
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
, nameLocation(nameLocation) , nameLocation(nameLocation)
, body(body) , body(body)
, exported(exported)
{ {
} }

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Lexer.h" #include "Luau/Lexer.h"
#include "Luau/Allocator.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Confusables.h" #include "Luau/Confusables.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
@ -10,64 +11,6 @@
namespace Luau namespace Luau
{ {
Allocator::Allocator()
: root(static_cast<Page*>(operator new(sizeof(Page))))
, offset(0)
{
root->next = nullptr;
}
Allocator::Allocator(Allocator&& rhs)
: root(rhs.root)
, offset(rhs.offset)
{
rhs.root = nullptr;
rhs.offset = 0;
}
Allocator::~Allocator()
{
Page* page = root;
while (page)
{
Page* next = page->next;
operator delete(page);
page = next;
}
}
void* Allocator::allocate(size_t size)
{
constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double);
if (root)
{
uintptr_t data = reinterpret_cast<uintptr_t>(root->data);
uintptr_t result = (data + offset + align - 1) & ~(align - 1);
if (result + size <= data + sizeof(root->data))
{
offset = result - data + size;
return reinterpret_cast<void*>(result);
}
}
// allocate new page
size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data);
void* pageData = operator new(offsetof(Page, data) + pageSize);
Page* page = static_cast<Page*>(pageData);
page->next = root;
root = page;
offset = size;
return page->data;
}
Lexeme::Lexeme(const Location& location, Type type) Lexeme::Lexeme(const Location& location, Type type)
: type(type) : type(type)
, location(location) , location(location)

View File

@ -21,6 +21,7 @@ LUAU_FASTFLAGVARIABLE(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport)
LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing)
LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck) LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck)
@ -943,8 +944,11 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported)
Lexeme matchFn = lexer.current(); Lexeme matchFn = lexer.current();
nextLexeme(); nextLexeme();
if (exported) if (!FFlag::LuauUserDefinedTypeFunParseExport)
report(start, "Type function cannot be exported"); {
if (exported)
report(start, "Type function cannot be exported");
}
// parse the name of the type function // parse the name of the type function
std::optional<Name> fnName = parseNameOpt("type function name"); std::optional<Name> fnName = parseNameOpt("type function name");
@ -962,7 +966,7 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported)
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body); return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body, exported);
} }
AstDeclaredClassProp Parser::parseDeclaredClassMethod() AstDeclaredClassProp Parser::parseDeclaredClassMethod()
@ -3012,8 +3016,23 @@ std::optional<AstArray<char>> Parser::parseCharArray()
AstExpr* Parser::parseString() AstExpr* Parser::parseString()
{ {
Location location = lexer.current().location; Location location = lexer.current().location;
AstExprConstantString::QuoteStyle style;
switch (lexer.current().type)
{
case Lexeme::QuotedString:
case Lexeme::InterpStringSimple:
style = AstExprConstantString::QuotedSimple;
break;
case Lexeme::RawString:
style = AstExprConstantString::QuotedRaw;
break;
default:
LUAU_ASSERT(false && "Invalid string type");
}
if (std::optional<AstArray<char>> value = parseCharArray()) if (std::optional<AstArray<char>> value = parseCharArray())
return allocator.alloc<AstExprConstantString>(location, *value); return allocator.alloc<AstExprConstantString>(location, *value, style);
else else
return reportExprError(location, {}, "String literal contains malformed escape sequence"); return reportExprError(location, {}, "String literal contains malformed escape sequence");
} }

View File

@ -1,4 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Config.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
@ -224,7 +225,14 @@ struct CliConfigResolver : Luau::ConfigResolver
if (std::optional<std::string> contents = readFile(configPath)) if (std::optional<std::string> contents = readFile(configPath))
{ {
std::optional<std::string> error = Luau::parseConfig(*contents, result); Luau::ConfigOptions::AliasOptions aliasOpts;
aliasOpts.configLocation = configPath;
aliasOpts.overwriteAliases = true;
Luau::ConfigOptions opts;
opts.aliasOptions = std::move(aliasOpts);
std::optional<std::string> error = Luau::parseConfig(*contents, result, opts);
if (error) if (error)
configErrors.push_back({configPath, *error}); configErrors.push_back({configPath, *error});
} }

View File

@ -181,6 +181,16 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath)
return resolvedPath; return resolvedPath;
} }
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions)
{
for (const std::string& extension : extensions)
{
if (name.size() >= extension.size() && name.substr(name.size() - extension.size()) == extension)
return true;
}
return false;
}
std::optional<std::string> readFile(const std::string& name) std::optional<std::string> readFile(const std::string& name)
{ {
#ifdef _WIN32 #ifdef _WIN32

View File

@ -15,6 +15,8 @@ std::string resolvePath(std::string_view relativePath, std::string_view baseFile
std::optional<std::string> readFile(const std::string& name); std::optional<std::string> readFile(const std::string& name);
std::optional<std::string> readStdin(); std::optional<std::string> readStdin();
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions);
bool isAbsolutePath(std::string_view path); bool isAbsolutePath(std::string_view path);
bool isFile(const std::string& path); bool isFile(const std::string& path);
bool isDirectory(const std::string& path); bool isDirectory(const std::string& path);

View File

@ -3,6 +3,7 @@
#include "FileUtils.h" #include "FileUtils.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Config.h"
#include <algorithm> #include <algorithm>
#include <array> #include <array>
@ -83,6 +84,9 @@ RequireResolver::ModuleStatus RequireResolver::findModuleImpl()
absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix
} }
if (hasFileExtension(absolutePath, {".luau", ".lua"}) && isFile(absolutePath))
luaL_argerrorL(L, 1, "error requiring module: consider removing the file extension");
return ModuleStatus::NotFound; return ModuleStatus::NotFound;
} }
@ -235,14 +239,15 @@ std::optional<std::string> RequireResolver::getAlias(std::string alias)
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
} }
); );
while (!config.aliases.count(alias) && !isConfigFullyResolved) while (!config.aliases.contains(alias) && !isConfigFullyResolved)
{ {
parseNextConfig(); parseNextConfig();
} }
if (!config.aliases.count(alias) && isConfigFullyResolved) if (!config.aliases.contains(alias) && isConfigFullyResolved)
return std::nullopt; // could not find alias return std::nullopt; // could not find alias
return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName)); const Luau::Config::AliasInfo& aliasInfo = config.aliases[alias];
return resolvePath(aliasInfo.value, aliasInfo.configLocation);
} }
void RequireResolver::parseNextConfig() void RequireResolver::parseNextConfig()
@ -275,9 +280,16 @@ void RequireResolver::parseConfigInDirectory(const std::string& directory)
{ {
std::string configPath = joinPaths(directory, Luau::kConfigName); std::string configPath = joinPaths(directory, Luau::kConfigName);
Luau::ConfigOptions::AliasOptions aliasOpts;
aliasOpts.configLocation = configPath;
aliasOpts.overwriteAliases = false;
Luau::ConfigOptions opts;
opts.aliasOptions = std::move(aliasOpts);
if (std::optional<std::string> contents = readFile(configPath)) if (std::optional<std::string> contents = readFile(configPath))
{ {
std::optional<std::string> error = Luau::parseConfig(*contents, config); std::optional<std::string> error = Luau::parseConfig(*contents, config, opts);
if (error) if (error)
luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str()); luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str());
} }

View File

@ -19,7 +19,7 @@ class Variant
static_assert(std::disjunction_v<std::is_reference<Ts>...> == false, "variant does not allow references as an alternative type"); static_assert(std::disjunction_v<std::is_reference<Ts>...> == false, "variant does not allow references as an alternative type");
static_assert(std::disjunction_v<std::is_array<Ts>...> == false, "variant does not allow arrays as an alternative type"); static_assert(std::disjunction_v<std::is_array<Ts>...> == false, "variant does not allow arrays as an alternative type");
private: public:
template<typename T> template<typename T>
static constexpr int getTypeId() static constexpr int getTypeId()
{ {
@ -35,6 +35,7 @@ private:
return -1; return -1;
} }
private:
template<typename T, typename... Tail> template<typename T, typename... Tail>
struct First struct First
{ {

View File

@ -1,12 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/DenseHash.h"
#include "Luau/LinterConfig.h" #include "Luau/LinterConfig.h"
#include "Luau/ParseOptions.h" #include "Luau/ParseOptions.h"
#include <memory>
#include <optional> #include <optional>
#include <string> #include <string>
#include <unordered_map> #include <string_view>
#include <vector> #include <vector>
namespace Luau namespace Luau
@ -19,6 +21,10 @@ constexpr const char* kConfigName = ".luaurc";
struct Config struct Config
{ {
Config(); Config();
Config(const Config& other) noexcept;
Config& operator=(const Config& other) noexcept;
Config(Config&& other) noexcept = default;
Config& operator=(Config&& other) noexcept = default;
Mode mode = Mode::Nonstrict; Mode mode = Mode::Nonstrict;
@ -32,7 +38,19 @@ struct Config
std::vector<std::string> globals; std::vector<std::string> globals;
std::unordered_map<std::string, std::string> aliases; struct AliasInfo
{
std::string value;
std::string_view configLocation;
};
DenseHashMap<std::string, AliasInfo> aliases{""};
void setAlias(std::string alias, const std::string& value, const std::string configLocation);
private:
// Prevents making unnecessary copies of the same config location string.
DenseHashMap<std::string, std::unique_ptr<std::string>> configLocationCache{""};
}; };
struct ConfigResolver struct ConfigResolver
@ -60,6 +78,18 @@ std::optional<std::string> parseLintRuleString(
bool isValidAlias(const std::string& alias); bool isValidAlias(const std::string& alias);
std::optional<std::string> parseConfig(const std::string& contents, Config& config, bool compat = false); struct ConfigOptions
{
bool compat = false;
struct AliasOptions
{
std::string configLocation;
bool overwriteAliases;
};
std::optional<AliasOptions> aliasOptions = std::nullopt;
};
std::optional<std::string> parseConfig(const std::string& contents, Config& config, const ConfigOptions& options = ConfigOptions{});
} // namespace Luau } // namespace Luau

View File

@ -4,7 +4,8 @@
#include "Luau/Lexer.h" #include "Luau/Lexer.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <memory>
#include <string>
namespace Luau namespace Luau
{ {
@ -16,6 +17,50 @@ Config::Config()
enabledLint.setDefaults(); enabledLint.setDefaults();
} }
Config::Config(const Config& other) noexcept
: mode(other.mode)
, parseOptions(other.parseOptions)
, enabledLint(other.enabledLint)
, fatalLint(other.fatalLint)
, lintErrors(other.lintErrors)
, typeErrors(other.typeErrors)
, globals(other.globals)
{
for (const auto& [alias, aliasInfo] : other.aliases)
{
std::string configLocation = std::string(aliasInfo.configLocation);
if (!configLocationCache.contains(configLocation))
configLocationCache[configLocation] = std::make_unique<std::string>(configLocation);
AliasInfo newAliasInfo;
newAliasInfo.value = aliasInfo.value;
newAliasInfo.configLocation = *configLocationCache[configLocation];
aliases[alias] = std::move(newAliasInfo);
}
}
Config& Config::operator=(const Config& other) noexcept
{
if (this != &other)
{
Config copy(other);
std::swap(*this, copy);
}
return *this;
}
void Config::setAlias(std::string alias, const std::string& value, const std::string configLocation)
{
AliasInfo& info = aliases[alias];
info.value = value;
if (!configLocationCache.contains(configLocation))
configLocationCache[configLocation] = std::make_unique<std::string>(configLocation);
info.configLocation = *configLocationCache[configLocation];
}
static Error parseBoolean(bool& result, const std::string& value) static Error parseBoolean(bool& result, const std::string& value)
{ {
if (value == "true") if (value == "true")
@ -136,7 +181,12 @@ bool isValidAlias(const std::string& alias)
return true; return true;
} }
Error parseAlias(std::unordered_map<std::string, std::string>& aliases, std::string aliasKey, const std::string& aliasValue) static Error parseAlias(
Config& config,
std::string aliasKey,
const std::string& aliasValue,
const std::optional<ConfigOptions::AliasOptions>& aliasOptions
)
{ {
if (!isValidAlias(aliasKey)) if (!isValidAlias(aliasKey))
return Error{"Invalid alias " + aliasKey}; return Error{"Invalid alias " + aliasKey};
@ -150,8 +200,12 @@ Error parseAlias(std::unordered_map<std::string, std::string>& aliases, std::str
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
} }
); );
if (!aliases.count(aliasKey))
aliases[std::move(aliasKey)] = aliasValue; if (!aliasOptions)
return Error("Cannot parse aliases without alias options");
if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey))
config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation);
return std::nullopt; return std::nullopt;
} }
@ -285,16 +339,16 @@ static Error parseJson(const std::string& contents, Action action)
return {}; return {};
} }
Error parseConfig(const std::string& contents, Config& config, bool compat) Error parseConfig(const std::string& contents, Config& config, const ConfigOptions& options)
{ {
return parseJson( return parseJson(
contents, contents,
[&](const std::vector<std::string>& keys, const std::string& value) -> Error [&](const std::vector<std::string>& keys, const std::string& value) -> Error
{ {
if (keys.size() == 1 && keys[0] == "languageMode") if (keys.size() == 1 && keys[0] == "languageMode")
return parseModeString(config.mode, value, compat); return parseModeString(config.mode, value, options.compat);
else if (keys.size() == 2 && keys[0] == "lint") else if (keys.size() == 2 && keys[0] == "lint")
return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, compat); return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, options.compat);
else if (keys.size() == 1 && keys[0] == "lintErrors") else if (keys.size() == 1 && keys[0] == "lintErrors")
return parseBoolean(config.lintErrors, value); return parseBoolean(config.lintErrors, value);
else if (keys.size() == 1 && keys[0] == "typeErrors") else if (keys.size() == 1 && keys[0] == "typeErrors")
@ -305,9 +359,9 @@ Error parseConfig(const std::string& contents, Config& config, bool compat)
return std::nullopt; return std::nullopt;
} }
else if (keys.size() == 2 && keys[0] == "aliases") else if (keys.size() == 2 && keys[0] == "aliases")
return parseAlias(config.aliases, keys[1], value); return parseAlias(config, keys[1], value, options.aliasOptions);
else if (compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") else if (options.compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode")
return parseModeString(config.mode, value, compat); return parseModeString(config.mode, value, options.compat);
else else
{ {
std::vector<std::string_view> keysv(keys.begin(), keys.end()); std::vector<std::string_view> keysv(keys.begin(), keys.end());

View File

@ -23,6 +23,13 @@ struct Analysis final
using D = typename N::Data; using D = typename N::Data;
Analysis() = default;
Analysis(N a)
: analysis(std::move(a))
{
}
template<typename T> template<typename T>
static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode) static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode)
{ {
@ -59,6 +66,15 @@ struct EClass final
template<typename L, typename N> template<typename L, typename N>
struct EGraph final struct EGraph final
{ {
using EClassT = EClass<L, typename N::Data>;
EGraph() = default;
explicit EGraph(N analysis)
: analysis(std::move(analysis))
{
}
Id find(Id id) const Id find(Id id) const
{ {
return unionfind.find(id); return unionfind.find(id);
@ -85,33 +101,59 @@ struct EGraph final
return id; return id;
} }
void merge(Id id1, Id id2) // Returns true if the two IDs were not previously merged.
bool merge(Id id1, Id id2)
{ {
id1 = find(id1); id1 = find(id1);
id2 = find(id2); id2 = find(id2);
if (id1 == id2) if (id1 == id2)
return; return false;
unionfind.merge(id1, id2); const Id mergedId = unionfind.merge(id1, id2);
EClass<L, typename N::Data>& eclass1 = get(id1); // Ensure that id1 is the Id that we keep, and id2 is the id that we drop.
EClass<L, typename N::Data> eclass2 = std::move(get(id2)); if (mergedId == id2)
std::swap(id1, id2);
EClassT& eclass1 = get(id1);
EClassT eclass2 = std::move(get(id2));
classes.erase(id2); classes.erase(id2);
worklist.reserve(worklist.size() + eclass2.parents.size()); eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end());
for (auto [enode, id] : eclass2.parents) eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end());
worklist.push_back({std::move(enode), id});
std::sort(
eclass1.nodes.begin(),
eclass1.nodes.end(),
[](const L& left, const L& right)
{
return left.index() < right.index();
}
);
worklist.reserve(worklist.size() + eclass1.parents.size());
for (const auto& [eclass, id] : eclass1.parents)
worklist.push_back(id);
analysis.join(eclass1.data, eclass2.data); analysis.join(eclass1.data, eclass2.data);
return true;
} }
void rebuild() void rebuild()
{ {
std::unordered_set<Id> seen;
while (!worklist.empty()) while (!worklist.empty())
{ {
auto [enode, id] = worklist.back(); Id id = worklist.back();
worklist.pop_back(); worklist.pop_back();
repair(get(find(id)));
const bool isFresh = seen.insert(id).second;
if (!isFresh)
continue;
repair(find(id));
} }
} }
@ -120,16 +162,21 @@ struct EGraph final
return classes.size(); return classes.size();
} }
EClass<L, typename N::Data>& operator[](Id id) EClassT& operator[](Id id)
{ {
return get(find(id)); return get(find(id));
} }
const EClass<L, typename N::Data>& operator[](Id id) const const EClassT& operator[](Id id) const
{ {
return const_cast<EGraph*>(this)->get(find(id)); return const_cast<EGraph*>(this)->get(find(id));
} }
const std::unordered_map<Id, EClassT>& getAllClasses() const
{
return classes;
}
private: private:
Analysis<L, N> analysis; Analysis<L, N> analysis;
@ -139,19 +186,19 @@ private:
/// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same
/// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the
/// e-class 𝑀[find(𝑎)]. /// e-class 𝑀[find(𝑎)].
std::unordered_map<Id, EClass<L, typename N::Data>> classes; std::unordered_map<Id, EClassT> classes;
/// The hashcons 𝐻 is a map from e-nodes to e-class ids. /// The hashcons 𝐻 is a map from e-nodes to e-class ids.
std::unordered_map<L, Id, typename L::Hash> hashcons; std::unordered_map<L, Id, typename L::Hash> hashcons;
std::vector<std::pair<L, Id>> worklist; std::vector<Id> worklist;
private: private:
void canonicalize(L& enode) void canonicalize(L& enode)
{ {
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
for (Id& id : enode.operands()) for (Id& id : enode.mutableOperands())
id = find(id); id = find(id);
} }
@ -171,7 +218,7 @@ private:
classes.insert_or_assign( classes.insert_or_assign(
id, id,
EClass<L, typename N::Data>{ EClassT{
id, id,
{enode}, {enode},
analysis.make(*this, enode), analysis.make(*this, enode),
@ -182,7 +229,7 @@ private:
for (Id operand : enode.operands()) for (Id operand : enode.operands())
get(operand).parents.push_back({enode, id}); get(operand).parents.push_back({enode, id});
worklist.emplace_back(enode, id); worklist.emplace_back(id);
hashcons.insert_or_assign(enode, id); hashcons.insert_or_assign(enode, id);
return id; return id;
@ -190,12 +237,13 @@ private:
// Looks up for an eclass from a given non-canonicalized `id`. // Looks up for an eclass from a given non-canonicalized `id`.
// For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`.
EClass<L, typename N::Data>& get(Id id) EClassT& get(Id id)
{ {
LUAU_ASSERT(classes.count(id));
return classes.at(id); return classes.at(id);
} }
void repair(EClass<L, typename N::Data>& eclass) void repair(Id id)
{ {
// In the egg paper, the `repair` function makes use of two loops over the `eclass.parents` // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents`
// by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id. // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id.
@ -204,26 +252,54 @@ private:
// Here, we unify the two loops. I think it's equivalent? // Here, we unify the two loops. I think it's equivalent?
// After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent.
std::unordered_map<L, Id, typename L::Hash> map; std::unordered_map<L, Id, typename L::Hash> newParents;
for (auto& [enode, id] : eclass.parents)
// The eclass can be deallocated if it is merged into another eclass, so
// we take what we need from it and avoid retaining a pointer.
std::vector<std::pair<L, Id>> parents = get(id).parents;
for (auto& pair : parents)
{ {
L& enode = pair.first;
Id id = pair.second;
// By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
hashcons.erase(enode); hashcons.erase(enode);
canonicalize(enode); canonicalize(enode);
hashcons.insert_or_assign(enode, find(id)); hashcons.insert_or_assign(enode, find(id));
if (auto it = map.find(enode); it != map.end()) if (auto it = newParents.find(enode); it != newParents.end())
merge(id, it->second); merge(id, it->second);
map.insert_or_assign(enode, find(id)); newParents.insert_or_assign(enode, find(id));
} }
eclass.parents.clear(); // We reacquire the pointer because the prior loop potentially merges
for (auto it = map.begin(); it != map.end();) // the eclass into another, which might move it around in memory.
EClassT* eclass = &get(find(id));
eclass->parents.clear();
for (const auto& [node, id] : newParents)
eclass->parents.emplace_back(std::move(node), std::move(id));
std::unordered_set<L, typename L::Hash> newNodes;
for (L node : eclass->nodes)
{ {
auto node = map.extract(it++); canonicalize(node);
eclass.parents.emplace_back(std::move(node.key()), node.mapped()); newNodes.insert(std::move(node));
} }
eclass->nodes.assign(newNodes.begin(), newNodes.end());
// FIXME: Extract into sortByTag()
std::sort(
eclass->nodes.begin(),
eclass->nodes.end(),
[](const L& left, const L& right)
{
return left.index() < right.index();
}
);
} }
}; };

View File

@ -2,6 +2,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <cstdint>
#include <functional> #include <functional>
namespace Luau::EqSat namespace Luau::EqSat
@ -9,15 +10,17 @@ namespace Luau::EqSat
struct Id final struct Id final
{ {
explicit Id(size_t id); explicit Id(uint32_t id);
explicit operator size_t() const; explicit operator uint32_t() const;
bool operator==(Id rhs) const; bool operator==(Id rhs) const;
bool operator!=(Id rhs) const; bool operator!=(Id rhs) const;
bool operator<(Id rhs) const;
private: private:
size_t id; uint32_t id;
}; };
} // namespace Luau::EqSat } // namespace Luau::EqSat

View File

@ -6,9 +6,19 @@
#include "Luau/Slice.h" #include "Luau/Slice.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include <algorithm>
#include <array> #include <array>
#include <type_traits> #include <type_traits>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector>
#define LUAU_EQSAT_UNIT(name) \
struct name : ::Luau::EqSat::Unit<name> \
{ \
static constexpr const char* tag = #name; \
using Unit::Unit; \
}
#define LUAU_EQSAT_ATOM(name, t) \ #define LUAU_EQSAT_ATOM(name, t) \
struct name : public ::Luau::EqSat::Atom<name, t> \ struct name : public ::Luau::EqSat::Atom<name, t> \
@ -31,21 +41,57 @@
using NodeVector::NodeVector; \ using NodeVector::NodeVector; \
} }
#define LUAU_EQSAT_FIELD(name) \ #define LUAU_EQSAT_NODE_SET(name) \
struct name : public ::Luau::EqSat::Field<name> \ struct name : public ::Luau::EqSat::NodeSet<name, std::vector<::Luau::EqSat::Id>> \
{ \
}
#define LUAU_EQSAT_NODE_FIELDS(name, ...) \
struct name : public ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \
{ \ { \
static constexpr const char* tag = #name; \ static constexpr const char* tag = #name; \
using NodeFields::NodeFields; \ using NodeSet::NodeSet; \
}
#define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \
struct name : public ::Luau::EqSat::NodeAtomAndVector<name, t, std::vector<::Luau::EqSat::Id>> \
{ \
static constexpr const char* tag = #name; \
using NodeAtomAndVector::NodeAtomAndVector; \
} }
namespace Luau::EqSat namespace Luau::EqSat
{ {
template<typename Phantom>
struct Unit
{
Slice<Id> mutableOperands()
{
return {};
}
Slice<const Id> operands() const
{
return {};
}
bool operator==(const Unit& rhs) const
{
return true;
}
bool operator!=(const Unit& rhs) const
{
return false;
}
struct Hash
{
size_t operator()(const Unit& value) const
{
// chosen by fair dice roll.
// guaranteed to be random.
return 4;
}
};
};
template<typename Phantom, typename T> template<typename Phantom, typename T>
struct Atom struct Atom
{ {
@ -60,7 +106,7 @@ struct Atom
} }
public: public:
Slice<Id> operands() Slice<Id> mutableOperands()
{ {
return {}; return {};
} }
@ -92,6 +138,62 @@ private:
T _value; T _value;
}; };
template<typename Phantom, typename X, typename T>
struct NodeAtomAndVector
{
template<typename... Args>
NodeAtomAndVector(const X& value, Args&&... args)
: _value(value)
, vector{std::forward<Args>(args)...}
{
}
Id operator[](size_t i) const
{
return vector[i];
}
public:
const X& value() const
{
return _value;
}
Slice<Id> mutableOperands()
{
return Slice{vector.data(), vector.size()};
}
Slice<const Id> operands() const
{
return Slice{vector.data(), vector.size()};
}
bool operator==(const NodeAtomAndVector& rhs) const
{
return _value == rhs._value && vector == rhs.vector;
}
bool operator!=(const NodeAtomAndVector& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const NodeAtomAndVector& value) const
{
size_t result = languageHash(value._value);
hashCombine(result, languageHash(value.vector));
return result;
}
};
private:
X _value;
T vector;
};
template<typename Phantom, typename T> template<typename Phantom, typename T>
struct NodeVector struct NodeVector
{ {
@ -107,7 +209,7 @@ struct NodeVector
} }
public: public:
Slice<Id> operands() Slice<Id> mutableOperands()
{ {
return Slice{vector.data(), vector.size()}; return Slice{vector.data(), vector.size()};
} }
@ -139,90 +241,61 @@ private:
T vector; T vector;
}; };
/// Empty base class just for static_asserts. template<typename Phantom, typename T>
struct FieldBase struct NodeSet
{ {
FieldBase() = delete; template<typename... Args>
NodeSet(Args&&... args)
FieldBase(FieldBase&&) = delete; : vector{std::forward<Args>(args)...}
FieldBase& operator=(FieldBase&&) = delete;
FieldBase(const FieldBase&) = delete;
FieldBase& operator=(const FieldBase&) = delete;
};
template<typename Phantom>
struct Field : FieldBase
{
};
template<typename Phantom, typename... Fields>
struct NodeFields
{
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
template<typename T>
static constexpr int getIndex()
{ {
constexpr int N = sizeof...(Fields); std::sort(begin(vector), end(vector));
constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...}; auto it = std::unique(begin(vector), end(vector));
vector.erase(it, end(vector));
}
for (int i = 0; i < N; ++i) Id operator[](size_t i) const
if (is[i]) {
return i; return vector[i];
return -1;
} }
public: public:
template<typename... Args> Slice<Id> mutableOperands()
NodeFields(Args&&... args)
: array{std::forward<Args>(args)...}
{ {
} return Slice{vector.data(), vector.size()};
Slice<Id> operands()
{
return Slice{array};
} }
Slice<const Id> operands() const Slice<const Id> operands() const
{ {
return Slice{array.data(), array.size()}; return Slice{vector.data(), vector.size()};
} }
template<typename T> bool operator==(const NodeSet& rhs) const
Id field() const
{ {
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>); return vector == rhs.vector;
return array[getIndex<T>()];
} }
bool operator==(const NodeFields& rhs) const bool operator!=(const NodeSet& rhs) const
{
return array == rhs.array;
}
bool operator!=(const NodeFields& rhs) const
{ {
return !(*this == rhs); return !(*this == rhs);
} }
struct Hash struct Hash
{ {
size_t operator()(const NodeFields& value) const size_t operator()(const NodeSet& value) const
{ {
return languageHash(value.array); return languageHash(value.vector);
} }
}; };
private: protected:
std::array<Id, sizeof...(Fields)> array; T vector;
}; };
template<typename... Ts> template<typename... Ts>
struct Language final struct Language final
{ {
using VariantTy = Luau::Variant<Ts...>;
template<typename T> template<typename T>
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>; using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
@ -237,14 +310,14 @@ struct Language final
return v.index(); return v.index();
} }
/// You should never call this function with the intention of mutating the `Id`. /// This should only be used in canonicalization!
/// Reading is ok, but you should also never assume that these `Id`s are stable. /// Always prefer operands()
Slice<Id> operands() noexcept Slice<Id> mutableOperands() noexcept
{ {
return visit( return visit(
[](auto&& v) -> Slice<Id> [](auto&& v) -> Slice<Id>
{ {
return v.operands(); return v.mutableOperands();
}, },
v v
); );
@ -306,7 +379,7 @@ public:
}; };
private: private:
Variant<Ts...> v; VariantTy v;
}; };
} // namespace Luau::EqSat } // namespace Luau::EqSat

View File

@ -3,6 +3,7 @@
#include <cstddef> #include <cstddef>
#include <functional> #include <functional>
#include <unordered_set>
#include <vector> #include <vector>
namespace Luau::EqSat namespace Luau::EqSat

View File

@ -14,7 +14,9 @@ struct UnionFind final
Id makeSet(); Id makeSet();
Id find(Id id) const; Id find(Id id) const;
Id find(Id id); Id find(Id id);
void merge(Id a, Id b);
// Merge aSet with bSet and return the canonicalized Id into the merged set.
Id merge(Id aSet, Id bSet);
private: private:
std::vector<Id> parents; std::vector<Id> parents;

View File

@ -1,15 +1,16 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Id.h" #include "Luau/Id.h"
#include "Luau/Common.h"
namespace Luau::EqSat namespace Luau::EqSat
{ {
Id::Id(size_t id) Id::Id(uint32_t id)
: id(id) : id(id)
{ {
} }
Id::operator size_t() const Id::operator uint32_t() const
{ {
return id; return id;
} }
@ -24,9 +25,14 @@ bool Id::operator!=(Id rhs) const
return id != rhs.id; return id != rhs.id;
} }
bool Id::operator<(Id rhs) const
{
return id < rhs.id;
}
} // namespace Luau::EqSat } // namespace Luau::EqSat
size_t std::hash<Luau::EqSat::Id>::operator()(Luau::EqSat::Id id) const size_t std::hash<Luau::EqSat::Id>::operator()(Luau::EqSat::Id id) const
{ {
return std::hash<size_t>()(size_t(id)); return std::hash<uint32_t>()(uint32_t(id));
} }

View File

@ -3,12 +3,16 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include <limits>
namespace Luau::EqSat namespace Luau::EqSat
{ {
Id UnionFind::makeSet() Id UnionFind::makeSet()
{ {
Id id{parents.size()}; LUAU_ASSERT(parents.size() < std::numeric_limits<uint32_t>::max());
Id id{uint32_t(parents.size())};
parents.push_back(id); parents.push_back(id);
ranks.push_back(0); ranks.push_back(0);
@ -25,42 +29,44 @@ Id UnionFind::find(Id id)
Id set = canonicalize(id); Id set = canonicalize(id);
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)]) while (id != parents[uint32_t(id)])
{ {
// Note: we don't update the ranks here since a rank // Note: we don't update the ranks here since a rank
// represents the upper bound on the maximum depth of a tree // represents the upper bound on the maximum depth of a tree
Id parent = parents[size_t(id)]; Id parent = parents[uint32_t(id)];
parents[size_t(id)] = set; parents[uint32_t(id)] = set;
id = parent; id = parent;
} }
return set; return set;
} }
void UnionFind::merge(Id a, Id b) Id UnionFind::merge(Id a, Id b)
{ {
Id aSet = find(a); Id aSet = find(a);
Id bSet = find(b); Id bSet = find(b);
if (aSet == bSet) if (aSet == bSet)
return; return aSet;
// Ensure that the rank of set A is greater than the rank of set B // Ensure that the rank of set A is greater than the rank of set B
if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) if (ranks[uint32_t(aSet)] > ranks[uint32_t(bSet)])
std::swap(aSet, bSet); std::swap(aSet, bSet);
parents[size_t(bSet)] = aSet; parents[uint32_t(bSet)] = aSet;
if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) if (ranks[uint32_t(aSet)] == ranks[uint32_t(bSet)])
ranks[size_t(aSet)]++; ranks[uint32_t(aSet)]++;
return aSet;
} }
Id UnionFind::canonicalize(Id id) const Id UnionFind::canonicalize(Id id) const
{ {
LUAU_ASSERT(size_t(id) < parents.size()); LUAU_ASSERT(uint32_t(id) < parents.size());
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)]) while (id != parents[uint32_t(id)])
id = parents[size_t(id)]; id = parents[uint32_t(id)];
return id; return id;
} }

View File

@ -14,6 +14,7 @@ endif()
# Luau.Ast Sources # Luau.Ast Sources
target_sources(Luau.Ast PRIVATE target_sources(Luau.Ast PRIVATE
Ast/include/Luau/Allocator.h
Ast/include/Luau/Ast.h Ast/include/Luau/Ast.h
Ast/include/Luau/Confusables.h Ast/include/Luau/Confusables.h
Ast/include/Luau/Lexer.h Ast/include/Luau/Lexer.h
@ -24,6 +25,7 @@ target_sources(Luau.Ast PRIVATE
Ast/include/Luau/StringUtils.h Ast/include/Luau/StringUtils.h
Ast/include/Luau/TimeTrace.h Ast/include/Luau/TimeTrace.h
Ast/src/Allocator.cpp
Ast/src/Ast.cpp Ast/src/Ast.cpp
Ast/src/Confusables.cpp Ast/src/Confusables.cpp
Ast/src/Lexer.cpp Ast/src/Lexer.cpp
@ -168,6 +170,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstJsonEncoder.h
Analysis/include/Luau/AstQuery.h Analysis/include/Luau/AstQuery.h
Analysis/include/Luau/Autocomplete.h Analysis/include/Luau/Autocomplete.h
Analysis/include/Luau/AutocompleteTypes.h
Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Cancellation.h
Analysis/include/Luau/Clone.h Analysis/include/Luau/Clone.h
@ -181,6 +184,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Differ.h Analysis/include/Luau/Differ.h
Analysis/include/Luau/Documentation.h Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h Analysis/include/Luau/Error.h
Analysis/include/Luau/EqSatSimplification.h
Analysis/include/Luau/FileResolver.h Analysis/include/Luau/FileResolver.h
Analysis/include/Luau/FragmentAutocomplete.h Analysis/include/Luau/FragmentAutocomplete.h
Analysis/include/Luau/Frontend.h Analysis/include/Luau/Frontend.h
@ -245,6 +249,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/AstJsonEncoder.cpp Analysis/src/AstJsonEncoder.cpp
Analysis/src/AstQuery.cpp Analysis/src/AstQuery.cpp
Analysis/src/Autocomplete.cpp Analysis/src/Autocomplete.cpp
Analysis/src/AutocompleteCore.cpp
Analysis/src/BuiltinDefinitions.cpp Analysis/src/BuiltinDefinitions.cpp
Analysis/src/Clone.cpp Analysis/src/Clone.cpp
Analysis/src/Constraint.cpp Analysis/src/Constraint.cpp
@ -256,6 +261,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Differ.cpp Analysis/src/Differ.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp Analysis/src/Error.cpp
Analysis/src/EqSatSimplification.cpp
Analysis/src/FragmentAutocomplete.cpp Analysis/src/FragmentAutocomplete.cpp
Analysis/src/Frontend.cpp Analysis/src/Frontend.cpp
Analysis/src/Generalization.cpp Analysis/src/Generalization.cpp
@ -417,7 +423,7 @@ endif()
if(TARGET Luau.UnitTest) if(TARGET Luau.UnitTest)
# Luau.UnitTest Sources # Luau.UnitTest Sources
target_sources(Luau.UnitTest PRIVATE target_sources(Luau.UnitTest PRIVATE
tests/AnyTypeSummary.test.cpp tests/AnyTypeSummary.test.cpp
tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderA64.test.cpp
tests/AssemblyBuilderX64.test.cpp tests/AssemblyBuilderX64.test.cpp
tests/AstJsonEncoder.test.cpp tests/AstJsonEncoder.test.cpp
@ -444,6 +450,7 @@ if(TARGET Luau.UnitTest)
tests/EqSat.language.test.cpp tests/EqSat.language.test.cpp
tests/EqSat.propositional.test.cpp tests/EqSat.propositional.test.cpp
tests/EqSat.slice.test.cpp tests/EqSat.slice.test.cpp
tests/EqSatSimplification.test.cpp
tests/Error.test.cpp tests/Error.test.cpp
tests/Fixture.cpp tests/Fixture.cpp
tests/Fixture.h tests/Fixture.h

View File

@ -39,7 +39,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n"; "$URL: www.lua.org $\n";
const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" const char* luau_ident = "$Luau: Copyright (C) 2019-2024 Roblox Corporation $\n"
"$URL: luau.org $\n"; "$URL: luau.org $\n";
#define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base))

View File

@ -67,7 +67,7 @@ TEST_CASE("encode_constants")
charString.data = const_cast<char*>("a\x1d\0\\\"b"); charString.data = const_cast<char*>("a\x1d\0\\\"b");
charString.size = 6; charString.size = 6;
AstExprConstantString needsEscaping{Location(), charString}; AstExprConstantString needsEscaping{Location(), charString, AstExprConstantString::QuotedSimple};
CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil));
CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b));
@ -83,7 +83,7 @@ TEST_CASE("basic_escaping")
{ {
std::string s = "hello \"world\""; std::string s = "hello \"world\"";
AstArray<char> theString{s.data(), s.size()}; AstArray<char> theString{s.data(), s.size()};
AstExprConstantString str{Location(), theString}; AstExprConstantString str{Location(), theString, AstExprConstantString::QuotedSimple};
std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})";
CHECK_EQ(expected, toJson(&str)); CHECK_EQ(expected, toJson(&str));

View File

@ -151,40 +151,6 @@ struct ACBuiltinsFixture : ACFixtureImpl<BuiltinsFixture>
{ {
}; };
#define LUAU_CHECK_HAS_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \
if (!count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
#define LUAU_CHECK_HAS_NO_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \
if (count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
TEST_SUITE_BEGIN("AutocompleteTest"); TEST_SUITE_BEGIN("AutocompleteTest");
TEST_CASE_FIXTURE(ACFixture, "empty_program") TEST_CASE_FIXTURE(ACFixture, "empty_program")

View File

@ -58,7 +58,11 @@ TEST_CASE("report_a_syntax_error")
TEST_CASE("noinfer_is_still_allowed") TEST_CASE("noinfer_is_still_allowed")
{ {
Config config; Config config;
auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, true);
ConfigOptions opts;
opts.compat = true;
auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, opts);
REQUIRE(!err); REQUIRE(!err);
CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode)); CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode));
@ -147,6 +151,10 @@ TEST_CASE("extra_globals")
TEST_CASE("lint_rules_compat") TEST_CASE("lint_rules_compat")
{ {
Config config; Config config;
ConfigOptions opts;
opts.compat = true;
auto err = parseConfig( auto err = parseConfig(
R"( R"(
{"lint": { {"lint": {
@ -156,7 +164,7 @@ TEST_CASE("lint_rules_compat")
}} }}
)", )",
config, config,
true opts
); );
REQUIRE(!err); REQUIRE(!err);

View File

@ -10,6 +10,7 @@ namespace Luau
ConstraintGeneratorFixture::ConstraintGeneratorFixture() ConstraintGeneratorFixture::ConstraintGeneratorFixture()
: Fixture() : Fixture()
, mainModule(new Module) , mainModule(new Module)
, simplifier(newSimplifier(NotNull{&arena}, builtinTypes))
, forceTheFlag{FFlag::LuauSolverV2, true} , forceTheFlag{FFlag::LuauSolverV2, true}
{ {
mainModule->name = "MainModule"; mainModule->name = "MainModule";
@ -25,6 +26,7 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code)
cg = std::make_unique<ConstraintGenerator>( cg = std::make_unique<ConstraintGenerator>(
mainModule, mainModule,
NotNull{&normalizer}, NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull(&moduleResolver), NotNull(&moduleResolver),
builtinTypes, builtinTypes,
@ -44,8 +46,19 @@ void ConstraintGeneratorFixture::solve(const std::string& code)
{ {
generateConstraints(code); generateConstraints(code);
ConstraintSolver cs{ ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {} NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{rootScope},
constraints,
"MainModule",
NotNull(&moduleResolver),
{},
&logger,
NotNull{dfg.get()},
{}
}; };
cs.run(); cs.run();
} }

View File

@ -4,8 +4,9 @@
#include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h" #include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/TypeArena.h" #include "Luau/EqSatSimplification.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/TypeArena.h"
#include "Fixture.h" #include "Fixture.h"
#include "ScopedFlags.h" #include "ScopedFlags.h"
@ -20,6 +21,7 @@ struct ConstraintGeneratorFixture : Fixture
DcrLogger logger; DcrLogger logger;
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
SimplifierPtr simplifier;
TypeCheckLimits limits; TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}}; TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}};

View File

@ -11,9 +11,7 @@ LUAU_EQSAT_ATOM(I32, int);
LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_ATOM(Bool, bool);
LUAU_EQSAT_ATOM(Str, std::string); LUAU_EQSAT_ATOM(Str, std::string);
LUAU_EQSAT_FIELD(Left); LUAU_EQSAT_NODE_ARRAY(Add, 2);
LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_NODE_FIELDS(Add, Left, Right);
using namespace Luau; using namespace Luau;
@ -117,8 +115,8 @@ TEST_CASE("node_field")
Add add{left, right}; Add add{left, right};
EqSat::Id left2 = add.field<Left>(); EqSat::Id left2 = add.operands()[0];
EqSat::Id right2 = add.field<Right>(); EqSat::Id right2 = add.operands()[1];
CHECK(left == left2); CHECK(left == left2);
CHECK(left != right2); CHECK(left != right2);
@ -135,10 +133,10 @@ TEST_CASE("language_operands")
const Add* add = v2.get<Add>(); const Add* add = v2.get<Add>();
REQUIRE(add); REQUIRE(add);
EqSat::Slice<EqSat::Id> actual = v2.operands(); EqSat::Slice<const EqSat::Id> actual = v2.operands();
CHECK(actual.size() == 2); CHECK(actual.size() == 2);
CHECK(actual[0] == add->field<Left>()); CHECK(actual[0] == add->operands()[0]);
CHECK(actual[1] == add->field<Right>()); CHECK(actual[1] == add->operands()[1]);
} }
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -0,0 +1,728 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "Luau/EqSatSimplification.h"
using namespace Luau;
struct ESFixture : Fixture
{
ScopedFastFlag newSolverOnly{FFlag::LuauSolverV2, true};
TypeArena arena_;
const NotNull<TypeArena> arena{&arena_};
SimplifierPtr simplifier;
TypeId parentClass;
TypeId childClass;
TypeId anotherChild;
TypeId unrelatedClass;
TypeId genericT = arena_.addType(GenericType{"T"});
TypeId genericU = arena_.addType(GenericType{"U"});
TypeId numberToString = arena_.addType(FunctionType{
arena_.addTypePack({builtinTypes->numberType}),
arena_.addTypePack({builtinTypes->stringType})
});
TypeId stringToNumber = arena_.addType(FunctionType{
arena_.addTypePack({builtinTypes->stringType}),
arena_.addTypePack({builtinTypes->numberType})
});
ESFixture()
: simplifier(newSimplifier(arena, builtinTypes))
{
createSomeClasses(&frontend);
ScopePtr moduleScope = frontend.globals.globalScope;
parentClass = moduleScope->linearSearchForBinding("Parent")->typeId;
childClass = moduleScope->linearSearchForBinding("Child")->typeId;
anotherChild = moduleScope->linearSearchForBinding("AnotherChild")->typeId;
unrelatedClass = moduleScope->linearSearchForBinding("Unrelated")->typeId;
}
std::optional<std::string> simplifyStr(TypeId ty)
{
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
LUAU_ASSERT(res);
return toString(res->result);
}
TypeId tbl(TableType::Props props)
{
return arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, TableState::Sealed});
}
};
TEST_SUITE_BEGIN("EqSatSimplification");
TEST_CASE_FIXTURE(ESFixture, "primitive")
{
CHECK("number" == simplifyStr(builtinTypes->numberType));
}
TEST_CASE_FIXTURE(ESFixture, "number | number")
{
TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->numberType}});
CHECK("number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "number | string")
{
CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1")
{
TypeId ty = arena->freshType(nullptr);
asMutable(ty)->ty.emplace<UnionType>(std::vector<TypeId>{builtinTypes->numberType, ty});
CHECK("number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "number | string | number")
{
TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->numberType}});
CHECK("number | string" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string | (number | string) | number")
{
TypeId u1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}});
TypeId u2 = arena->addType(UnionType{{builtinTypes->stringType, u1, builtinTypes->numberType}});
CHECK("number | string" == simplifyStr(u2));
}
TEST_CASE_FIXTURE(ESFixture, "string | any")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->anyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "any | string")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "any | never")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | unknown")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "unknown | string")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "unknown | never")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | never")
{
CHECK("string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | never | number")
{
CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType, builtinTypes->numberType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & string")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & number")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->numberType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & unknown")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "never & string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & (unknown | never)")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "true | false")
{
CHECK("boolean" == simplifyStr(arena->addType(UnionType{{builtinTypes->trueType, builtinTypes->falseType}})));
}
/*
* Intuitively, if we have a type like
*
* x where x = A & B & (C | D | x)
*
* We know that x is certainly not larger than A & B.
* We also know that the union (C | D | x) can be rewritten `(C | D | (A & B & (C | D | x)))
* This tells us that the union part is not smaller than A & B.
* We can therefore discard the union entirely and simplify this type to A & B
*/
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (number | t1)")
{
TypeId intersectionTy = arena->addType(BlockedType{});
TypeId unionTy = arena->addType(UnionType{{builtinTypes->numberType, intersectionTy}});
asMutable(intersectionTy)->ty.emplace<IntersectionType>(std::vector<TypeId>{builtinTypes->stringType, unionTy});
CHECK("string" == simplifyStr(intersectionTy));
}
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (unknown | t1)")
{
TypeId intersectionTy = arena->addType(BlockedType{});
TypeId unionTy = arena->addType(UnionType{{builtinTypes->unknownType, intersectionTy}});
asMutable(intersectionTy)->ty.emplace<IntersectionType>(std::vector<TypeId>{builtinTypes->stringType, unionTy});
CHECK("string" == simplifyStr(intersectionTy));
}
TEST_CASE_FIXTURE(ESFixture, "error | unknown")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->errorType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "\"hello\" | string")
{
CHECK("string" == simplifyStr(arena->addType(UnionType{{
arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "\"hello\" | \"world\" | \"hello\"")
{
CHECK("\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{
arena->addType(SingletonType{StringSingleton{"hello"}}),
arena->addType(SingletonType{StringSingleton{"world"}}),
arena->addType(SingletonType{StringSingleton{"hello"}}),
}})));
}
TEST_CASE_FIXTURE(ESFixture, "nil | boolean | number | string | thread | function | table | class | buffer")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{
builtinTypes->nilType,
builtinTypes->booleanType,
builtinTypes->numberType,
builtinTypes->stringType,
builtinTypes->threadType,
builtinTypes->functionType,
builtinTypes->tableType,
builtinTypes->classType,
builtinTypes->bufferType,
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent & number")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
parentClass, builtinTypes->numberType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & Parent")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
childClass, parentClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & Unrelated")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
childClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child | Parent")
{
CHECK("Parent" == simplifyStr(arena->addType(UnionType{{
childClass, parentClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "class | Child")
{
CHECK("class" == simplifyStr(arena->addType(UnionType{{
builtinTypes->classType, childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent | class | Child")
{
CHECK("class" == simplifyStr(arena->addType(UnionType{{
parentClass, builtinTypes->classType, childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
parentClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "never | Parent | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
builtinTypes->neverType, parentClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "never | Parent | (number & string) | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
builtinTypes->neverType, parentClass,
arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}),
unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "T & U")
{
CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{
genericT, genericU
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & true")
{
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, builtinTypes->trueType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | function | table | class | buffer)")
{
TypeId truthy = arena->addType(UnionType{{
builtinTypes->trueType,
builtinTypes->numberType,
builtinTypes->stringType,
builtinTypes->threadType,
builtinTypes->functionType,
builtinTypes->tableType,
builtinTypes->classType,
builtinTypes->bufferType,
}});
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, truthy
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & ~(false?)")
{
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "false & ~(false?)")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->falseType, builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (number) -> string")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, numberToString}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (number) -> string")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(UnionType{{numberToString, numberToString}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & function")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->functionType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & boolean")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->booleanType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & ~function")
{
TypeId notFunction = arena->addType(NegationType{builtinTypes->functionType});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, notFunction}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | function")
{
CHECK("function" == simplifyStr(arena->addType(UnionType{{numberToString, builtinTypes->functionType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (string) -> number")
{
CHECK("((number) -> string) & ((string) -> number)" == simplifyStr(arena->addType(IntersectionType{{numberToString, stringToNumber}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (string) -> number")
{
CHECK("((number) -> string) | ((string) -> number)" == simplifyStr(arena->addType(UnionType{{numberToString, stringToNumber}})));
}
TEST_CASE_FIXTURE(ESFixture, "add<number, number>")
{
CHECK("number" == simplifyStr(arena->addType(
TypeFunctionInstanceType{builtinTypeFunctions().addFunc, {
builtinTypes->numberType, builtinTypes->numberType
}}
)));
}
TEST_CASE_FIXTURE(ESFixture, "union<number, number>")
{
CHECK("number" == simplifyStr(arena->addType(
TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, {
builtinTypes->numberType, builtinTypes->numberType
}}
)));
}
TEST_CASE_FIXTURE(ESFixture, "never & ~string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->neverType,
arena->addType(NegationType{builtinTypes->stringType})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "blocked & never")
{
const TypeId blocked = arena->addType(BlockedType{});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{blocked, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "blocked & ~number & function")
{
const TypeId blocked = arena->addType(BlockedType{});
const TypeId notNumber = arena->addType(NegationType{builtinTypes->numberType});
const TypeId ty = arena->addType(IntersectionType{{blocked, notNumber, builtinTypes->functionType}});
std::string expected = toString(blocked) + " & function";
CHECK(expected == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "(number | boolean | string | nil | table) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number | boolean | nil) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->nilType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->booleanType, builtinTypes->nilType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
// (('a & false) | ('a & nil)) | number
// Child & ~Parent
// ~Parent & Child
// ~Child & Parent
// Parent & ~Child
// ~Child & ~Parent
// ~Parent & ~Child
TEST_CASE_FIXTURE(ESFixture, "free & string & number")
{
Scope scope{builtinTypes->anyTypePack};
const TypeId freeTy = arena->addType(FreeType{&scope});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(blocked & number) | (blocked & number)")
{
const TypeId blocked = arena->addType(BlockedType{});
const TypeId u = arena->addType(IntersectionType{{blocked, builtinTypes->numberType}});
const TypeId ty = arena->addType(UnionType{{u, u}});
const std::string blockedStr = toString(blocked);
CHECK(blockedStr + " & number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "{} & unknown")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->unknownType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{} & table")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->tableType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{} & ~(false?)")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: number}")
{
const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId hasX = tbl({{"x", builtinTypes->numberType}});
const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}});
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
CHECK("{ x: number }" == toString(res->result));
// Also assert that we don't allocate a fresh TableType in this case.
CHECK(follow(res->result) == hasX);
}
TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: ~(false?)}")
{
const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId hasX = tbl({{"x", builtinTypes->truthyType}});
const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}});
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
CHECK("{ x: number }" == toString(res->result));
}
TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) }")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
const TypeId ty = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "never | (({ x: number? }?) & { x: ~(false?) })")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
// ({x: number?}?) & {x: ~(false?)}
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}});
const TypeId ty = arena->addType(UnionType{{builtinTypes->neverType, intersectionTy}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "({ x: number? }?) & { x: ~(false?) } & ~(false?)")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
// ({x: number?}?) & {x: ~(false?)} & ~(false?)
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}});
CHECK("{ x: number }" == simplifyStr(intersectionTy));
}
#if 0
// TODO
TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) } & ~(false?)) | number")
{
// ({ x: number? }?) & { x: ~(false?) } & ~(false?)
const TypeId xWithOptionalNumber = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}});
const TypeId ty = arena->addType(UnionType{{intersectionTy, builtinTypes->numberType}});
CHECK("{ x: number } | number" == simplifyStr(ty));
}
#endif
TEST_CASE_FIXTURE(ESFixture, "number & no-refine")
{
CHECK("number" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->noRefineType}})));
}
TEST_CASE_FIXTURE(ESFixture, "{ x: number } & ~boolean")
{
const TypeId tblTy = tbl(TableType::Props{{"x", builtinTypes->numberType}});
const TypeId ty = arena->addType(IntersectionType{{
tblTy,
arena->addType(NegationType{builtinTypes->booleanType})
}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "(nil & string)?")
{
const TypeId nilAndString = arena->addType(IntersectionType{{builtinTypes->nilType, builtinTypes->stringType}});
const TypeId ty = arena->addType(UnionType{{nilAndString, builtinTypes->nilType}});
CHECK("nil" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string & \"hi\"")
{
const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}});
CHECK("\"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, hi}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")")
{
const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}});
const TypeId bye = arena->addType(SingletonType{StringSingleton{"bye"}});
CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(UnionType{{hi, bye}})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child")
{
const TypeId ty = arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, unrelatedClass}}),
arena->addType(NegationType{childClass})
}});
CHECK("Unrelated" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string & ~Child")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(NegationType{childClass})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & Child")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, unrelatedClass}}),
childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | AnotherChild) & ~Child")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, anotherChild}}),
childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: never }")
{
const TypeId ty = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->neverType}});
CHECK("never" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: number? } & { x: string }")
{
const TypeId leftTable = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->optionalNumberType}});
const TypeId rightTable = tbl({{"x", builtinTypes->stringType}});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{leftTable, rightTable}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & add<Child | AnotherChild | string, Parent>")
{
const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}});
const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{
builtinTypeFunctions().addFunc,
{u, parentClass},
{}
});
const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}});
CHECK("Child & add<AnotherChild | Child | string, Parent>" == simplifyStr(intersection));
}
TEST_CASE_FIXTURE(ESFixture, "Child & intersect<Child | AnotherChild | string, Parent>")
{
const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}});
const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{
builtinTypeFunctions().intersectFunc,
{u, parentClass},
{}
});
const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}});
CHECK("Child" == simplifyStr(intersection));
}
// {someKey: ~any}
//
// Maybe something we could do here is to try to reduce the key, get the
// class->node mapping, and skip the extraction process if the class corresponds
// to TNever.
// t1 where t1 = add<union<number, t1>, number>
TEST_SUITE_END();

View File

@ -293,3 +293,37 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric<BuiltinsFixture>;
} while (false) } while (false)
#define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) #define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result)
#define LUAU_CHECK_HAS_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \
if (!count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
#define LUAU_CHECK_HAS_NO_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \
if (count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)

View File

@ -4,19 +4,37 @@
#include "Fixture.h" #include "Fixture.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Autocomplete.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/AutocompleteTypes.h"
using namespace Luau; using namespace Luau;
LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2); LUAU_FASTFLAG(LuauStoreDFGOnModule2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
static std::optional<AutocompleteEntryMap> nullCallback(std::string tag, std::optional<const ClassType*> ptr, std::optional<std::string> contents)
{
return std::nullopt;
}
struct FragmentAutocompleteFixture : Fixture struct FragmentAutocompleteFixture : Fixture
{ {
ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}}; ScopedFastFlag sffs[4] = {
{FFlag::LuauAllowFragmentParsing, true},
{FFlag::LuauSolverV2, true},
{FFlag::LuauStoreDFGOnModule2, true},
{FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}
};
FragmentAutocompleteFixture()
{
addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType});
addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType});
}
FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos)
{ {
ParseResult p = tryParse(source); // We don't care about parsing incomplete asts ParseResult p = tryParse(source); // We don't care about parsing incomplete asts
@ -26,7 +44,6 @@ struct FragmentAutocompleteFixture : Fixture
CheckResult checkBase(const std::string& document) CheckResult checkBase(const std::string& document)
{ {
ScopedFastFlag sff{FFlag::LuauSolverV2, true};
FrontendOptions opts; FrontendOptions opts;
opts.retainFullTypeGraphs = true; opts.retainFullTypeGraphs = true;
return this->frontend.check("MainModule", opts); return this->frontend.check("MainModule", opts);
@ -48,6 +65,16 @@ struct FragmentAutocompleteFixture : Fixture
options.runLintChecks = false; options.runLintChecks = false;
return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document); return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document);
} }
FragmentAutocompleteResult autocompleteFragment(const std::string& document, Position cursorPos)
{
FrontendOptions options;
options.retainFullTypeGraphs = true;
// Don't strictly need this in the new solver
options.forAutocomplete = true;
options.runLintChecks = false;
return Luau::fragmentAutocomplete(frontend, document, "MainModule", cursorPos, options, nullCallback);
}
}; };
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests");
@ -172,6 +199,13 @@ TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); TEST_SUITE_BEGIN("FragmentAutocompleteParserTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{
check("local a =");
auto fragment = parseFragment("local a =", Position(0, 10));
CHECK_EQ("local a =", fragment.fragmentToParse);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null")
{ {
auto res = check(R"( auto res = check(R"(
@ -278,6 +312,33 @@ local y = 5
CHECK_EQ("y", std::string(rhs->name.value)); CHECK_EQ("y", std::string(rhs->name.value));
} }
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope")
{
check(R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)");
auto fragment = parseFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{6, 0}
);
CHECK_EQ("function abc()\n local myInnerLocal = 1\n\n end\n", fragment.fragmentToParse);
}
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests");
@ -302,7 +363,7 @@ local z = x + y
Position{3, 15} Position{3, 15}
); );
auto opt = linearSearchForBinding(fragment.freshScope, "z"); auto opt = linearSearchForBinding(fragment.freshScope.get(), "z");
REQUIRE(opt); REQUIRE(opt);
CHECK_EQ("number", toString(*opt)); CHECK_EQ("number", toString(*opt));
} }
@ -326,9 +387,222 @@ local y = 5
Position{2, 11} Position{2, 11}
); );
auto correct = linearSearchForBinding(fragment.freshScope, "z"); auto correct = linearSearchForBinding(fragment.freshScope.get(), "z");
REQUIRE(correct); REQUIRE(correct);
CHECK_EQ("number", toString(*correct)); CHECK_EQ("number", toString(*correct));
} }
TEST_SUITE_END(); TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_simple_property_access")
{
auto res = check(
R"(
local tbl = { abc = 1234}
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
local tbl = { abc = 1234}
tbl.
)",
Position{2, 5}
);
LUAU_ASSERT(fragment.freshScope);
CHECK_EQ(1, fragment.acResults.entryMap.size());
CHECK(fragment.acResults.entryMap.count("abc"));
CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_nested_property_access")
{
auto res = check(
R"(
local tbl = { abc = { def = 1234, egh = false } }
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
local tbl = { abc = { def = 1234, egh = false } }
tbl.abc.
)",
Position{2, 8}
);
LUAU_ASSERT(fragment.freshScope);
CHECK_EQ(2, fragment.acResults.entryMap.size());
CHECK(fragment.acResults.entryMap.count("def"));
CHECK(fragment.acResults.entryMap.count("egh"));
CHECK_EQ(fragment.acResults.context, AutocompleteContext::Property);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "inline_autocomplete_picks_the_right_scope")
{
auto res = check(
R"(
type Table = { a: number, b: number }
do
type Table = { x: string, y: string }
end
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
type Table = { a: number, b: number }
do
type Table = { x: string, y: string }
local a : T
end
)",
Position{4, 15}
);
LUAU_ASSERT(fragment.freshScope);
REQUIRE(fragment.acResults.entryMap.count("Table"));
REQUIRE(fragment.acResults.entryMap["Table"].type);
const TableType* tv = get<TableType>(follow(*fragment.acResults.entryMap["Table"].type));
REQUIRE(tv);
CHECK(tv->props.count("x"));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nested_recursive_function")
{
auto res = check(R"(
function foo()
end
)");
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
function foo()
end
)",
Position{2, 0}
);
CHECK(fragment.acResults.entryMap.count("foo"));
CHECK_EQ(AutocompleteContext::Statement, fragment.acResults.context);
}
// Start compatibility tests!
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "empty_program")
{
check("");
auto frag = autocompleteFragment(" ", Position{0, 1});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{
check("local a =");
auto frag = autocompleteFragment("local a =", Position{0, 9});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Expression);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "leave_numbers_alone")
{
check("local a = 3.");
auto frag = autocompleteFragment("local a = 3.", Position{0, 12});
auto ac = frag.acResults;
CHECK(ac.entryMap.empty());
CHECK_EQ(ac.context, AutocompleteContext::Unknown);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "user_defined_globals")
{
check("local myLocal = 4; ");
auto frag = autocompleteFragment("local myLocal = 4; ", Position{0, 18});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("myLocal"));
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "dont_suggest_local_before_its_definition")
{
check(R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)");
// autocomplete after abc but before myInnerLocal
auto fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{3, 0}
);
auto ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal");
// autocomplete after my inner local
fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{4, 0}
);
ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
CHECK(ac.entryMap.count("myInnerLocal"));
fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{6, 0}
);
ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal");
}
TEST_SUITE_END();

View File

@ -18,6 +18,7 @@ LUAU_FASTINT(LuauParseErrorLimit)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport)
namespace namespace
{ {
@ -2377,10 +2378,15 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms")
TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions")
{ {
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag sff2{FFlag::LuauUserDefinedTypeFunParseExport, true};
AstStat* stat = parse(R"( AstStat* stat = parse(R"(
type function foo() type function foo()
return return types.number
end
export type function bar()
return types.string
end end
)"); )");
@ -2417,7 +2423,6 @@ TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions")
{ {
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
matchParseError("export type function foo() end", "Type function cannot be exported");
matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'");
matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'");
} }

View File

@ -424,6 +424,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath")
assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"});
} }
TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithExtension")
{
std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau";
runProtectedRequire(path);
assertOutputContainsAll({"false", "error requiring module: consider removing the file extension"});
}
TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias")
{ {
std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer"; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer";

View File

@ -964,6 +964,7 @@ TEST_CASE_FIXTURE(Fixture, "correct_stringification_user_defined_type_functions"
std::vector<TypeId>{builtinTypes->numberType}, // Type Function Arguments std::vector<TypeId>{builtinTypes->numberType}, // Type Function Arguments
{}, {},
{AstName{"woohoo"}}, // Type Function Name {AstName{"woohoo"}}, // Type Function Name
{},
}; };
Type tv{tftt}; Type tv{tftt};

View File

@ -16,6 +16,8 @@ LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite)
LUAU_FASTFLAG(LuauUserTypeFunFixMetatable) LUAU_FASTFLAG(LuauUserTypeFunFixMetatable)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState)
LUAU_FASTFLAG(LuauUserTypeFunNonstrict) LUAU_FASTFLAG(LuauUserTypeFunNonstrict)
LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport)
TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests");
@ -1298,4 +1300,92 @@ local a: foo<> = "a"
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
fileResolver.source["game/A"] = R"(
type function concat(a, b)
return types.singleton(a:value() .. b:value())
end
export type Concat<T, U> = concat<T, U>
local a: concat<'first', 'second'>
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")");
CheckResult bResult = check(R"(
local Test = require(game.A);
local b: Test.Concat<'third', 'fourth'>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
CHECK(toString(requireType("b")) == R"("thirdfourth")");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
CheckResult result = check(R"(
type function foo()
return "hi"
end
local function test()
type function bar()
return types.singleton(foo())
end
return ("" :: any) :: bar<>
end
local a = test()
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK(toString(requireType("a")) == R"("hi")");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
ScopedFastFlag luauUserDefinedTypeFunParseExport{FFlag::LuauUserDefinedTypeFunParseExport, true};
fileResolver.source["game/A"] = R"(
export type function concat(a, b)
return types.singleton(a:value() .. b:value())
end
local a: concat<'first', 'second'>
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")");
CheckResult bResult = check(R"(
local Test = require(game.A);
local b: Test.concat<'third', 'fourth'>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
CHECK(toString(requireType("b")) == R"("thirdfourth")");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_SUITE_BEGIN("TypeInferFunctions");
@ -681,6 +682,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
{ {
// CLI-114134: this code *probably* wants the egraph in order
// to work properly. The new solver either falls over or
// forces so many constraints as to be unreliable.
DOES_NOT_PASS_NEW_SOLVER_GUARD();
CheckResult result = check(R"( CheckResult result = check(R"(
function bottomupmerge(comp, a, b, left, mid, right) function bottomupmerge(comp, a, b, left, mid, right)
local i, j = left, mid local i, j = left, mid
@ -743,6 +749,11 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3")
TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4")
{ {
// CLI-114134: this code *probably* wants the egraph in order
// to work properly. The new solver either falls over or
// forces so many constraints as to be unreliable.
DOES_NOT_PASS_NEW_SOLVER_GUARD();
CheckResult result = check(R"( CheckResult result = check(R"(
function bottomupmerge(comp, a, b, left, mid, right) function bottomupmerge(comp, a, b, left, mid, right)
local i, j = left, mid local i, j = left, mid
@ -2554,8 +2565,17 @@ end
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type")
{ {
if (!FFlag::LuauSolverV2) ScopedFastFlag sffs[] = {
return; {FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
// CLI-114134: This test:
// a) Has a kind of weird result (suggesting `number | false` is not great);
// b) Is force solving some constraints.
// We end up with a weird recursive type that, if you roughly look at it, is
// clearly `number`. Hopefully the egraph will be able to unfold this.
CheckResult result = check(R"( CheckResult result = check(R"(
function fib(n) function fib(n)
return n < 2 and 1 or fib(n-1) + fib(n-2) return n < 2 and 1 or fib(n-1) + fib(n-2)
@ -2565,9 +2585,7 @@ end
LUAU_REQUIRE_ERRORS(result); LUAU_REQUIRE_ERRORS(result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back()); auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back());
LUAU_ASSERT(err); LUAU_ASSERT(err);
CHECK("number" == toString(err->recommendedReturn)); CHECK("false | number" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("number" == toString(err->recommendedArgs[0].second));
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type")
@ -2862,6 +2880,8 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun")
TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
{ {
ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true};
CheckResult result = check(R"( CheckResult result = check(R"(
function foo(player) function foo(player)
local success,result = player:thing() local success,result = player:thing()
@ -2889,7 +2909,7 @@ TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
auto tm2 = get<TypePackMismatch>(result.errors[1]); auto tm2 = get<TypePackMismatch>(result.errors[1]);
REQUIRE(tm2); REQUIRE(tm2);
CHECK(toString(tm2->wantedTp) == "string"); CHECK(toString(tm2->wantedTp) == "string");
CHECK(toString(tm2->givenTp) == "buffer | class | function | number | string | table | thread | true"); CHECK(toString(tm2->givenTp) == "(buffer | class | function | number | string | table | thread | true) & unknown");
} }
else else
{ {

View File

@ -24,6 +24,7 @@ LUAU_FASTINT(LuauNormalizeCacheLimit);
LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
using namespace Luau; using namespace Luau;
@ -1730,4 +1731,36 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue")
)")); )"));
} }
TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function")
{
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict
local function foo(a : string?)
local b = a or ""
return b:upper()
end
)"));
}
TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type")
{
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict
local function wtf(name: string?)
local message
message = "invalid alternate fiber: " .. (name or "UNNAMED alternate")
end
)"));
}
TEST_SUITE_END(); TEST_SUITE_END();