diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h
index def7ab27..3ab54954 100644
--- a/Analysis/include/Luau/Constraint.h
+++ b/Analysis/include/Luau/Constraint.h
@@ -112,6 +112,7 @@ struct FunctionCheckConstraint
     TypePackId argsPack;
 
     class AstExprCall* callSite = nullptr;
+    NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes;
     NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
 };
 
diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h
index 7d762fc0..4eb6bcc7 100644
--- a/Analysis/include/Luau/ConstraintSolver.h
+++ b/Analysis/include/Luau/ConstraintSolver.h
@@ -277,8 +277,12 @@ private:
      *
      * To determine which scope is appropriate, we also accept rootTy, which is
      * to be the type that contains blockedTy.
+     *
+     * A constraint is required and will validate that blockedTy is owned by this
+     * constraint. This prevents one constraint from interfering with another's
+     * blocked types.
      */
-    void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, Location location);
+    void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull<const Constraint> constraint);
 
     /**
     * Marks a constraint as being blocked on a type or type pack. The constraint
diff --git a/Analysis/include/Luau/Instantiation2.h b/Analysis/include/Luau/Instantiation2.h
index 9dfbb613..1ddaaf2b 100644
--- a/Analysis/include/Luau/Instantiation2.h
+++ b/Analysis/include/Luau/Instantiation2.h
@@ -61,6 +61,7 @@ struct Instantiation2 : Substitution
     {
     }
 
+    bool ignoreChildren(TypeId ty) override;
     bool isDirty(TypeId ty) override;
     bool isDirty(TypePackId tp) override;
     TypeId clean(TypeId ty) override;
diff --git a/Analysis/include/Luau/Set.h b/Analysis/include/Luau/Set.h
index 033bf840..2fea2e6a 100644
--- a/Analysis/include/Luau/Set.h
+++ b/Analysis/include/Luau/Set.h
@@ -53,6 +53,17 @@ public:
             insert(*it);
     }
 
+    void erase(T&& element)
+    {
+        bool& entry = mapping[element];
+
+        if (entry)
+        {
+            entry = false;
+            entryCount--;
+        }
+    }
+
     void erase(const T& element)
     {
         bool& entry = mapping[element];
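Aside: as the hunk above shows, Set wraps a map from element to a live flag (mapping) plus a live-entry counter (entryCount), so erase is a tombstone flip rather than a true removal. A minimal standalone model of that scheme, with std::unordered_map standing in for Luau's DenseHashMap (ToySet is illustrative, not part of this diff):

    #include <cstddef>
    #include <unordered_map>

    template<typename T>
    class ToySet
    {
        std::unordered_map<T, bool> mapping; // true = present, false = tombstone
        size_t entryCount = 0;

    public:
        void insert(const T& element)
        {
            bool& entry = mapping[element];
            if (!entry)
            {
                entry = true;
                entryCount++;
            }
        }

        // Same body as the const T& overload; having both lets callers erase
        // with a temporary (as TableLiteralInference.cpp later does with
        // freshly built std::strings).
        void erase(T&& element)
        {
            bool& entry = mapping[element];
            if (entry)
            {
                entry = false;
                entryCount--;
            }
        }

        size_t size() const
        {
            return entryCount;
        }
    };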
diff --git a/Analysis/include/Luau/TableLiteralInference.h b/Analysis/include/Luau/TableLiteralInference.h
new file mode 100644
index 00000000..1a6d51ea
--- /dev/null
+++ b/Analysis/include/Luau/TableLiteralInference.h
@@ -0,0 +1,28 @@
+// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
+
+#pragma once
+
+#include "Luau/DenseHash.h"
+#include "Luau/NotNull.h"
+#include "Luau/TypeFwd.h"
+
+namespace Luau
+{
+
+struct TypeArena;
+struct BuiltinTypes;
+struct Unifier2;
+class AstExpr;
+
+TypeId matchLiteralType(
+    NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
+    NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
+    NotNull<BuiltinTypes> builtinTypes,
+    NotNull<TypeArena> arena,
+    NotNull<Unifier2> unifier,
+    TypeId expectedType,
+    TypeId exprType,
+    const AstExpr* expr
+);
+
+}
diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h
index 25b71e03..8e82ee8f 100644
--- a/Analysis/include/Luau/Type.h
+++ b/Analysis/include/Luau/Type.h
@@ -146,6 +146,10 @@ struct BlockedType
     BlockedType();
     int index;
 
+    Constraint* getOwner() const;
+    void setOwner(Constraint* newOwner);
+
+private:
     // The constraint that is intended to unblock this type. Other constraints
     // should block on this constraint if present.
     Constraint* owner = nullptr;
@@ -419,6 +423,9 @@ struct Property
     TypeId type() const;
     void setType(TypeId ty);
 
+    // Sets the write type of this property to the read type.
+    void makeShared();
+
     bool isShared() const;
     bool isReadOnly() const;
     bool isWriteOnly() const;
diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h
index 11721f40..f3df9ae7 100644
--- a/Analysis/include/Luau/TypePack.h
+++ b/Analysis/include/Luau/TypePack.h
@@ -82,6 +82,8 @@ struct BlockedTypePack
     BlockedTypePack();
 
     size_t index;
+    struct Constraint* owner = nullptr;
+
     static size_t nextIndex;
 };
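The ownership discipline these declarations introduce: exactly one constraint may claim a BlockedType, and claiming is set-once, while BlockedTypePack only gains a raw owner pointer (nullptr continues to mean unowned). A standalone sketch of the set-once behavior, with the surrounding solver machinery stubbed out:

    #include <cassert>

    struct Constraint; // opaque here; the real one lives in Constraint.h

    struct ToyBlockedType
    {
        Constraint* owner = nullptr;

        Constraint* getOwner() const
        {
            return owner;
        }

        void setOwner(Constraint* newOwner)
        {
            assert(owner == nullptr); // a double claim asserts in debug builds...
            if (owner != nullptr)
                return;               // ...and is ignored in release builds
            owner = newOwner;
        }
    };

Making the field private and funneling writes through setOwner is what lets the solver later assert that only the owning constraint binds the type.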
diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp
index f894d88d..308f6c78 100644
--- a/Analysis/src/ConstraintGenerator.cpp
+++ b/Analysis/src/ConstraintGenerator.cpp
@@ -8,16 +8,18 @@
 #include "Luau/ControlFlow.h"
 #include "Luau/DcrLogger.h"
 #include "Luau/DenseHash.h"
+#include "Luau/InsertionOrderedMap.h"
 #include "Luau/ModuleResolver.h"
 #include "Luau/RecursionCounter.h"
 #include "Luau/Refinement.h"
 #include "Luau/Scope.h"
-#include "Luau/TypeUtils.h"
+#include "Luau/Simplify.h"
+#include "Luau/TableLiteralInference.h"
 #include "Luau/Type.h"
 #include "Luau/TypeFamily.h"
-#include "Luau/Simplify.h"
+#include "Luau/TypeUtils.h"
+#include "Luau/Unifier2.h"
 #include "Luau/VisitType.h"
-#include "Luau/InsertionOrderedMap.h"
 
 #include <algorithm>
 #include <memory>
 
@@ -230,7 +232,9 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)
 
     Checkpoint end = checkpoint(this);
 
-    NotNull<Constraint> genConstraint = addConstraint(scope, block->location, GeneralizationConstraint{arena->addType(BlockedType{}), moduleFnTy, std::move(interiorTypes.back())});
+    TypeId result = arena->addType(BlockedType{});
+    NotNull<Constraint> genConstraint = addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())});
+    getMutable<BlockedType>(result)->setOwner(genConstraint);
     forEachConstraint(start, end, this, [genConstraint](const ConstraintPtr& c) {
         genConstraint->dependencies.push_back(NotNull{c.get()});
     });
@@ -422,7 +426,11 @@ void ConstraintGenerator::computeRefinement(const ScopePtr& scope, Location loca
         TypeId nextDiscriminantTy = arena->addType(TableType{});
         NotNull<TableType> table{getMutable<TableType>(nextDiscriminantTy)};
-        table->props[*key->propName] = Property::readonly(discriminantTy);
+        // When we fully support read-write properties (i.e. when we allow properties with
+        // completely disparate read and write types), then the following property can be
+        // set to read-only since refinements only tell us about what we read. This cannot
+        // be allowed yet though because it causes read and write types to diverge.
+        table->props[*key->propName] = Property::rw(discriminantTy);
         table->scope = scope.get();
         table->state = TableState::Sealed;
@@ -935,7 +943,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFuncti
             }
         });
 
-        addConstraint(scope, std::move(c));
+        getMutable<BlockedType>(functionType)->setOwner(addConstraint(scope, std::move(c)));
         module->astTypes[function->func] = functionType;
     }
     else
@@ -1045,7 +1053,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f
             }
         });
 
-        addConstraint(scope, std::move(c));
+        if (auto blocked = getMutable<BlockedType>(generalizedType))
+            blocked->setOwner(addConstraint(scope, std::move(c)));
    }
 
     return ControlFlow::None;
@@ -1155,7 +1164,13 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass
     }
 
     TypePackId resultPack = checkPack(scope, assign->values).tp;
-    addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(std::move(assignees)), resultPack, /*resultIsLValue*/ true});
+    auto c = addConstraint(scope, assign->location, UnpackConstraint{arena->addTypePack(assignees), resultPack, /*resultIsLValue*/ true});
+    for (TypeId assignee : assignees)
+    {
+        auto blocked = getMutable<BlockedType>(assignee);
+        LUAU_ASSERT(blocked);
+        blocked->setOwner(c);
+    }
 
     return ControlFlow::None;
 }
@@ -1665,7 +1680,10 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
             unpackedTypes.emplace_back(mt);
             TypePackId mtPack = arena->addTypePack(std::move(unpackedTypes));
 
-            addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail});
+            auto c = addConstraint(scope, call->location, UnpackConstraint{mtPack, *argTail});
+            getMutable<BlockedType>(mt)->setOwner(c);
+            if (auto b = getMutable<BlockedType>(target); b && b->getOwner() == nullptr)
+                b->setOwner(c);
         }
 
         LUAU_ASSERT(target);
@@ -1723,7 +1741,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
      */
 
     NotNull<Constraint> checkConstraint =
-        addConstraint(scope, call->func->location, FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astExpectedTypes}});
+        addConstraint(scope, call->func->location, FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}});
 
     forEachConstraint(funcBeginCheckpoint, funcEndCheckpoint, this, [checkConstraint](const ConstraintPtr& constraint) {
         checkConstraint->dependencies.emplace_back(constraint.get());
@@ -1739,6 +1757,8 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall*
         &module->astOverloadResolvedTypes,
     });
 
+    getMutable<BlockedTypePack>(rets)->owner = callConstraint.get();
+
     callConstraint->dependencies.push_back(checkConstraint);
 
     forEachConstraint(argBeginCheckpoint, argEndCheckpoint, this, [checkConstraint, callConstraint](const ConstraintPtr& constraint) {
@@ -1913,7 +1933,8 @@ Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const Refin
         scope->rvalueRefinements[key->def] = result;
     }
 
-    addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)});
+    auto c = addConstraint(scope, indexee->location, HasPropConstraint{result, obj, std::move(index), ValueContext::RValue, inConditional(typeContext)});
+    getMutable<BlockedType>(result)->setOwner(c);
 
     if (key)
         return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)};
@@ -1949,7 +1970,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexExpr* in
         scope->rvalueRefinements[key->def] = result;
     }
 
-    addConstraint(scope, indexExpr->expr->location, HasIndexerConstraint{result, obj, indexType});
+    auto c = addConstraint(scope, indexExpr->expr->location, HasIndexerConstraint{result, obj, indexType});
+    getMutable<BlockedType>(result)->setOwner(c);
 
     if (key)
         return Inference{result, refinementArena.proposition(key, builtinTypes->truthyType)};
@@ -1968,6 +1990,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun
     TypeId generalizedTy = arena->addType(BlockedType{});
     NotNull<Constraint> gc = addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature, std::move(interiorTypes.back())});
+    getMutable<BlockedType>(generalizedTy)->setOwner(gc);
     interiorTypes.pop_back();
 
     Constraint* previous = nullptr;
@@ -2411,7 +2434,7 @@ std::optional<TypeId> ConstraintGenerator::checkLValue(const ScopePtr& scope, As
     {
         Constraint* owner = nullptr;
         if (auto blocked = get<BlockedType>(*ty))
-            owner = blocked->owner;
+            owner = blocked->getOwner();
 
         auto unpackC = addConstraint(scope, local->location,
             UnpackConstraint{arena->addTypePack({*ty}), arena->addTypePack({assignedTy}),
@@ -2419,6 +2442,8 @@ std::optional<TypeId> ConstraintGenerator::checkLValue(const ScopePtr& scope, As
 
         if (owner)
             unpackC->dependencies.push_back(NotNull{owner});
+        else if (auto blocked = getMutable<BlockedType>(*ty))
+            blocked->setOwner(unpackC);
 
         recordInferredBinding(local->local, *ty);
     }
@@ -2477,7 +2502,8 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr,
         TypeId resultType = arena->addType(BlockedType{});
         TypeId subjectType = check(scope, indexExpr->expr).ty;
         TypeId indexType = check(scope, indexExpr->index).ty;
-        addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, assignedTy});
+        auto c = addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, assignedTy});
+        getMutable<BlockedType>(resultType)->setOwner(c);
 
         module->astTypes[expr] = assignedTy;
 
@@ -2548,7 +2574,7 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr,
 
     TypeId updatedType = arena->addType(BlockedType{});
     auto setC = addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), assignedTy});
-    getMutable<BlockedType>(updatedType)->owner = setC.get();
+    getMutable<BlockedType>(updatedType)->setOwner(setC);
 
     TypeId prevSegmentTy = updatedType;
     for (size_t i = 0; i < segments.size(); ++i)
@@ -2557,6 +2583,7 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr,
         module->astTypes[exprs[i]] = segmentTy;
         ValueContext ctx = i == segments.size() - 1 ? ValueContext::LValue : ValueContext::RValue;
         auto hasC = addConstraint(scope, expr->location, HasPropConstraint{segmentTy, prevSegmentTy, segments[i], ctx, inConditional(typeContext)});
+        getMutable<BlockedType>(segmentTy)->setOwner(hasC);
         setC->dependencies.push_back(hasC);
         prevSegmentTy = segmentTy;
     }
@@ -2582,8 +2609,6 @@ TypeId ConstraintGenerator::updateProperty(const ScopePtr& scope, AstExpr* expr,
 
 Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType)
 {
-    const bool expectedTypeIsFree = expectedType && get<FreeType>(follow(*expectedType));
-
     TypeId ty = arena->addType(TableType{});
     TableType* ttv = getMutable<TableType>(ty);
     LUAU_ASSERT(ttv);
@@ -2601,98 +2626,21 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
         indexValueLowerBound.insert(follow(currentResultType));
     };
 
-    std::optional<TypeId> annotatedKeyType;
-    std::optional<TypeId> annotatedIndexResultType;
-
-    if (expectedType)
-    {
-        if (const TableType* ttv = get<TableType>(follow(*expectedType)))
-        {
-            if (ttv->indexer)
-            {
-                annotatedKeyType.emplace(follow(ttv->indexer->indexType));
-                annotatedIndexResultType.emplace(ttv->indexer->indexResultType);
-            }
-        }
-    }
-
-    bool isIndexedResultType = false;
-    std::optional<TypeId> pinnedIndexResultType;
-
+    TypeIds valuesLowerBound;
 
     for (const AstExprTable::Item& item : expr->items)
    {
-        std::optional<TypeId> expectedValueType;
-        if (item.kind == AstExprTable::Item::Kind::General || item.kind == AstExprTable::Item::Kind::List)
-            isIndexedResultType = true;
+        // Expected types are threaded through table literals separately via the
+        // function matchLiteralType.
 
-        if (item.key && expectedType && !expectedTypeIsFree)
-        {
-            if (auto stringKey = item.key->as<AstExprConstantString>())
-            {
-                ErrorVec errorVec;
-                std::optional<TypeId> propTy =
-                    findTablePropertyRespectingMeta(builtinTypes, errorVec, follow(*expectedType), stringKey->value.data, item.value->location);
-                if (propTy)
-                    expectedValueType = propTy;
-                else
-                {
-                    expectedValueType = arena->addType(BlockedType{});
-                    addConstraint(scope, item.value->location,
-                        HasPropConstraint{
-                            *expectedValueType, *expectedType, stringKey->value.data, ValueContext::RValue, /*inConditional*/ inConditional(typeContext), /*suppressSimplification*/ true});
-                }
-            }
-        }
-
-        // We'll resolve the expected index result type here with the following priority:
-        // 1. Record table types - in which key, value pairs must be handled on a k,v pair basis.
-        //    In this case, the above if-statement will populate expectedValueType
-        // 2. Someone places an annotation on a General or List table
-        //    Trust the annotation and have the solver inform them if they get it wrong
-        // 3. Someone omits the annotation on a general or List table
-        //    Use the type of the first indexResultType as the expected type
-        std::optional<TypeId> checkExpectedIndexResultType;
-        if (expectedValueType)
-        {
-            checkExpectedIndexResultType = expectedValueType;
-        }
-        else if (annotatedIndexResultType)
-        {
-            checkExpectedIndexResultType = annotatedIndexResultType;
-        }
-        else if (pinnedIndexResultType)
-        {
-            checkExpectedIndexResultType = pinnedIndexResultType;
-        }
-
-        TypeId itemTy = check(scope, item.value, checkExpectedIndexResultType).ty;
-
-        // we should preserve error-suppressingness from the expected value type if we have one
-        if (expectedValueType)
-        {
-            switch (shouldSuppressErrors(normalizer, *expectedValueType))
-            {
-            case ErrorSuppression::DoNotSuppress:
-                break;
-            case ErrorSuppression::Suppress:
-                itemTy = simplifyUnion(builtinTypes, arena, itemTy, builtinTypes->errorType).result;
-                break;
-            case ErrorSuppression::NormalizationFailed:
-                reportError(item.value->location, NormalizationTooComplex{});
-                break;
-            }
-        }
-
-        if (isIndexedResultType && !pinnedIndexResultType)
-            pinnedIndexResultType = itemTy;
+        TypeId itemTy = check(scope, item.value).ty;
 
         if (item.key)
         {
             // Even though we don't need to use the type of the item's key if
             // it's a string constant, we still want to check it to populate
             // astTypes.
-            TypeId keyTy = check(scope, item.key, annotatedKeyType).ty;
+            TypeId keyTy = check(scope, item.key).ty;
 
             if (AstExprConstantString* key = item.key->as<AstExprConstantString>())
             {
@@ -2729,6 +2677,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
         ttv->indexer = TableIndexer{indexKey, indexValue};
     }
 
+    if (expectedType)
+    {
+        Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice};
+        matchLiteralType(NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr);
+    }
+
     return Inference{ty};
 }
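The upshot of this rewrite: a table literal is now checked with no per-item expected type, and the annotation is reconciled afterwards by matchLiteralType (added later in this diff). A loose standalone model of the table-side reconciliation, under the simplifying assumption that a map from property names to type names stands in for TableType (ToyTable and matchLiteral are illustrative only):

    #include <map>
    #include <optional>
    #include <string>

    using ToyTable = std::map<std::string, std::string>; // prop name -> type name

    // Props named by the expected table adopt the expected prop type; extra
    // props fall into the expected table's string indexer when it has one.
    ToyTable matchLiteral(const ToyTable& expected, ToyTable inferred, const std::optional<std::string>& indexerResult)
    {
        for (auto& [name, ty] : inferred)
        {
            if (auto it = expected.find(name); it != expected.end())
                ty = it->second;
            else if (indexerResult)
                ty = *indexerResult;
        }
        return inferred;
    }

The real function also records results into astTypes/astExpectedTypes and recurses through nested literals, which the toy omits.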
@@ -2786,13 +2740,15 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu
     // This check ensures that expectedType is precisely optional and not any (since any is also an optional type)
     if (expectedType && isOptional(*expectedType) && !get<AnyType>(*expectedType))
     {
-        auto ut = get<UnionType>(*expectedType);
-        for (auto u : ut)
+        if (auto ut = get<UnionType>(*expectedType))
         {
-            if (get<FunctionType>(u) && !isNil(u))
+            for (auto u : ut)
             {
-                expectedFunction = get<FunctionType>(u);
-                break;
+                if (get<FunctionType>(u) && !isNil(u))
+                {
+                    expectedFunction = get<FunctionType>(u);
+                    break;
+                }
             }
         }
     }
@@ -3307,7 +3263,8 @@ Inference ConstraintGenerator::flattenPack(const ScopePtr& scope, Location locat
 
     TypeId typeResult = arena->addType(BlockedType{});
     TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get()));
-    addConstraint(scope, location, UnpackConstraint{resultPack, tp});
+    auto c = addConstraint(scope, location, UnpackConstraint{resultPack, tp});
+    getMutable<BlockedType>(typeResult)->setOwner(c);
 
     return Inference{typeResult, refinement};
 }
@@ -3426,7 +3383,8 @@ void ConstraintGenerator::fillInInferredBindings(const ScopePtr& globalScope, As
         else
         {
             TypeId ty = arena->addType(BlockedType{});
-            addConstraint(globalScope, Location{}, SetOpConstraint{SetOpConstraint::Union, ty, std::move(tys)});
+            auto c = addConstraint(globalScope, Location{}, SetOpConstraint{SetOpConstraint::Union, ty, std::move(tys)});
+            getMutable<BlockedType>(ty)->setOwner(c);
 
             scope->bindings[symbol] = Binding{ty, location};
         }
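The checkFunctionSignature change above also fixes a latent crash: the old code iterated the result of get<UnionType>(*expectedType) without a null check. The same guard pattern in a self-contained toy (tryGetUnion and pickFunctionPart are stand-ins, not Luau API):

    #include <map>
    #include <vector>

    struct UnionView
    {
        std::vector<int> parts;
    };

    // Stand-in registry; the real code queries the arena via get<UnionType>.
    std::map<int, UnionView> unions = {{1, UnionView{{7, 0}}}};

    const UnionView* tryGetUnion(int ty)
    {
        auto it = unions.find(ty);
        return it == unions.end() ? nullptr : &it->second; // may be null!
    }

    int pickFunctionPart(int ty)
    {
        if (const UnionView* ut = tryGetUnion(ty)) // the guard the fix adds
        {
            for (int part : ut->parts)
                if (part != 0) // stand-in for get<FunctionType>(u) && !isNil(u)
                    return part;
        }
        return -1;
    }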
diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp
index db7e84c5..702b299d 100644
--- a/Analysis/src/ConstraintSolver.cpp
+++ b/Analysis/src/ConstraintSolver.cpp
@@ -13,6 +13,7 @@
 #include "Luau/Quantify.h"
 #include "Luau/RecursionCounter.h"
 #include "Luau/Simplify.h"
+#include "Luau/TableLiteralInference.h"
 #include "Luau/TimeTrace.h"
 #include "Luau/ToString.h"
 #include "Luau/Type.h"
@@ -60,6 +61,24 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
         dumpBindings(child, opts);
 }
 
+// used only in asserts
+[[maybe_unused]] static bool canMutate(TypeId ty, NotNull<const Constraint> constraint)
+{
+    if (auto blocked = get<BlockedType>(ty))
+        return blocked->getOwner() == constraint;
+
+    return true;
+}
+
+// used only in asserts
+[[maybe_unused]] static bool canMutate(TypePackId tp, NotNull<const Constraint> constraint)
+{
+    if (auto blocked = get<BlockedTypePack>(tp))
+        return blocked->owner == nullptr || blocked->owner == constraint;
+
+    return true;
+}
+
 static std::pair<std::vector<TypeId>, std::vector<TypePackId>> saturateArguments(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes,
     const TypeFun& fn, const std::vector<TypeId>& rawTypeArguments, const std::vector<TypePackId>& rawPackArguments)
 {
@@ -327,7 +346,7 @@ void ConstraintSolver::run()
 
     if (FFlag::DebugLuauLogSolver)
     {
-        printf("Starting solver\n");
+        printf("Starting solver for module %s (%s)\n", moduleResolver->getHumanReadableModuleName(currentModuleName).c_str(), currentModuleName.c_str());
         dump(this, opts);
         printf("Bindings:\n");
         dumpBindings(rootScope, opts);
@@ -587,7 +606,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<co
     if (get<BlockedType>(generalizedType))
-        asMutable(generalizedType)->ty.emplace<BoundType>(generalized->result);
+        bindBlockedType(generalizedType, generalized->result, c.sourceType, constraint);
     else
         unify(constraint, generalizedType, generalized->result);
@@ -624,18 +643,19 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull<con
     std::optional<TypeId> instantiated = instantiate(builtinTypes, NotNull{arena}, NotNull{&limits}, constraint->scope, c.superType);
 
     LUAU_ASSERT(get<BlockedType>(c.subType));
+    LUAU_ASSERT(canMutate(c.subType, constraint));
 
     if (!instantiated.has_value())
     {
         reportError(UnificationTooComplex{}, constraint->location);
-        asMutable(c.subType)->ty.emplace<BoundType>(errorRecoveryType());
+        bindBlockedType(c.subType, errorRecoveryType(), c.superType, constraint);
         unblock(c.subType, constraint->location);
         return true;
     }
 
-    asMutable(c.subType)->ty.emplace<BoundType>(*instantiated);
+    bindBlockedType(c.subType, *instantiated, c.superType, constraint);
 
     InstantiationQueuer queuer{constraint->scope, constraint->location, this};
     queuer.traverse(c.subType);
@@ -1096,18 +1116,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
     std::optional<TypePackId> subst = instantiation.substitute(result);
-
     if (!subst)
     {
         reportError(CodeTooComplex{}, constraint->location);
         result = builtinTypes->errorTypePack;
     }
     else
-    {
         result = *subst;
-    }
 
     if (c.result != result)
         asMutable(c.result)->ty.emplace<BoundTypePack>(result);
@@ -1185,7 +1201,6 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
 
     const std::vector<TypeId> expectedArgs = flatten(ftv->argTypes).first;
     const std::vector<TypeId> argPackHead = flatten(argsPack).first;
-
@@ -1219,11 +1234,16 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
             if (expr->is<AstExprConstantBool>() || expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() ||
-                expr->is<AstExprConstantNil>() || expr->is<AstExprTable>())
+                expr->is<AstExprConstantNil>())
             {
                 Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
                 u2.unify(actualArgTy, expectedArgTy);
             }
+            else if (expr->is<AstExprTable>())
+            {
+                Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
+                (void) matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr);
+            }
         }
 
     return true;
@@ -1267,6 +1287,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
 
     LUAU_ASSERT(get<BlockedType>(resultType));
+    LUAU_ASSERT(canMutate(resultType, constraint));
 
     if (isBlocked(subjectType) || get<PendingExpansionType>(subjectType) || get<TypeFamilyInstanceType>(subjectType))
         return block(subjectType, constraint);
@@ -1280,7 +1301,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
 
-    bindBlockedType(resultType, result.value_or(builtinTypes->anyType), c.subjectType, constraint->location);
+    bindBlockedType(resultType, result.value_or(builtinTypes->anyType), c.subjectType, constraint);
     unblock(resultType, constraint->location);
     return true;
 }
@@ -1394,7 +1415,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
     auto bind = [&](TypeId a, TypeId b) {
-        bindBlockedType(a, b, subjectType, constraint->location);
+        bindBlockedType(a, b, subjectType, constraint);
     };
 
     if (existingPropType)
@@ -1412,15 +1433,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
 
     if (get<FreeType>(subjectType))
-    {
-        /*
-         * This should never occur because lookupTableProp() will add bounds to
-         * any free types it encounters. There will always be an
-         * existingPropType if the subject is free.
-         */
-        LUAU_ASSERT(false);
         return false;
-    }
     else if (auto ttv = getMutable<TableType>(subjectType))
     {
         if (ttv->state == TableState::Free)
@@ -1453,6 +1466,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull<const
 
     LUAU_ASSERT(get<BlockedType>(resultType));
+    LUAU_ASSERT(canMutate(resultType, constraint));
 
     if (auto ft = get<FreeType>(subjectType))
     {
@@ -1476,8 +1490,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull<const
         if (auto indexer = tt->indexer)
         {
             unify(constraint, indexType, indexer->indexType);
-            LUAU_ASSERT(get<BlockedType>(resultType));
-            bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint->location);
+            bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint);
             return true;
         }
         else if (tt->state == TableState::Unsealed)
@@ -1498,12 +1511,12 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull<const
         if (auto indexer = tt->indexer)
         {
             unify(constraint, indexType, indexer->indexType);
-            asMutable(resultType)->ty.emplace<BoundType>(indexer->indexResultType);
+            bindBlockedType(resultType, indexer->indexResultType, subjectType, constraint);
             return true;
         }
         else if (isString(indexType))
         {
-            asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->unknownType);
+            bindBlockedType(resultType, builtinTypes->unknownType, subjectType, constraint);
             return true;
         }
     }
@@ -1525,6 +1538,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull<const
         for (TypeId part : parts)
         {
             TypeId r = arena->addType(BlockedType{});
+            getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get()));
             bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r);
             // FIXME: It's too late to stop and block now I think? We should
@@ -1537,9 +1551,9 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull<const
 
         if (0 == results.size())
-            asMutable(resultType)->ty.emplace<BoundType>(builtinTypes->errorType);
+            bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint);
         else if (1 == results.size())
-            asMutable(resultType)->ty.emplace<BoundType>(*results.begin());
+            bindBlockedType(resultType, *results.begin(), subjectType, constraint);
         else
             asMutable(resultType)->ty.emplace<UnionType>(std::vector<TypeId>(results.begin(), results.end()));
 
         return true;
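How the union case above combines per-part results: every part is dispatched into its own freshly owned BlockedType, and the overall result is bound to the error type (no results), the unique result (one result), or a union of the distinct results. The same bookkeeping restated standalone over strings (toy code, not the Luau API):

    #include <set>
    #include <string>

    std::string combineIndexResults(const std::set<std::string>& results)
    {
        if (results.empty())
            return "*error*";        // no part had a usable indexer
        if (results.size() == 1)
            return *results.begin(); // bind directly to the single result
        std::string u;
        for (const std::string& r : results)
            u += (u.empty() ? "" : " | ") + r; // otherwise build a union
        return u;
    }

Collecting into a set rather than a vector is what collapses duplicate per-part results before the union is built.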
@@ -1556,6 +1570,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull<const
         for (TypeId part : parts)
         {
             TypeId r = arena->addType(BlockedType{});
+            getMutable<BlockedType>(r)->setOwner(const_cast<Constraint*>(constraint.get()));
             bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r);
             // We should have found all the blocked types ahead of time (see BlockedTypeFinder below)
@@ -1576,7 +1591,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull<const
 
-    bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint->location);
+    bindBlockedType(resultType, builtinTypes->errorType, subjectType, constraint);
 
     return true;
 }
@@ -1714,6 +1729,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
         LUAU_ASSERT(get<BlockedTypePack>(resultPack));
+        LUAU_ASSERT(canMutate(resultPack, constraint));
         asMutable(resultPack)->ty.emplace<BoundTypePack>(sourcePack);
         unblock(resultPack, constraint->location);
@@ -1727,6 +1743,8 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
 
+        LUAU_ASSERT(canMutate(resultTy, constraint));
+
         if (auto lt = getMutable<LocalType>(resultTy); c.resultIsLValue && lt)
         {
             lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, srcTy).result;
@@ -1750,7 +1768,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
                 asMutable(resultTy)->ty.emplace<BoundType>(f);
             }
             else
-                asMutable(resultTy)->ty.emplace<BoundType>(srcTy);
+                bindBlockedType(resultTy, srcTy, srcTy, constraint);
         }
         else
         {
@@ -1796,6 +1814,7 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Cons
 
+        LUAU_ASSERT(canMutate(resultTy, constraint));
         if (auto lt = getMutable<LocalType>(resultTy); c.resultIsLValue && lt)
         {
             lt->domain = simplifyUnion(builtinTypes, arena, lt->domain, builtinTypes->nilType).result;
@@ -1838,6 +1857,8 @@ bool ConstraintSolver::tryDispatch(const SetOpConstraint& c, NotNull<const Const
 
+    LUAU_ASSERT(canMutate(c.resultType, constraint));
+
     asMutable(c.resultType)->ty.emplace<BoundType>(res);
 
     return true;
@@ -1849,7 +1870,6 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull<const Cons
         reduceFamilies(ty, constraint->location, TypeFamilyContext{NotNull{this}, constraint->scope, constraint}, force);
-
     for (TypeId r : result.reducedTypes)
         unblock(r, constraint->location);
@@ -2366,11 +2386,11 @@ bool ConstraintSolver::unify(NotNull<const Constraint> constraint, TID subTy, TI
     return unify(constraint->scope, constraint->location, subTy, superTy);
 }
 
-void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, Location location)
+void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, NotNull<const Constraint> constraint)
 {
     resultTy = follow(resultTy);
 
-    LUAU_ASSERT(get<BlockedType>(blockedTy));
+    LUAU_ASSERT(get<BlockedType>(blockedTy) && canMutate(blockedTy, constraint));
 
     if (blockedTy == resultTy)
     {
@@ -2381,7 +2401,7 @@ void ConstraintSolver::bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId
         else if (auto tt = get<TableType>(rootTy); tt && tt->state == TableState::Free)
             freeScope = tt->scope;
         else
-            iceReporter.ice("bindBlockedType couldn't find an appropriate scope for a fresh type!", location);
+            iceReporter.ice("bindBlockedType couldn't find an appropriate scope for a fresh type!", constraint->location);
 
         LUAU_ASSERT(freeScope);
 
@@ -2418,7 +2438,7 @@ void ConstraintSolver::block(NotNull<const Constraint> target, NotNull<const Con
             logger->pushBlock(constraint, target);
 
         if (FFlag::DebugLuauLogSolver)
-            printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str());
+            printf("%s depends on constraint %s\n", toString(*constraint, opts).c_str(), toString(*target, opts).c_str());
     }
 }
 
@@ -2431,7 +2451,7 @@ bool ConstraintSolver::block(TypeId target, NotNull<const Constraint> constraint
         logger->pushBlock(constraint, target);
 
     if (FFlag::DebugLuauLogSolver)
-        printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str());
+        printf("%s depends on TypeId %s\n", toString(*constraint, opts).c_str(), toString(target, opts).c_str());
 
     return false;
@@ -2446,7 +2466,7 @@ bool ConstraintSolver::block(TypePackId target, NotNull<const Constraint> constr
         logger->pushBlock(constraint, target);
 
     if (FFlag::DebugLuauLogSolver)
-        printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str());
+        printf("%s depends on TypePackId %s\n", toString(*constraint, opts).c_str(), toString(target, opts).c_str());
 
     return false;
diff --git a/Analysis/src/Instantiation2.cpp b/Analysis/src/Instantiation2.cpp
index 31f27f8e..a9284ff9 100644
--- a/Analysis/src/Instantiation2.cpp
+++ b/Analysis/src/Instantiation2.cpp
@@ -4,6 +4,13 @@
 namespace Luau
 {
 
+bool Instantiation2::ignoreChildren(TypeId ty)
+{
+    if (get<ClassType>(ty))
+        return true;
+
+    return false;
+}
+
 bool Instantiation2::isDirty(TypeId ty)
 {
     return get<GenericType>(ty) && genericSubstitutions.contains(ty);
diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp
index d285bf26..18a0d6f7 100644
--- a/Analysis/src/NonStrictTypeChecker.cpp
+++ b/Analysis/src/NonStrictTypeChecker.cpp
@@ -148,7 +148,7 @@ struct NonStrictTypeChecker
 
     NotNull<BuiltinTypes> builtinTypes;
     const NotNull<InternalErrorReporter> ice;
-    TypeArena arena;
+    NotNull<TypeArena> arena;
     Module* module;
     Normalizer normalizer;
     Subtyping subtyping;
@@ -159,13 +159,14 @@ struct NonStrictTypeChecker
 
     const NotNull<TypeCheckLimits> limits;
 
-    NonStrictTypeChecker(NotNull<BuiltinTypes> builtinTypes, const NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
-        NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, Module* module)
+    NonStrictTypeChecker(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, const NotNull<InternalErrorReporter> ice,
+        NotNull<UnifierSharedState> unifierState, NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, Module* module)
         : builtinTypes(builtinTypes)
         , ice(ice)
+        , arena(arena)
         , module(module)
-        , normalizer{&arena, builtinTypes, unifierState, /* cache inhabitance */ true}
-        , subtyping{builtinTypes, NotNull{&arena}, NotNull(&normalizer), ice, NotNull{module->getModuleScope().get()}}
+        , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true}
+        , subtyping{builtinTypes, arena, NotNull(&normalizer), ice, NotNull{module->getModuleScope().get()}}
         , dfg(dfg)
         , limits(limits)
     {
@@ -187,8 +188,8 @@ struct NonStrictTypeChecker
             return *fst;
         else if (auto ftp = get<FreeTypePack>(pack))
         {
-            TypeId result = arena.addType(FreeType{ftp->scope});
-            TypePackId freeTail = arena.addTypePack(FreeTypePack{ftp->scope});
+            TypeId result = arena->addType(FreeType{ftp->scope});
+            TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope});
 
             TypePack& resultPack = asMutable(pack)->ty.emplace<TypePack>();
             resultPack.head.assign(1, result);
@@ -210,9 +211,8 @@ struct NonStrictTypeChecker
         if (noTypeFamilyErrors.find(instance))
             return instance;
 
-        ErrorVec errors = reduceFamilies(
-            instance, location, TypeFamilyContext{NotNull{&arena}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
-                               .errors;
+        ErrorVec errors =
+            reduceFamilies(instance, location, TypeFamilyContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true).errors;
 
         if (errors.empty())
             noTypeFamilyErrors.insert(instance);
@@ -303,7 +303,7 @@ struct NonStrictTypeChecker
             ctx.remove(dfg->getDef(local));
         }
         else
-            ctx = NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, visit(stat), ctx);
+            ctx = NonStrictContext::disjunction(builtinTypes, arena, visit(stat), ctx);
     }
     return ctx;
 }
@@ -317,9 +317,9 @@ struct NonStrictTypeChecker
     {
         NonStrictContext thenBody = visit(ifStatement->thenbody);
         NonStrictContext elseBody = visit(ifStatement->elsebody);
-        branchContext = NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenBody, elseBody);
+        branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody);
     }
-    return NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, condB, branchContext);
+    return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext);
 }
 
 NonStrictContext visit(AstStatWhile* whileStatement)
@@ -641,8 +641,7 @@ struct NonStrictTypeChecker
     NonStrictContext condB = visit(ifElse->condition);
     NonStrictContext thenB = visit(ifElse->trueExpr);
     NonStrictContext elseB = visit(ifElse->falseExpr);
-    return NonStrictContext::disjunction(
-        builtinTypes, NotNull{&arena}, condB, NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenB, elseB));
+    return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB));
 }
 
 NonStrictContext visit(AstExprInterpString* interpString)
@@ -710,7 +709,7 @@ private:
     {
         TypeId& cachedResult = cachedNegations[baseType];
         if (!cachedResult)
-            cachedResult = arena.addType(NegationType{baseType});
+            cachedResult = arena->addType(NegationType{baseType});
         return cachedResult;
     };
 };
@@ -718,8 +717,7 @@ void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorRep
     NotNull<UnifierSharedState> unifierState, NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module)
 {
-    // TODO: unimplemented
-    NonStrictTypeChecker typeChecker{builtinTypes, ice, unifierState, dfg, limits, module};
+    NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, ice, unifierState, dfg, limits, module};
 
     typeChecker.visit(sourceModule.root);
     unfreeze(module->interfaceTypes);
     copyErrors(module->errors, module->interfaceTypes, builtinTypes);
diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp
index 015507cb..096ade6f 100644
--- a/Analysis/src/Normalize.cpp
+++ b/Analysis/src/Normalize.cpp
@@ -2968,6 +2968,12 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set<T
+    else if (auto nt = get<NegationType>(t); nt && get<NeverType>(follow(nt->ty)))
+    {
+        // if we're intersecting with `~never`, this is equivalent to intersecting with `unknown`
+        // this is a noop since an intersection with `unknown` is trivial.
+        return true;
+    }
     else if (auto nt = get<NegationType>(t))
         return intersectNormalWithTy(here, nt->ty, seenSetTypes);
     else
@@ -2991,6 +2997,20 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set<T
     return true;
 }
 
+static void makeTableShared(TypeId ty)
+{
+    if (auto tableTy = getMutable<TableType>(ty))
+    {
+        for (auto& [_, prop] : tableTy->props)
+            prop.makeShared();
+    }
+    else if (auto metatableTy = get<MetatableType>(ty))
+    {
+        makeTableShared(metatableTy->metatable);
+        makeTableShared(metatableTy->table);
+    }
+}
+
 // -------- Convert back from a normalized type to a type
 
 TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
@@ -3085,7 +3105,18 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm)
     if (!get<NeverType>(norm.buffers))
         result.push_back(builtinTypes->bufferType);
 
-    result.insert(result.end(), norm.tables.begin(), norm.tables.end());
+    if (FFlag::DebugLuauDeferredConstraintResolution)
+    {
+        result.reserve(result.size() + norm.tables.size());
+        for (auto table : norm.tables)
+        {
+            makeTableShared(table);
+            result.push_back(table);
+        }
+    }
+    else
+        result.insert(result.end(), norm.tables.begin(), norm.tables.end());
+
     for (auto& [tyvar, intersect] : norm.tyvars)
     {
         if (get<NeverType>(intersect->tops))
diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp
index a21f5092..8b15e919 100644
--- a/Analysis/src/Simplify.cpp
+++ b/Analysis/src/Simplify.cpp
@@ -377,8 +377,12 @@ Relation relate(TypeId left, TypeId right, SimplifierSeenSet& seen)
     {
         std::vector<TypeId> opts;
         for (TypeId part : ut)
-            if (relate(left, part, seen) == Relation::Subset)
+        {
+            Relation r = relate(left, part, seen);
+
+            if (r == Relation::Subset || r == Relation::Coincident)
                 return Relation::Subset;
+        }
 
         return Relation::Intersects;
     }
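Why the Coincident case matters in the Simplify.cpp fix above: relate() distinguishes "same set" (Coincident) from "contained in" (Subset), and the old union scan only accepted the latter, so a union part equal to the left type fell through to Intersects. A runnable model over sets of ints; Relation mirrors the names in Simplify.cpp, everything else is illustrative:

    #include <algorithm>
    #include <iterator>
    #include <set>
    #include <vector>

    enum class Relation { Disjoint, Intersects, Subset, Coincident };

    using Ty = std::set<int>; // a type modeled as the set of values it admits

    Relation relate(const Ty& left, const Ty& right)
    {
        if (left == right)
            return Relation::Coincident;
        if (std::includes(right.begin(), right.end(), left.begin(), left.end()))
            return Relation::Subset;
        std::vector<int> common;
        std::set_intersection(left.begin(), left.end(), right.begin(), right.end(), std::back_inserter(common));
        return common.empty() ? Relation::Disjoint : Relation::Intersects;
    }

    // The fixed union rule: left <: (p1 | p2 | ...) if left relates to any part
    // as Subset *or* Coincident; testing Subset alone loses the equal case.
    Relation relateToUnion(const Ty& left, const std::vector<Ty>& parts)
    {
        for (const Ty& part : parts)
        {
            Relation r = relate(left, part);
            if (r == Relation::Subset || r == Relation::Coincident)
                return Relation::Subset;
        }
        return Relation::Intersects;
    }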
diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp
new file mode 100644
index 00000000..b23f614a
--- /dev/null
+++ b/Analysis/src/TableLiteralInference.cpp
@@ -0,0 +1,416 @@
+// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
+
+#include "Luau/Ast.h"
+#include "Luau/Normalize.h"
+#include "Luau/Simplify.h"
+#include "Luau/Type.h"
+#include "Luau/ToString.h"
+#include "Luau/TypeArena.h"
+#include "Luau/Unifier2.h"
+
+namespace Luau
+{
+
+static bool isLiteral(const AstExpr* expr)
+{
+    return (
+        expr->is<AstExprTable>() ||
+        expr->is<AstExprFunction>() ||
+        expr->is<AstExprConstantNumber>() ||
+        expr->is<AstExprConstantString>() ||
+        expr->is<AstExprConstantBool>() ||
+        expr->is<AstExprConstantNil>()
+    );
+}
+
+// A fast approximation of subTy <: superTy
+static bool fastIsSubtype(TypeId subTy, TypeId superTy)
+{
+    Relation r = relate(superTy, subTy);
+    return r == Relation::Coincident || r == Relation::Superset;
+}
+
+static bool isRecord(const AstExprTable::Item& item)
+{
+    if (item.kind == AstExprTable::Item::Record)
+        return true;
+    else if (item.kind == AstExprTable::Item::General && item.key->is<AstExprConstantString>())
+        return true;
+    else
+        return false;
+}
+
+static std::optional<TypeId> extractMatchingTableType(std::vector<TypeId>& tables, TypeId exprType, NotNull<BuiltinTypes> builtinTypes)
+{
+    if (tables.empty())
+        return std::nullopt;
+
+    const TableType* exprTable = get<TableType>(follow(exprType));
+    if (!exprTable)
+        return std::nullopt;
+
+    size_t tableCount = 0;
+    std::optional<TypeId> firstTable;
+
+    for (TypeId ty : tables)
+    {
+        ty = follow(ty);
+        if (auto tt = get<TableType>(ty))
+        {
+            // If the expected table has a key whose type is a string or boolean
+            // singleton and the corresponding exprType property does not match,
+            // then skip this table.
+
+            if (!firstTable)
+                firstTable = ty;
+            ++tableCount;
+
+            for (const auto& [name, expectedProp] : tt->props)
+            {
+                if (!expectedProp.readTy)
+                    continue;
+
+                const TypeId expectedType = follow(*expectedProp.readTy);
+
+                auto st = get<SingletonType>(expectedType);
+                if (!st)
+                    continue;
+
+                auto it = exprTable->props.find(name);
+                if (it == exprTable->props.end())
+                    continue;
+
+                const auto& [_name, exprProp] = *it;
+
+                if (!exprProp.readTy)
+                    continue;
+
+                const TypeId propType = follow(*exprProp.readTy);
+
+                const FreeType* ft = get<FreeType>(propType);
+
+                if (ft && get<SingletonType>(ft->lowerBound))
+                {
+                    if (fastIsSubtype(builtinTypes->booleanType, ft->upperBound) &&
+                        fastIsSubtype(expectedType, builtinTypes->booleanType))
+                    {
+                        return ty;
+                    }
+
+                    if (fastIsSubtype(builtinTypes->stringType, ft->upperBound) &&
+                        fastIsSubtype(expectedType, ft->lowerBound))
+                    {
+                        return ty;
+                    }
+                }
+            }
+        }
+    }
+
+    if (tableCount == 1)
+    {
+        LUAU_ASSERT(firstTable);
+        return firstTable;
+    }
+
+    return std::nullopt;
+}
+
+TypeId matchLiteralType(
+    NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
+    NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
+    NotNull<BuiltinTypes> builtinTypes,
+    NotNull<TypeArena> arena,
+    NotNull<Unifier2> unifier,
+    TypeId expectedType,
+    TypeId exprType,
+    const AstExpr* expr
+)
+{
+    /*
+     * Table types that arise from literal table expressions have some
+     * properties that make this algorithm much simpler.
+     *
+     * Most importantly, the parts of the type that arise directly from the
+     * table expression are guaranteed to be acyclic. This means we can do all
+     * kinds of naive depth first traversal shenanigans and not worry about
+     * nasty details like aliasing or reentrancy.
+     *
+     * We are therefore completely free to mutate these portions of the
+     * TableType however we choose! We'll take advantage of this property to do
+     * things like replace explicit named properties with indexers as required
+     * by the expected type.
+     */
+    if (!isLiteral(expr))
+        return exprType;
+
+    expectedType = follow(expectedType);
+    exprType = follow(exprType);
+
+    if (get<AnyType>(expectedType) || get<UnknownType>(expectedType))
+    {
+        // "Narrowing" to unknown or any is not going to do anything useful.
+        return exprType;
+    }
+
+    if (expr->is<AstExprConstantString>())
+    {
+        auto ft = get<FreeType>(exprType);
+        if (ft &&
+            get<SingletonType>(ft->lowerBound) &&
+            fastIsSubtype(builtinTypes->stringType, ft->upperBound) &&
+            fastIsSubtype(ft->lowerBound, builtinTypes->stringType)
+        )
+        {
+            // if the upper bound is a subtype of the expected type, we can push the expected type in
+            Relation upperBoundRelation = relate(ft->upperBound, expectedType);
+            if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident)
+            {
+                asMutable(exprType)->ty.emplace<BoundType>(expectedType);
+                return exprType;
+            }
+
+            // likewise, if the lower bound is a subtype, we can force the expected type in
+            // if this is the case and the previous relation failed, it means that the primitive type
+            // constraint was going to have to select the lower bound for this type anyway.
+            Relation lowerBoundRelation = relate(ft->lowerBound, expectedType);
+            if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident)
+            {
+                asMutable(exprType)->ty.emplace<BoundType>(expectedType);
+                return exprType;
+            }
+        }
+    }
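What the constant-literal branch above does (the boolean branch that follows repeats the same pattern): a literal like "hello" initially gets a free type whose lower bound is the singleton "hello" and whose upper bound admits string; if either bound already commits the type to the expected one, the free type is collapsed to the expected type. A standalone model of the bound test, with strings standing in for TypeIds (all names illustrative):

    #include <optional>
    #include <string>

    struct ToyFreeType
    {
        std::string lowerBound; // e.g. the singleton type "hello"
        std::string upperBound; // e.g. string
    };

    // Toy subtype oracle: every singleton is a subtype of string.
    bool toyIsSubtype(const std::string& sub, const std::string& super)
    {
        return sub == super || super == "string";
    }

    // Collapse the free type to `expected` when one of its bounds already
    // commits it to that choice; otherwise leave it unsolved (nullopt).
    std::optional<std::string> pushExpectedType(const ToyFreeType& ft, const std::string& expected)
    {
        if (toyIsSubtype(ft.upperBound, expected))
            return expected; // upper bound fits: expected is always safe
        if (toyIsSubtype(ft.lowerBound, expected))
            return expected; // the lower bound would have been chosen anyway
        return std::nullopt;
    }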
+    else if (expr->is<AstExprConstantBool>())
+    {
+        auto ft = get<FreeType>(exprType);
+        if (ft &&
+            get<SingletonType>(ft->lowerBound) &&
+            fastIsSubtype(builtinTypes->booleanType, ft->upperBound) &&
+            fastIsSubtype(ft->lowerBound, builtinTypes->booleanType)
+        )
+        {
+            // if the upper bound is a subtype of the expected type, we can push the expected type in
+            Relation upperBoundRelation = relate(ft->upperBound, expectedType);
+            if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident)
+            {
+                asMutable(exprType)->ty.emplace<BoundType>(expectedType);
+                return exprType;
+            }
+
+            // likewise, if the lower bound is a subtype, we can force the expected type in
+            // if this is the case and the previous relation failed, it means that the primitive type
+            // constraint was going to have to select the lower bound for this type anyway.
+            Relation lowerBoundRelation = relate(ft->lowerBound, expectedType);
+            if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident)
+            {
+                asMutable(exprType)->ty.emplace<BoundType>(expectedType);
+                return exprType;
+            }
+        }
+    }
+
+    if (expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNil>())
+    {
+        if (auto ft = get<FreeType>(exprType); ft && fastIsSubtype(ft->upperBound, expectedType))
+        {
+            asMutable(exprType)->ty.emplace<BoundType>(expectedType);
+            return exprType;
+        }
+
+        Relation r = relate(exprType, expectedType);
+        if (r == Relation::Coincident || r == Relation::Subset)
+            return expectedType;
+
+        return exprType;
+    }
+
+    // TODO: lambdas
+
+    if (auto exprTable = expr->as<AstExprTable>())
+    {
+        TableType* tableTy = getMutable<TableType>(exprType);
+        LUAU_ASSERT(tableTy);
+
+        const TableType* expectedTableTy = get<TableType>(expectedType);
+
+        if (!expectedTableTy)
+        {
+            if (auto utv = get<UnionType>(expectedType))
+            {
+                std::vector<TypeId> parts{begin(utv), end(utv)};
+
+                std::optional<TypeId> tt = extractMatchingTableType(parts, exprType, builtinTypes);
+
+                if (tt)
+                {
+                    TypeId res = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *tt, exprType, expr);
+
+                    parts.push_back(res);
+                    return arena->addType(UnionType{std::move(parts)});
+                }
+            }
+
+            return exprType;
+        }
+
+        for (const AstExprTable::Item& item : exprTable->items)
+        {
+            if (isRecord(item))
+            {
+                const AstArray<char>& s = item.key->as<AstExprConstantString>()->value;
+                std::string keyStr{s.data, s.data + s.size};
+                auto it = tableTy->props.find(keyStr);
+                LUAU_ASSERT(it != tableTy->props.end());
+
+                Property& prop = it->second;
+
+                // Table literals always initially result in shared read-write types
+                LUAU_ASSERT(prop.isShared());
+                TypeId propTy = *prop.readTy;
+
+                auto it2 = expectedTableTy->props.find(keyStr);
+
+                if (it2 == expectedTableTy->props.end())
+                {
+                    // expectedType may instead have an indexer. This is
+                    // kind of interesting because it means we clip the prop
+                    // from the exprType and fold it into the indexer.
+                    if (expectedTableTy->indexer && isString(expectedTableTy->indexer->indexType))
+                    {
+                        (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;
+                        (*astExpectedTypes)[item.value] = expectedTableTy->indexer->indexResultType;
+
+                        TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, expectedTableTy->indexer->indexResultType, propTy, item.value);
+
+                        if (tableTy->indexer)
+                            unifier->unify(matchedType, tableTy->indexer->indexResultType);
+                        else
+                            tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType};
+
+                        tableTy->props.erase(keyStr);
+                    }
+
+                    // If it's just an extra property and the expected type
+                    // has no indexer, there's no work to do here.
+
+                    continue;
+                }
+
+                LUAU_ASSERT(it2 != expectedTableTy->props.end());
+
+                const Property& expectedProp = it2->second;
+
+                std::optional<TypeId> expectedReadTy = expectedProp.readTy;
+                std::optional<TypeId> expectedWriteTy = expectedProp.writeTy;
+
+                TypeId matchedType = nullptr;
+
+                // Important optimization: If we traverse into the read and
+                // write types separately even when they are shared, we go
+                // quadratic in a hurry.
+                if (expectedProp.isShared())
+                {
+                    matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value);
+                    prop.readTy = matchedType;
+                    prop.writeTy = matchedType;
+                }
+                else if (expectedReadTy)
+                {
+                    matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value);
+                    prop.readTy = matchedType;
+                    prop.writeTy.reset();
+                }
+                else if (expectedWriteTy)
+                {
+                    matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedWriteTy, propTy, item.value);
+                    prop.readTy.reset();
+                    prop.writeTy = matchedType;
+                }
+                else
+                {
+                    // Also important: It is presently the case that all
+                    // table properties are either read-only, or have the
+                    // same read and write types.
+                    LUAU_ASSERT(!"Should be unreachable");
+                }
+
+                LUAU_ASSERT(prop.readTy || prop.writeTy);
+                LUAU_ASSERT(matchedType);
+
+                (*astExpectedTypes)[item.value] = matchedType;
+            }
+            else if (item.kind == AstExprTable::Item::List)
+            {
+                LUAU_ASSERT(tableTy->indexer);
+
+                if (expectedTableTy->indexer)
+                {
+                    const TypeId* propTy = astTypes->find(item.value);
+                    LUAU_ASSERT(propTy);
+
+                    unifier->unify(expectedTableTy->indexer->indexType, builtinTypes->numberType);
+                    TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, expectedTableTy->indexer->indexResultType, *propTy, item.value);
+
+                    tableTy->indexer->indexResultType = matchedType;
+                }
+            }
+            else if (item.kind == AstExprTable::Item::General)
+            {
+                LUAU_ASSERT(!"TODO");
+            }
+            else
+                LUAU_ASSERT(!"Unexpected");
+        }
+
+        // Keys that the expectedType says we should have, but that aren't
+        // specified by the AST fragment.
+        //
+        // If any such keys are optional, then we'll add them to the expression
+        // type.
+        //
+        // We use Set<std::optional<std::string>> here because the empty string is a
+        // perfectly reasonable value to insert into the set. We'll use
+        // std::nullopt as our sentinel value.
+        Set<std::optional<std::string>> missingKeys{{}};
+        for (const auto& [name, _] : expectedTableTy->props)
+            missingKeys.insert(name);
+
+        for (const AstExprTable::Item& item : exprTable->items)
+        {
+            if (item.key)
+            {
+                if (const auto str = item.key->as<AstExprConstantString>())
+                {
+                    missingKeys.erase(std::string(str->value.data, str->value.size));
+                }
+            }
+        }
+
+        for (const auto& key : missingKeys)
+        {
+            LUAU_ASSERT(key.has_value());
+
+            auto it = expectedTableTy->props.find(*key);
+            LUAU_ASSERT(it != expectedTableTy->props.end());
+
+            const Property& expectedProp = it->second;
+
+            Property exprProp;
+
+            if (expectedProp.readTy && isOptional(*expectedProp.readTy))
+                exprProp.readTy = *expectedProp.readTy;
+            if (expectedProp.writeTy && isOptional(*expectedProp.writeTy))
+                exprProp.writeTy = *expectedProp.writeTy;
+
+            // If the property isn't actually optional, do nothing.
+            if (exprProp.readTy || exprProp.writeTy)
+                tableTy->props[*key] = std::move(exprProp);
+        }
+    }
+
+    return exprType;
+}
+
+}
diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp
index d8044c16..718e9e8f 100644
--- a/Analysis/src/Type.cpp
+++ b/Analysis/src/Type.cpp
@@ -542,6 +542,19 @@ BlockedType::BlockedType()
 {
 }
 
+Constraint* BlockedType::getOwner() const
+{
+    return owner;
+}
+
+void BlockedType::setOwner(Constraint* newOwner)
+{
+    LUAU_ASSERT(owner == nullptr);
+
+    if (owner != nullptr)
+        return;
+
+    owner = newOwner;
+}
+
 PendingExpansionType::PendingExpansionType(
     std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
     : prefix(prefix)
@@ -686,6 +699,12 @@ void Property::setType(TypeId ty)
     writeTy = ty;
 }
 
+void Property::makeShared()
+{
+    if (writeTy)
+        writeTy = readTy;
+}
+
 bool Property::isShared() const
 {
     return readTy && writeTy && readTy == writeTy;
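Property::makeShared, restated standalone (optional ints stand in for the optional TypeId read/write slots declared in Type.h): collapsing the write type onto the read type is what makes a property "shared", i.e. readable and writable at one type.

    #include <optional>

    struct ToyProperty
    {
        std::optional<int> readTy;
        std::optional<int> writeTy; // ints stand in for TypeIds

        // If the property is writable at all, its write type becomes the
        // read type; read-only properties are left untouched.
        void makeShared()
        {
            if (writeTy)
                writeTy = readTy;
        }

        bool isShared() const
        {
            return readTy && writeTy && *readTy == *writeTy;
        }
    };

Normalize.cpp uses this (behind the DCR flag) to re-share every table property when converting a normalized type back into a TypeId, matching the shared-prop invariant matchLiteralType asserts for literal tables.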
diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp
index 61475a57..bd46e29c 100644
--- a/Analysis/src/TypeChecker2.cpp
+++ b/Analysis/src/TypeChecker2.cpp
@@ -240,7 +240,6 @@ struct TypeChecker2
     const NotNull<InternalErrorReporter> ice;
     const SourceModule* sourceModule;
     Module* module;
-    TypeArena testArena;
 
     TypeContext typeContext = TypeContext::Default;
     std::vector<NotNull<Scope>> stack;
@@ -260,8 +259,8 @@ struct TypeChecker2
         , ice(unifierState->iceHandler)
         , sourceModule(sourceModule)
         , module(module)
-        , normalizer{&testArena, builtinTypes, unifierState, /* cacheInhabitance */ true}
-        , _subtyping{builtinTypes, NotNull{&testArena}, NotNull{&normalizer}, NotNull{unifierState->iceHandler},
+        , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true}
+        , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler},
              NotNull{module->getModuleScope().get()}}
         , subtyping(&_subtyping)
     {
@@ -443,7 +442,7 @@ struct TypeChecker2
         seenTypeFamilyInstances.insert(instance);
 
         ErrorVec errors = reduceFamilies(
-            instance, location, TypeFamilyContext{NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
+            instance, location, TypeFamilyContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
                               .errors;
         if (!isErrorSuppressing(location, instance))
             reportErrors(std::move(errors));
@@ -651,7 +650,7 @@ struct TypeChecker2
         Scope* scope = findInnermostScope(ret->location);
         TypePackId expectedRetType = scope->returnType;
-        TypeArena* arena = &testArena;
+        TypeArena* arena = &module->internalTypes;
         TypePackId actualRetType = reconstructPack(ret->list, *arena);
 
         testIsSubtype(actualRetType, expectedRetType, ret->location);
@@ -778,7 +777,7 @@ struct TypeChecker2
             return;
 
         NotNull<Scope> scope = stack.back();
-        TypeArena& arena = testArena;
+        TypeArena& arena = module->internalTypes;
 
         std::vector<TypeId> variableTypes;
         for (AstLocal* var : forInStatement->vars)
@@ -1198,38 +1197,46 @@ struct TypeChecker2
     void visit(AstExprConstantNil* expr)
     {
+#if defined(LUAU_ENABLE_ASSERT)
         TypeId actualType = lookupType(expr);
         TypeId expectedType = builtinTypes->nilType;
 
         SubtypingResult r = subtyping->isSubtype(actualType, expectedType);
         LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, actualType));
+#endif
     }
 
     void visit(AstExprConstantBool* expr)
     {
-        TypeId actualType = lookupType(expr);
-        TypeId expectedType = builtinTypes->booleanType;
+#if defined(LUAU_ENABLE_ASSERT)
+        const TypeId bestType = expr->value ? builtinTypes->trueType : builtinTypes->falseType;
+        const TypeId inferredType = lookupType(expr);
 
-        SubtypingResult r = subtyping->isSubtype(actualType, expectedType);
-        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, actualType));
+        const SubtypingResult r = subtyping->isSubtype(bestType, inferredType);
+        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType));
+#endif
     }
 
     void visit(AstExprConstantNumber* expr)
     {
-        TypeId actualType = lookupType(expr);
-        TypeId expectedType = builtinTypes->numberType;
+#if defined(LUAU_ENABLE_ASSERT)
+        const TypeId bestType = builtinTypes->numberType;
+        const TypeId inferredType = lookupType(expr);
 
-        SubtypingResult r = subtyping->isSubtype(actualType, expectedType);
-        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, actualType));
+        const SubtypingResult r = subtyping->isSubtype(bestType, inferredType);
+        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType));
+#endif
     }
 
     void visit(AstExprConstantString* expr)
     {
-        TypeId actualType = lookupType(expr);
-        TypeId expectedType = builtinTypes->stringType;
+#if defined(LUAU_ENABLE_ASSERT)
+        const TypeId bestType = module->internalTypes.addType(SingletonType{StringSingleton{std::string{expr->value.data, expr->value.size}}});
+        const TypeId inferredType = lookupType(expr);
 
-        SubtypingResult r = subtyping->isSubtype(actualType, expectedType);
-        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, actualType));
+        const SubtypingResult r = subtyping->isSubtype(bestType, inferredType);
+        LUAU_ASSERT(r.isSubtype || isErrorSuppressing(expr->location, inferredType));
+#endif
     }
 
     void visit(AstExprLocal* expr)
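Note the direction flip in the rewritten assertions: the old checks asked whether the inferred type was a subtype of a broad built-in (e.g. string); the new ones compute the most specific "best" type of the literal (for strings, a fresh singleton) and assert that it is a subtype of whatever was inferred. In miniature, with types modeled as value sets:

    #include <cassert>
    #include <set>
    #include <string>

    using Ty = std::set<std::string>; // a type as the set of values it admits

    bool isSubtype(const Ty& sub, const Ty& super)
    {
        for (const std::string& v : sub)
            if (!super.count(v))
                return false;
        return true;
    }

    int main()
    {
        Ty bestType = {"hi"};            // singleton type of the literal "hi"
        Ty inferredType = {"hi", "bye"}; // whatever inference produced
        assert(isSubtype(bestType, inferredType)); // must always hold
    }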
@@ -1333,7 +1340,7 @@ struct TypeChecker2
         OverloadResolver resolver{
             builtinTypes,
-            NotNull{&testArena},
+            NotNull{&module->internalTypes},
             NotNull{&normalizer},
             NotNull{stack.back()},
             ice,
@@ -1703,10 +1710,10 @@ struct TypeChecker2
             return;
         }
 
-        TypePackId expectedArgs = testArena.addTypePack({operandType});
-        TypePackId expectedRet = testArena.addTypePack({resultType});
+        TypePackId expectedArgs = module->internalTypes.addTypePack({operandType});
+        TypePackId expectedRet = module->internalTypes.addTypePack({resultType});
 
-        TypeId expectedFunction = testArena.addType(FunctionType{expectedArgs, expectedRet});
+        TypeId expectedFunction = module->internalTypes.addType(FunctionType{expectedArgs, expectedRet});
 
         bool success = testIsSubtype(*mm, expectedFunction, expr->location);
         if (!success)
@@ -1758,8 +1765,8 @@ struct TypeChecker2
         bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe;
         bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or;
 
-        TypeId leftType = lookupType(expr->left);
-        TypeId rightType = lookupType(expr->right);
+        TypeId leftType = follow(lookupType(expr->left));
+        TypeId rightType = follow(lookupType(expr->right));
         TypeId expectedResult = follow(lookupType(expr));
 
         if (get<TypeFamilyInstanceType>(expectedResult))
@@ -1770,7 +1777,7 @@ struct TypeChecker2
 
         if (expr->op == AstExprBinary::Op::Or)
         {
-            leftType = stripNil(builtinTypes, testArena, leftType);
+            leftType = stripNil(builtinTypes, module->internalTypes, leftType);
         }
 
         const NormalizedType* normLeft = normalizer.normalize(leftType);
@@ -1874,25 +1881,25 @@ struct TypeChecker2
                 // swapped argument ordering.
                 if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt)
                 {
-                    expectedArgs = testArena.addTypePack({rightType, leftType});
+                    expectedArgs = module->internalTypes.addTypePack({rightType, leftType});
                 }
                 else
                 {
-                    expectedArgs = testArena.addTypePack({leftType, rightType});
+                    expectedArgs = module->internalTypes.addTypePack({leftType, rightType});
                 }
 
                 TypePackId expectedRets;
                 if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe ||
                     expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt)
                 {
-                    expectedRets = testArena.addTypePack({builtinTypes->booleanType});
+                    expectedRets = module->internalTypes.addTypePack({builtinTypes->booleanType});
                 }
                 else
                 {
-                    expectedRets = testArena.addTypePack({testArena.freshType(scope, TypeLevel{})});
+                    expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})});
                 }
 
-                TypeId expectedTy = testArena.addType(FunctionType(expectedArgs, expectedRets));
+                TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets));
 
                 testIsSubtype(follow(*mm), expectedTy, expr->location);
@@ -2097,8 +2104,8 @@ struct TypeChecker2
             return *fst;
         else if (auto ftp = get<FreeTypePack>(pack))
         {
-            TypeId result = testArena.addType(FreeType{ftp->scope});
-            TypePackId freeTail = testArena.addTypePack(FreeTypePack{ftp->scope});
+            TypeId result = module->internalTypes.addType(FreeType{ftp->scope});
+            TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope});
 
             TypePack& resultPack = asMutable(pack)->ty.emplace<TypePack>();
             resultPack.head.assign(1, result);
@@ -2623,7 +2630,7 @@ struct TypeChecker2
         {
             std::vector<TypeId> parts;
             parts.insert(parts.end(), norm->functions.parts.begin(), norm->functions.parts.end());
-            fetch(testArena.addType(IntersectionType{std::move(parts)}));
+            fetch(module->internalTypes.addType(IntersectionType{std::move(parts)}));
         }
     }
     for (const auto& [tyvar, intersect] : norm->tyvars)
@@ -2631,7 +2638,7 @@ struct TypeChecker2
         if (get<NeverType>(intersect->tops))
         {
             TypeId ty = normalizer.typeFromNormal(*intersect);
-            fetch(testArena.addType(IntersectionType{{tyvar, ty}}));
+            fetch(module->internalTypes.addType(IntersectionType{{tyvar, ty}}));
         }
         else
             fetch(tyvar);
@@ -2739,7 +2746,7 @@ struct TypeChecker2
             return true;
         if (cls->indexer)
         {
-            TypeId inhabitatedTestType = testArena.addType(IntersectionType{{cls->indexer->indexType, astIndexExprType}});
+            TypeId inhabitatedTestType = module->internalTypes.addType(IntersectionType{{cls->indexer->indexType, astIndexExprType}});
             return normalizer.isInhabited(inhabitatedTestType);
         }
         return false;
diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp
index e5df313e..a61fb88a 100644
--- a/Analysis/src/TypeFamily.cpp
+++ b/Analysis/src/TypeFamily.cpp
@@ -25,6 +25,8 @@
 LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000);
 
+LUAU_FASTFLAG(DebugLuauLogSolver);
+
 namespace Luau
 {
 
@@ -152,6 +154,13 @@ struct FamilyReducer
     template<typename T>
     void replace(T subject, T replacement)
     {
+        if (FFlag::DebugLuauLogSolver)
+            printf("%s -> %s\n", toString(subject, {true}).c_str(), toString(replacement, {true}).c_str());
+
+        // TODO: This should be an ICE (CLI-100942)
+        if (subject->owningArena != ctx.arena.get())
+            ctx.ice->ice("Attempting to modify a type family instance from another arena", location);
+
         asMutable(subject)->ty.template emplace<Unifiable::Bound<T>>(replacement);
 
         if constexpr (std::is_same_v<T, TypeId>)
@@ -171,6 +180,9 @@ struct FamilyReducer
 
         if (reduction.uninhabited || force)
         {
+            if (FFlag::DebugLuauLogSolver)
+                printf("%s is uninhabited\n", toString(subject, {true}).c_str());
+
             if constexpr (std::is_same_v<T, TypeId>)
                 result.errors.push_back(TypeError{location, UninhabitedTypeFamily{subject}});
             else if constexpr (std::is_same_v<T, TypePackId>)
@@ -178,6 +190,9 @@ struct FamilyReducer
         }
         else if (!reduction.uninhabited && !force)
         {
+            if (FFlag::DebugLuauLogSolver)
+                printf("%s is irreducible; blocked on %zu types, %zu packs\n", toString(subject, {true}).c_str(), reduction.blockedTypes.size(), reduction.blockedPacks.size());
+
             for (TypeId b : reduction.blockedTypes)
                 result.blockedTypes.insert(b);
@@ -201,11 +216,17 @@ struct FamilyReducer
 
         if (skip == SkipTestResult::Irreducible)
         {
+            if (FFlag::DebugLuauLogSolver)
+                printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str());
+
             irreducible.insert(subject);
             return false;
         }
         else if (skip == SkipTestResult::Defer)
        {
+            if (FFlag::DebugLuauLogSolver)
+                printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str());
+
             if constexpr (std::is_same_v<T, TypeId>)
                 queuedTys.push_back(subject);
             else if constexpr (std::is_same_v<T, TypePackId>)
@@ -221,11 +242,17 @@ struct FamilyReducer
 
         if (skip == SkipTestResult::Irreducible)
         {
+            if (FFlag::DebugLuauLogSolver)
+                printf("%s is irreducible due to a dependency on %s\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str());
+
             irreducible.insert(subject);
             return false;
         }
         else if (skip == SkipTestResult::Defer)
         {
+            if (FFlag::DebugLuauLogSolver)
+                printf("Deferring %s until %s is solved\n", toString(subject, {true}).c_str(), toString(p, {true}).c_str());
+
             if constexpr (std::is_same_v<T, TypeId>)
                 queuedTys.push_back(subject);
             else if constexpr (std::is_same_v<T, TypePackId>)
@@ -246,12 +273,20 @@ struct FamilyReducer
         if (irreducible.contains(subject))
             return;
 
+        if (FFlag::DebugLuauLogSolver)
+            printf("Trying to reduce %s\n", toString(subject, {true}).c_str());
+
         if (const TypeFamilyInstanceType* tfit = get<TypeFamilyInstanceType>(subject))
         {
             SkipTestResult testCyclic = testForSkippability(subject);
 
             if (!testParameters(subject, tfit) && testCyclic != SkipTestResult::CyclicTypeFamily)
+            {
+                if (FFlag::DebugLuauLogSolver)
+                    printf("Irreducible due to irreducible/pending and a non-cyclic family\n");
+
                 return;
+            }
 
             TypeFamilyReductionResult<TypeId> result = tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx});
             handleFamilyReduction(subject, result);
@@ -266,6 +301,9 @@ struct FamilyReducer
         if (irreducible.contains(subject))
             return;
 
+        if (FFlag::DebugLuauLogSolver)
+            printf("Trying to reduce %s\n", toString(subject, {true}).c_str());
+
         if (const TypeFamilyInstanceTypePack* tfit = get<TypeFamilyInstanceTypePack>(subject))
         {
             if (!testParameters(subject, tfit))
@@ -346,7 +384,7 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati
 
 bool isPending(TypeId ty, ConstraintSolver* solver)
 {
-    return is<BlockedType>(ty) || is<PendingExpansionType>(ty) || is<TypeFamilyInstanceType>(ty) || (solver && solver->hasUnresolvedConstraints(ty));
+    return is<BlockedType>(ty) || (solver && solver->hasUnresolvedConstraints(ty));
 }
 
 TypeFamilyReductionResult<TypeId> notFamilyFn(
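The new guard in FamilyReducer::replace, restated standalone: a type family instance may only be overwritten inside the arena that owns it; anything else would mutate another module's (possibly frozen) memory. The in-tree code reports the violation via ctx.ice, with a TODO to harden it further (CLI-100942). All names below are stand-ins:

    #include <stdexcept>

    struct Arena {};

    struct ToyType
    {
        const Arena* owningArena = nullptr;
    };

    void replaceWith(ToyType& subject, const Arena* currentArena)
    {
        // Mirrors the owningArena check added above (reported via ice()).
        if (subject.owningArena != currentArena)
            throw std::runtime_error("attempting to modify a type family instance from another arena");

        // ... emplace the bound replacement into `subject` ...
    }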
&& solver->hasUnresolvedConstraints(ty)); + return is<BlockedType, PendingExpansionType, TypeFamilyInstanceType>(ty) || (solver && solver->hasUnresolvedConstraints(ty)); } TypeFamilyReductionResult<TypeId> notFamilyFn( diff --git a/CMakeLists.txt b/CMakeLists.txt index 0189fcd0..9e99e931 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,7 @@ endif() project(Luau LANGUAGES CXX C) add_library(Luau.Common INTERFACE) +add_library(Luau.CLI.lib STATIC) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Config STATIC) @@ -65,9 +66,12 @@ include(Sources.cmake) target_include_directories(Luau.Common INTERFACE Common/include) +target_compile_features(Luau.CLI.lib PUBLIC cxx_std_17) +target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common) + target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) -target_link_libraries(Luau.Ast PUBLIC Luau.Common) +target_link_libraries(Luau.Ast PUBLIC Luau.Common Luau.CLI.lib) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) @@ -137,6 +141,7 @@ endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.CLI.lib PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) @@ -193,7 +198,7 @@ if(LUAU_BUILD_CLI) target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) - target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM isocline) + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.CLI.lib isocline) if(UNIX) find_library(LIBPTHREAD pthread) @@ -203,17 +208,17 @@ if(LUAU_BUILD_CLI) endif() endif() - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis Luau.CLI.lib) - target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis) + target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis Luau.CLI.lib) target_compile_features(Luau.Reduce.CLI PRIVATE cxx_std_17) target_include_directories(Luau.Reduce.CLI PUBLIC Reduce/include) - target_link_libraries(Luau.Reduce.CLI PRIVATE Luau.Common Luau.Ast Luau.Analysis) + target_link_libraries(Luau.Reduce.CLI PRIVATE Luau.Common Luau.Ast Luau.Analysis Luau.CLI.lib) - target_link_libraries(Luau.Compile.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen) + target_link_libraries(Luau.Compile.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen Luau.CLI.lib) - target_link_libraries(Luau.Bytecode.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen) + target_link_libraries(Luau.Bytecode.CLI PRIVATE Luau.Compiler Luau.VM Luau.CodeGen Luau.CLI.lib) endif() if(LUAU_BUILD_TESTS) @@ -230,7 +235,7 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.CLI.Test PRIVATE extern CLI) - target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM isocline) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.Config Luau.CodeGen Luau.VM Luau.CLI.lib isocline) if(UNIX) find_library(LIBPTHREAD pthread) if (LIBPTHREAD) diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index 3d67d13c..09167ef3 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -4,7
+4,7 @@ #include "Luau/Common.h" #include "Luau/IrData.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) namespace Luau { @@ -188,7 +188,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.def(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) { // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use visitor.useRange(vmRegOp(inst.b), 3); @@ -216,7 +216,8 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i // After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated case IrCmd::CHECK_TAG: - visitor.maybeUse(inst.a); + if (!FFlag::LuauCodegenRemoveDeadStores4) + visitor.maybeUse(inst.a); break; default: diff --git a/CodeGen/include/Luau/NativeProtoExecData.h b/CodeGen/include/Luau/NativeProtoExecData.h new file mode 100644 index 00000000..0033c276 --- /dev/null +++ b/CodeGen/include/Luau/NativeProtoExecData.h @@ -0,0 +1,49 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include <memory> +#include <stdint.h> + +namespace Luau +{ +namespace CodeGen +{ + +// The NativeProtoExecData is constant metadata associated with a NativeProto. +// We generally refer to the NativeProtoExecData via a pointer to the instruction +// offsets array because this makes the logic in the entry gate simpler. + +class NativeModule; + +struct NativeProtoExecDataHeader +{ + // The NativeModule that owns this NativeProto. This is initialized + // when the NativeProto is bound to the NativeModule via assignToModule(). + NativeModule* nativeModule = nullptr; + + // The number of bytecode instructions in the proto. This is the number of + // elements in the instruction offsets array following this header. + uint32_t bytecodeInstructionCount = 0; + + // The size of the native code for this NativeProto, in bytes. + size_t nativeCodeSize = 0; +}; + +// Make sure that the instruction offsets array following the header will be +// correctly aligned: +static_assert(sizeof(NativeProtoExecDataHeader) % sizeof(uint32_t) == 0); + +struct NativeProtoExecDataDeleter +{ + void operator()(const uint32_t* instructionOffsets) const noexcept; +}; + +using NativeProtoExecDataPtr = std::unique_ptr<uint32_t[], NativeProtoExecDataDeleter>; + +[[nodiscard]] NativeProtoExecDataPtr createNativeProtoExecData(uint32_t bytecodeInstructionCount); + +[[nodiscard]] NativeProtoExecDataHeader& getNativeProtoExecDataHeader(uint32_t* instructionOffsets) noexcept; +[[nodiscard]] const NativeProtoExecDataHeader& getNativeProtoExecDataHeader(const uint32_t* instructionOffsets) noexcept; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/SharedCodeAllocator.h b/CodeGen/include/Luau/SharedCodeAllocator.h new file mode 100644 index 00000000..577532e8 --- /dev/null +++ b/CodeGen/include/Luau/SharedCodeAllocator.h @@ -0,0 +1,210 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/NativeProtoExecData.h" + +#include <array> +#include <atomic> +#include <functional> +#include <memory> +#include <mutex> +#include <unordered_map> +#include <vector> + +namespace Luau +{ +namespace CodeGen +{ + +// SharedCodeAllocator is a native executable code allocator that provides +// shared ownership of the native code. Code is allocated on a per-module +// basis.
Each module is uniquely identifiable via an id, which may be a hash +// or other unique value. Each module may contain multiple natively compiled +// functions (protos). +// +// The module is the unit of shared ownership (i.e., it is where the reference +// count is maintained). + +using ModuleId = std::array<uint8_t, 16>; + +class NativeProto; +class NativeModule; +class NativeModuleRef; +class SharedCodeAllocator; + +// A NativeProto represents a single natively-compiled function. A NativeProto +// should be constructed for each function as it is compiled. When compilation +// of all of the functions in a module is complete, the set of NativeProtos +// representing those functions should be passed to the NativeModule constructor. +class NativeProto +{ +public: + NativeProto(uint32_t bytecodeId, NativeProtoExecDataPtr nativeExecData); + + NativeProto(const NativeProto&) = delete; + NativeProto(NativeProto&&) noexcept = default; + NativeProto& operator=(const NativeProto&) = delete; + NativeProto& operator=(NativeProto&&) noexcept = default; + + // This should be called to initialize the NativeProto state prior to + // passing the NativeProto to the NativeModule constructor. + void setEntryOffset(uint32_t entryOffset) noexcept; + + // This will be called by the NativeModule constructor to bind this + // NativeProto to the NativeModule. + void assignToModule(NativeModule* nativeModule) noexcept; + + // Gets the bytecode id for the Proto that was compiled into this NativeProto + [[nodiscard]] uint32_t getBytecodeId() const noexcept; + + // Gets the address of the entry point for this function + [[nodiscard]] const uint8_t* getEntryAddress() const noexcept; + + // Gets the native exec data for this function + [[nodiscard]] const NativeProtoExecDataHeader& getNativeExecDataHeader() const noexcept; + + // The NativeProto stores an array that maps bytecode instruction indices to + // native code offsets relative to the native entry point. When compilation + // and code allocation is complete, we store a pointer to this data in the + // Luau VM Proto object for this function. When we do this, we must acquire + // a reference to the NativeModule that owns this NativeProto. The + // getOwning-version of this function acquires that reference and gets the + // instruction offsets pointer. When the Proto object is destroyed, this + // pointer must be passed to releaseOwningPointerToInstructionOffsets to + // release the reference. + // + // (This structure is designed to make it much more difficult to "forget" + // to acquire a reference.) + [[nodiscard]] const uint32_t* getNonOwningPointerToInstructionOffsets() const noexcept; + [[nodiscard]] const uint32_t* getOwningPointerToInstructionOffsets() const noexcept; + + static void releaseOwningPointerToInstructionOffsets(const uint32_t* ownedInstructionOffsets) noexcept; + +private: + uint32_t bytecodeId = 0; + + // We store the native code offset until assignToModule() is called, after + // which point we store the actual address. + const uint8_t* entryOffsetOrAddress = nullptr; + + NativeProtoExecDataPtr nativeExecData = {}; +}; + +// A NativeModule represents a single natively-compiled module (script). It is +// the unit of shared ownership and is thus where the reference count is +// maintained. It owns a set of NativeProtos, with associated native exec data, +// and the allocated native data and code.
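The acquire/release discipline described for getOwningPointerToInstructionOffsets above is the core of the design, so a usage sketch may help. This is illustrative only: the VM-side `Proto` with its `void* execdata` field is borrowed from the old code path further down in CodeGen.cpp, and `bindToVmProto`/`onDestroyVmProto` are hypothetical helpers, not functions in this change.

```cpp
// Sketch of the ownership handshake described above (hypothetical integration).
void bindToVmProto(Proto* proto, const NativeProto& nativeProto)
{
    // Acquiring the owning pointer bumps the refcount of the NativeModule that
    // owns this NativeProto, so the native code stays alive while the VM holds it.
    proto->execdata = const_cast<uint32_t*>(nativeProto.getOwningPointerToInstructionOffsets());
}

void onDestroyVmProto(Proto* proto)
{
    // Releasing the pointer drops that reference; the NativeModule may be
    // destroyed here if this was the last outstanding reference.
    NativeProto::releaseOwningPointerToInstructionOffsets(static_cast<const uint32_t*>(proto->execdata));
}
```

Because the only way to obtain the pointer that gets stored in the VM is through the getOwning- accessor, forgetting to take the reference is hard by construction, which is exactly the property the comment calls out.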
+class NativeModule +{ +public: + NativeModule( + SharedCodeAllocator* allocator, const ModuleId& moduleId, const uint8_t* moduleBaseAddress, std::vector<NativeProto> nativeProtos) noexcept; + + NativeModule(const NativeModule&) = delete; + NativeModule(NativeModule&&) = delete; + NativeModule& operator=(const NativeModule&) = delete; + NativeModule& operator=(NativeModule&&) = delete; + + // The NativeModule must not be destroyed if there are any outstanding + // references. It should thus only be destroyed by a call to release() + // that releases the last reference. + ~NativeModule() noexcept; + + size_t addRef() const noexcept; + size_t release() const noexcept; + [[nodiscard]] size_t getRefcount() const noexcept; + + // Gets the base address of the executable native code for the module. + [[nodiscard]] const uint8_t* getModuleBaseAddress() const noexcept; + + // Attempts to find the NativeProto with the given bytecode id. If no + // NativeProto for that bytecode id exists, a null pointer is returned. + [[nodiscard]] const NativeProto* tryGetNativeProto(uint32_t bytecodeId) const noexcept; + +private: + mutable std::atomic<size_t> refcount = 0; + + SharedCodeAllocator* allocator = nullptr; + ModuleId moduleId = {}; + const uint8_t* moduleBaseAddress = nullptr; + + std::vector<NativeProto> nativeProtos = {}; +}; + +// A NativeModuleRef is an owning reference to a NativeModule. (Note: We do +// not use shared_ptr, to avoid complex state management in the Luau GC Proto +// object.) +class NativeModuleRef +{ +public: + NativeModuleRef() noexcept = default; + NativeModuleRef(NativeModule* nativeModule) noexcept; + + NativeModuleRef(const NativeModuleRef& other) noexcept; + NativeModuleRef(NativeModuleRef&& other) noexcept; + NativeModuleRef& operator=(NativeModuleRef other) noexcept; + + ~NativeModuleRef() noexcept; + + void reset() noexcept; + void swap(NativeModuleRef& other) noexcept; + + [[nodiscard]] bool empty() const noexcept; + explicit operator bool() const noexcept; + + [[nodiscard]] const NativeModule* get() const noexcept; + [[nodiscard]] const NativeModule* operator->() const noexcept; + [[nodiscard]] const NativeModule& operator*() const noexcept; + +private: + const NativeModule* nativeModule = nullptr; +}; + +class SharedCodeAllocator +{ +public: + SharedCodeAllocator() = default; + + SharedCodeAllocator(const SharedCodeAllocator&) = delete; + SharedCodeAllocator(SharedCodeAllocator&&) = delete; + SharedCodeAllocator& operator=(const SharedCodeAllocator&) = delete; + SharedCodeAllocator& operator=(SharedCodeAllocator&&) = delete; + + ~SharedCodeAllocator() noexcept; + + // If we have a NativeModule for the given ModuleId, an owning reference to + // it is returned. Otherwise, an empty NativeModuleRef is returned. + [[nodiscard]] NativeModuleRef tryGetNativeModule(const ModuleId& moduleId) const noexcept; + + // If we have a NativeModule for the given ModuleId, an owning reference to + // it is returned. Otherwise, a new NativeModule is created for that ModuleId + // using the provided NativeProtos, data, and code (space is allocated for the + // data and code such that it can be executed). + NativeModuleRef getOrInsertNativeModule( + const ModuleId& moduleId, std::vector<NativeProto> nativeProtos, const std::vector<uint8_t>& data, const std::vector<uint8_t>& code); + + // If a NativeModule exists for the given ModuleId and that NativeModule + // is no longer referenced, the NativeModule is destroyed.
This should + // usually only be called by NativeModule::release() when the reference + // count becomes zero + void eraseNativeModuleIfUnreferenced(const ModuleId& moduleId); + +private: + struct ModuleIdHash + { + [[nodiscard]] size_t operator()(const ModuleId& moduleId) const noexcept; + }; + + [[nodiscard]] NativeModuleRef tryGetNativeModuleWithLockHeld(const ModuleId& moduleId) const noexcept; + + mutable std::mutex mutex; + + // Will be removed when backend allocator is integrated + const uint8_t* baseAddress = reinterpret_cast<const uint8_t*>(0x0f00'0000); + + std::unordered_map<ModuleId, std::unique_ptr<NativeModule>, ModuleIdHash, std::equal_to<>> nativeModules; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 14f5b5b5..07004ae3 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -69,7 +69,7 @@ static const Instruction kCodeEntryInsn = LOP_NATIVECALL; static void* gPerfLogContext = nullptr; static PerfLogFn gPerfLogFn = nullptr; -struct NativeProto +struct OldNativeProto { Proto* p; void* execdata; @@ -116,7 +116,7 @@ ExtraExecData* getExtraExecData(Proto* proto, void* execdata) { return reinterpret_cast<ExtraExecData*>(reinterpret_cast<char*>(execdata) + size); } -static NativeProto createNativeProto(Proto* proto, const IrBuilder& ir) +static OldNativeProto createOldNativeProto(Proto* proto, const IrBuilder& ir) { if (FFlag::LuauCodegenHeapSizeReport) { @@ -186,7 +186,7 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) } template<typename AssemblyBuilder> -static std::optional<NativeProto> createNativeFunction( +static std::optional<OldNativeProto> createNativeFunction( AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) { IrBuilder ir; @@ -204,7 +204,7 @@ static std::optional<NativeProto> createNativeFunction( if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr, result)) return std::nullopt; - return createNativeProto(proto, ir); + return createOldNativeProto(proto, ir); } static NativeState* getNativeState(lua_State* L) @@ -455,7 +455,7 @@ CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, Comp X64::assembleHelpers(build, helpers); #endif - std::vector<NativeProto> results; + std::vector<OldNativeProto> results; results.reserve(protos.size()); uint32_t totalIrInstCount = 0; @@ -468,7 +468,7 @@ CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, Comp // If multiple compilations fail, we only use the failure from the first unsuccessful compilation. CodeGenCompilationResult temp = CodeGenCompilationResult::Success; - if (std::optional<NativeProto> np = createNativeFunction(build, helpers, p, totalIrInstCount, temp)) + if (std::optional<OldNativeProto> np = createNativeFunction(build, helpers, p, totalIrInstCount, temp)) results.push_back(*np); // second compilation failure onwards, this condition fails and codeGenCompilationResult is not assigned.
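The compile() loop here follows a first-failure-wins convention: each function compiles into a scratch `temp` result, and only the first non-success value is latched into the aggregate result. A minimal sketch of that pattern, using a hypothetical `tryCompileOne` helper for illustration:

```cpp
// First-failure-wins aggregation, as in compile() above (sketch only).
std::vector<OldNativeProto> results;
CodeGenCompilationResult aggregate = CodeGenCompilationResult::Success;

for (Proto* p : protos)
{
    CodeGenCompilationResult temp = CodeGenCompilationResult::Success;

    if (std::optional<OldNativeProto> np = tryCompileOne(p, temp)) // hypothetical helper
        results.push_back(*np);
    else if (aggregate == CodeGenCompilationResult::Success)
        aggregate = temp; // later failures are deliberately ignored
}
```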
else if (codeGenCompilationResult == CodeGenCompilationResult::Success) @@ -478,7 +478,7 @@ CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, Comp // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module if (!build.finalize()) { - for (NativeProto result : results) + for (OldNativeProto result : results) destroyExecData(result.execdata); return CodeGenCompilationResult::CodeGenAssemblerFinalizationFailure; @@ -497,7 +497,7 @@ CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, Comp if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast<const uint8_t*>(build.code.data()), int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart)) { - for (NativeProto result : results) + for (OldNativeProto result : results) destroyExecData(result.execdata); return CodeGenCompilationResult::AllocationFailed; @@ -538,7 +538,7 @@ CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, Comp } } - for (const NativeProto& result : results) + for (const OldNativeProto& result : results) { // the memory is now managed by VM and will be freed via onDestroyFunction result.p->execdata = result.execdata; @@ -548,7 +548,7 @@ CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, Comp if (stats != nullptr) { - for (const NativeProto& result : results) + for (const OldNativeProto& result : results) { stats->bytecodeSizeBytes += result.p->sizecode * sizeof(Instruction); diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 4f769178..e1a2b2a9 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -27,7 +27,7 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) namespace Luau { @@ -312,7 +312,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& } } - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) markDeadStoresInBlockChains(ir); } diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index fdce733f..8b27f40d 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -12,7 +12,7 @@ #include "lstate.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) namespace Luau { @@ -30,7 +30,7 @@ static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(luauRegValue(ra), xmm0); - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) @@ -38,7 +38,7 @@ static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra + 1), xmm0); - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } @@ -53,14 +53,14 @@ static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) {
build.vmovsd(luauRegValue(ra + 1), xmm0); - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } @@ -91,7 +91,7 @@ static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(luauRegValue(ra), tmp0.reg); - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) build.mov(luauRegTag(ra), LUA_TNUMBER); } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 5f101b9f..6d4bb350 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -15,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorA64, false) LUAU_FASTFLAGVARIABLE(LuauCodeGenOptVecA64, false) LUAU_FASTFLAG(LuauCodegenVectorTag2) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) namespace Luau { @@ -204,7 +204,7 @@ static bool emitBuiltin( { case LBF_MATH_FREXP: { - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) { CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); @@ -238,7 +238,7 @@ static bool emitBuiltin( } case LBF_MATH_MODF: { - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) { CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); @@ -278,7 +278,7 @@ static bool emitBuiltin( build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less)); build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) { RegisterA64 temp = regs.allocTemp(KindA64::w); build.mov(temp, LUA_TNUMBER); @@ -1512,20 +1512,35 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) Label fresh; // used when guard aborts execution or jumps to a VM exit Label& fail = getTargetLabel(inst.c, fresh); - // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled - RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? regs.allocTemp(KindA64::w) : regOp(inst.a); - - if (inst.a.kind == IrOpKind::VmReg) - build.ldr(tag, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); - - if (tagOp(inst.b) == 0) + if (FFlag::LuauCodegenRemoveDeadStores4) { - build.cbnz(tag, fail); + if (tagOp(inst.b) == 0) + { + build.cbnz(regOp(inst.a), fail); + } + else + { + build.cmp(regOp(inst.a), tagOp(inst.b)); + build.b(ConditionA64::NotEqual, fail); + } } else { - build.cmp(tag, tagOp(inst.b)); - build.b(ConditionA64::NotEqual, fail); + // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled + RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? 
regs.allocTemp(KindA64::w) : regOp(inst.a); + + if (inst.a.kind == IrOpKind::VmReg) + build.ldr(tag, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + + if (tagOp(inst.b) == 0) + { + build.cbnz(tag, fail); + } + else + { + build.cmp(tag, tagOp(inst.b)); + build.b(ConditionA64::NotEqual, fail); + } } finalizeTargetLabel(inst.c, fresh); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 30778fd9..b88dca81 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -8,7 +8,7 @@ #include <math.h> -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results @@ -48,7 +48,7 @@ static BuiltinImplResult translateBuiltinNumberToNumber( builtinCheckDouble(build, build.vmReg(arg), pcpos); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); - if (!FFlag::LuauCodegenRemoveDeadStores3) + if (!FFlag::LuauCodegenRemoveDeadStores4) { if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -112,7 +112,7 @@ static BuiltinImplResult translateBuiltinNumberTo2Number( build.inst( IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 1 : 2)); - if (!FFlag::LuauCodegenRemoveDeadStores3) + if (!FFlag::LuauCodegenRemoveDeadStores4) { if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 5d55c877..b3471573 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -13,7 +13,6 @@ #include "ltm.h" LUAU_FASTFLAGVARIABLE(LuauCodegenVectorTag2, false) -LUAU_FASTFLAGVARIABLE(LuauCodegenVectorTag, false) LUAU_FASTFLAGVARIABLE(LuauCodegenLoadTVTag, false) namespace Luau diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index 93a452ff..3975e25a 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -119,7 +119,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) break; // These instructions read VmReg only after optimizeMemoryOperandsX64 - case IrCmd::CHECK_TAG: + case IrCmd::CHECK_TAG: // TODO: remove with FFlagLuauCodegenRemoveDeadStores4 case IrCmd::CHECK_TRUTHY: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: diff --git a/CodeGen/src/NativeProtoExecData.cpp b/CodeGen/src/NativeProtoExecData.cpp new file mode 100644 index 00000000..1b2ca379 --- /dev/null +++ b/CodeGen/src/NativeProtoExecData.cpp @@ -0,0 +1,45 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/NativeProtoExecData.h" + +#include "Luau/Common.h" + +#include <new> + +namespace Luau +{ +namespace CodeGen +{ + +[[nodiscard]] static size_t computeNativeExecDataSize(uint32_t bytecodeInstructionCount) noexcept +{ + return sizeof(NativeProtoExecDataHeader) + (bytecodeInstructionCount * sizeof(uint32_t)); +} + +void NativeProtoExecDataDeleter::operator()(const uint32_t* instructionOffsets) const noexcept +{ + const NativeProtoExecDataHeader* header = &getNativeProtoExecDataHeader(instructionOffsets); + header->~NativeProtoExecDataHeader(); + delete[] reinterpret_cast<const uint8_t*>(header); +} + +[[nodiscard]] NativeProtoExecDataPtr
createNativeProtoExecData(uint32_t bytecodeInstructionCount) +{ + std::unique_ptr<uint8_t[]> bytes = std::make_unique<uint8_t[]>(computeNativeExecDataSize(bytecodeInstructionCount)); + new (static_cast<void*>(bytes.get())) NativeProtoExecDataHeader{}; + return NativeProtoExecDataPtr{reinterpret_cast<uint32_t*>(bytes.release() + sizeof(NativeProtoExecDataHeader))}; +} + +[[nodiscard]] NativeProtoExecDataHeader& getNativeProtoExecDataHeader(uint32_t* instructionOffsets) noexcept +{ + return *reinterpret_cast<NativeProtoExecDataHeader*>(reinterpret_cast<uint8_t*>(instructionOffsets) - sizeof(NativeProtoExecDataHeader)); +} + +[[nodiscard]] const NativeProtoExecDataHeader& getNativeProtoExecDataHeader(const uint32_t* instructionOffsets) noexcept +{ + return *reinterpret_cast<const NativeProtoExecDataHeader*>( + reinterpret_cast<const uint8_t*>(instructionOffsets) - sizeof(NativeProtoExecDataHeader)); +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 946fac41..8bffb3e0 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -19,8 +19,8 @@ LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAG(LuauCodegenVectorTag2) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenCoverForgprepEffect, false) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) LUAU_FASTFLAG(LuauCodegenLoadTVTag) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) namespace Luau { @@ -610,7 +610,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (state.tryGetTag(source) == value) { - if (FFlag::DebugLuauAbortingChecks) + if (FFlag::DebugLuauAbortingChecks && !FFlag::LuauCodegenRemoveDeadStores4) replace(function, block, index, {IrCmd::CHECK_TAG, inst.a, inst.b, build.undef()}); else kill(function, inst); @@ -1075,7 +1075,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FASTCALL: { - if (FFlag::LuauCodegenRemoveDeadStores3) + if (FFlag::LuauCodegenRemoveDeadStores4) { LuauBuiltinFunction bfid = LuauBuiltinFunction(function.uintOp(inst.a)); int firstReturnReg = vmRegOp(inst.b); diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index 25b7fed4..0b6550fd 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -9,8 +9,9 @@ #include "lobject.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores3, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores4, false) LUAU_FASTFLAG(LuauCodegenVectorTag2) +LUAU_FASTFLAG(LuauCodegenLoadTVTag) // TODO: optimization can be improved by knowing which registers are live in at each VM exit @@ -336,7 +337,7 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, // If the argument is a vector, it's not a GC object // Note that for known boolean/number/GCO, we already optimize into STORE_SPLIT_TVALUE form - // TODO: this can be removed if TAG_VECTOR+STORE_TVALUE is replaced with STORE_SPLIT_TVALUE + // TODO (CLI-101027): similar code is used in constant propagation optimization and should be shared in utilities if (IrInst* arg = function.asInstOp(inst.b)) { if (FFlag::LuauCodegenVectorTag2) @@ -350,6 +351,9 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, arg->cmd == IrCmd::UNM_VEC) regInfo.maybeGco = false; } + + if (FFlag::LuauCodegenLoadTVTag && arg->cmd == IrCmd::LOAD_TVALUE && arg->c.kind != IrOpKind::None) + regInfo.maybeGco = isGCO(function.tagOp(arg->c)); } state.hasGcoToClear |= regInfo.maybeGco; @@ -377,9 +381,6 @@ static void
markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, // Guard checks can jump to a block which might be using some or all the values we stored case IrCmd::CHECK_TAG: - // After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG might use a VM register - visitVmRegDefsUses(state, function, inst); - state.checkLiveIns(inst.c); break; case IrCmd::TRY_NUM_TO_INDEX: diff --git a/CodeGen/src/SharedCodeAllocator.cpp b/CodeGen/src/SharedCodeAllocator.cpp new file mode 100644 index 00000000..b26bd245 --- /dev/null +++ b/CodeGen/src/SharedCodeAllocator.cpp @@ -0,0 +1,293 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/SharedCodeAllocator.h" + +#include <algorithm> +#include <string_view> +#include <utility> + +namespace Luau +{ +namespace CodeGen +{ + + + +NativeProto::NativeProto(uint32_t bytecodeId, NativeProtoExecDataPtr nativeExecData) + : bytecodeId{bytecodeId} + , nativeExecData{std::move(nativeExecData)} +{ +} + +void NativeProto::setEntryOffset(uint32_t entryOffset) noexcept +{ + entryOffsetOrAddress = reinterpret_cast<const uint8_t*>(static_cast<uintptr_t>(entryOffset)); +} + +void NativeProto::assignToModule(NativeModule* nativeModule) noexcept +{ + getNativeProtoExecDataHeader(nativeExecData.get()).nativeModule = nativeModule; + + entryOffsetOrAddress = nativeModule->getModuleBaseAddress() + reinterpret_cast<uintptr_t>(entryOffsetOrAddress); +} + +[[nodiscard]] uint32_t NativeProto::getBytecodeId() const noexcept +{ + return bytecodeId; +} + +[[nodiscard]] const uint8_t* NativeProto::getEntryAddress() const noexcept +{ + return entryOffsetOrAddress; +} + +[[nodiscard]] const NativeProtoExecDataHeader& NativeProto::getNativeExecDataHeader() const noexcept +{ + return getNativeProtoExecDataHeader(nativeExecData.get()); +} + +[[nodiscard]] const uint32_t* NativeProto::getNonOwningPointerToInstructionOffsets() const noexcept +{ + return nativeExecData.get(); +} + +[[nodiscard]] const uint32_t* NativeProto::getOwningPointerToInstructionOffsets() const noexcept +{ + getNativeProtoExecDataHeader(nativeExecData.get()).nativeModule->addRef(); + return nativeExecData.get(); +} + +void NativeProto::releaseOwningPointerToInstructionOffsets(const uint32_t* ownedInstructionOffsets) noexcept +{ + getNativeProtoExecDataHeader(ownedInstructionOffsets).nativeModule->release(); +} + + +struct NativeProtoBytecodeIdEqual +{ + [[nodiscard]] bool operator()(const NativeProto& left, const NativeProto& right) const noexcept + { + return left.getBytecodeId() == right.getBytecodeId(); + } +}; + +struct NativeProtoBytecodeIdLess +{ + [[nodiscard]] bool operator()(const NativeProto& left, const NativeProto& right) const noexcept + { + return left.getBytecodeId() < right.getBytecodeId(); + } + + [[nodiscard]] bool operator()(const NativeProto& left, uint32_t right) const noexcept + { + return left.getBytecodeId() < right; + } + + [[nodiscard]] bool operator()(uint32_t left, const NativeProto& right) const noexcept + { + return left < right.getBytecodeId(); + } +}; + +NativeModule::NativeModule( + SharedCodeAllocator* allocator, const ModuleId& moduleId, const uint8_t* moduleBaseAddress, std::vector<NativeProto> nativeProtos) noexcept + : allocator{allocator} + , moduleId{moduleId} + , moduleBaseAddress{moduleBaseAddress} + , nativeProtos{std::move(nativeProtos)} +{ + LUAU_ASSERT(allocator != nullptr); + LUAU_ASSERT(moduleBaseAddress != nullptr); + + // Bind all of the NativeProtos to this module: + for (NativeProto& nativeProto : this->nativeProtos) + { + nativeProto.assignToModule(this);
} + + std::sort(this->nativeProtos.begin(), this->nativeProtos.end(), NativeProtoBytecodeIdLess{}); + + // We should not have two NativeProtos for the same bytecode id: + LUAU_ASSERT(std::adjacent_find(this->nativeProtos.begin(), this->nativeProtos.end(), NativeProtoBytecodeIdEqual{}) == this->nativeProtos.end()); +} + +NativeModule::~NativeModule() noexcept +{ + LUAU_ASSERT(refcount == 0); +} + +size_t NativeModule::addRef() const noexcept +{ + return refcount.fetch_add(1) + 1; +} + +size_t NativeModule::release() const noexcept +{ + size_t newRefcount = refcount.fetch_sub(1) - 1; + if (newRefcount != 0) + return newRefcount; + + allocator->eraseNativeModuleIfUnreferenced(moduleId); + + // NOTE: *this may have been destroyed by the prior call, and must not be + // accessed after this point. + return 0; +} + +[[nodiscard]] size_t NativeModule::getRefcount() const noexcept +{ + return refcount; +} + +[[nodiscard]] const uint8_t* NativeModule::getModuleBaseAddress() const noexcept +{ + return moduleBaseAddress; +} + +[[nodiscard]] const NativeProto* NativeModule::tryGetNativeProto(uint32_t bytecodeId) const noexcept +{ + const auto range = std::equal_range(nativeProtos.begin(), nativeProtos.end(), bytecodeId, NativeProtoBytecodeIdLess{}); + if (range.first == range.second) + return nullptr; + + LUAU_ASSERT(std::next(range.first) == range.second); + + return &*range.first; +} + + +NativeModuleRef::NativeModuleRef(NativeModule* nativeModule) noexcept + : nativeModule{nativeModule} +{ + if (nativeModule != nullptr) + nativeModule->addRef(); +} + +NativeModuleRef::NativeModuleRef(const NativeModuleRef& other) noexcept + : nativeModule{other.nativeModule} +{ + if (nativeModule != nullptr) + nativeModule->addRef(); +} + +NativeModuleRef::NativeModuleRef(NativeModuleRef&& other) noexcept + : nativeModule{std::exchange(other.nativeModule, nullptr)} +{ +} + +NativeModuleRef& NativeModuleRef::operator=(NativeModuleRef other) noexcept +{ + swap(other); + + return *this; +} + +NativeModuleRef::~NativeModuleRef() noexcept +{ + reset(); +} + +void NativeModuleRef::reset() noexcept +{ + if (nativeModule == nullptr) + return; + + nativeModule->release(); + nativeModule = nullptr; +} + +void NativeModuleRef::swap(NativeModuleRef& other) noexcept +{ + std::swap(nativeModule, other.nativeModule); +} + +[[nodiscard]] bool NativeModuleRef::empty() const noexcept +{ + return nativeModule == nullptr; +} + +NativeModuleRef::operator bool() const noexcept +{ + return nativeModule != nullptr; +} + +[[nodiscard]] const NativeModule* NativeModuleRef::get() const noexcept +{ + return nativeModule; +} + +[[nodiscard]] const NativeModule* NativeModuleRef::operator->() const noexcept +{ + return nativeModule; +} + +[[nodiscard]] const NativeModule& NativeModuleRef::operator*() const noexcept +{ + return *nativeModule; +} + + +SharedCodeAllocator::~SharedCodeAllocator() noexcept +{ + // The allocator should not be destroyed until all outstanding references + // have been released and all allocated modules have been destroyed.
+ LUAU_ASSERT(nativeModules.empty()); +} + +[[nodiscard]] NativeModuleRef SharedCodeAllocator::tryGetNativeModule(const ModuleId& moduleId) const noexcept +{ + std::unique_lock lock{mutex}; + + return tryGetNativeModuleWithLockHeld(moduleId); +} + +NativeModuleRef SharedCodeAllocator::getOrInsertNativeModule( + const ModuleId& moduleId, std::vector<NativeProto> nativeProtos, const std::vector<uint8_t>& data, const std::vector<uint8_t>& code) +{ + std::unique_lock lock{mutex}; + + if (NativeModuleRef existingModule = tryGetNativeModuleWithLockHeld(moduleId)) + return existingModule; + + // We simulate allocation until the backend allocator is integrated + + std::unique_ptr<NativeModule>& nativeModule = nativeModules[moduleId]; + nativeModule = std::make_unique<NativeModule>(this, moduleId, baseAddress, std::move(nativeProtos)); + + baseAddress += data.size() + code.size(); + + return NativeModuleRef{nativeModule.get()}; +} + +void SharedCodeAllocator::eraseNativeModuleIfUnreferenced(const ModuleId& moduleId) +{ + std::unique_lock lock{mutex}; + + const auto it = nativeModules.find(moduleId); + if (it == nativeModules.end()) + return; + + // It is possible that someone acquired a reference to the module between + // the time that we called this function and the time that we acquired the + // lock. If so, that's okay. + if (it->second->getRefcount() != 0) + return; + + nativeModules.erase(it); +} + +[[nodiscard]] NativeModuleRef SharedCodeAllocator::tryGetNativeModuleWithLockHeld(const ModuleId& moduleId) const noexcept +{ + const auto it = nativeModules.find(moduleId); + if (it == nativeModules.end()) + return NativeModuleRef{}; + + return NativeModuleRef{it->second.get()}; +} + +[[nodiscard]] size_t SharedCodeAllocator::ModuleIdHash::operator()(const ModuleId& moduleId) const noexcept +{ + return std::hash<std::string_view>{}(std::string_view{reinterpret_cast<const char*>(moduleId.data()), moduleId.size()}); +} + +} // namespace CodeGen +} // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 852ac80a..3c9bc074 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -7,6 +7,9 @@ #include <algorithm> #include <string.h> +LUAU_FASTFLAGVARIABLE(LuauCompileNoJumpLineRetarget, false) +LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) + namespace Luau { @@ -969,7 +972,11 @@ void BytecodeBuilder::foldJumps() if (LUAU_INSN_OP(jumpInsn) == LOP_JUMP && LUAU_INSN_OP(targetInsn) == LOP_RETURN) { insns[jumpLabel] = targetInsn; - lines[jumpLabel] = lines[targetLabel]; + + if (!FFlag::LuauCompileNoJumpLineRetarget) + { + lines[jumpLabel] = lines[targetLabel]; + } } else if (int16_t(offset) == offset) { @@ -2171,13 +2178,23 @@ std::string BytecodeBuilder::dumpCurrentFunction(std::vector<int>& dumpinstoffs) { const DebugLocal& l = debugLocals[i]; - LUAU_ASSERT(l.startpc < l.endpc); - LUAU_ASSERT(l.startpc < lines.size()); - LUAU_ASSERT(l.endpc <= lines.size()); // endpc is exclusive in the debug info, but it's more intuitive to print inclusive data + if (FFlag::LuauCompileRepeatUntilSkippedLocals && l.startpc == l.endpc) + { + LUAU_ASSERT(l.startpc < lines.size()); - // it would be nice to emit name as well but it requires reverse lookup through stringtable - formatAppend(result, "local %d: reg %d, start pc %d line %d, end pc %d line %d\n", int(i), l.reg, l.startpc, lines[l.startpc], - l.endpc - 1, lines[l.endpc - 1]); + // it would be nice to emit name as well but it requires reverse lookup through stringtable + formatAppend(result, "local %d: reg %d, start pc %d line %d, no live range\n", int(i), l.reg, l.startpc, lines[l.startpc]); + } +
else + { + LUAU_ASSERT(l.startpc < l.endpc); + LUAU_ASSERT(l.startpc < lines.size()); + LUAU_ASSERT(l.endpc <= lines.size()); // endpc is exclusive in the debug info, but it's more intuitive to print inclusive data + + // it would be nice to emit name as well but it requires reverse lookup through stringtable + formatAppend(result, "local %d: reg %d, start pc %d line %d, end pc %d line %d\n", int(i), l.reg, l.startpc, lines[l.startpc], + l.endpc - 1, lines[l.endpc - 1]); + } } } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6d859aa2..9d562874 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,6 +26,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileRepeatUntilSkippedLocals, false) + namespace Luau { @@ -2674,6 +2676,7 @@ struct Compiler RegScope rs(this); bool continueValidated = false; + size_t conditionLocals = 0; for (size_t i = 0; i < body->body.size; ++i) { @@ -2691,9 +2694,25 @@ struct Compiler { validateContinueUntil(loops.back().continueUsed, stat->condition, body, i + 1); continueValidated = true; + + if (FFlag::LuauCompileRepeatUntilSkippedLocals) + conditionLocals = localStack.size(); } } + // if continue was used, some locals might not have had their initialization completed + // the lifetime of these locals has to end before the condition is executed + // because referencing skipped locals is not possible from the condition, this earlier closure doesn't affect upvalues + if (FFlag::LuauCompileRepeatUntilSkippedLocals && continueValidated) + { + // if continueValidated is set, it means we have visited at least one body node and size > 0 + setDebugLineEnd(body->body.data[body->body.size - 1]); + + closeLocals(conditionLocals); + + popLocals(conditionLocals); + } + size_t contLabel = bytecode.emitLabel(); size_t endLabel; diff --git a/Sources.cmake b/Sources.cmake index 5db4b83b..1ae40f3d 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -86,12 +86,14 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/IrUtils.h CodeGen/include/Luau/IrVisitUseDef.h CodeGen/include/Luau/Label.h + CodeGen/include/Luau/NativeProtoExecData.h CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/OptimizeConstProp.h CodeGen/include/Luau/OptimizeDeadStore.h CodeGen/include/Luau/OptimizeFinalX64.h CodeGen/include/Luau/RegisterA64.h CodeGen/include/Luau/RegisterX64.h + CodeGen/include/Luau/SharedCodeAllocator.h CodeGen/include/Luau/UnwindBuilder.h CodeGen/include/Luau/UnwindBuilderDwarf2.h CodeGen/include/Luau/UnwindBuilderWin.h @@ -124,6 +126,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/IrUtils.cpp CodeGen/src/IrValueLocationTracking.cpp CodeGen/src/lcodegen.cpp + CodeGen/src/NativeProtoExecData.cpp CodeGen/src/NativeState.cpp CodeGen/src/OptimizeConstProp.cpp CodeGen/src/OptimizeDeadStore.cpp @@ -132,6 +135,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/UnwindBuilderWin.cpp CodeGen/src/BytecodeAnalysis.cpp CodeGen/src/BytecodeSummary.cpp + CodeGen/src/SharedCodeAllocator.cpp CodeGen/src/BitUtils.h CodeGen/src/ByteUtils.h @@ -200,6 +204,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Substitution.h Analysis/include/Luau/Subtyping.h Analysis/include/Luau/Symbol.h + Analysis/include/Luau/TableLiteralInference.h Analysis/include/Luau/ToDot.h Analysis/include/Luau/TopoSortStatements.h Analysis/include/Luau/ToString.h @@ -264,6 +269,7 @@ target_sources(Luau.Analysis 
PRIVATE Analysis/src/Substitution.cpp Analysis/src/Subtyping.cpp Analysis/src/Symbol.cpp + Analysis/src/TableLiteralInference.cpp Analysis/src/ToDot.cpp Analysis/src/TopoSortStatements.cpp Analysis/src/ToString.cpp @@ -350,15 +356,25 @@ target_sources(isocline PRIVATE extern/isocline/src/isocline.c ) + +if (TARGET Luau.Repl.CLI OR TARGET Luau.Analyze.CLI OR + TARGET Luau.Ast.CLI OR TARGET Luau.CLI.Test OR + TARGET Luau.Reduce.CLI OR TARGET Luau.Compile.CLI OR + TARGET Luau.Bytecode.CLI) + # Common sources shared between all CLI apps. + target_sources(Luau.CLI.lib PRIVATE + CLI/FileUtils.cpp + CLI/Flags.cpp + CLI/Flags.h + CLI/FileUtils.h + ) +endif() + if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE CLI/Coverage.h CLI/Coverage.cpp - CLI/FileUtils.h - CLI/FileUtils.cpp - CLI/Flags.h - CLI/Flags.cpp CLI/Profiler.h CLI/Profiler.cpp CLI/Repl.cpp @@ -369,10 +385,6 @@ endif() if(TARGET Luau.Analyze.CLI) # Luau.Analyze.CLI Sources target_sources(Luau.Analyze.CLI PRIVATE - CLI/FileUtils.h - CLI/FileUtils.cpp - CLI/Flags.h - CLI/Flags.cpp CLI/Analyze.cpp) endif() @@ -380,8 +392,6 @@ if(TARGET Luau.Ast.CLI) # Luau.Ast.CLI Sources target_sources(Luau.Ast.CLI PRIVATE CLI/Ast.cpp - CLI/FileUtils.h - CLI/FileUtils.cpp ) endif() @@ -437,6 +447,7 @@ if(TARGET Luau.UnitTest) tests/ScopedFlags.h tests/Simplify.test.cpp tests/Set.test.cpp + tests/SharedCodeAllocator.test.cpp tests/StringUtils.test.cpp tests/Subtyping.test.cpp tests/Symbol.test.cpp @@ -497,10 +508,6 @@ if(TARGET Luau.CLI.Test) target_sources(Luau.CLI.Test PRIVATE CLI/Coverage.h CLI/Coverage.cpp - CLI/FileUtils.h - CLI/FileUtils.cpp - CLI/Flags.h - CLI/Flags.cpp CLI/Profiler.h CLI/Profiler.cpp CLI/Repl.cpp @@ -523,27 +530,17 @@ if(TARGET Luau.Reduce.CLI) # Luau.Reduce.CLI Sources target_sources(Luau.Reduce.CLI PRIVATE CLI/Reduce.cpp - CLI/FileUtils.cpp - CLI/FileUtils.h ) endif() if(TARGET Luau.Compile.CLI) # Luau.Compile.CLI Sources target_sources(Luau.Compile.CLI PRIVATE - CLI/FileUtils.h - CLI/FileUtils.cpp - CLI/Flags.h - CLI/Flags.cpp CLI/Compile.cpp) endif() if(TARGET Luau.Bytecode.CLI) # Luau.Bytecode.CLI Sources target_sources(Luau.Bytecode.CLI PRIVATE - CLI/FileUtils.h - CLI/FileUtils.cpp - CLI/Flags.h - CLI/Flags.cpp CLI/Bytecode.cpp) endif() diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index d52d3794..54b0dd32 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -637,6 +637,14 @@ void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlo *blockSize = page->blockSize; } +void luaM_getpageinfo(lua_Page* page, int* pageBlocks, int* busyBlocks, int* blockSize, int* pageSize) +{ + *pageBlocks = (page->pageSize - offsetof(lua_Page, data)) / page->blockSize; + *busyBlocks = page->busyBlocks; + *blockSize = page->blockSize; + *pageSize = page->pageSize; +} + lua_Page* luaM_getnextgcopage(lua_Page* page) { return page->gcolistnext; diff --git a/VM/src/lmem.h b/VM/src/lmem.h index e552d739..d6508402 100644 --- a/VM/src/lmem.h +++ b/VM/src/lmem.h @@ -26,6 +26,7 @@ LUAI_FUNC void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t ns LUAI_FUNC l_noret luaM_toobig(lua_State* L); LUAI_FUNC void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlocks, int* blockSize); +LUAI_FUNC void luaM_getpageinfo(lua_Page* page, int* pageBlocks, int* busyBlocks, int* blockSize, int* pageSize); LUAI_FUNC lua_Page* luaM_getnextgcopage(lua_Page* page); LUAI_FUNC void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* 
gco)); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 6d4d3e47..8d57cb50 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -146,6 +146,38 @@ struct ACBuiltinsFixture : ACFixtureImpl { }; +#define LUAU_CHECK_HAS_KEY(map, key) do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ + if (!count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v]: _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + +#define LUAU_CHECK_HAS_NO_KEY(map, key) do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ + if (count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v]: _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -203,7 +235,7 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); - CHECK(!ac.entryMap.count("myInnerLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); ac = autocomplete('2'); CHECK(ac.entryMap.count("myLocal")); @@ -211,7 +243,7 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") ac = autocomplete('3'); CHECK(ac.entryMap.count("myLocal")); - CHECK(!ac.entryMap.count("myInnerLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); } TEST_CASE_FIXTURE(ACFixture, "recursive_function") @@ -298,7 +330,7 @@ TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); - CHECK(!ac.entryMap.count("abc")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "abc"); } TEST_CASE_FIXTURE(ACFixture, "function_parameters") @@ -325,7 +357,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_member_completions") CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); - CHECK(!ac.entryMap.count("math")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "math"); CHECK_EQ(ac.context, AutocompleteContext::Property); } @@ -471,7 +503,7 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") CHECK_NE(0, ac.entryMap.size()); - CHECK(!ac.entryMap.count("math")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "math"); CHECK_EQ(ac.context, AutocompleteContext::Property); } @@ -485,7 +517,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "method_call_inside_if_conditional") CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("concat")); - CHECK(!ac.entryMap.count("math")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "math"); CHECK_EQ(ac.context, AutocompleteContext::Property); } @@ -1330,7 +1362,7 @@ local a: nu@3 ac = autocomplete('3'); - CHECK(!ac.entryMap.count("num")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "num"); CHECK(ac.entryMap.count("number")); } @@ -2052,7 +2084,7 @@ ex.a(function(x: auto ac = autocomplete("Module/B", Position{2, 16}); - CHECK(!ac.entryMap.count("done")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "done"); fileResolver.source["Module/C"] = R"( local ex = require(script.Parent.A) @@ -2063,7 +2095,7 @@ ex.b(function(x: ac = autocomplete("Module/C", Position{2, 16}); - CHECK(!ac.entryMap.count("(done) -> number")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "(done) -> number"); } TEST_CASE_FIXTURE(ACBuiltinsFixture, 
"suggest_external_module_type") @@ -2086,7 +2118,7 @@ ex.a(function(x: auto ac = autocomplete("Module/B", Position{2, 16}); - CHECK(!ac.entryMap.count("done")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "done"); CHECK(ac.entryMap.count("ex.done")); CHECK(ac.entryMap["ex.done"].typeCorrect == TypeCorrectKind::Correct); @@ -2099,7 +2131,7 @@ ex.b(function(x: ac = autocomplete("Module/C", Position{2, 16}); - CHECK(!ac.entryMap.count("(done) -> number")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "(done) -> number"); CHECK(ac.entryMap.count("(ex.done) -> number")); CHECK(ac.entryMap["(ex.done) -> number"].typeCorrect == TypeCorrectKind::Correct); } @@ -2113,7 +2145,7 @@ local bar: @1= foo auto ac = autocomplete('1'); - CHECK(!ac.entryMap.count("foo")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "foo"); } TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") @@ -2338,7 +2370,7 @@ local name = na@1 auto ac = autocomplete('1'); - CHECK(!ac.entryMap.count("name")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "name"); CHECK(ac.entryMap.count("other")); check(R"( @@ -2348,8 +2380,8 @@ local name, test = na@1 ac = autocomplete('1'); - CHECK(!ac.entryMap.count("name")); - CHECK(!ac.entryMap.count("test")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "name"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "test"); CHECK(ac.entryMap.count("other")); } @@ -2539,7 +2571,7 @@ TEST_CASE_FIXTURE(ACFixture, "not_the_var_we_are_defining") fileResolver.source["Module/A"] = "abc,de"; auto ac = autocomplete("Module/A", Position{0, 6}); - CHECK(!ac.entryMap.count("de")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "de"); } TEST_CASE_FIXTURE(ACFixture, "recursive_function_global") @@ -2597,8 +2629,8 @@ local t: Test = { s@1 } ac = autocomplete('1'); CHECK(ac.entryMap.count("second")); - CHECK(!ac.entryMap.count("first")); - CHECK(!ac.entryMap.count("third")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "first"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "third"); CHECK_EQ(ac.context, AutocompleteContext::Property); // No parenthesis suggestion @@ -2641,8 +2673,8 @@ local t: Test = { "f@1" } )"); ac = autocomplete('1'); - CHECK(!ac.entryMap.count("first")); - CHECK(!ac.entryMap.count("second")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "first"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "second"); CHECK_EQ(ac.context, AutocompleteContext::String); // Skip keys that are already defined @@ -2652,7 +2684,7 @@ local t: Test = { first = 2, s@1 } )"); ac = autocomplete('1'); - CHECK(!ac.entryMap.count("first")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "first"); CHECK(ac.entryMap.count("second")); CHECK_EQ(ac.context, AutocompleteContext::Property); @@ -3122,8 +3154,8 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") ac = autocomplete('4'); - CHECK(!ac.entryMap.count("up")); - CHECK(!ac.entryMap.count("down")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "up"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "down"); CHECK(ac.entryMap.count("\"up\"")); CHECK(ac.entryMap.count("\"down\"")); @@ -3145,8 +3177,8 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") ac = autocomplete('8'); - CHECK(!ac.entryMap.count("up")); - CHECK(!ac.entryMap.count("down")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "up"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "down"); CHECK(ac.entryMap.count("\"up\"")); CHECK(ac.entryMap.count("\"down\"")); @@ -3174,55 +3206,57 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement") local a: {[Direction]: boolean} = {[@A`@B`@C]} )"); - auto ac = autocomplete('1'); + Luau::AutocompleteResult ac; - CHECK(!ac.entryMap.count("left")); - 
CHECK(!ac.entryMap.count("right")); + ac = autocomplete('1'); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('2'); CHECK(ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('3'); - CHECK(!ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('4'); - CHECK(!ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('5'); - CHECK(ac.entryMap.count("left")); + LUAU_CHECK_HAS_KEY(ac.entryMap, "left"); CHECK(ac.entryMap.count("right")); ac = autocomplete('6'); - CHECK(!ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('7'); - CHECK(!ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('8'); CHECK(ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('9'); - CHECK(!ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('A'); - CHECK(!ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); ac = autocomplete('B'); @@ -3231,8 +3265,8 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement") ac = autocomplete('C'); - CHECK(!ac.entryMap.count("left")); - CHECK(!ac.entryMap.count("right")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "left"); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "right"); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index d6973aa8..e58526bd 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -22,6 +22,9 @@ LUAU_FASTINT(LuauCompileLoopUnrollThreshold) LUAU_FASTINT(LuauCompileLoopUnrollThresholdMaxBoost) LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTFLAG(LuauCompileNoJumpLineRetarget) +LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) + using namespace Luau; static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, bool enableVectors = false) @@ -2100,6 +2103,50 @@ RETURN R0 0 )"); } +TEST_CASE("LoopContinueEarlyCleanup") +{ + ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; + + // locals after a potential 'continue' are not accessible inside the condition and can be closed at the end of a block + CHECK_EQ("\n" + compileFunction(R"( +local y +repeat + local a, b + do continue end + local c, d + local function x() + return a + b + c + d + end + + c = 2 + a = 4 + + y = x +until a +)", + 1), + R"( +LOADNIL R0 +L0: LOADNIL R1 +LOADNIL R2 +JUMP L1 +LOADNIL R3 +LOADNIL R4 +NEWCLOSURE R5 P0 +CAPTURE REF R1 +CAPTURE REF R3 +LOADN R3 2 +LOADN R1 4 +MOVE R0 R5 +CLOSEUPVALS R3 +L1: JUMPIF R1 L2 +CLOSEUPVALS R1 +JUMPBACK L0 +L2: CLOSEUPVALS R1 +RETURN R0 0 +)"); +} + TEST_CASE("AndOrOptimizations") { // the OR/ORK optimization triggers for cutoff since lhs is simple @@ -2740,6 
+2787,8 @@ end TEST_CASE("DebugLineInfoWhile") { + ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; + Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -2761,7 +2810,7 @@ end 6: GETIMPORT R1 2 [print] 6: LOADK R2 K3 ['done!'] 6: CALL R1 1 0 -10: RETURN R0 0 +7: RETURN R0 0 3: L1: JUMPBACK L0 10: RETURN R0 0 )"); @@ -3084,6 +3133,75 @@ local 8: reg 3, start pc 35 line 21, end pc 35 line 21 )"); } +TEST_CASE("DebugLocals2") +{ + ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; + + const char* source = R"( +function foo(x) + repeat + local a, b + until true +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.debugLevel = 2; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +local 0: reg 1, start pc 2 line 6, no live range +local 1: reg 2, start pc 2 line 6, no live range +local 2: reg 0, start pc 0 line 4, end pc 2 line 6 +4: LOADNIL R1 +4: LOADNIL R2 +6: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLocals3") +{ + ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; + ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; + + const char* source = R"( +function foo(x) + repeat + local a, b + do continue end + local c, d = 2 + until true +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.debugLevel = 2; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +local 0: reg 3, start pc 5 line 8, no live range +local 1: reg 4, start pc 5 line 8, no live range +local 2: reg 1, start pc 2 line 5, end pc 4 line 6 +local 3: reg 2, start pc 2 line 5, end pc 4 line 6 +local 4: reg 0, start pc 0 line 4, end pc 5 line 8 +4: LOADNIL R1 +4: LOADNIL R2 +5: RETURN R0 0 +6: LOADN R3 2 +6: LOADNIL R4 +8: RETURN R0 0 +)"); +} TEST_CASE("DebugRemarks") { Luau::BytecodeBuilder bcb; @@ -4039,6 +4157,8 @@ RETURN R0 0 TEST_CASE("Coverage") { + ScopedFastFlag luauCompileNoJumpLineRetarget{FFlag::LuauCompileNoJumpLineRetarget, true}; + // basic statement coverage CHECK_EQ("\n" + compileFunction0Coverage(R"( print(1) @@ -4074,7 +4194,7 @@ end 3: GETIMPORT R0 3 [print] 3: LOADN R1 1 3: CALL R0 1 0 -7: RETURN R0 0 +3: RETURN R0 0 5: L0: COVERAGE 5: GETIMPORT R0 3 [print] 5: LOADN R1 2 @@ -4102,7 +4222,7 @@ end 4: GETIMPORT R0 3 [print] 4: LOADN R1 1 4: CALL R0 1 0 -9: RETURN R0 0 +4: RETURN R0 0 7: L0: COVERAGE 7: GETIMPORT R0 3 [print] 7: LOADN R1 2 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index baa5639b..095809d5 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -30,9 +30,11 @@ extern int optimizationLevel; void luaC_fullgc(lua_State* L); void luaC_validate(lua_State* L); +LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTFLAG(LuauLoadExceptionSafe) LUAU_DYNAMIC_FASTFLAG(LuauDebugInfoDupArgLeftovers) +LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) static lua_CompileOptions defaultOptions() { @@ -642,6 
+644,8 @@ TEST_CASE("Debugger") static bool singlestep = false; static int stephits = 0; + ScopedFastFlag luauCompileRepeatUntilSkippedLocals{FFlag::LuauCompileRepeatUntilSkippedLocals, true}; + SUBCASE("") { singlestep = false; @@ -788,6 +792,17 @@ TEST_CASE("Debugger") CHECK(lua_isnil(L, -1)); lua_pop(L, 1); } + else if (breakhits == 15) + { + // test lua_getlocal + const char* x = lua_getlocal(L, 2, 1); + REQUIRE(x); + CHECK(strcmp(x, "x") == 0); + lua_pop(L, 1); + + const char* a1 = lua_getlocal(L, 2, 2); + REQUIRE(!a1); + } if (interruptedthread) { @@ -797,7 +812,7 @@ TEST_CASE("Debugger") }, nullptr, &copts, /* skipCodegen */ true); // Native code doesn't support debugging yet - CHECK(breakhits == 14); // 2 hits per breakpoint + CHECK(breakhits == 16); // 2 hits per breakpoint if (singlestep) CHECK(stephits > 100); // note; this will depend on number of instructions which can vary, so we just make sure the callback gets hit often @@ -2040,6 +2055,16 @@ TEST_CASE("Native") if (!codegen || !luau_codegen_supported()) return; + SUBCASE("Checked") + { + FFlag::DebugLuauAbortingChecks.value = true; + } + + SUBCASE("Regular") + { + FFlag::DebugLuauAbortingChecks.value = false; + } + runConformance("native.lua", [](lua_State* L) { setupNativeHelpers(L); }); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 00ab9e98..5fbfb8bc 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -13,7 +13,7 @@ #include LUAU_FASTFLAG(LuauCodegenVectorTag2) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) LUAU_FASTFLAG(DebugLuauAbortingChecks) using namespace Luau::CodeGen; @@ -2541,7 +2541,7 @@ bb_0: ; useCount: 0 TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp block = build.block(IrBlockKind::Internal); IrOp followup = build.block(IrBlockKind::Internal); @@ -2582,7 +2582,7 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -2607,7 +2607,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -2942,7 +2942,7 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepImplicitUse") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); IrOp direct = build.block(IrBlockKind::Internal); @@ -3452,7 +3452,7 @@ TEST_SUITE_BEGIN("DeadStoreRemoval"); TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3498,7 +3498,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturn") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + 
ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3528,7 +3528,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3558,7 +3558,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3592,7 +3592,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3622,7 +3622,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse4") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3656,7 +3656,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "PartialVsFullStoresWithRecombination") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3681,7 +3681,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3710,7 +3710,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); IrOp next = build.block(IrBlockKind::Internal); @@ -3747,7 +3747,7 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "KeepCapturedRegisterStores") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3784,27 +3784,40 @@ bb_0: )"); } -TEST_CASE_FIXTURE(IrBuilderFixture, "AbortingChecksRequireStores") +TEST_CASE_FIXTURE(IrBuilderFixture, "StoreCannotBeReplacedWithCheck") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; ScopedFastFlag debugLuauAbortingChecks{FFlag::DebugLuauAbortingChecks, true}; IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); build.beginBlock(block); - build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5)); + IrOp ptr = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); 
- build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2))); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), ptr); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); - build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5)); + build.inst(IrCmd::CHECK_READONLY, ptr, fallback); - build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), build.inst(IrCmd::LOAD_POINTER, build.vmReg(0))); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); - build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(6)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnil)); + + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + IrOp fallbackPtr = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), fallbackPtr); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(3)); updateUseCounts(build.function); computeCfgInfo(build.function); @@ -3813,21 +3826,36 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "AbortingChecksRequireStores") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: -; in regs: R1, R4 - STORE_TAG R0, tnumber - STORE_DOUBLE R2, 0.5 - STORE_TAG R3, tnumber - STORE_DOUBLE R5, 0.5 - CHECK_TAG R0, tnumber, undef - STORE_TAG R0, tnil - RETURN R0, 6i +; successors: bb_fallback_1, bb_2 +; in regs: R0, R1 +; out regs: R0, R1, R2 + %0 = LOAD_POINTER R1 + CHECK_READONLY %0, bb_fallback_1 + STORE_TAG R2, tnil + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1 +; out regs: R0, R1, R2 + %9 = LOAD_POINTER R1 + STORE_POINTER R2, %9 + STORE_TAG R2, ttable + CHECK_GC + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1, R2 + RETURN R0, 3i )"); } TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; IrOp entry = build.block(IrBlockKind::Internal); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 12a87309..baafdb01 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -13,7 +13,8 @@ #include LUAU_FASTFLAG(LuauCodegenVectorTag2) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores3) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenLoadTVTag) static std::string getCodegenAssembly(const char* source) { @@ -91,7 +92,7 @@ bb_bytecode_1: TEST_CASE("VectorComponentRead") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function compsum(a: vector) @@ -175,7 +176,7 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv") { ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + 
getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) @@ -210,7 +211,7 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv2") { ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector) @@ -241,7 +242,7 @@ bb_bytecode_1: TEST_CASE("VectorMulDivMixed") { ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) @@ -283,7 +284,7 @@ bb_bytecode_1: TEST_CASE("ExtraMathMemoryOperands") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: number, b: number, c: number, d: number, e: number) @@ -321,7 +322,7 @@ bb_bytecode_1: TEST_CASE("DseInitialStackState") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo() @@ -361,7 +362,7 @@ bb_5: TEST_CASE("DseInitialStackState2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a) @@ -382,7 +383,7 @@ bb_bytecode_0: TEST_CASE("DseInitialStackState3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores3, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a) @@ -401,4 +402,33 @@ bb_bytecode_0: )"); } +TEST_CASE("VectorConstantTag") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; + ScopedFastFlag luauCodegenLoadTVTag{FFlag::LuauCodegenLoadTVTag, true}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function vecrcp(a: vector) + return vector(1, 2, 3) + a +end +)"), +R"( +; function vecrcp($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0, 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = ADD_VEC %4, %11 + %13 = TAG_VECTOR %12 + STORE_TVALUE R1, %13 + INTERRUPT 2u + RETURN R1, 1i +)"); +} + TEST_SUITE_END(); diff --git a/tests/Set.test.cpp b/tests/Set.test.cpp index a3a6d05e..94de4f01 100644 --- a/tests/Set.test.cpp +++ b/tests/Set.test.cpp @@ -133,4 +133,18 @@ TEST_CASE("iterate_over_set_skips_first_element_if_it_is_erased") CHECK(1 == out.size()); } +TEST_CASE("erase_using_const_ref_argument") +{ + Luau::Set s1{{}}; + + s1.insert("x"); + s1.insert("y"); + + std::string key = "y"; + s1.erase(key); + + CHECK(s1.count("x")); + CHECK(!s1.count("y")); +} + TEST_SUITE_END(); diff --git a/tests/SharedCodeAllocator.test.cpp b/tests/SharedCodeAllocator.test.cpp new file mode 100644 index 
00000000..70cc4d75 --- /dev/null +++ b/tests/SharedCodeAllocator.test.cpp @@ -0,0 +1,333 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/SharedCodeAllocator.h" + +#include "doctest.h" + +// We explicitly test correctness of self-assignment for some types +#ifdef __clang__ +#pragma GCC diagnostic ignored "-Wself-assign-overloaded" +#endif + +using namespace Luau::CodeGen; + +TEST_SUITE_BEGIN("SharedCodeAllocator"); + +TEST_CASE("NativeModuleRefRefcounting") +{ + SharedCodeAllocator allocator{}; + + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).empty()); + + NativeModuleRef modRefA = allocator.getOrInsertNativeModule(ModuleId{0x0a}, {}, {}, {}); + REQUIRE(!modRefA.empty()); + + // If we attempt to get the module again, we should get the same module back: + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).get() == modRefA.get()); + + // If we try to insert another instance of the module, we should get the + // existing module back: + REQUIRE(allocator.getOrInsertNativeModule(ModuleId{0x0a}, {}, {}, {}).get() == modRefA.get()); + + // If we try to look up a different module, we should not get the existing + // module back: + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0b}).empty()); + + // (Insert a second module to help with validation below) + NativeModuleRef modRefB = allocator.getOrInsertNativeModule(ModuleId{0x0b}, {}, {}, {}); + REQUIRE(!modRefB.empty()); + REQUIRE(modRefB.get() != modRefA.get()); + + // Verify NativeModuleRef refcounting: + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null copy construction: + { + NativeModuleRef modRef1{modRefA}; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null copy construction: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{modRef1}; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null move construction: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{std::move(modRef1)}; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null move construction: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{std::move(modRef1)}; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> non-null copy assignment: + { + NativeModuleRef modRef1{}; + modRef1 = modRefA; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> null copy assignment: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{}; + modRef2 = modRef1; + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef self copy assignment: + { + NativeModuleRef modRef1{modRefA}; + modRef1 = modRef1; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + 
REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null -> non-null copy assignment: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{modRefB}; + modRef2 = modRef1; + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 3); + REQUIRE(modRefB->getRefcount() == 1); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> non-null move assignment: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{}; + modRef2 = std::move(modRef1); + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null -> null move assignment: + { + NativeModuleRef modRef1{}; + NativeModuleRef modRef2{}; + modRef2 = std::move(modRef1); + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef self move assignment: + { + NativeModuleRef modRef1{modRefA}; + modRef1 = std::move(modRef1); + REQUIRE(modRef1.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null -> non-null move assignment: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{modRefB}; + modRef2 = std::move(modRef1); + REQUIRE(modRef1.empty()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + REQUIRE(modRefB->getRefcount() == 1); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef null reset: + { + NativeModuleRef modRef1{}; + modRef1.reset(); + REQUIRE(modRef1.empty()); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef non-null reset: + { + NativeModuleRef modRef1{modRefA}; + modRef1.reset(); + REQUIRE(modRef1.empty()); + REQUIRE(modRefA->getRefcount() == 1); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // NativeModuleRef swap: + { + NativeModuleRef modRef1{modRefA}; + NativeModuleRef modRef2{modRefB}; + modRef1.swap(modRef2); + REQUIRE(modRef1.get() == modRefB.get()); + REQUIRE(modRef2.get() == modRefA.get()); + REQUIRE(modRefA->getRefcount() == 2); + REQUIRE(modRefB->getRefcount() == 2); + } + + REQUIRE(modRefA->getRefcount() == 1); + REQUIRE(modRefB->getRefcount() == 1); + + // If we release the last reference to a module, it should destroy the + // module: + modRefA.reset(); + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).empty()); +} + +TEST_CASE("NativeProtoRefcounting") +{ + SharedCodeAllocator allocator{}; + + std::vector<NativeProto> nativeProtos; + nativeProtos.reserve(1); + nativeProtos.push_back(NativeProto{0x01, createNativeProtoExecData(0)}); + + NativeModuleRef modRefA = allocator.getOrInsertNativeModule(ModuleId{0x0a}, std::move(nativeProtos), {}, {}); + REQUIRE(!modRefA.empty()); + REQUIRE(modRefA->getRefcount()); + + const NativeProto* proto1 = modRefA->tryGetNativeProto(0x01); + REQUIRE(proto1 != nullptr); + + // getNonOwningPointerToInstructionOffsets should not acquire ownership: + const uint32_t* unownedInstructionOffsets = proto1->getNonOwningPointerToInstructionOffsets(); + REQUIRE(unownedInstructionOffsets != nullptr); + 
REQUIRE(modRefA->getRefcount() == 1); + + // getOwningPointerToInstructionOffsets should acquire ownership: + const uint32_t* ownedInstructionOffsets = proto1->getOwningPointerToInstructionOffsets(); + REQUIRE(ownedInstructionOffsets == unownedInstructionOffsets); + REQUIRE(modRefA->getRefcount() == 2); + + // We should be able to call it multiple times to get multiple references: + const uint32_t* ownedInstructionOffsets2 = proto1->getOwningPointerToInstructionOffsets(); + REQUIRE(ownedInstructionOffsets2 == unownedInstructionOffsets); + REQUIRE(modRefA->getRefcount() == 3); + + // releaseOwningPointerToInstructionOffsets should be callable to release + // the reference: + NativeProto::releaseOwningPointerToInstructionOffsets(ownedInstructionOffsets2); + REQUIRE(modRefA->getRefcount() == 2); + + // If we release our NativeModuleRef, the module should be kept alive by + // the owning instruction offsets pointer: + modRefA.reset(); + + modRefA = allocator.tryGetNativeModule(ModuleId{0x0a}); + REQUIRE(!modRefA.empty()); + REQUIRE(modRefA->getRefcount() == 2); + + // If the last "release" comes via releaseOwningPointerToInstructionOffsets, + // the module should be successfully destroyed: + modRefA.reset(); + NativeProto::releaseOwningPointerToInstructionOffsets(ownedInstructionOffsets); + REQUIRE(allocator.tryGetNativeModule(ModuleId{0x0a}).empty()); +} + +TEST_CASE("NativeProtoState") +{ + SharedCodeAllocator allocator{}; + + const std::vector<uint8_t> data(16); + const std::vector<uint8_t> code(16); + + std::vector<NativeProto> nativeProtos; + nativeProtos.reserve(2); + + { + NativeProtoExecDataPtr nativeExecData = createNativeProtoExecData(2); + nativeExecData[0] = 0; + nativeExecData[1] = 4; + + NativeProto proto{1, std::move(nativeExecData)}; + proto.setEntryOffset(0x00); + nativeProtos.push_back(std::move(proto)); + } + + { + NativeProtoExecDataPtr nativeExecData = createNativeProtoExecData(2); + nativeExecData[0] = 8; + nativeExecData[1] = 12; + + NativeProto proto{3, std::move(nativeExecData)}; + proto.setEntryOffset(0x08); + nativeProtos.push_back(std::move(proto)); + } + + NativeModuleRef modRefA = allocator.getOrInsertNativeModule(ModuleId{0x0a}, std::move(nativeProtos), data, code); + REQUIRE(!modRefA.empty()); + REQUIRE(modRefA->getModuleBaseAddress() != nullptr); + + const NativeProto* proto1 = modRefA->tryGetNativeProto(1); + REQUIRE(proto1 != nullptr); + REQUIRE(proto1->getBytecodeId() == 1); + REQUIRE(proto1->getEntryAddress() == modRefA->getModuleBaseAddress() + 0x00); + const uint32_t* proto1Offsets = proto1->getNonOwningPointerToInstructionOffsets(); + REQUIRE(proto1Offsets != nullptr); + REQUIRE(proto1Offsets[0] == 0); + REQUIRE(proto1Offsets[1] == 4); + + const NativeProto* proto3 = modRefA->tryGetNativeProto(3); + REQUIRE(proto3 != nullptr); + REQUIRE(proto3->getBytecodeId() == 3); + REQUIRE(proto3->getEntryAddress() == modRefA->getModuleBaseAddress() + 0x08); + const uint32_t* proto3Offsets = proto3->getNonOwningPointerToInstructionOffsets(); + REQUIRE(proto3Offsets != nullptr); + REQUIRE(proto3Offsets[0] == 8); + REQUIRE(proto3Offsets[1] == 12); + + // Ensure that non-existent native protos cannot be found: + REQUIRE(modRefA->tryGetNativeProto(0) == nullptr); + REQUIRE(modRefA->tryGetNativeProto(2) == nullptr); + REQUIRE(modRefA->tryGetNativeProto(4) == nullptr); +} diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b9c0381f..da21d5cb 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -933,7 +933,7 @@ TEST_CASE_FIXTURE(Fixture, "tostring_error_mismatch") 
CheckResult result = check(R"( --!strict function f1() : {a : number, b : string, c : { d : number}} - return { a = 1, b = "a", c = {d = "a"}} + return { a = 1, b = "b", c = {d = "d"}} end )"); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index fe7ff512..f3b63538 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -8,6 +8,7 @@ #include "Luau/Type.h" #include "Luau/VisitType.h" +#include "ClassFixture.h" #include "Fixture.h" #include "doctest.h" @@ -2365,4 +2366,60 @@ TEST_CASE_FIXTURE(Fixture, "local_function_fwd_decl_doesnt_crash") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_callback_property") +{ + CheckResult result = check(R"( + function print(x: number) end + + type Point = {x: number, y: number} + local T : {callback: ((Point) -> ())?} = {} + + T.callback = function(p) -- No error here + print(p.z) -- error here. Point has no property z + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_MESSAGE(get(result.errors[0]), "Expected UnknownProperty but got " << result.errors[0]); + + Location location = result.errors[0].location; + CHECK(location.begin.line == 7); + CHECK(location.end.line == 7); +} + +TEST_CASE_FIXTURE(ClassFixture, "bidirectional_inference_of_class_methods") +{ + CheckResult result = check(R"( + local c = ChildClass.New() + + -- Instead of reporting that the lambda is the wrong type, report that we are using its argument improperly. + c.Touched:Connect(function(other) + print(other.ThisDoesNotExist) + end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* err = get(result.errors[0]); + REQUIRE(err); + + CHECK("ThisDoesNotExist" == err->key); + CHECK("BaseClass" == toString(err->table)); +} + +TEST_CASE_FIXTURE(Fixture, "pass_table_literal_to_function_expecting_optional_prop") +{ + CheckResult result = check(R"( + type T = {prop: number?} + + function f(t: T) end + + f({prop=5}) + f({}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 7f681023..f1260fcd 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -1004,6 +1004,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_should_not_add_an_indexer") +{ + CheckResult result = check(R"( + --!strict + local prices = { + hat = 1, + bat = 2, + } + print(prices.wwwww) + for _, _ in pairs(prices) do + end + print(prices.wwwww) + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_ERROR_COUNT(2, result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "lti_fuzzer_uninitialized_loop_crash") { CheckResult result = check(R"( @@ -1015,4 +1035,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "lti_fuzzer_uninitialized_loop_crash") LUAU_REQUIRE_ERROR_COUNT(3, result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_array_of_singletons") +{ + CheckResult result = check(R"( + --!strict + type Direction = "Left" | "Right" | "Up" | "Down" + local Instructions: { Direction } = { "Left", "Down" } + + for _, step in Instructions do + local dir: Direction = step + print(dir) + end + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp 
index 1350a0d1..7c770ecb 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4247,4 +4247,45 @@ TEST_CASE_FIXTURE(Fixture, "refined_thing_can_be_an_array") CHECK("({a}, a) -> a" == toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug") +{ + CheckResult result = check(R"( + type MockedResponseBody = string | (() -> MockedResponseBody) + type MockedResponse = { type: 'body', body: MockedResponseBody } | { type: 'error' } + + local function mockedResponseToHttpResponse(mockedResponse: MockedResponse) + assert(mockedResponse.type == 'body', 'Mocked response is not a body') + if typeof(mockedResponse.body) == 'string' then + else + return mockedResponseToHttpResponse(mockedResponse) + end + end + )"); + + // we're primarily interested in knowing that this does not crash. + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug_2") +{ + CheckResult result = check(R"( + type MockedResponse = { type: 'body' } | { type: 'error' } + + local function mockedResponseToHttpResponse(mockedResponse: MockedResponse) + assert(mockedResponse.type == 'body', 'Mocked response is not a body') + + if typeof(mockedResponse.body) == 'string' then + elseif typeof(mockedResponse.body) == 'table' then + else + return mockedResponseToHttpResponse(mockedResponse) + end + end + )"); + + // we're primarily interested in knowing that this does not crash. + LUAU_REQUIRE_ERRORS(result); +} + + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index ce2cfe6b..8d252ddb 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1198,48 +1198,6 @@ TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_higher_order_function") CHECK(location.end.line == 4); } -TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_callback_property") -{ - CheckResult result = check(R"( - local print: (number) -> () - - type Point = {x: number, y: number} - local T : {callback: ((Point) -> ())?} = {} - - T.callback = function(p) -- No error here - print(p.z) -- error here. Point has no property z - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_MESSAGE(get<UnknownProperty>(result.errors[0]), "Expected UnknownProperty but got " << result.errors[0]); - - Location location = result.errors[0].location; - CHECK(location.begin.line == 7); - CHECK(location.end.line == 7); -} - -TEST_CASE_FIXTURE(ClassFixture, "bidirectional_inference_of_class_methods") -{ - CheckResult result = check(R"( - local c = ChildClass.New() - - -- Instead of reporting that the lambda is the wrong type, report that we are using its argument improperly. 
- c.Touched:Connect(function(other) - print(other.ThisDoesNotExist) - end) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownProperty* err = get<UnknownProperty>(result.errors[0]); - REQUIRE(err); - - CHECK("ThisDoesNotExist" == err->key); - CHECK("BaseClass" == toString(err->table)); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict") { CheckResult result = check(R"( diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 77b02fd1..0980703a 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -82,4 +82,21 @@ breakpoint(77) pcall(cond, nil) -- prevent inlining +local function continueLocals() + repeat + local x = tostring(game) + do continue end + local a1, a2, a3, a4, a5, a6 + until pcall( + function() + print("1") + print("2") + end, nil + ) or true +end + +breakpoint(93) + +pcall(continueLocals, nil) -- prevent inlining + return 'OK' diff --git a/tools/faillist.txt b/tools/faillist.txt index b657385d..a5349e48 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -2,8 +2,12 @@ AstQuery.last_argument_function_call_type AutocompleteTest.anonymous_autofilled_generic_on_argument_type_pack_vararg AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg AutocompleteTest.autocomplete_string_singleton_equality +AutocompleteTest.autocomplete_string_singletons AutocompleteTest.do_wrong_compatible_nonself_calls +AutocompleteTest.string_singleton_as_table_key +AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_suggestion_for_overloads +AutocompleteTest.type_correct_suggestion_in_table BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type @@ -17,7 +21,6 @@ BuiltinTests.gmatch_capture_types_default_capture BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored BuiltinTests.gmatch_capture_types_set_containing_lbracket BuiltinTests.gmatch_definition -BuiltinTests.os_time_takes_optional_date_table BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range BuiltinTests.select_with_variadic_typepack_tail_and_string_head @@ -147,16 +150,15 @@ ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.while_body_are_also_refined RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.call_an_incompatible_function_after_using_typeguard -RefinementTest.correctly_lookup_property_whose_base_was_previously_refined RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x RefinementTest.discriminate_tag RefinementTest.discriminate_tag_with_implicit_else +RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union +RefinementTest.function_call_with_colon_after_refining_not_to_be_nil RefinementTest.globals_can_be_narrowed_too -RefinementTest.index_on_a_refined_property RefinementTest.isa_type_refinement_must_be_known_ahead_of_time -RefinementTest.narrow_property_of_a_bounded_variable RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage RefinementTest.refine_a_property_of_some_global @@ -184,13 +186,7 @@ TableTests.casting_unsealed_tables_with_props_into_table_with_indexer TableTests.checked_prop_too_early TableTests.cli_84607_missing_prop_in_array_or_dict TableTests.common_table_element_general 
-TableTests.common_table_element_inner_index -TableTests.common_table_element_inner_prop -TableTests.common_table_element_list -TableTests.common_table_element_union_assignment -TableTests.common_table_element_union_in_call TableTests.common_table_element_union_in_call_tail -TableTests.common_table_element_union_in_prop TableTests.confusing_indexing TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar @@ -255,7 +251,6 @@ TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors2 TableTests.table_unification_4 TableTests.table_unifies_into_map -TableTests.top_table_type TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.unification_of_unions_in_a_self_referential_type TableTests.used_colon_instead_of_dot @@ -280,7 +275,6 @@ TryUnifyTests.uninhabited_table_sub_never TypeAliases.dont_lose_track_of_PendingExpansionTypes_after_substitution TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param -TypeAliases.mutually_recursive_aliases TypeAliases.mutually_recursive_generic_aliases TypeAliases.mutually_recursive_types_restriction_not_ok_1 TypeAliases.mutually_recursive_types_restriction_not_ok_2 @@ -302,7 +296,6 @@ TypeFamilyTests.table_internal_families TypeFamilyTests.type_families_inhabited_with_normalization TypeFamilyTests.unsolvable_family TypeInfer.be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload -TypeInfer.bidirectional_checking_of_callback_property TypeInfer.check_type_infer_recursion_count TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error @@ -343,6 +336,7 @@ TypeInferClasses.table_indexers_are_invariant TypeInferClasses.unions_of_intersections_of_classes TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.another_other_higher_order_function +TypeInferFunctions.bidirectional_checking_of_callback_property TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.complicated_return_types_require_an_explicit_annotation TypeInferFunctions.concrete_functions_are_not_supertypes_of_function @@ -431,7 +425,6 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted -TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOOP.promise_type_error_too_complex TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union @@ -444,7 +437,6 @@ TypeInferOperators.equality_operations_succeed_if_any_union_branch_succeeds TypeInferOperators.error_on_invalid_operand_types_to_relational_operators2 TypeInferOperators.luau_polyfill_is_array TypeInferOperators.mm_comparisons_must_return_a_boolean -TypeInferOperators.refine_and_or TypeInferOperators.reworked_and TypeInferOperators.reworked_or TypeInferOperators.strict_binary_op_where_lhs_unknown
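The LUAU_CHECK_HAS_KEY / LUAU_CHECK_HAS_NO_KEY macros that replace the bare CHECK(map.count(...)) assertions in the autocomplete tests above exist so that a failing check prints the offending key rather than an opaque boolean. Their real definitions live in Luau's test support headers; a minimal sketch of the same idea on top of doctest (the _SKETCH names and bodies here are illustrative, not the actual macros) might be:

#include "doctest.h"

// Illustrative sketch only; the real macros are defined in Luau's test
// support headers and may differ in detail.
#define LUAU_CHECK_HAS_KEY_SKETCH(map, key) \
    do \
    { \
        CHECK_MESSAGE((map).count(key), "Expected map to contain key '" << (key) << "'"); \
    } while (false)

#define LUAU_CHECK_HAS_NO_KEY_SKETCH(map, key) \
    do \
    { \
        CHECK_MESSAGE(!(map).count(key), "Expected map to not contain key '" << (key) << "'"); \
    } while (false)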
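The Checked/Regular SUBCASE split added to the Native conformance test leans on doctest's subcase model: the enclosing TEST_CASE body is executed once per leaf SUBCASE, so runConformance runs twice, once with DebugLuauAbortingChecks enabled and once with it disabled. A small self-contained illustration of that execution model, unrelated to Luau itself:

#include "doctest.h"

#include <vector>

TEST_CASE("subcases rerun the enclosing body")
{
    // Everything before the subcases runs once per SUBCASE below,
    // so each subcase starts from a fresh three-element vector.
    std::vector<int> v{1, 2, 3};

    SUBCASE("push")
    {
        v.push_back(4);
        CHECK(v.size() == 4);
    }

    SUBCASE("pop")
    {
        v.pop_back();
        CHECK(v.size() == 2);
    }
}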
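The rewritten StoreCannotBeReplacedWithCheck case above probes a dead-store-elimination hazard: a store is only dead if nothing can observe the old value before it is overwritten, and a guard such as CHECK_READONLY can branch to a fallback block that re-reads the register. A toy version of that rule over a deliberately simplified instruction list (this is a sketch of the general idea, not the real Luau pass, which also tracks loads and partial stores):

#include <cstddef>
#include <vector>

// Toy IR: a Store writes one register; a Guard may branch away, and its
// target is assumed able to observe every register.
enum class Op { Store, Guard };

struct Inst
{
    Op op;
    int reg = -1; // target register for Store
};

// A Store is dead only if the same register is stored again before any
// Guard (or the end of the block) could observe the old value.
static std::vector<bool> findDeadStores(const std::vector<Inst>& block)
{
    std::vector<bool> dead(block.size(), false);
    for (size_t i = 0; i < block.size(); ++i)
    {
        if (block[i].op != Op::Store)
            continue;
        for (size_t j = i + 1; j < block.size(); ++j)
        {
            if (block[j].op == Op::Guard)
                break; // the guard's fallback may read the stored value
            if (block[j].op == Op::Store && block[j].reg == block[i].reg)
            {
                dead[i] = true; // overwritten before anything observed it
                break;
            }
        }
    }
    return dead;
}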
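The NativeModuleRefRefcounting case is, in effect, a truth table for every construction, assignment, reset, and swap path of an intrusively refcounted handle, including the self-assignment paths that the -Wself-assign-overloaded pragma keeps warning-free. A simplified stand-in handle that satisfies the invariants those REQUIREs pin down (not the actual NativeModuleRef, which also cooperates with the allocator that owns the modules):

#include <cstddef>
#include <utility>

// Simplified stand-in for an intrusively refcounted handle; illustrative only.
struct Counted
{
    size_t refcount = 0;
};

class Handle
{
public:
    Handle() = default;
    explicit Handle(Counted* p) : ptr(p) { retain(ptr); }
    Handle(const Handle& other) : ptr(other.ptr) { retain(ptr); }
    Handle(Handle&& other) noexcept : ptr(std::exchange(other.ptr, nullptr)) {}
    ~Handle() { release(ptr); }

    Handle& operator=(const Handle& other)
    {
        // Retain the new target before releasing the old one so that
        // self-assignment never transiently drops the count to zero.
        Counted* old = ptr;
        ptr = other.ptr;
        retain(ptr);
        release(old);
        return *this;
    }

    Handle& operator=(Handle&& other) noexcept
    {
        // Self-move leaves the handle untouched, matching the test's
        // expectation that modRef1 = std::move(modRef1) keeps its target.
        if (this != &other)
        {
            release(ptr);
            ptr = std::exchange(other.ptr, nullptr);
        }
        return *this;
    }

    Counted* get() const { return ptr; }
    bool empty() const { return ptr == nullptr; }

    void reset()
    {
        release(ptr);
        ptr = nullptr;
    }

    void swap(Handle& other) { std::swap(ptr, other.ptr); }

private:
    static void retain(Counted* p)
    {
        if (p)
            ++p->refcount;
    }

    static void release(Counted* p)
    {
        if (p && --p->refcount == 0)
        {
            // A real allocator would destroy the module here; the sketch
            // leaves lifetime management to the caller.
        }
    }

    Counted* ptr = nullptr;
};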