Merge branch 'upstream' into merge

This commit is contained in:
Hunter Goldstein 2024-11-08 11:35:18 -08:00
commit af9d9ba13e
62 changed files with 7307 additions and 2451 deletions

View File

@ -1,10 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AutocompleteTypes.h"
#include "Luau/Location.h"
#include "Luau/Type.h"
#include <unordered_map>
#include <string>
#include <memory>
#include <optional>
@ -16,90 +16,8 @@ struct Frontend;
struct SourceModule;
struct Module;
struct TypeChecker;
using ModulePtr = std::shared_ptr<Module>;
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using ModuleName = std::string;
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
struct FileResolver;
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View File

@ -0,0 +1,92 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Type.h"
#include <unordered_map>
namespace Luau
{
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind
{
Property,
Binding,
Keyword,
String,
Type,
Module,
GeneratedFunction,
RequirePath,
};
enum class ParenthesesRecommendation
{
None,
CursorAfter,
CursorInside,
};
enum class TypeCorrectKind
{
None,
Correct,
CorrectFunctionResult,
};
struct AutocompleteEntry
{
AutocompleteEntryKind kind = AutocompleteEntryKind::Property;
// Nullopt if kind is Keyword
std::optional<TypeId> type = std::nullopt;
bool deprecated = false;
// Only meaningful if kind is Property.
bool wrongIndexType = false;
// Set if this suggestion matches the type expected in the context
TypeCorrectKind typeCorrect = TypeCorrectKind::None;
std::optional<const ClassType*> containingClass = std::nullopt;
std::optional<const Property*> prop = std::nullopt;
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
struct AutocompleteResult
{
AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap))
, ancestry(std::move(ancestry))
, context(context)
{
}
};
using StringCompletionCallback =
std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassType*> ctx, std::optional<std::string> contents)>;
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View File

@ -5,6 +5,7 @@
#include "Luau/Constraint.h"
#include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
@ -15,7 +16,6 @@
#include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h"
#include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include <memory>
#include <vector>
@ -109,6 +109,9 @@ struct ConstraintGenerator
// Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
// Needed to register all available type functions for execution at later stages.
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// Needed to resolve modules to make 'require' import types properly.
@ -128,6 +131,7 @@ struct ConstraintGenerator
ConstraintGenerator(
ModulePtr module,
NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
@ -405,6 +409,7 @@ private:
TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
// make an intersect type function of these two types
TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs);
void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program);
/** Scan the program for global definitions.
*
@ -435,6 +440,8 @@ private:
const ScopePtr& scope,
Location location
);
TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right);
};
/** Borrow a vector of pointers from a vector of owning pointers to constraints.

View File

@ -5,6 +5,7 @@
#include "Luau/Constraint.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DenseHash.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Error.h"
#include "Luau/Location.h"
#include "Luau/Module.h"
@ -64,6 +65,7 @@ struct ConstraintSolver
NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer;
NotNull<Simplifier> simplifier;
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
@ -117,6 +119,7 @@ struct ConstraintSolver
explicit ConstraintSolver(
NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
@ -384,6 +387,10 @@ public:
**/
void reproduceConstraints(NotNull<Scope> scope, const Location& location, const Substitution& subst);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts);
TypeId simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right);
TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const;

View File

@ -0,0 +1,50 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/NotNull.h"
#include "Luau/DenseHash.h"
#include <memory>
#include <optional>
#include <vector>
namespace Luau
{
struct TypeArena;
}
// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent
// the complexity from leaking outside its implementation sources.
namespace Luau::EqSatSimplification
{
struct Simplifier;
using SimplifierPtr = std::unique_ptr<Simplifier, void (*)(Simplifier*)>;
SimplifierPtr newSimplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
} // namespace Luau::EqSatSimplification
namespace Luau
{
struct EqSatSimplificationResult
{
TypeId result;
// New type function applications that were created by the reduction phase.
// We return these so that the ConstraintSolver can know to try to reduce
// them.
std::vector<TypeId> newTypeFunctions;
};
using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect.
using Luau::EqSatSimplification::Simplifier; // NOLINT
using Luau::EqSatSimplification::SimplifierPtr;
std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simplifier, TypeId ty);
} // namespace Luau

View File

@ -0,0 +1,363 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/EGraph.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/Lexer.h" // For Allocator
#include "Luau/NotNull.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
struct TypeFunction;
}
namespace Luau::EqSatSimplification
{
using StringId = uint32_t;
using Id = Luau::EqSat::Id;
LUAU_EQSAT_UNIT(TNil);
LUAU_EQSAT_UNIT(TBoolean);
LUAU_EQSAT_UNIT(TNumber);
LUAU_EQSAT_UNIT(TString);
LUAU_EQSAT_UNIT(TThread);
LUAU_EQSAT_UNIT(TTopFunction);
LUAU_EQSAT_UNIT(TTopTable);
LUAU_EQSAT_UNIT(TTopClass);
LUAU_EQSAT_UNIT(TBuffer);
// Used for any type that eqsat can't do anything interesting with.
LUAU_EQSAT_ATOM(TOpaque, TypeId);
LUAU_EQSAT_ATOM(SBoolean, bool);
LUAU_EQSAT_ATOM(SString, StringId);
LUAU_EQSAT_ATOM(TFunction, TypeId);
LUAU_EQSAT_ATOM(TImportedTable, TypeId);
LUAU_EQSAT_ATOM(TClass, TypeId);
LUAU_EQSAT_UNIT(TAny);
LUAU_EQSAT_UNIT(TError);
LUAU_EQSAT_UNIT(TUnknown);
LUAU_EQSAT_UNIT(TNever);
LUAU_EQSAT_NODE_SET(Union);
LUAU_EQSAT_NODE_SET(Intersection);
LUAU_EQSAT_NODE_ARRAY(Negation, 1);
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*);
LUAU_EQSAT_UNIT(TNoRefine);
LUAU_EQSAT_UNIT(Invalid);
// enodes are immutable, but types are cyclic. We need a way to tie the knot.
// We handle this by generating TBound nodes at points where we encounter cycles.
// Each TBound has an ordinal that we later map onto the type.
// We use a substitution rule to replace all TBound nodes with their referrent.
LUAU_EQSAT_ATOM(TBound, size_t);
// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it.
struct TTable
{
explicit TTable(Id basis);
TTable(Id basis, std::vector<StringId> propNames_, std::vector<Id> propTypes_);
// All TTables extend some other table. This may be TTopTable.
//
// It will frequently be a TImportedTable, in which case we can reuse things
// like source location and documentation info.
Id getBasis() const;
EqSat::Slice<const Id> propTypes() const;
// TODO: Also support read-only table props
// TODO: Indexer type, index result type.
std::vector<StringId> propNames;
// The enode interface
EqSat::Slice<Id> mutableOperands();
EqSat::Slice<const Id> operands() const;
bool operator==(const TTable& rhs) const;
bool operator!=(const TTable& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const TTable& value) const;
};
private:
// The first element of this vector is the basis. Subsequent elements are
// property types. As we add other things like read-only properties and
// indexers, the structure of this array is likely to change.
//
// We encode our data in this way so that the operands() method can properly
// return a Slice<Id>.
std::vector<Id> storage;
};
using EType = EqSat::Language<
TNil,
TBoolean,
TNumber,
TString,
TThread,
TTopFunction,
TTopTable,
TTopClass,
TBuffer,
TOpaque,
SBoolean,
SString,
TFunction,
TTable,
TImportedTable,
TClass,
TAny,
TError,
TUnknown,
TNever,
Union,
Intersection,
Negation,
TTypeFun,
Invalid,
TNoRefine,
TBound>;
struct StringCache
{
Allocator allocator;
DenseHashMap<size_t, StringId> strings{{}};
std::vector<std::string_view> views;
StringId add(std::string_view s);
std::string_view asStringView(StringId id) const;
std::string asString(StringId id) const;
};
using EGraph = Luau::EqSat::EGraph<EType, struct Simplify>;
struct Simplify
{
using Data = bool;
template<typename T>
Data make(const EGraph&, const T&) const;
void join(Data& left, const Data& right) const;
};
struct Subst
{
Id eclass;
Id newClass;
std::string desc;
Subst(Id eclass, Id newClass, std::string desc = "");
};
struct Simplifier
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
EGraph egraph;
StringCache stringCache;
// enodes are immutable but types can be cyclic, so we need some way to
// encode the cycle. This map is used to connect TBound nodes to the right
// eclass.
//
// The cyclicIntersection rewrite rule uses this to sense when a cycle can
// be deleted from an intersection or union.
std::unordered_map<size_t, Id> mappingIdToClass;
std::vector<Subst> substs;
using RewriteRuleFn = void (Simplifier::*)(Id id);
Simplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
// Utilities
const EqSat::EClass<EType, Simplify::Data>& get(Id id) const;
Id find(Id id) const;
Id add(EType enode);
template<typename Tag>
const Tag* isTag(Id id) const;
template<typename Tag>
const Tag* isTag(const EType& enode) const;
void subst(Id from, Id to);
void subst(Id from, Id to, const std::string& ruleName);
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
void unionClasses(std::vector<Id>& hereParts, Id there);
// Rewrite rules
void simplifyUnion(Id id);
void uninhabitedIntersection(Id id);
void intersectWithNegatedClass(Id id);
void intersectWithNoRefine(Id id);
void cyclicIntersectionOfUnion(Id id);
void cyclicUnionOfIntersection(Id id);
void expandNegation(Id id);
void intersectionOfUnion(Id id);
void intersectTableProperty(Id id);
void uninhabitedTable(Id id);
void unneededTableModification(Id id);
void builtinTypeFunctions(Id id);
void iffyTypeFunctions(Id id);
};
template<typename Tag>
struct QueryIterator
{
QueryIterator();
QueryIterator(EGraph* egraph, Id eclass);
bool operator==(const QueryIterator& other) const;
bool operator!=(const QueryIterator& other) const;
std::pair<const Tag*, size_t> operator*() const;
QueryIterator& operator++();
QueryIterator& operator++(int);
private:
EGraph* egraph = nullptr;
Id eclass;
size_t index = 0;
};
template<typename Tag>
struct Query
{
EGraph* egraph;
Id eclass;
Query(EGraph* egraph, Id eclass)
: egraph(egraph)
, eclass(eclass)
{
}
QueryIterator<Tag> begin()
{
return QueryIterator<Tag>{egraph, eclass};
}
QueryIterator<Tag> end()
{
return QueryIterator<Tag>{};
}
};
template<typename Tag>
QueryIterator<Tag>::QueryIterator()
: egraph(nullptr)
, eclass(Id{0})
, index(0)
{
}
template<typename Tag>
QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
: egraph(egraph_)
, eclass(eclass)
, index(0)
{
const auto& ecl = (*egraph)[eclass];
static constexpr const int idx = EType::VariantTy::getTypeId<Tag>();
for (const auto& enode : ecl.nodes)
{
if (enode.index() < idx)
++index;
else
break;
}
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx)
{
egraph = nullptr;
index = 0;
}
}
template<typename Tag>
bool QueryIterator<Tag>::operator==(const QueryIterator<Tag>& rhs) const
{
if (egraph == nullptr && rhs.egraph == nullptr)
return true;
return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index;
}
template<typename Tag>
bool QueryIterator<Tag>::operator!=(const QueryIterator<Tag>& rhs) const
{
return !(*this == rhs);
}
template<typename Tag>
std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
{
LUAU_ASSERT(egraph != nullptr);
EGraph::EClassT& ecl = (*egraph)[eclass];
LUAU_ASSERT(index < ecl.nodes.size());
auto& enode = ecl.nodes[index];
Tag* result = enode.template get<Tag>();
LUAU_ASSERT(result);
return {result, index};
}
// pre-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++()
{
const auto& ecl = (*egraph)[eclass];
++index;
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId<Tag>())
{
egraph = nullptr;
index = 0;
}
return *this;
}
// post-increment
template<typename Tag>
QueryIterator<Tag>& QueryIterator<Tag>::operator++(int)
{
QueryIterator<Tag> res = *this;
++res;
return res;
}
} // namespace Luau::EqSatSimplification

View File

@ -32,7 +32,11 @@ struct ModuleInfo
bool optional = false;
};
using RequireSuggestion = std::string;
struct RequireSuggestion
{
std::string label;
std::string fullPath;
};
using RequireSuggestions = std::vector<RequireSuggestion>;
struct FileResolver

View File

@ -3,9 +3,10 @@
#include "Luau/Ast.h"
#include "Luau/Parser.h"
#include "Luau/Autocomplete.h"
#include "Luau/AutocompleteTypes.h"
#include "Luau/DenseHash.h"
#include "Luau/Module.h"
#include "Luau/Frontend.h"
#include <memory>
#include <vector>
@ -27,13 +28,23 @@ struct FragmentParseResult
std::string fragmentToParse;
AstStatBlock* root = nullptr;
std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
};
struct FragmentTypeCheckResult
{
ModulePtr incrementalModule = nullptr;
Scope* freshScope = nullptr;
ScopePtr freshScope;
std::vector<AstNode*> ancestry;
};
struct FragmentAutocompleteResult
{
ModulePtr incrementalModule;
Scope* freshScope;
TypeArena arenaForAutocomplete;
AutocompleteResult acResults;
};
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);
@ -48,11 +59,11 @@ FragmentTypeCheckResult typecheckFragment(
std::string_view src
);
AutocompleteResult fragmentAutocomplete(
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position& cursorPosition,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback
);

View File

@ -44,6 +44,7 @@ struct ToStringOptions
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil.
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections

View File

@ -31,6 +31,7 @@ namespace Luau
struct TypeArena;
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct Module;
struct TypeFunction;
struct Constraint;
@ -598,6 +599,18 @@ struct ClassType
}
};
// Data required to initialize a user-defined function and its environment
struct UserDefinedFunctionData
{
// Store a weak module reference to ensure the lifetime requirements are preserved
std::weak_ptr<Module> owner;
// References to AST elements are owned by the Module allocator which also stores this type
AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, AstStatTypeFunction*> environment{""};
};
/**
* An instance of a type function that has not yet been reduced to a more concrete
* type. The constraint solver receives a constraint to reduce each
@ -613,17 +626,20 @@ struct TypeFunctionInstanceType
std::vector<TypePackId> packArguments;
std::optional<AstName> userFuncName; // Name of the user-defined type function; only available for UDTFs
UserDefinedFunctionData userFuncData;
TypeFunctionInstanceType(
NotNull<const TypeFunction> function,
std::vector<TypeId> typeArguments,
std::vector<TypePackId> packArguments,
std::optional<AstName> userFuncName = std::nullopt
std::optional<AstName> userFuncName,
UserDefinedFunctionData userFuncData
)
: function(function)
, typeArguments(typeArguments)
, packArguments(packArguments)
, userFuncName(userFuncName)
, userFuncData(userFuncData)
{
}
@ -640,6 +656,13 @@ struct TypeFunctionInstanceType
, packArguments(packArguments)
{
}
TypeFunctionInstanceType(NotNull<const TypeFunction> function, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: function{function}
, typeArguments(typeArguments)
, packArguments(packArguments)
{
}
};
/** Represents a pending type alias instantiation.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,27 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AutocompleteTypes.h"
namespace Luau
{
struct Module;
struct FileResolver;
using ModulePtr = std::shared_ptr<Module>;
using ModuleName = std::string;
AutocompleteResult autocomplete_(
const ModulePtr& module,
NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena,
std::vector<AstNode*>& ancestry,
Scope* globalScope,
const ScopePtr& scopeAtPosition,
Position position,
FileResolver* fileResolver,
StringCompletionCallback callback
);
} // namespace Luau

View File

@ -33,7 +33,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2);
namespace Luau
{
@ -426,7 +426,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze);
}
if (FFlag::AutocompleteRequirePathSuggestions)
if (FFlag::AutocompleteRequirePathSuggestions2)
{
TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName);

View File

@ -3,6 +3,8 @@
#include "Luau/Constraint.h"
#include "Luau/VisitType.h"
LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions)
namespace Luau
{
@ -46,6 +48,21 @@ struct ReferenceCountInitializer : TypeOnceVisitor
// ClassTypes never contain free types.
return false;
}
bool visit(TypeId, const TypeFunctionInstanceType&) override
{
// We do not consider reference counted types that are inside a type
// function to be part of the reachable reference counted types.
// Otherwise, code can be constructed in just the right way such
// that two type functions both claim to mutate a free type, which
// prevents either type function from trying to generalize it, so
// we potentially get stuck.
//
// The default behavior here is `true` for "visit the child types"
// of this type, hence:
return !FFlag::LuauDontRefCountTypesInTypeFunctions;
}
};
bool isReferenceCountedType(const TypeId typ)

View File

@ -10,6 +10,7 @@
#include "Luau/Def.h"
#include "Luau/DenseHash.h"
#include "Luau/ModuleResolver.h"
#include "Luau/NotNull.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Refinement.h"
#include "Luau/Scope.h"
@ -30,11 +31,13 @@
LUAU_FASTINT(LuauCheckRecursionLimit)
LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(DebugLuauEqSatSimplification)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations)
namespace Luau
@ -172,6 +175,7 @@ bool hasFreeType(TypeId ty)
ConstraintGenerator::ConstraintGenerator(
ModulePtr module,
NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes,
@ -188,6 +192,7 @@ ConstraintGenerator::ConstraintGenerator(
, rootScope(nullptr)
, dfg(dfg)
, normalizer(normalizer)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, moduleResolver(moduleResolver)
, ice(ice)
@ -257,7 +262,7 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
d = follow(d);
if (d == ty)
continue;
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result;
domainTy = simplifyUnion(scope, Location{}, domainTy, d);
}
LUAU_ASSERT(get<BlockedType>(ty));
@ -267,7 +272,15 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block)
{
// We prepopulate global data in the resumeScope to avoid writing data into the old modules scopes
prepopulateGlobalScopeForFragmentTypecheck(globalScope, resumeScope, block);
// Pre
// We need to pop the interior types,
interiorTypes.emplace_back();
visitBlockWithoutChildScope(resumeScope, block);
// Post
interiorTypes.pop_back();
fillInInferredBindings(resumeScope, block);
if (logger)
@ -282,7 +295,7 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat
d = follow(d);
if (d == ty)
continue;
domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result;
domainTy = simplifyUnion(resumeScope, resumeScope->location, domainTy, d);
}
LUAU_ASSERT(get<BlockedType>(ty));
@ -711,7 +724,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
continue;
}
if (scope->parent != globalScope)
if (!FFlag::LuauUserTypeFunExportedAndLocal && scope->parent != globalScope)
{
reportError(function->location, GenericError{"Local user-defined functions are not supported yet"});
continue;
@ -740,17 +753,26 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function))
reportError(function->location, GenericError{*error});
TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{
NotNull{&builtinTypeFunctions().userFunc},
std::move(typeParams),
{},
function->name,
});
UserDefinedFunctionData udtfData;
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
udtfData.owner = module;
udtfData.definition = function;
}
TypeId typeFunctionTy = arena->addType(
TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData}
);
TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};
// Set type bindings and definition locations for this user-defined type function
scope->privateTypeBindings[function->name.value] = std::move(typeFunction);
if (FFlag::LuauUserTypeFunExportedAndLocal && function->exported)
scope->exportedTypeBindings[function->name.value] = std::move(typeFunction);
else
scope->privateTypeBindings[function->name.value] = std::move(typeFunction);
aliasDefinitionLocations[function->name.value] = function->location;
}
else if (auto classDeclaration = stat->as<AstStatDeclareClass>())
@ -780,6 +802,55 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc
classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location;
}
}
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Additional pass for user-defined type functions to fill in their environments completely
for (AstStat* stat : block->body)
{
if (auto function = stat->as<AstStatTypeFunction>())
{
// Find the type function we have already created
TypeFunctionInstanceType* mainTypeFun = nullptr;
if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
if (!mainTypeFun)
{
if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end())
mainTypeFun = getMutable<TypeFunctionInstanceType>(it->second.type);
}
// Fill it with all visible type functions
if (mainTypeFun)
{
UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData;
for (Scope* curr = scope.get(); curr; curr = curr->parent.get())
{
for (auto& [name, tf] : curr->privateTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
for (auto& [name, tf] : curr->exportedTypeBindings)
{
if (userFuncData.environment.find(name))
continue;
if (auto ty = get<TypeFunctionInstanceType>(tf.type); ty && ty->userFuncData.definition)
userFuncData.environment[name] = ty->userFuncData.definition;
}
}
}
}
}
}
}
ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block)
@ -900,12 +971,8 @@ ControlFlow ConstraintGenerator::visitBlockWithoutChildScope_DEPRECATED(const Sc
if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function))
reportError(function->location, GenericError{*error});
TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{
NotNull{&builtinTypeFunctions().userFunc},
std::move(typeParams),
{},
function->name,
});
TypeId typeFunctionTy =
arena->addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, {}});
TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};
@ -2807,7 +2874,7 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local
case ErrorSuppression::DoNotSuppress:
break;
case ErrorSuppression::Suppress:
ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result;
ty = simplifyUnion(scope, local->location, *ty, builtinTypes->errorType);
break;
case ErrorSuppression::NormalizationFailed:
reportError(local->local->annotation->location, NormalizationTooComplex{});
@ -3673,6 +3740,32 @@ TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location locati
return resultType;
}
struct FragmentTypeCheckGlobalPrepopulator : AstVisitor
{
const NotNull<Scope> globalScope;
const NotNull<Scope> currentScope;
const NotNull<const DataFlowGraph> dfg;
FragmentTypeCheckGlobalPrepopulator(NotNull<Scope> globalScope, NotNull<Scope> currentScope, NotNull<const DataFlowGraph> dfg)
: globalScope(globalScope)
, currentScope(currentScope)
, dfg(dfg)
{
}
bool visit(AstExprGlobal* global) override
{
if (auto ty = globalScope->lookup(global->name))
{
DefId def = dfg->getDef(global);
// We only want to write into the current scope the type of the global
currentScope->lvalueTypes[def] = *ty;
}
return true;
}
};
struct GlobalPrepopulator : AstVisitor
{
const NotNull<Scope> globalScope;
@ -3719,6 +3812,14 @@ struct GlobalPrepopulator : AstVisitor
}
};
void ConstraintGenerator::prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program)
{
FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg};
if (prepareModuleScope)
prepareModuleScope(module->name, resumeScope);
program->visit(&gp);
}
void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program)
{
GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg};
@ -3870,6 +3971,24 @@ TypeId ConstraintGenerator::createTypeFunctionInstance(
return result;
}
TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(UnionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId tyFun : res->newTypeFunctions)
addConstraint(scope, location, ReduceConstraint{tyFun});
return res->result;
}
else
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
}
std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints)
{
std::vector<NotNull<Constraint>> result;

View File

@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings)
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
namespace Luau
@ -320,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor
ConstraintSolver::ConstraintSolver(
NotNull<Normalizer> normalizer,
NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints,
@ -333,6 +335,7 @@ ConstraintSolver::ConstraintSolver(
: arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer)
, simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime)
, constraints(std::move(constraints))
, rootScope(rootScope)
@ -1802,7 +1805,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
upperTable->props[c.propName] = rhsType;
// Food for thought: Could we block if simplification encounters a blocked type?
lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result;
lhsFree->upperBound = simplifyIntersection(constraint->scope, constraint->location, lhsFreeUpperBound, newUpperBound);
bind(constraint, c.propType, rhsType);
return true;
@ -2016,7 +2019,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull<const
}
}
TypeId res = simplifyIntersection(builtinTypes, arena, std::move(parts)).result;
TypeId res = simplifyIntersection(constraint->scope, constraint->location, std::move(parts));
unify(constraint, rhsType, res);
}
@ -2596,9 +2599,9 @@ std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTa
// if we're in an lvalue context, we need the _common_ type here.
if (context == ValueContext::LValue)
return {{}, simplifyIntersection(builtinTypes, arena, one, two).result};
return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)};
return {{}, simplifyUnion(builtinTypes, arena, one, two).result};
return {{}, simplifyUnion(constraint->scope, constraint->location, one, two)};
}
// if we're in an lvalue context, we need the _common_ type here.
else if (context == ValueContext::LValue)
@ -2630,7 +2633,7 @@ std::pair<std::vector<TypeId>, std::optional<TypeId>> ConstraintSolver::lookupTa
{
TypeId one = *begin(options);
TypeId two = *(++begin(options));
return {{}, simplifyIntersection(builtinTypes, arena, one, two).result};
return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)};
}
else
return {{}, arena->addType(IntersectionType{std::vector<TypeId>(begin(options), end(options))})};
@ -3019,6 +3022,63 @@ bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty)
return false;
}
TypeId ConstraintSolver::simplifyIntersection(NotNull<Scope> scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(IntersectionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyIntersection(builtinTypes, arena, left, right).result;
}
TypeId ConstraintSolver::simplifyIntersection(NotNull<Scope> scope, Location location, std::set<TypeId> parts)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(IntersectionType{std::vector(parts.begin(), parts.end())});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyIntersection(builtinTypes, arena, std::move(parts)).result;
}
TypeId ConstraintSolver::simplifyUnion(NotNull<Scope> scope, Location location, TypeId left, TypeId right)
{
if (FFlag::DebugLuauEqSatSimplification)
{
TypeId ty = arena->addType(UnionType{{left, right}});
std::optional<EqSatSimplificationResult> res = eqSatSimplify(simplifier, ty);
if (!res)
return ty;
for (TypeId ty : res->newTypeFunctions)
pushConstraint(scope, location, ReduceConstraint{ty});
return res->result;
}
else
return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result;
}
TypeId ConstraintSolver::errorRecoveryType() const
{
return builtinTypes->errorRecoveryType();

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Common.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
@ -18,11 +19,14 @@
#include "Luau/ParseOptions.h"
#include "Luau/Module.h"
#include "AutocompleteCore.h"
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
namespace
{
@ -41,7 +45,6 @@ void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K,
} // namespace
namespace Luau
{
@ -88,6 +91,14 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro
return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)};
}
/**
* Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that
* document and attempts to get the concrete text between those points. It returns a pair of:
* - start offset that represents an index in the source `char*` corresponding to startPos
* - length, that represents how many more bytes to read to get to endPos.
* Example - your document is "foo bar baz" and getDocumentOffsets is passed (1, 4) - (1, 8). This function returns the pair {3, 7},
* which corresponds to the string " bar "
*/
std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos)
{
unsigned int lineCount = 0;
@ -115,6 +126,13 @@ std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view&
foundEnd = true;
}
// We put a cursor position that extends beyond the extents of the current line
if (foundStart && !foundEnd && (lineCount > endPos.line))
{
foundEnd = true;
endOffset = docOffset - 1;
}
if (c == '\n')
{
lineCount++;
@ -125,20 +143,24 @@ std::pair<unsigned int, unsigned int> getDocumentOffsets(const std::string_view&
docOffset++;
}
if (foundStart && !foundEnd)
endOffset = src.length();
unsigned int min = std::min(startOffset, endOffset);
unsigned int len = std::max(startOffset, endOffset) - min;
return {min, len};
}
ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos)
ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement)
{
LUAU_ASSERT(module->hasModuleScope());
ScopePtr closest = module->getModuleScope();
// find the scope the nearest statement belonged to.
for (auto [loc, sc] : module->scopes)
{
if (loc.begin <= cursorPos && closest->location.begin <= loc.begin)
if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin)
closest = sc;
}
@ -152,13 +174,27 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
opts.allowDeclarationSyntax = false;
opts.captureComments = false;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)};
AstStat* enclosingStatement = result.nearestStatement;
AstStat* nearestStatement = result.nearestStatement;
const Position& endPos = cursorPos;
// If the statement starts on a previous line, grab the statement beginning
// otherwise, grab the statement end to whatever is being typed right now
const Position& startPos =
enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end;
const Location& rootSpan = srcModule.root->location;
// Did we append vs did we insert inline
bool appended = cursorPos >= rootSpan.end;
// statement spans multiple lines
bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line;
const Position endPos = cursorPos;
// We start by re-parsing everything (we'll refine this as we go)
Position startPos = srcModule.root->location.begin;
// If we added to the end of the sourceModule, use the end of the nearest location
if (appended && multiline)
startPos = nearestStatement->location.end;
// Statement spans one line && cursorPos is on a different line
else if (!multiline && cursorPos.line != nearestStatement->location.end.line)
startPos = nearestStatement->location.end;
else
startPos = nearestStatement->location.begin;
auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos);
@ -173,10 +209,11 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end);
fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end());
if (enclosingStatement == nullptr)
enclosingStatement = p.root;
if (nearestStatement == nullptr)
nearestStatement = p.root;
fragmentResult.root = std::move(p.root);
fragmentResult.ancestry = std::move(fabricatedAncestry);
fragmentResult.nearestStatement = nearestStatement;
return fragmentResult;
}
@ -205,7 +242,7 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
return incrementalModule;
}
FragmentTypeCheckResult typeCheckFragmentHelper(
FragmentTypeCheckResult typecheckFragment_(
Frontend& frontend,
AstStatBlock* root,
const ModulePtr& stale,
@ -245,15 +282,18 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
/// Create a DataFlowGraph just for the surrounding context
auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);
/// Contraint Generator
ConstraintGenerator cg{
incrementalModule,
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&frontend.moduleResolver},
frontend.builtinTypes,
iceHandler,
frontend.globals.globalScope,
stale->getModuleScope(),
nullptr,
nullptr,
NotNull{&updatedDfg},
@ -262,7 +302,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
cg.rootScope = stale->getModuleScope().get();
// Any additions to the scope must occur in a fresh scope
auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope});
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);
// closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy
@ -274,9 +314,11 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
LUAU_ASSERT(back == freshChildOfNearestScope.get());
closestScope->children.pop_back();
/// Initialize the constraint solver and run it
ConstraintSolver cs{
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
@ -307,7 +349,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper(
freeze(incrementalModule->internalTypes);
freeze(incrementalModule->interfaceTypes);
return {std::move(incrementalModule), freshChildOfNearestScope.get()};
return {std::move(incrementalModule), std::move(freshChildOfNearestScope)};
}
@ -327,27 +369,51 @@ FragmentTypeCheckResult typecheckFragment(
}
ModulePtr module = frontend.moduleResolver.getModule(moduleName);
const ScopePtr& closestScope = findClosestScope(module, cursorPos);
FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos);
FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos);
FrontendOptions frontendOptions = opts.value_or(frontend.options);
return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions);
const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement);
FragmentTypeCheckResult result =
typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions);
result.ancestry = std::move(parseResult.ancestry);
return result;
}
AutocompleteResult fragmentAutocomplete(
FragmentAutocompleteResult fragmentAutocomplete(
Frontend& frontend,
std::string_view src,
const ModuleName& moduleName,
Position& cursorPosition,
const FrontendOptions& opts,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback
)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
LUAU_ASSERT(FFlag::LuauAllowFragmentParsing);
LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2);
return {};
LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
{
LUAU_ASSERT(!"Expected Source Module for fragment typecheck");
return {};
}
auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src);
TypeArena arenaForFragmentAutocomplete;
auto result = Luau::autocomplete_(
tcResult.incrementalModule,
frontend.builtinTypes,
&arenaForFragmentAutocomplete,
tcResult.ancestry,
frontend.globals.globalScope.get(),
tcResult.freshScope,
cursorPosition,
frontend.fileResolver,
callback
);
return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)};
}
} // namespace Luau

View File

@ -10,6 +10,7 @@
#include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/FileResolver.h"
#include "Luau/NonStrictTypeChecker.h"
#include "Luau/Parser.h"
@ -46,7 +47,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection)
LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2)
@ -287,8 +287,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector<HotCom
std::vector<RequireCycle> getRequireCycles(
const FileResolver* resolver,
const std::unordered_map<ModuleName, std::shared_ptr<SourceNode>>& sourceNodes,
const SourceNode* start,
bool stopAtFirst = false
const SourceNode* start
)
{
std::vector<RequireCycle> result;
@ -358,9 +357,6 @@ std::vector<RequireCycle> getRequireCycles(
{
result.push_back({depLocation, std::move(cycle)});
if (stopAtFirst)
return result;
// note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start
// so it's safe to *only* clear seen vector when we find a cycle
// if we don't do it, we will not have correct reporting for some cycles
@ -884,18 +880,11 @@ void Frontend::addBuildQueueItems(
data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete);
data.recordJsonLog = FFlag::DebugLuauLogSolverToJson;
const Mode mode = sourceModule->mode.value_or(data.config.mode);
// in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely
if (cycleDetected)
{
if (FFlag::LuauMoreThoroughCycleDetection)
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), false);
else
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck);
}
data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get());
data.options = frontendOptions;
@ -1334,6 +1323,7 @@ ModulePtr check(
unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit);
Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}};
SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes);
TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}};
if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation)
@ -1342,6 +1332,7 @@ ModulePtr check(
ConstraintGenerator cg{
result,
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
moduleResolver,
builtinTypes,
@ -1358,6 +1349,7 @@ ModulePtr check(
ConstraintSolver cs{
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),

View File

@ -132,7 +132,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
return dest.addType(NegationType{a.ty});
else if constexpr (std::is_same_v<T, TypeFunctionInstanceType>)
{
TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName};
TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData};
return dest.addType(std::move(clone));
}
else

View File

@ -4,6 +4,7 @@
#include "Luau/Common.h"
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSymbolEquality)
namespace Luau
{
@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const
return local == rhs.local;
else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
else if (FFlag::LuauSolverV2)
else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else
return false;

View File

@ -870,6 +870,8 @@ struct TypeStringifier
return;
}
LUAU_ASSERT(uv.options.size() > 1);
bool optional = false;
bool hasNonNilDisjunct = false;
@ -878,7 +880,7 @@ struct TypeStringifier
{
el = follow(el);
if (isNil(el))
if (state.opts.useQuestionMarks && isNil(el))
{
optional = true;
continue;

View File

@ -51,6 +51,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation)
LUAU_FASTFLAG(LuauUserTypeFunFixRegister)
LUAU_FASTFLAG(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState)
LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
@ -610,10 +611,29 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
NotNull<TypeFunctionContext> ctx
)
{
if (!ctx->userFuncName)
auto typeFunction = getMutable<TypeFunctionInstanceType>(instance);
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
ctx->ice->ice("all user-defined type functions must have an associated function definition");
return {std::nullopt, true, {}, {}};
if (typeFunction->userFuncData.owner.expired())
{
ctx->ice->ice("user-defined type function module has expired");
return {std::nullopt, true, {}, {}};
}
if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition)
{
ctx->ice->ice("all user-defined type functions must have an associated function definition");
return {std::nullopt, true, {}, {}};
}
}
else
{
if (!ctx->userFuncName)
{
ctx->ice->ice("all user-defined type functions must have an associated function definition");
return {std::nullopt, true, {}, {}};
}
}
if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation)
@ -632,7 +652,22 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
return {std::nullopt, false, {ty}, {}};
}
AstName name = *ctx->userFuncName;
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Ensure that whole type function environment is registered
for (auto& [name, definition] : typeFunction->userFuncData.environment)
{
if (std::optional<std::string> error = ctx->typeFunctionRuntime->registerFunction(definition))
{
// Failure to register at this point means that original definition had to error out and should not have been present in the
// environment
ctx->ice->ice("user-defined type function reference cannot be registered");
return {std::nullopt, true, {}, {}};
}
}
}
AstName name = FFlag::LuauUserTypeFunExportedAndLocal ? typeFunction->userFuncData.definition->name : *ctx->userFuncName;
lua_State* global = ctx->typeFunctionRuntime->state.get();
@ -643,8 +678,44 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
lua_State* L = lua_newthread(global);
LuauTempThreadPopper popper(global);
lua_getglobal(global, name.value);
lua_xmove(global, L, 1);
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Fetch the function we want to evaluate
lua_pushlightuserdata(L, typeFunction->userFuncData.definition);
lua_gettable(L, LUA_REGISTRYINDEX);
if (!lua_isfunction(L, -1))
{
ctx->ice->ice("user-defined type function reference cannot be found in the registry");
return {std::nullopt, true, {}, {}};
}
// Build up the environment
lua_getfenv(L, -1);
lua_setreadonly(L, -1, false);
for (auto& [name, definition] : typeFunction->userFuncData.environment)
{
lua_pushlightuserdata(L, definition);
lua_gettable(L, LUA_REGISTRYINDEX);
if (!lua_isfunction(L, -1))
{
ctx->ice->ice("user-defined type function reference cannot be found in the registry");
return {std::nullopt, true, {}, {}};
}
lua_setfield(L, -2, name.c_str());
}
lua_setreadonly(L, -1, true);
lua_pop(L, 1);
}
else
{
lua_getglobal(global, name.value);
lua_xmove(global, L, 1);
}
if (FFlag::LuauUserDefinedTypeFunctionResetState)
resetTypeFunctionState(L);
@ -693,7 +764,7 @@ TypeFunctionReductionResult<TypeId> userDefinedTypeFunction(
TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get());
// At least 1 error occured while deserializing
// At least 1 error occurred while deserializing
if (runtimeBuilder->errors.size() > 0)
return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()};
@ -935,6 +1006,23 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
prepareState();
lua_State* global = state.get();
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Fetch to check if function is already registered
lua_pushlightuserdata(global, function);
lua_gettable(global, LUA_REGISTRYINDEX);
if (!lua_isnil(global, -1))
{
lua_pop(global, 1);
return std::nullopt;
}
lua_pop(global, 1);
}
AstName name = function->name;
// Construct ParseResult containing the type function
@ -961,7 +1049,6 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
std::string bytecode = builder.getBytecode();
lua_State* global = state.get();
// Separate sandboxed thread for individual execution and private globals
lua_State* L = lua_newthread(global);
@ -989,9 +1076,19 @@ std::optional<std::string> TypeFunctionRuntime::registerFunction(AstStatTypeFunc
return format("Could not find '%s' type function in the global scope", name.value);
}
// Store resulting function in the global environment
lua_xmove(L, global, 1);
lua_setglobal(global, name.value);
if (FFlag::LuauUserTypeFunExportedAndLocal)
{
// Store resulting function in the registry
lua_pushlightuserdata(global, function);
lua_xmove(L, global, 1);
lua_settable(global, LUA_REGISTRYINDEX);
}
else
{
// Store resulting function in the global environment
lua_xmove(L, global, 1);
lua_setglobal(global, name.value);
}
return std::nullopt;
}

View File

@ -0,0 +1,48 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Location.h"
#include "Luau/DenseHash.h"
#include "Luau/Common.h"
#include <vector>
namespace Luau
{
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
char data[8192];
};
Page* root;
size_t offset;
};
}

View File

@ -316,16 +316,18 @@ public:
enum QuoteStyle
{
Quoted,
QuotedSimple,
QuotedRaw,
Unquoted
};
AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle = Quoted);
AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle);
void visit(AstVisitor* visitor) override;
bool isQuoted() const;
AstArray<char> value;
QuoteStyle quoteStyle = Quoted;
QuoteStyle quoteStyle;
};
class AstExprLocal : public AstExpr
@ -876,13 +878,14 @@ class AstStatTypeFunction : public AstStat
public:
LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported);
void visit(AstVisitor* visitor) override;
AstName name;
Location nameLocation;
AstExprFunction* body;
bool exported;
};
class AstStatDeclareGlobal : public AstStat

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Allocator.h"
#include "Luau/Ast.h"
#include "Luau/Location.h"
#include "Luau/DenseHash.h"
@ -11,40 +12,6 @@
namespace Luau
{
class Allocator
{
public:
Allocator();
Allocator(Allocator&&);
Allocator& operator=(Allocator&&) = delete;
~Allocator();
void* allocate(size_t size);
template<typename T, typename... Args>
T* alloc(Args&&... args)
{
static_assert(std::is_trivially_destructible<T>::value, "Objects allocated with this allocator will never have their destructors run!");
T* t = static_cast<T*>(allocate(sizeof(T)));
new (t) T(std::forward<Args>(args)...);
return t;
}
private:
struct Page
{
Page* next;
char data[8192];
};
Page* root;
size_t offset;
};
struct Lexeme
{
enum Type

66
Ast/src/Allocator.cpp Normal file
View File

@ -0,0 +1,66 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Allocator.h"
namespace Luau
{
Allocator::Allocator()
: root(static_cast<Page*>(operator new(sizeof(Page))))
, offset(0)
{
root->next = nullptr;
}
Allocator::Allocator(Allocator&& rhs)
: root(rhs.root)
, offset(rhs.offset)
{
rhs.root = nullptr;
rhs.offset = 0;
}
Allocator::~Allocator()
{
Page* page = root;
while (page)
{
Page* next = page->next;
operator delete(page);
page = next;
}
}
void* Allocator::allocate(size_t size)
{
constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double);
if (root)
{
uintptr_t data = reinterpret_cast<uintptr_t>(root->data);
uintptr_t result = (data + offset + align - 1) & ~(align - 1);
if (result + size <= data + sizeof(root->data))
{
offset = result - data + size;
return reinterpret_cast<void*>(result);
}
}
// allocate new page
size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data);
void* pageData = operator new(offsetof(Page, data) + pageSize);
Page* page = static_cast<Page*>(pageData);
page->next = root;
root = page;
offset = size;
return page->data;
}
}

View File

@ -92,6 +92,11 @@ void AstExprConstantString::visit(AstVisitor* visitor)
visitor->visit(this);
}
bool AstExprConstantString::isQuoted() const
{
return quoteStyle == QuoteStyle::QuotedSimple || quoteStyle == QuoteStyle::QuotedRaw;
}
AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue)
: AstExpr(ClassIndex(), location)
, local(local)
@ -760,11 +765,18 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
}
}
AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body)
AstStatTypeFunction::AstStatTypeFunction(
const Location& location,
const AstName& name,
const Location& nameLocation,
AstExprFunction* body,
bool exported
)
: AstStat(ClassIndex(), location)
, name(name)
, nameLocation(nameLocation)
, body(body)
, exported(exported)
{
}

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Lexer.h"
#include "Luau/Allocator.h"
#include "Luau/Common.h"
#include "Luau/Confusables.h"
#include "Luau/StringUtils.h"
@ -10,64 +11,6 @@
namespace Luau
{
Allocator::Allocator()
: root(static_cast<Page*>(operator new(sizeof(Page))))
, offset(0)
{
root->next = nullptr;
}
Allocator::Allocator(Allocator&& rhs)
: root(rhs.root)
, offset(rhs.offset)
{
rhs.root = nullptr;
rhs.offset = 0;
}
Allocator::~Allocator()
{
Page* page = root;
while (page)
{
Page* next = page->next;
operator delete(page);
page = next;
}
}
void* Allocator::allocate(size_t size)
{
constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double);
if (root)
{
uintptr_t data = reinterpret_cast<uintptr_t>(root->data);
uintptr_t result = (data + offset + align - 1) & ~(align - 1);
if (result + size <= data + sizeof(root->data))
{
offset = result - data + size;
return reinterpret_cast<void*>(result);
}
}
// allocate new page
size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data);
void* pageData = operator new(offsetof(Page, data) + pageSize);
Page* page = static_cast<Page*>(pageData);
page->next = root;
root = page;
offset = size;
return page->data;
}
Lexeme::Lexeme(const Location& location, Type type)
: type(type)
, location(location)

View File

@ -21,6 +21,7 @@ LUAU_FASTFLAGVARIABLE(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauNativeAttribute)
LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport)
LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing)
LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck)
@ -943,8 +944,11 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported)
Lexeme matchFn = lexer.current();
nextLexeme();
if (exported)
report(start, "Type function cannot be exported");
if (!FFlag::LuauUserDefinedTypeFunParseExport)
{
if (exported)
report(start, "Type function cannot be exported");
}
// parse the name of the type function
std::optional<Name> fnName = parseNameOpt("type function name");
@ -962,7 +966,7 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported)
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body);
return allocator.alloc<AstStatTypeFunction>(Location(start, body->location), fnName->name, fnName->location, body, exported);
}
AstDeclaredClassProp Parser::parseDeclaredClassMethod()
@ -3012,8 +3016,23 @@ std::optional<AstArray<char>> Parser::parseCharArray()
AstExpr* Parser::parseString()
{
Location location = lexer.current().location;
AstExprConstantString::QuoteStyle style;
switch (lexer.current().type)
{
case Lexeme::QuotedString:
case Lexeme::InterpStringSimple:
style = AstExprConstantString::QuotedSimple;
break;
case Lexeme::RawString:
style = AstExprConstantString::QuotedRaw;
break;
default:
LUAU_ASSERT(false && "Invalid string type");
}
if (std::optional<AstArray<char>> value = parseCharArray())
return allocator.alloc<AstExprConstantString>(location, *value);
return allocator.alloc<AstExprConstantString>(location, *value, style);
else
return reportExprError(location, {}, "String literal contains malformed escape sequence");
}

View File

@ -1,4 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Config.h"
#include "Luau/ModuleResolver.h"
#include "Luau/TypeInfer.h"
#include "Luau/BuiltinDefinitions.h"
@ -224,7 +225,14 @@ struct CliConfigResolver : Luau::ConfigResolver
if (std::optional<std::string> contents = readFile(configPath))
{
std::optional<std::string> error = Luau::parseConfig(*contents, result);
Luau::ConfigOptions::AliasOptions aliasOpts;
aliasOpts.configLocation = configPath;
aliasOpts.overwriteAliases = true;
Luau::ConfigOptions opts;
opts.aliasOptions = std::move(aliasOpts);
std::optional<std::string> error = Luau::parseConfig(*contents, result, opts);
if (error)
configErrors.push_back({configPath, *error});
}

View File

@ -181,6 +181,16 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath)
return resolvedPath;
}
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions)
{
for (const std::string& extension : extensions)
{
if (name.size() >= extension.size() && name.substr(name.size() - extension.size()) == extension)
return true;
}
return false;
}
std::optional<std::string> readFile(const std::string& name)
{
#ifdef _WIN32

View File

@ -15,6 +15,8 @@ std::string resolvePath(std::string_view relativePath, std::string_view baseFile
std::optional<std::string> readFile(const std::string& name);
std::optional<std::string> readStdin();
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions);
bool isAbsolutePath(std::string_view path);
bool isFile(const std::string& path);
bool isDirectory(const std::string& path);

View File

@ -3,6 +3,7 @@
#include "FileUtils.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include <algorithm>
#include <array>
@ -83,6 +84,9 @@ RequireResolver::ModuleStatus RequireResolver::findModuleImpl()
absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix
}
if (hasFileExtension(absolutePath, {".luau", ".lua"}) && isFile(absolutePath))
luaL_argerrorL(L, 1, "error requiring module: consider removing the file extension");
return ModuleStatus::NotFound;
}
@ -235,14 +239,15 @@ std::optional<std::string> RequireResolver::getAlias(std::string alias)
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
}
);
while (!config.aliases.count(alias) && !isConfigFullyResolved)
while (!config.aliases.contains(alias) && !isConfigFullyResolved)
{
parseNextConfig();
}
if (!config.aliases.count(alias) && isConfigFullyResolved)
if (!config.aliases.contains(alias) && isConfigFullyResolved)
return std::nullopt; // could not find alias
return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName));
const Luau::Config::AliasInfo& aliasInfo = config.aliases[alias];
return resolvePath(aliasInfo.value, aliasInfo.configLocation);
}
void RequireResolver::parseNextConfig()
@ -275,9 +280,16 @@ void RequireResolver::parseConfigInDirectory(const std::string& directory)
{
std::string configPath = joinPaths(directory, Luau::kConfigName);
Luau::ConfigOptions::AliasOptions aliasOpts;
aliasOpts.configLocation = configPath;
aliasOpts.overwriteAliases = false;
Luau::ConfigOptions opts;
opts.aliasOptions = std::move(aliasOpts);
if (std::optional<std::string> contents = readFile(configPath))
{
std::optional<std::string> error = Luau::parseConfig(*contents, config);
std::optional<std::string> error = Luau::parseConfig(*contents, config, opts);
if (error)
luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str());
}

View File

@ -19,7 +19,7 @@ class Variant
static_assert(std::disjunction_v<std::is_reference<Ts>...> == false, "variant does not allow references as an alternative type");
static_assert(std::disjunction_v<std::is_array<Ts>...> == false, "variant does not allow arrays as an alternative type");
private:
public:
template<typename T>
static constexpr int getTypeId()
{
@ -35,6 +35,7 @@ private:
return -1;
}
private:
template<typename T, typename... Tail>
struct First
{

View File

@ -1,12 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/LinterConfig.h"
#include "Luau/ParseOptions.h"
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <string_view>
#include <vector>
namespace Luau
@ -19,6 +21,10 @@ constexpr const char* kConfigName = ".luaurc";
struct Config
{
Config();
Config(const Config& other) noexcept;
Config& operator=(const Config& other) noexcept;
Config(Config&& other) noexcept = default;
Config& operator=(Config&& other) noexcept = default;
Mode mode = Mode::Nonstrict;
@ -32,7 +38,19 @@ struct Config
std::vector<std::string> globals;
std::unordered_map<std::string, std::string> aliases;
struct AliasInfo
{
std::string value;
std::string_view configLocation;
};
DenseHashMap<std::string, AliasInfo> aliases{""};
void setAlias(std::string alias, const std::string& value, const std::string configLocation);
private:
// Prevents making unnecessary copies of the same config location string.
DenseHashMap<std::string, std::unique_ptr<std::string>> configLocationCache{""};
};
struct ConfigResolver
@ -60,6 +78,18 @@ std::optional<std::string> parseLintRuleString(
bool isValidAlias(const std::string& alias);
std::optional<std::string> parseConfig(const std::string& contents, Config& config, bool compat = false);
struct ConfigOptions
{
bool compat = false;
struct AliasOptions
{
std::string configLocation;
bool overwriteAliases;
};
std::optional<AliasOptions> aliasOptions = std::nullopt;
};
std::optional<std::string> parseConfig(const std::string& contents, Config& config, const ConfigOptions& options = ConfigOptions{});
} // namespace Luau

View File

@ -4,7 +4,8 @@
#include "Luau/Lexer.h"
#include "Luau/StringUtils.h"
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <string>
namespace Luau
{
@ -16,6 +17,50 @@ Config::Config()
enabledLint.setDefaults();
}
Config::Config(const Config& other) noexcept
: mode(other.mode)
, parseOptions(other.parseOptions)
, enabledLint(other.enabledLint)
, fatalLint(other.fatalLint)
, lintErrors(other.lintErrors)
, typeErrors(other.typeErrors)
, globals(other.globals)
{
for (const auto& [alias, aliasInfo] : other.aliases)
{
std::string configLocation = std::string(aliasInfo.configLocation);
if (!configLocationCache.contains(configLocation))
configLocationCache[configLocation] = std::make_unique<std::string>(configLocation);
AliasInfo newAliasInfo;
newAliasInfo.value = aliasInfo.value;
newAliasInfo.configLocation = *configLocationCache[configLocation];
aliases[alias] = std::move(newAliasInfo);
}
}
Config& Config::operator=(const Config& other) noexcept
{
if (this != &other)
{
Config copy(other);
std::swap(*this, copy);
}
return *this;
}
void Config::setAlias(std::string alias, const std::string& value, const std::string configLocation)
{
AliasInfo& info = aliases[alias];
info.value = value;
if (!configLocationCache.contains(configLocation))
configLocationCache[configLocation] = std::make_unique<std::string>(configLocation);
info.configLocation = *configLocationCache[configLocation];
}
static Error parseBoolean(bool& result, const std::string& value)
{
if (value == "true")
@ -136,7 +181,12 @@ bool isValidAlias(const std::string& alias)
return true;
}
Error parseAlias(std::unordered_map<std::string, std::string>& aliases, std::string aliasKey, const std::string& aliasValue)
static Error parseAlias(
Config& config,
std::string aliasKey,
const std::string& aliasValue,
const std::optional<ConfigOptions::AliasOptions>& aliasOptions
)
{
if (!isValidAlias(aliasKey))
return Error{"Invalid alias " + aliasKey};
@ -150,8 +200,12 @@ Error parseAlias(std::unordered_map<std::string, std::string>& aliases, std::str
return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c;
}
);
if (!aliases.count(aliasKey))
aliases[std::move(aliasKey)] = aliasValue;
if (!aliasOptions)
return Error("Cannot parse aliases without alias options");
if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey))
config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation);
return std::nullopt;
}
@ -285,16 +339,16 @@ static Error parseJson(const std::string& contents, Action action)
return {};
}
Error parseConfig(const std::string& contents, Config& config, bool compat)
Error parseConfig(const std::string& contents, Config& config, const ConfigOptions& options)
{
return parseJson(
contents,
[&](const std::vector<std::string>& keys, const std::string& value) -> Error
{
if (keys.size() == 1 && keys[0] == "languageMode")
return parseModeString(config.mode, value, compat);
return parseModeString(config.mode, value, options.compat);
else if (keys.size() == 2 && keys[0] == "lint")
return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, compat);
return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, options.compat);
else if (keys.size() == 1 && keys[0] == "lintErrors")
return parseBoolean(config.lintErrors, value);
else if (keys.size() == 1 && keys[0] == "typeErrors")
@ -305,9 +359,9 @@ Error parseConfig(const std::string& contents, Config& config, bool compat)
return std::nullopt;
}
else if (keys.size() == 2 && keys[0] == "aliases")
return parseAlias(config.aliases, keys[1], value);
else if (compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode")
return parseModeString(config.mode, value, compat);
return parseAlias(config, keys[1], value, options.aliasOptions);
else if (options.compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode")
return parseModeString(config.mode, value, options.compat);
else
{
std::vector<std::string_view> keysv(keys.begin(), keys.end());

View File

@ -23,6 +23,13 @@ struct Analysis final
using D = typename N::Data;
Analysis() = default;
Analysis(N a)
: analysis(std::move(a))
{
}
template<typename T>
static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode)
{
@ -59,6 +66,15 @@ struct EClass final
template<typename L, typename N>
struct EGraph final
{
using EClassT = EClass<L, typename N::Data>;
EGraph() = default;
explicit EGraph(N analysis)
: analysis(std::move(analysis))
{
}
Id find(Id id) const
{
return unionfind.find(id);
@ -85,33 +101,59 @@ struct EGraph final
return id;
}
void merge(Id id1, Id id2)
// Returns true if the two IDs were not previously merged.
bool merge(Id id1, Id id2)
{
id1 = find(id1);
id2 = find(id2);
if (id1 == id2)
return;
return false;
unionfind.merge(id1, id2);
const Id mergedId = unionfind.merge(id1, id2);
EClass<L, typename N::Data>& eclass1 = get(id1);
EClass<L, typename N::Data> eclass2 = std::move(get(id2));
// Ensure that id1 is the Id that we keep, and id2 is the id that we drop.
if (mergedId == id2)
std::swap(id1, id2);
EClassT& eclass1 = get(id1);
EClassT eclass2 = std::move(get(id2));
classes.erase(id2);
worklist.reserve(worklist.size() + eclass2.parents.size());
for (auto [enode, id] : eclass2.parents)
worklist.push_back({std::move(enode), id});
eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end());
eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end());
std::sort(
eclass1.nodes.begin(),
eclass1.nodes.end(),
[](const L& left, const L& right)
{
return left.index() < right.index();
}
);
worklist.reserve(worklist.size() + eclass1.parents.size());
for (const auto& [eclass, id] : eclass1.parents)
worklist.push_back(id);
analysis.join(eclass1.data, eclass2.data);
return true;
}
void rebuild()
{
std::unordered_set<Id> seen;
while (!worklist.empty())
{
auto [enode, id] = worklist.back();
Id id = worklist.back();
worklist.pop_back();
repair(get(find(id)));
const bool isFresh = seen.insert(id).second;
if (!isFresh)
continue;
repair(find(id));
}
}
@ -120,16 +162,21 @@ struct EGraph final
return classes.size();
}
EClass<L, typename N::Data>& operator[](Id id)
EClassT& operator[](Id id)
{
return get(find(id));
}
const EClass<L, typename N::Data>& operator[](Id id) const
const EClassT& operator[](Id id) const
{
return const_cast<EGraph*>(this)->get(find(id));
}
const std::unordered_map<Id, EClassT>& getAllClasses() const
{
return classes;
}
private:
Analysis<L, N> analysis;
@ -139,19 +186,19 @@ private:
/// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same
/// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the
/// e-class 𝑀[find(𝑎)].
std::unordered_map<Id, EClass<L, typename N::Data>> classes;
std::unordered_map<Id, EClassT> classes;
/// The hashcons 𝐻 is a map from e-nodes to e-class ids.
std::unordered_map<L, Id, typename L::Hash> hashcons;
std::vector<std::pair<L, Id>> worklist;
std::vector<Id> worklist;
private:
void canonicalize(L& enode)
{
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
for (Id& id : enode.operands())
for (Id& id : enode.mutableOperands())
id = find(id);
}
@ -171,7 +218,7 @@ private:
classes.insert_or_assign(
id,
EClass<L, typename N::Data>{
EClassT{
id,
{enode},
analysis.make(*this, enode),
@ -182,7 +229,7 @@ private:
for (Id operand : enode.operands())
get(operand).parents.push_back({enode, id});
worklist.emplace_back(enode, id);
worklist.emplace_back(id);
hashcons.insert_or_assign(enode, id);
return id;
@ -190,12 +237,13 @@ private:
// Looks up for an eclass from a given non-canonicalized `id`.
// For a canonicalized eclass, use `get(find(id))` or `egraph[id]`.
EClass<L, typename N::Data>& get(Id id)
EClassT& get(Id id)
{
LUAU_ASSERT(classes.count(id));
return classes.at(id);
}
void repair(EClass<L, typename N::Data>& eclass)
void repair(Id id)
{
// In the egg paper, the `repair` function makes use of two loops over the `eclass.parents`
// by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id.
@ -204,26 +252,54 @@ private:
// Here, we unify the two loops. I think it's equivalent?
// After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent.
std::unordered_map<L, Id, typename L::Hash> map;
for (auto& [enode, id] : eclass.parents)
std::unordered_map<L, Id, typename L::Hash> newParents;
// The eclass can be deallocated if it is merged into another eclass, so
// we take what we need from it and avoid retaining a pointer.
std::vector<std::pair<L, Id>> parents = get(id).parents;
for (auto& pair : parents)
{
L& enode = pair.first;
Id id = pair.second;
// By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
hashcons.erase(enode);
canonicalize(enode);
hashcons.insert_or_assign(enode, find(id));
if (auto it = map.find(enode); it != map.end())
if (auto it = newParents.find(enode); it != newParents.end())
merge(id, it->second);
map.insert_or_assign(enode, find(id));
newParents.insert_or_assign(enode, find(id));
}
eclass.parents.clear();
for (auto it = map.begin(); it != map.end();)
// We reacquire the pointer because the prior loop potentially merges
// the eclass into another, which might move it around in memory.
EClassT* eclass = &get(find(id));
eclass->parents.clear();
for (const auto& [node, id] : newParents)
eclass->parents.emplace_back(std::move(node), std::move(id));
std::unordered_set<L, typename L::Hash> newNodes;
for (L node : eclass->nodes)
{
auto node = map.extract(it++);
eclass.parents.emplace_back(std::move(node.key()), node.mapped());
canonicalize(node);
newNodes.insert(std::move(node));
}
eclass->nodes.assign(newNodes.begin(), newNodes.end());
// FIXME: Extract into sortByTag()
std::sort(
eclass->nodes.begin(),
eclass->nodes.end(),
[](const L& left, const L& right)
{
return left.index() < right.index();
}
);
}
};

View File

@ -9,15 +9,17 @@ namespace Luau::EqSat
struct Id final
{
explicit Id(size_t id);
explicit Id(uint32_t id);
explicit operator size_t() const;
explicit operator uint32_t() const;
bool operator==(Id rhs) const;
bool operator!=(Id rhs) const;
bool operator<(Id rhs) const;
private:
size_t id;
uint32_t id;
};
} // namespace Luau::EqSat

View File

@ -6,9 +6,19 @@
#include "Luau/Slice.h"
#include "Luau/Variant.h"
#include <algorithm>
#include <array>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>
#define LUAU_EQSAT_UNIT(name) \
struct name : ::Luau::EqSat::Unit<name> \
{ \
static constexpr const char* tag = #name; \
using Unit::Unit; \
}
#define LUAU_EQSAT_ATOM(name, t) \
struct name : public ::Luau::EqSat::Atom<name, t> \
@ -31,21 +41,57 @@
using NodeVector::NodeVector; \
}
#define LUAU_EQSAT_FIELD(name) \
struct name : public ::Luau::EqSat::Field<name> \
{ \
}
#define LUAU_EQSAT_NODE_FIELDS(name, ...) \
struct name : public ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \
#define LUAU_EQSAT_NODE_SET(name) \
struct name : public ::Luau::EqSat::NodeSet<name, std::vector<::Luau::EqSat::Id>> \
{ \
static constexpr const char* tag = #name; \
using NodeFields::NodeFields; \
using NodeSet::NodeSet; \
}
#define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \
struct name : public ::Luau::EqSat::NodeAtomAndVector<name, t, std::vector<::Luau::EqSat::Id>> \
{ \
static constexpr const char* tag = #name; \
using NodeAtomAndVector::NodeAtomAndVector; \
}
namespace Luau::EqSat
{
template<typename Phantom>
struct Unit
{
Slice<Id> mutableOperands()
{
return {};
}
Slice<const Id> operands() const
{
return {};
}
bool operator==(const Unit& rhs) const
{
return true;
}
bool operator!=(const Unit& rhs) const
{
return false;
}
struct Hash
{
size_t operator()(const Unit& value) const
{
// chosen by fair dice roll.
// guaranteed to be random.
return 4;
}
};
};
template<typename Phantom, typename T>
struct Atom
{
@ -60,7 +106,7 @@ struct Atom
}
public:
Slice<Id> operands()
Slice<Id> mutableOperands()
{
return {};
}
@ -92,6 +138,62 @@ private:
T _value;
};
template<typename Phantom, typename X, typename T>
struct NodeAtomAndVector
{
template<typename... Args>
NodeAtomAndVector(const X& value, Args&&... args)
: _value(value)
, vector{std::forward<Args>(args)...}
{
}
Id operator[](size_t i) const
{
return vector[i];
}
public:
const X& value() const
{
return _value;
}
Slice<Id> mutableOperands()
{
return Slice{vector.data(), vector.size()};
}
Slice<const Id> operands() const
{
return Slice{vector.data(), vector.size()};
}
bool operator==(const NodeAtomAndVector& rhs) const
{
return _value == rhs._value && vector == rhs.vector;
}
bool operator!=(const NodeAtomAndVector& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const NodeAtomAndVector& value) const
{
size_t result = languageHash(value._value);
hashCombine(result, languageHash(value.vector));
return result;
}
};
private:
X _value;
T vector;
};
template<typename Phantom, typename T>
struct NodeVector
{
@ -107,7 +209,7 @@ struct NodeVector
}
public:
Slice<Id> operands()
Slice<Id> mutableOperands()
{
return Slice{vector.data(), vector.size()};
}
@ -139,90 +241,61 @@ private:
T vector;
};
/// Empty base class just for static_asserts.
struct FieldBase
template<typename Phantom, typename T>
struct NodeSet
{
FieldBase() = delete;
FieldBase(FieldBase&&) = delete;
FieldBase& operator=(FieldBase&&) = delete;
FieldBase(const FieldBase&) = delete;
FieldBase& operator=(const FieldBase&) = delete;
};
template<typename Phantom>
struct Field : FieldBase
{
};
template<typename Phantom, typename... Fields>
struct NodeFields
{
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
template<typename T>
static constexpr int getIndex()
template<typename... Args>
NodeSet(Args&&... args)
: vector{std::forward<Args>(args)...}
{
constexpr int N = sizeof...(Fields);
constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...};
std::sort(begin(vector), end(vector));
auto it = std::unique(begin(vector), end(vector));
vector.erase(it, end(vector));
}
for (int i = 0; i < N; ++i)
if (is[i])
return i;
return -1;
Id operator[](size_t i) const
{
return vector[i];
}
public:
template<typename... Args>
NodeFields(Args&&... args)
: array{std::forward<Args>(args)...}
Slice<Id> mutableOperands()
{
}
Slice<Id> operands()
{
return Slice{array};
return Slice{vector.data(), vector.size()};
}
Slice<const Id> operands() const
{
return Slice{array.data(), array.size()};
return Slice{vector.data(), vector.size()};
}
template<typename T>
Id field() const
bool operator==(const NodeSet& rhs) const
{
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>);
return array[getIndex<T>()];
return vector == rhs.vector;
}
bool operator==(const NodeFields& rhs) const
{
return array == rhs.array;
}
bool operator!=(const NodeFields& rhs) const
bool operator!=(const NodeSet& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const NodeFields& value) const
size_t operator()(const NodeSet& value) const
{
return languageHash(value.array);
return languageHash(value.vector);
}
};
private:
std::array<Id, sizeof...(Fields)> array;
protected:
T vector;
};
template<typename... Ts>
struct Language final
{
using VariantTy = Luau::Variant<Ts...>;
template<typename T>
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
@ -237,14 +310,14 @@ struct Language final
return v.index();
}
/// You should never call this function with the intention of mutating the `Id`.
/// Reading is ok, but you should also never assume that these `Id`s are stable.
Slice<Id> operands() noexcept
/// This should only be used in canonicalization!
/// Always prefer operands()
Slice<Id> mutableOperands() noexcept
{
return visit(
[](auto&& v) -> Slice<Id>
{
return v.operands();
return v.mutableOperands();
},
v
);
@ -306,7 +379,7 @@ public:
};
private:
Variant<Ts...> v;
VariantTy v;
};
} // namespace Luau::EqSat

View File

@ -3,6 +3,7 @@
#include <cstddef>
#include <functional>
#include <unordered_set>
#include <vector>
namespace Luau::EqSat

View File

@ -14,7 +14,9 @@ struct UnionFind final
Id makeSet();
Id find(Id id) const;
Id find(Id id);
void merge(Id a, Id b);
// Merge aSet with bSet and return the canonicalized Id into the merged set.
Id merge(Id aSet, Id bSet);
private:
std::vector<Id> parents;

View File

@ -1,15 +1,16 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Id.h"
#include "Luau/Common.h"
namespace Luau::EqSat
{
Id::Id(size_t id)
Id::Id(uint32_t id)
: id(id)
{
}
Id::operator size_t() const
Id::operator uint32_t() const
{
return id;
}
@ -24,9 +25,14 @@ bool Id::operator!=(Id rhs) const
return id != rhs.id;
}
bool Id::operator<(Id rhs) const
{
return id < rhs.id;
}
} // namespace Luau::EqSat
size_t std::hash<Luau::EqSat::Id>::operator()(Luau::EqSat::Id id) const
{
return std::hash<size_t>()(size_t(id));
return std::hash<uint32_t>()(uint32_t(id));
}

View File

@ -8,7 +8,9 @@ namespace Luau::EqSat
Id UnionFind::makeSet()
{
Id id{parents.size()};
LUAU_ASSERT(parents.size() < std::numeric_limits<uint32_t>::max());
Id id{uint32_t(parents.size())};
parents.push_back(id);
ranks.push_back(0);
@ -25,42 +27,44 @@ Id UnionFind::find(Id id)
Id set = canonicalize(id);
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)])
while (id != parents[uint32_t(id)])
{
// Note: we don't update the ranks here since a rank
// represents the upper bound on the maximum depth of a tree
Id parent = parents[size_t(id)];
parents[size_t(id)] = set;
Id parent = parents[uint32_t(id)];
parents[uint32_t(id)] = set;
id = parent;
}
return set;
}
void UnionFind::merge(Id a, Id b)
Id UnionFind::merge(Id a, Id b)
{
Id aSet = find(a);
Id bSet = find(b);
if (aSet == bSet)
return;
return aSet;
// Ensure that the rank of set A is greater than the rank of set B
if (ranks[size_t(aSet)] < ranks[size_t(bSet)])
if (ranks[uint32_t(aSet)] > ranks[uint32_t(bSet)])
std::swap(aSet, bSet);
parents[size_t(bSet)] = aSet;
parents[uint32_t(bSet)] = aSet;
if (ranks[size_t(aSet)] == ranks[size_t(bSet)])
ranks[size_t(aSet)]++;
if (ranks[uint32_t(aSet)] == ranks[uint32_t(bSet)])
ranks[uint32_t(aSet)]++;
return aSet;
}
Id UnionFind::canonicalize(Id id) const
{
LUAU_ASSERT(size_t(id) < parents.size());
LUAU_ASSERT(uint32_t(id) < parents.size());
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)])
id = parents[size_t(id)];
while (id != parents[uint32_t(id)])
id = parents[uint32_t(id)];
return id;
}

View File

@ -14,6 +14,7 @@ endif()
# Luau.Ast Sources
target_sources(Luau.Ast PRIVATE
Ast/include/Luau/Allocator.h
Ast/include/Luau/Ast.h
Ast/include/Luau/Confusables.h
Ast/include/Luau/Lexer.h
@ -24,6 +25,7 @@ target_sources(Luau.Ast PRIVATE
Ast/include/Luau/StringUtils.h
Ast/include/Luau/TimeTrace.h
Ast/src/Allocator.cpp
Ast/src/Ast.cpp
Ast/src/Confusables.cpp
Ast/src/Lexer.cpp
@ -168,6 +170,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/AstJsonEncoder.h
Analysis/include/Luau/AstQuery.h
Analysis/include/Luau/Autocomplete.h
Analysis/include/Luau/AutocompleteTypes.h
Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Cancellation.h
Analysis/include/Luau/Clone.h
@ -181,6 +184,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Differ.h
Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h
Analysis/include/Luau/EqSatSimplification.h
Analysis/include/Luau/FileResolver.h
Analysis/include/Luau/FragmentAutocomplete.h
Analysis/include/Luau/Frontend.h
@ -245,6 +249,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/AstJsonEncoder.cpp
Analysis/src/AstQuery.cpp
Analysis/src/Autocomplete.cpp
Analysis/src/AutocompleteCore.cpp
Analysis/src/BuiltinDefinitions.cpp
Analysis/src/Clone.cpp
Analysis/src/Constraint.cpp
@ -256,6 +261,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Differ.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp
Analysis/src/EqSatSimplification.cpp
Analysis/src/FragmentAutocomplete.cpp
Analysis/src/Frontend.cpp
Analysis/src/Generalization.cpp
@ -417,7 +423,7 @@ endif()
if(TARGET Luau.UnitTest)
# Luau.UnitTest Sources
target_sources(Luau.UnitTest PRIVATE
tests/AnyTypeSummary.test.cpp
tests/AnyTypeSummary.test.cpp
tests/AssemblyBuilderA64.test.cpp
tests/AssemblyBuilderX64.test.cpp
tests/AstJsonEncoder.test.cpp
@ -444,6 +450,7 @@ if(TARGET Luau.UnitTest)
tests/EqSat.language.test.cpp
tests/EqSat.propositional.test.cpp
tests/EqSat.slice.test.cpp
tests/EqSatSimplification.test.cpp
tests/Error.test.cpp
tests/Fixture.cpp
tests/Fixture.h

View File

@ -39,7 +39,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n";
const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n"
const char* luau_ident = "$Luau: Copyright (C) 2019-2024 Roblox Corporation $\n"
"$URL: luau.org $\n";
#define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base))

View File

@ -67,7 +67,7 @@ TEST_CASE("encode_constants")
charString.data = const_cast<char*>("a\x1d\0\\\"b");
charString.size = 6;
AstExprConstantString needsEscaping{Location(), charString};
AstExprConstantString needsEscaping{Location(), charString, AstExprConstantString::QuotedSimple};
CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil));
CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b));
@ -83,7 +83,7 @@ TEST_CASE("basic_escaping")
{
std::string s = "hello \"world\"";
AstArray<char> theString{s.data(), s.size()};
AstExprConstantString str{Location(), theString};
AstExprConstantString str{Location(), theString, AstExprConstantString::QuotedSimple};
std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})";
CHECK_EQ(expected, toJson(&str));

View File

@ -151,40 +151,6 @@ struct ACBuiltinsFixture : ACFixtureImpl<BuiltinsFixture>
{
};
#define LUAU_CHECK_HAS_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \
if (!count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
#define LUAU_CHECK_HAS_NO_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \
if (count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
TEST_SUITE_BEGIN("AutocompleteTest");
TEST_CASE_FIXTURE(ACFixture, "empty_program")

View File

@ -58,7 +58,11 @@ TEST_CASE("report_a_syntax_error")
TEST_CASE("noinfer_is_still_allowed")
{
Config config;
auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, true);
ConfigOptions opts;
opts.compat = true;
auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, opts);
REQUIRE(!err);
CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode));
@ -147,6 +151,10 @@ TEST_CASE("extra_globals")
TEST_CASE("lint_rules_compat")
{
Config config;
ConfigOptions opts;
opts.compat = true;
auto err = parseConfig(
R"(
{"lint": {
@ -156,7 +164,7 @@ TEST_CASE("lint_rules_compat")
}}
)",
config,
true
opts
);
REQUIRE(!err);

View File

@ -10,6 +10,7 @@ namespace Luau
ConstraintGeneratorFixture::ConstraintGeneratorFixture()
: Fixture()
, mainModule(new Module)
, simplifier(newSimplifier(NotNull{&arena}, builtinTypes))
, forceTheFlag{FFlag::LuauSolverV2, true}
{
mainModule->name = "MainModule";
@ -25,6 +26,7 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code)
cg = std::make_unique<ConstraintGenerator>(
mainModule,
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull(&moduleResolver),
builtinTypes,
@ -44,8 +46,19 @@ void ConstraintGeneratorFixture::solve(const std::string& code)
{
generateConstraints(code);
ConstraintSolver cs{
NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {}
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{rootScope},
constraints,
"MainModule",
NotNull(&moduleResolver),
{},
&logger,
NotNull{dfg.get()},
{}
};
cs.run();
}

View File

@ -4,8 +4,9 @@
#include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DcrLogger.h"
#include "Luau/TypeArena.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Module.h"
#include "Luau/TypeArena.h"
#include "Fixture.h"
#include "ScopedFlags.h"
@ -20,6 +21,7 @@ struct ConstraintGeneratorFixture : Fixture
DcrLogger logger;
UnifierSharedState sharedState{&ice};
Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}};
SimplifierPtr simplifier;
TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}};

View File

@ -11,9 +11,7 @@ LUAU_EQSAT_ATOM(I32, int);
LUAU_EQSAT_ATOM(Bool, bool);
LUAU_EQSAT_ATOM(Str, std::string);
LUAU_EQSAT_FIELD(Left);
LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_NODE_FIELDS(Add, Left, Right);
LUAU_EQSAT_NODE_ARRAY(Add, 2);
using namespace Luau;
@ -117,8 +115,8 @@ TEST_CASE("node_field")
Add add{left, right};
EqSat::Id left2 = add.field<Left>();
EqSat::Id right2 = add.field<Right>();
EqSat::Id left2 = add.operands()[0];
EqSat::Id right2 = add.operands()[1];
CHECK(left == left2);
CHECK(left != right2);
@ -135,10 +133,10 @@ TEST_CASE("language_operands")
const Add* add = v2.get<Add>();
REQUIRE(add);
EqSat::Slice<EqSat::Id> actual = v2.operands();
EqSat::Slice<const EqSat::Id> actual = v2.operands();
CHECK(actual.size() == 2);
CHECK(actual[0] == add->field<Left>());
CHECK(actual[1] == add->field<Right>());
CHECK(actual[0] == add->operands()[0]);
CHECK(actual[1] == add->operands()[1]);
}
TEST_SUITE_END();

View File

@ -0,0 +1,728 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Fixture.h"
#include "Luau/EqSatSimplification.h"
using namespace Luau;
struct ESFixture : Fixture
{
ScopedFastFlag newSolverOnly{FFlag::LuauSolverV2, true};
TypeArena arena_;
const NotNull<TypeArena> arena{&arena_};
SimplifierPtr simplifier;
TypeId parentClass;
TypeId childClass;
TypeId anotherChild;
TypeId unrelatedClass;
TypeId genericT = arena_.addType(GenericType{"T"});
TypeId genericU = arena_.addType(GenericType{"U"});
TypeId numberToString = arena_.addType(FunctionType{
arena_.addTypePack({builtinTypes->numberType}),
arena_.addTypePack({builtinTypes->stringType})
});
TypeId stringToNumber = arena_.addType(FunctionType{
arena_.addTypePack({builtinTypes->stringType}),
arena_.addTypePack({builtinTypes->numberType})
});
ESFixture()
: simplifier(newSimplifier(arena, builtinTypes))
{
createSomeClasses(&frontend);
ScopePtr moduleScope = frontend.globals.globalScope;
parentClass = moduleScope->linearSearchForBinding("Parent")->typeId;
childClass = moduleScope->linearSearchForBinding("Child")->typeId;
anotherChild = moduleScope->linearSearchForBinding("AnotherChild")->typeId;
unrelatedClass = moduleScope->linearSearchForBinding("Unrelated")->typeId;
}
std::optional<std::string> simplifyStr(TypeId ty)
{
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
LUAU_ASSERT(res);
return toString(res->result);
}
TypeId tbl(TableType::Props props)
{
return arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, TableState::Sealed});
}
};
TEST_SUITE_BEGIN("EqSatSimplification");
TEST_CASE_FIXTURE(ESFixture, "primitive")
{
CHECK("number" == simplifyStr(builtinTypes->numberType));
}
TEST_CASE_FIXTURE(ESFixture, "number | number")
{
TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->numberType}});
CHECK("number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "number | string")
{
CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1")
{
TypeId ty = arena->freshType(nullptr);
asMutable(ty)->ty.emplace<UnionType>(std::vector<TypeId>{builtinTypes->numberType, ty});
CHECK("number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "number | string | number")
{
TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->numberType}});
CHECK("number | string" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string | (number | string) | number")
{
TypeId u1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}});
TypeId u2 = arena->addType(UnionType{{builtinTypes->stringType, u1, builtinTypes->numberType}});
CHECK("number | string" == simplifyStr(u2));
}
TEST_CASE_FIXTURE(ESFixture, "string | any")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->anyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "any | string")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "any | never")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | unknown")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "unknown | string")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "unknown | never")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | never")
{
CHECK("string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string | never | number")
{
CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType, builtinTypes->numberType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & string")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & number")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->numberType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & unknown")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "never & string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & (unknown | never)")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "true | false")
{
CHECK("boolean" == simplifyStr(arena->addType(UnionType{{builtinTypes->trueType, builtinTypes->falseType}})));
}
/*
* Intuitively, if we have a type like
*
* x where x = A & B & (C | D | x)
*
* We know that x is certainly not larger than A & B.
* We also know that the union (C | D | x) can be rewritten `(C | D | (A & B & (C | D | x)))
* This tells us that the union part is not smaller than A & B.
* We can therefore discard the union entirely and simplify this type to A & B
*/
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (number | t1)")
{
TypeId intersectionTy = arena->addType(BlockedType{});
TypeId unionTy = arena->addType(UnionType{{builtinTypes->numberType, intersectionTy}});
asMutable(intersectionTy)->ty.emplace<IntersectionType>(std::vector<TypeId>{builtinTypes->stringType, unionTy});
CHECK("string" == simplifyStr(intersectionTy));
}
TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (unknown | t1)")
{
TypeId intersectionTy = arena->addType(BlockedType{});
TypeId unionTy = arena->addType(UnionType{{builtinTypes->unknownType, intersectionTy}});
asMutable(intersectionTy)->ty.emplace<IntersectionType>(std::vector<TypeId>{builtinTypes->stringType, unionTy});
CHECK("string" == simplifyStr(intersectionTy));
}
TEST_CASE_FIXTURE(ESFixture, "error | unknown")
{
CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->errorType, builtinTypes->unknownType}})));
}
TEST_CASE_FIXTURE(ESFixture, "\"hello\" | string")
{
CHECK("string" == simplifyStr(arena->addType(UnionType{{
arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "\"hello\" | \"world\" | \"hello\"")
{
CHECK("\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{
arena->addType(SingletonType{StringSingleton{"hello"}}),
arena->addType(SingletonType{StringSingleton{"world"}}),
arena->addType(SingletonType{StringSingleton{"hello"}}),
}})));
}
TEST_CASE_FIXTURE(ESFixture, "nil | boolean | number | string | thread | function | table | class | buffer")
{
CHECK("unknown" == simplifyStr(arena->addType(UnionType{{
builtinTypes->nilType,
builtinTypes->booleanType,
builtinTypes->numberType,
builtinTypes->stringType,
builtinTypes->threadType,
builtinTypes->functionType,
builtinTypes->tableType,
builtinTypes->classType,
builtinTypes->bufferType,
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent & number")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
parentClass, builtinTypes->numberType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & Parent")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
childClass, parentClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & Unrelated")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
childClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child | Parent")
{
CHECK("Parent" == simplifyStr(arena->addType(UnionType{{
childClass, parentClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "class | Child")
{
CHECK("class" == simplifyStr(arena->addType(UnionType{{
builtinTypes->classType, childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent | class | Child")
{
CHECK("class" == simplifyStr(arena->addType(UnionType{{
parentClass, builtinTypes->classType, childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "Parent | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
parentClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "never | Parent | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
builtinTypes->neverType, parentClass, unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "never | Parent | (number & string) | Unrelated")
{
CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{
builtinTypes->neverType, parentClass,
arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}),
unrelatedClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "T & U")
{
CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{
genericT, genericU
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & true")
{
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, builtinTypes->trueType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | function | table | class | buffer)")
{
TypeId truthy = arena->addType(UnionType{{
builtinTypes->trueType,
builtinTypes->numberType,
builtinTypes->stringType,
builtinTypes->threadType,
builtinTypes->functionType,
builtinTypes->tableType,
builtinTypes->classType,
builtinTypes->bufferType,
}});
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, truthy
}})));
}
TEST_CASE_FIXTURE(ESFixture, "boolean & ~(false?)")
{
CHECK("true" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->booleanType, builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "false & ~(false?)")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->falseType, builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (number) -> string")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, numberToString}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (number) -> string")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(UnionType{{numberToString, numberToString}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & function")
{
CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->functionType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & boolean")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->booleanType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & ~function")
{
TypeId notFunction = arena->addType(NegationType{builtinTypes->functionType});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, notFunction}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | function")
{
CHECK("function" == simplifyStr(arena->addType(UnionType{{numberToString, builtinTypes->functionType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (string) -> number")
{
CHECK("((number) -> string) & ((string) -> number)" == simplifyStr(arena->addType(IntersectionType{{numberToString, stringToNumber}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (string) -> number")
{
CHECK("((number) -> string) | ((string) -> number)" == simplifyStr(arena->addType(UnionType{{numberToString, stringToNumber}})));
}
TEST_CASE_FIXTURE(ESFixture, "add<number, number>")
{
CHECK("number" == simplifyStr(arena->addType(
TypeFunctionInstanceType{builtinTypeFunctions().addFunc, {
builtinTypes->numberType, builtinTypes->numberType
}}
)));
}
TEST_CASE_FIXTURE(ESFixture, "union<number, number>")
{
CHECK("number" == simplifyStr(arena->addType(
TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, {
builtinTypes->numberType, builtinTypes->numberType
}}
)));
}
TEST_CASE_FIXTURE(ESFixture, "never & ~string")
{
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->neverType,
arena->addType(NegationType{builtinTypes->stringType})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "blocked & never")
{
const TypeId blocked = arena->addType(BlockedType{});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{blocked, builtinTypes->neverType}})));
}
TEST_CASE_FIXTURE(ESFixture, "blocked & ~number & function")
{
const TypeId blocked = arena->addType(BlockedType{});
const TypeId notNumber = arena->addType(NegationType{builtinTypes->numberType});
const TypeId ty = arena->addType(IntersectionType{{blocked, notNumber, builtinTypes->functionType}});
std::string expected = toString(blocked) + " & function";
CHECK(expected == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "(number | boolean | string | nil | table) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(number | boolean | nil) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->nilType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)")
{
const TypeId t1 = arena->addType(UnionType{{builtinTypes->booleanType, builtinTypes->nilType}});
CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}})));
}
// (('a & false) | ('a & nil)) | number
// Child & ~Parent
// ~Parent & Child
// ~Child & Parent
// Parent & ~Child
// ~Child & ~Parent
// ~Parent & ~Child
TEST_CASE_FIXTURE(ESFixture, "free & string & number")
{
Scope scope{builtinTypes->anyTypePack};
const TypeId freeTy = arena->addType(FreeType{&scope});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}})));
}
TEST_CASE_FIXTURE(ESFixture, "(blocked & number) | (blocked & number)")
{
const TypeId blocked = arena->addType(BlockedType{});
const TypeId u = arena->addType(IntersectionType{{blocked, builtinTypes->numberType}});
const TypeId ty = arena->addType(UnionType{{u, u}});
const std::string blockedStr = toString(blocked);
CHECK(blockedStr + " & number" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "{} & unknown")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->unknownType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{} & table")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->tableType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{} & ~(false?)")
{
CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{
tbl({}),
builtinTypes->truthyType
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: number}")
{
const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId hasX = tbl({{"x", builtinTypes->numberType}});
const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}});
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
CHECK("{ x: number }" == toString(res->result));
// Also assert that we don't allocate a fresh TableType in this case.
CHECK(follow(res->result) == hasX);
}
TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: ~(false?)}")
{
const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId hasX = tbl({{"x", builtinTypes->truthyType}});
const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}});
auto res = eqSatSimplify(NotNull{simplifier.get()}, ty);
CHECK("{ x: number }" == toString(res->result));
}
TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) }")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
const TypeId ty = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "never | (({ x: number? }?) & { x: ~(false?) })")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
// ({x: number?}?) & {x: ~(false?)}
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}});
const TypeId ty = arena->addType(UnionType{{builtinTypes->neverType, intersectionTy}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "({ x: number? }?) & { x: ~(false?) } & ~(false?)")
{
// {x: number?}?
const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}});
// {x: ~(false?)}
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
// ({x: number?}?) & {x: ~(false?)} & ~(false?)
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}});
CHECK("{ x: number }" == simplifyStr(intersectionTy));
}
#if 0
// TODO
TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) } & ~(false?)) | number")
{
// ({ x: number? }?) & { x: ~(false?) } & ~(false?)
const TypeId xWithOptionalNumber = tbl({{"x", builtinTypes->optionalNumberType}});
const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}});
const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}});
const TypeId ty = arena->addType(UnionType{{intersectionTy, builtinTypes->numberType}});
CHECK("{ x: number } | number" == simplifyStr(ty));
}
#endif
TEST_CASE_FIXTURE(ESFixture, "number & no-refine")
{
CHECK("number" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->noRefineType}})));
}
TEST_CASE_FIXTURE(ESFixture, "{ x: number } & ~boolean")
{
const TypeId tblTy = tbl(TableType::Props{{"x", builtinTypes->numberType}});
const TypeId ty = arena->addType(IntersectionType{{
tblTy,
arena->addType(NegationType{builtinTypes->booleanType})
}});
CHECK("{ x: number }" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "(nil & string)?")
{
const TypeId nilAndString = arena->addType(IntersectionType{{builtinTypes->nilType, builtinTypes->stringType}});
const TypeId ty = arena->addType(UnionType{{nilAndString, builtinTypes->nilType}});
CHECK("nil" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string & \"hi\"")
{
const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}});
CHECK("\"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, hi}})));
}
TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")")
{
const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}});
const TypeId bye = arena->addType(SingletonType{StringSingleton{"bye"}});
CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(UnionType{{hi, bye}})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child")
{
const TypeId ty = arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, unrelatedClass}}),
arena->addType(NegationType{childClass})
}});
CHECK("Unrelated" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "string & ~Child")
{
CHECK("string" == simplifyStr(arena->addType(IntersectionType{{
builtinTypes->stringType,
arena->addType(NegationType{childClass})
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & Child")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, unrelatedClass}}),
childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "(Child | AnotherChild) & ~Child")
{
CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{
arena->addType(UnionType{{childClass, anotherChild}}),
childClass
}})));
}
TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: never }")
{
const TypeId ty = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->neverType}});
CHECK("never" == simplifyStr(ty));
}
TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: number? } & { x: string }")
{
const TypeId leftTable = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->optionalNumberType}});
const TypeId rightTable = tbl({{"x", builtinTypes->stringType}});
CHECK("never" == simplifyStr(arena->addType(IntersectionType{{leftTable, rightTable}})));
}
TEST_CASE_FIXTURE(ESFixture, "Child & add<Child | AnotherChild | string, Parent>")
{
const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}});
const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{
builtinTypeFunctions().addFunc,
{u, parentClass},
{}
});
const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}});
CHECK("Child & add<AnotherChild | Child | string, Parent>" == simplifyStr(intersection));
}
TEST_CASE_FIXTURE(ESFixture, "Child & intersect<Child | AnotherChild | string, Parent>")
{
const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}});
const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{
builtinTypeFunctions().intersectFunc,
{u, parentClass},
{}
});
const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}});
CHECK("Child" == simplifyStr(intersection));
}
// {someKey: ~any}
//
// Maybe something we could do here is to try to reduce the key, get the
// class->node mapping, and skip the extraction process if the class corresponds
// to TNever.
// t1 where t1 = add<union<number, t1>, number>
TEST_SUITE_END();

View File

@ -293,3 +293,37 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric<BuiltinsFixture>;
} while (false)
#define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result)
#define LUAU_CHECK_HAS_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \
if (!count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)
#define LUAU_CHECK_HAS_NO_KEY(map, key) \
do \
{ \
auto&& _m = (map); \
auto&& _k = (key); \
const size_t count = _m.count(_k); \
CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \
if (count) \
{ \
MESSAGE("Keys: (count " << _m.size() << ")"); \
for (const auto& [k, v] : _m) \
{ \
MESSAGE("\tkey: " << k); \
} \
} \
} while (false)

View File

@ -4,19 +4,37 @@
#include "Fixture.h"
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Autocomplete.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h"
#include "Luau/Frontend.h"
#include "Luau/AutocompleteTypes.h"
using namespace Luau;
LUAU_FASTFLAG(LuauAllowFragmentParsing);
LUAU_FASTFLAG(LuauStoreDFGOnModule2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
static std::optional<AutocompleteEntryMap> nullCallback(std::string tag, std::optional<const ClassType*> ptr, std::optional<std::string> contents)
{
return std::nullopt;
}
struct FragmentAutocompleteFixture : Fixture
{
ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}};
ScopedFastFlag sffs[4] = {
{FFlag::LuauAllowFragmentParsing, true},
{FFlag::LuauSolverV2, true},
{FFlag::LuauStoreDFGOnModule2, true},
{FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}
};
FragmentAutocompleteFixture()
{
addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType});
addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType});
}
FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos)
{
ParseResult p = tryParse(source); // We don't care about parsing incomplete asts
@ -26,7 +44,6 @@ struct FragmentAutocompleteFixture : Fixture
CheckResult checkBase(const std::string& document)
{
ScopedFastFlag sff{FFlag::LuauSolverV2, true};
FrontendOptions opts;
opts.retainFullTypeGraphs = true;
return this->frontend.check("MainModule", opts);
@ -48,6 +65,16 @@ struct FragmentAutocompleteFixture : Fixture
options.runLintChecks = false;
return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document);
}
FragmentAutocompleteResult autocompleteFragment(const std::string& document, Position cursorPos)
{
FrontendOptions options;
options.retainFullTypeGraphs = true;
// Don't strictly need this in the new solver
options.forAutocomplete = true;
options.runLintChecks = false;
return Luau::fragmentAutocomplete(frontend, document, "MainModule", cursorPos, options, nullCallback);
}
};
TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests");
@ -172,6 +199,13 @@ TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteParserTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{
check("local a =");
auto fragment = parseFragment("local a =", Position(0, 10));
CHECK_EQ("local a =", fragment.fragmentToParse);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null")
{
auto res = check(R"(
@ -278,6 +312,33 @@ local y = 5
CHECK_EQ("y", std::string(rhs->name.value));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope")
{
check(R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)");
auto fragment = parseFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{6, 0}
);
CHECK_EQ("function abc()\n local myInnerLocal = 1\n\n end\n", fragment.fragmentToParse);
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests");
@ -302,7 +363,7 @@ local z = x + y
Position{3, 15}
);
auto opt = linearSearchForBinding(fragment.freshScope, "z");
auto opt = linearSearchForBinding(fragment.freshScope.get(), "z");
REQUIRE(opt);
CHECK_EQ("number", toString(*opt));
}
@ -326,9 +387,222 @@ local y = 5
Position{2, 11}
);
auto correct = linearSearchForBinding(fragment.freshScope, "z");
auto correct = linearSearchForBinding(fragment.freshScope.get(), "z");
REQUIRE(correct);
CHECK_EQ("number", toString(*correct));
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("FragmentAutocompleteTests");
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_simple_property_access")
{
auto res = check(
R"(
local tbl = { abc = 1234}
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
local tbl = { abc = 1234}
tbl.
)",
Position{2, 5}
);
LUAU_ASSERT(fragment.freshScope);
CHECK_EQ(1, fragment.acResults.entryMap.size());
CHECK(fragment.acResults.entryMap.count("abc"));
CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_nested_property_access")
{
auto res = check(
R"(
local tbl = { abc = { def = 1234, egh = false } }
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
local tbl = { abc = { def = 1234, egh = false } }
tbl.abc.
)",
Position{2, 8}
);
LUAU_ASSERT(fragment.freshScope);
CHECK_EQ(2, fragment.acResults.entryMap.size());
CHECK(fragment.acResults.entryMap.count("def"));
CHECK(fragment.acResults.entryMap.count("egh"));
CHECK_EQ(fragment.acResults.context, AutocompleteContext::Property);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "inline_autocomplete_picks_the_right_scope")
{
auto res = check(
R"(
type Table = { a: number, b: number }
do
type Table = { x: string, y: string }
end
)"
);
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
type Table = { a: number, b: number }
do
type Table = { x: string, y: string }
local a : T
end
)",
Position{4, 15}
);
LUAU_ASSERT(fragment.freshScope);
REQUIRE(fragment.acResults.entryMap.count("Table"));
REQUIRE(fragment.acResults.entryMap["Table"].type);
const TableType* tv = get<TableType>(follow(*fragment.acResults.entryMap["Table"].type));
REQUIRE(tv);
CHECK(tv->props.count("x"));
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nested_recursive_function")
{
auto res = check(R"(
function foo()
end
)");
LUAU_REQUIRE_NO_ERRORS(res);
auto fragment = autocompleteFragment(
R"(
function foo()
end
)",
Position{2, 0}
);
CHECK(fragment.acResults.entryMap.count("foo"));
CHECK_EQ(AutocompleteContext::Statement, fragment.acResults.context);
}
// Start compatibility tests!
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "empty_program")
{
check("");
auto frag = autocompleteFragment(" ", Position{0, 1});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer")
{
check("local a =");
auto frag = autocompleteFragment("local a =", Position{0, 9});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Expression);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "leave_numbers_alone")
{
check("local a = 3.");
auto frag = autocompleteFragment("local a = 3.", Position{0, 12});
auto ac = frag.acResults;
CHECK(ac.entryMap.empty());
CHECK_EQ(ac.context, AutocompleteContext::Unknown);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "user_defined_globals")
{
check("local myLocal = 4; ");
auto frag = autocompleteFragment("local myLocal = 4; ", Position{0, 18});
auto ac = frag.acResults;
CHECK(ac.entryMap.count("myLocal"));
CHECK(ac.entryMap.count("table"));
CHECK(ac.entryMap.count("math"));
CHECK_EQ(ac.context, AutocompleteContext::Statement);
}
TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "dont_suggest_local_before_its_definition")
{
check(R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)");
// autocomplete after abc but before myInnerLocal
auto fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{3, 0}
);
auto ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal");
// autocomplete after my inner local
fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{4, 0}
);
ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
CHECK(ac.entryMap.count("myInnerLocal"));
fragment = autocompleteFragment(
R"(
local myLocal = 4
function abc()
local myInnerLocal = 1
end
)",
Position{6, 0}
);
ac = fragment.acResults;
CHECK(ac.entryMap.count("myLocal"));
LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal");
}
TEST_SUITE_END();

View File

@ -18,6 +18,7 @@ LUAU_FASTINT(LuauParseErrorLimit)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport)
namespace
{
@ -2377,10 +2378,15 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms")
TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions")
{
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag sff2{FFlag::LuauUserDefinedTypeFunParseExport, true};
AstStat* stat = parse(R"(
type function foo()
return
return types.number
end
export type function bar()
return types.string
end
)");
@ -2417,7 +2423,6 @@ TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions")
{
ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
matchParseError("export type function foo() end", "Type function cannot be exported");
matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'");
matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'");
}

View File

@ -424,6 +424,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath")
assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"});
}
TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithExtension")
{
std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau";
runProtectedRequire(path);
assertOutputContainsAll({"false", "error requiring module: consider removing the file extension"});
}
TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias")
{
std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer";

View File

@ -964,6 +964,7 @@ TEST_CASE_FIXTURE(Fixture, "correct_stringification_user_defined_type_functions"
std::vector<TypeId>{builtinTypes->numberType}, // Type Function Arguments
{},
{AstName{"woohoo"}}, // Type Function Name
{},
};
Type tv{tftt};

View File

@ -16,6 +16,8 @@ LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite)
LUAU_FASTFLAG(LuauUserTypeFunFixMetatable)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState)
LUAU_FASTFLAG(LuauUserTypeFunNonstrict)
LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal)
LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport)
TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests");
@ -1298,4 +1300,92 @@ local a: foo<> = "a"
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
fileResolver.source["game/A"] = R"(
type function concat(a, b)
return types.singleton(a:value() .. b:value())
end
export type Concat<T, U> = concat<T, U>
local a: concat<'first', 'second'>
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")");
CheckResult bResult = check(R"(
local Test = require(game.A);
local b: Test.Concat<'third', 'fourth'>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
CHECK(toString(requireType("b")) == R"("thirdfourth")");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
CheckResult result = check(R"(
type function foo()
return "hi"
end
local function test()
type function bar()
return types.singleton(foo())
end
return ("" :: any) :: bar<>
end
local a = test()
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK(toString(requireType("a")) == R"("hi")");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true};
ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true};
ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true};
ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true};
ScopedFastFlag luauUserDefinedTypeFunParseExport{FFlag::LuauUserDefinedTypeFunParseExport, true};
fileResolver.source["game/A"] = R"(
export type function concat(a, b)
return types.singleton(a:value() .. b:value())
end
local a: concat<'first', 'second'>
return {}
)";
CheckResult aResult = frontend.check("game/A");
LUAU_REQUIRE_NO_ERRORS(aResult);
CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")");
CheckResult bResult = check(R"(
local Test = require(game.A);
local b: Test.concat<'third', 'fourth'>
)");
LUAU_REQUIRE_NO_ERRORS(bResult);
CHECK(toString(requireType("b")) == R"("thirdfourth")");
}
TEST_SUITE_END();

View File

@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
TEST_SUITE_BEGIN("TypeInferFunctions");
@ -681,6 +682,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
{
// CLI-114134: this code *probably* wants the egraph in order
// to work properly. The new solver either falls over or
// forces so many constraints as to be unreliable.
DOES_NOT_PASS_NEW_SOLVER_GUARD();
CheckResult result = check(R"(
function bottomupmerge(comp, a, b, left, mid, right)
local i, j = left, mid
@ -743,6 +749,11 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3")
TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4")
{
// CLI-114134: this code *probably* wants the egraph in order
// to work properly. The new solver either falls over or
// forces so many constraints as to be unreliable.
DOES_NOT_PASS_NEW_SOLVER_GUARD();
CheckResult result = check(R"(
function bottomupmerge(comp, a, b, left, mid, right)
local i, j = left, mid
@ -2554,8 +2565,17 @@ end
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type")
{
if (!FFlag::LuauSolverV2)
return;
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
// CLI-114134: This test:
// a) Has a kind of weird result (suggesting `number | false` is not great);
// b) Is force solving some constraints.
// We end up with a weird recursive type that, if you roughly look at it, is
// clearly `number`. Hopefully the egraph will be able to unfold this.
CheckResult result = check(R"(
function fib(n)
return n < 2 and 1 or fib(n-1) + fib(n-2)
@ -2565,9 +2585,7 @@ end
LUAU_REQUIRE_ERRORS(result);
auto err = get<ExplicitFunctionAnnotationRecommended>(result.errors.back());
LUAU_ASSERT(err);
CHECK("number" == toString(err->recommendedReturn));
REQUIRE(1 == err->recommendedArgs.size());
CHECK("number" == toString(err->recommendedArgs[0].second));
CHECK("false | number" == toString(err->recommendedReturn));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type")
@ -2862,6 +2880,8 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun")
TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
{
ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true};
CheckResult result = check(R"(
function foo(player)
local success,result = player:thing()
@ -2889,7 +2909,7 @@ TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types")
auto tm2 = get<TypePackMismatch>(result.errors[1]);
REQUIRE(tm2);
CHECK(toString(tm2->wantedTp) == "string");
CHECK(toString(tm2->givenTp) == "buffer | class | function | number | string | table | thread | true");
CHECK(toString(tm2->givenTp) == "(buffer | class | function | number | string | table | thread | true) & unknown");
}
else
{

View File

@ -24,6 +24,7 @@ LUAU_FASTINT(LuauNormalizeCacheLimit);
LUAU_FASTINT(LuauRecursionLimit);
LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues)
LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions)
using namespace Luau;
@ -1730,4 +1731,36 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue")
)"));
}
TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function")
{
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict
local function foo(a : string?)
local b = a or ""
return b:upper()
end
)"));
}
TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type")
{
ScopedFastFlag sffs[] = {
{FFlag::LuauSolverV2, true},
{FFlag::LuauDontRefCountTypesInTypeFunctions, true}
};
LUAU_CHECK_NO_ERRORS(check(R"(
--!strict
local function wtf(name: string?)
local message
message = "invalid alternate fiber: " .. (name or "UNNAMED alternate")
end
)"));
}
TEST_SUITE_END();