diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 4598efc4..ad10ca99 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -258,6 +258,8 @@ struct Constraint ConstraintV c; std::vector> dependencies; + + DenseHashSet getFreeTypes() const; }; using ConstraintPtr = std::unique_ptr; diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 3f2feaef..a0afeed7 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -75,6 +75,9 @@ struct ConstraintSolver // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; + // A mapping from free types to the number of unresolved constraints that mention them. + DenseHashMap unresolvedConstraints{{}}; + // Recorded errors that take place within the solver. ErrorVec errors; diff --git a/Analysis/include/Luau/TypePairHash.h b/Analysis/include/Luau/TypePairHash.h index 5cddebef..591f20f1 100644 --- a/Analysis/include/Luau/TypePairHash.h +++ b/Analysis/include/Luau/TypePairHash.h @@ -3,6 +3,7 @@ #include "Luau/TypeFwd.h" +#include #include namespace Luau diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 3a6417dc..3035d480 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Constraint.h" +#include "Luau/VisitType.h" namespace Luau { @@ -12,4 +13,40 @@ Constraint::Constraint(NotNull scope, const Location& location, Constrain { } +struct FreeTypeCollector : TypeOnceVisitor +{ + + DenseHashSet* result; + + FreeTypeCollector(DenseHashSet* result) + : result(result) + { + } + + bool visit(TypeId ty, const FreeType&) override + { + result->insert(ty); + return false; + } +}; + +DenseHashSet Constraint::getFreeTypes() const +{ + DenseHashSet types{{}}; + FreeTypeCollector ftc{&types}; + + if (auto sc = get(*this)) + { + ftc.traverse(sc->subType); + ftc.traverse(sc->superType); + } + else if (auto psc = get(*this)) + { + ftc.traverse(psc->subPack); + ftc.traverse(psc->superPack); + } + + return types; +} + } // namespace Luau diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 924958ab..3b478494 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -280,6 +280,13 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNullgetFreeTypes()) + { + // increment the reference count for `ty` + unresolvedConstraints[ty] += 1; + } + for (NotNull dep : c->dependencies) { block(dep, c); @@ -360,6 +367,10 @@ void ConstraintSolver::run() unblock(c); unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + // decrement the referenced free types for this constraint if we dispatched successfully! + for (auto ty : c->getFreeTypes()) + unresolvedConstraints[ty] -= 1; + if (logger) { logger->commitStepSnapshot(snapshot); @@ -2380,48 +2391,12 @@ void ConstraintSolver::reportError(TypeError e) errors.back().moduleName = currentModuleName; } -struct ContainsType : TypeOnceVisitor -{ - TypeId needle; - bool found = false; - - explicit ContainsType(TypeId needle) - : needle(needle) - { - } - - bool visit(TypeId) override - { - return !found; // traverse into the type iff we have yet to find the needle - } - - bool visit(TypeId ty, const FreeType&) override - { - found |= ty == needle; - return false; - } -}; - bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) { - if (!get(ty) || unsolvedConstraints.empty()) - return false; // if it's not free, it never has any unresolved constraints, maybe? + if (auto refCount = unresolvedConstraints.find(ty)) + return *refCount > 0; - ContainsType containsTy{ty}; - - for (auto constraint : unsolvedConstraints) - { - if (auto sc = get(*constraint)) - { - containsTy.traverse(sc->subType); - containsTy.traverse(sc->superType); - - if (containsTy.found) - return true; - } - } - - return containsTy.found; + return false; } TypeId ConstraintSolver::errorRecoveryType() const diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 65b04d62..632f8e5c 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,9 +1,45 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAGVARIABLE(LuauBufferDefinitions, false) + namespace Luau { +static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC( + +-- TODO: this will be replaced with a built-in primitive type +declare class buffer end + +declare buffer: { + create: (size: number) -> buffer, + fromstring: (str: string) -> buffer, + tostring: () -> string, + len: (b: buffer) -> number, + copy: (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (), + fill: (b: buffer, offset: number, value: number, count: number?) -> (), + readi8: (b: buffer, offset: number) -> number, + readu8: (b: buffer, offset: number) -> number, + readi16: (b: buffer, offset: number) -> number, + readu16: (b: buffer, offset: number) -> number, + readi32: (b: buffer, offset: number) -> number, + readu32: (b: buffer, offset: number) -> number, + readf32: (b: buffer, offset: number) -> number, + readf64: (b: buffer, offset: number) -> number, + writei8: (b: buffer, offset: number, value: number) -> (), + writeu8: (b: buffer, offset: number, value: number) -> (), + writei16: (b: buffer, offset: number, value: number) -> (), + writeu16: (b: buffer, offset: number, value: number) -> (), + writei32: (b: buffer, offset: number, value: number) -> (), + writeu32: (b: buffer, offset: number, value: number) -> (), + writef32: (b: buffer, offset: number, value: number) -> (), + writef64: (b: buffer, offset: number, value: number) -> (), + readstring: (b: buffer, offset: number, count: number) -> string, + writestring: (b: buffer, offset: number, value: string, count: number?) -> (), +} + +)BUILTIN_SRC"; + static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( declare bit32: { @@ -21,6 +57,7 @@ declare bit32: { replace: (n: number, v: number, field: number, width: number?) -> number, countlz: (n: number) -> number, countrz: (n: number) -> number, + byteswap: (n: number) -> number, } declare math: { @@ -198,6 +235,10 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { std::string result = kBuiltinDefinitionLuaSrc; + + if (FFlag::LuauBufferDefinitions) + result = kBuiltinDefinitionBufferSrc + result; + return result; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 166b7525..bf8e362d 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -868,6 +868,22 @@ struct TypeChecker2 } } + std::optional getBindingType(AstExpr* expr) + { + if (auto localExpr = expr->as()) + { + Scope* s = stack.back(); + return s->lookup(localExpr->local); + } + else if (auto globalExpr = expr->as()) + { + Scope* s = stack.back(); + return s->lookup(globalExpr->name); + } + else + return std::nullopt; + } + void visit(AstStatAssign* assign) { size_t count = std::min(assign->vars.size, assign->values.size); @@ -885,7 +901,15 @@ struct TypeChecker2 if (get(lhsType)) continue; - testIsSubtype(rhsType, lhsType, rhs->location); + bool ok = testIsSubtype(rhsType, lhsType, rhs->location); + + // If rhsType bindingType = getBindingType(lhs); + if (bindingType) + testIsSubtype(rhsType, *bindingType, rhs->location); + } } } diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 7e4a591d..e3afb944 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -347,6 +347,11 @@ TypeFamilyReductionResult lenFamilyFn(const std::vector& typePar } TypeId operandTy = follow(typeParams.at(0)); + + // check to see if the operand type is resolved enough, and wait to reduce if not + if (isPending(operandTy, ctx->solver)) + return {std::nullopt, false, {operandTy}, {}}; + const NormalizedType* normTy = ctx->normalizer->normalize(operandTy); // if the type failed to normalize, we can't reduce, but know nothing about inhabitance. @@ -370,10 +375,6 @@ TypeFamilyReductionResult lenFamilyFn(const std::vector& typePar if (normTy->hasTopTable() || get(normalizedOperand)) return {ctx->builtins->numberType, false, {}, {}}; - // otherwise, we wait to see if the operand type is resolved - if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; - // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. ErrorVec dummy; @@ -421,6 +422,11 @@ TypeFamilyReductionResult unmFamilyFn( } TypeId operandTy = follow(typeParams.at(0)); + + // check to see if the operand type is resolved enough, and wait to reduce if not + if (isPending(operandTy, ctx->solver)) + return {std::nullopt, false, {operandTy}, {}}; + const NormalizedType* normTy = ctx->normalizer->normalize(operandTy); // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. @@ -439,10 +445,6 @@ TypeFamilyReductionResult unmFamilyFn( if (normTy->isExactlyNumber()) return {ctx->builtins->numberType, false, {}, {}}; - // otherwise, check if we need to wait on the type to be further resolved - if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; - // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. ErrorVec dummy; @@ -493,6 +495,13 @@ TypeFamilyReductionResult numericBinopFamilyFn( TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); @@ -512,12 +521,6 @@ TypeFamilyReductionResult numericBinopFamilyFn( if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) return {ctx->builtins->numberType, false, {}, {}}; - // otherwise, check if we need to wait on either type to be further resolved - if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; - else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; - // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. ErrorVec dummy; @@ -653,6 +656,13 @@ TypeFamilyReductionResult concatFamilyFn(const std::vector& type TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); @@ -672,12 +682,6 @@ TypeFamilyReductionResult concatFamilyFn(const std::vector& type if ((normLhsTy->isSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber())) return {ctx->builtins->stringType, false, {}, {}}; - // otherwise, check if we need to wait on either type to be further resolved - if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; - else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; - // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. ErrorVec dummy; @@ -738,14 +742,11 @@ TypeFamilyReductionResult andFamilyFn(const std::vector& typePar TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - { return {std::nullopt, false, {lhsTy}, {}}; - } else if (isPending(rhsTy, ctx->solver)) - { return {std::nullopt, false, {rhsTy}, {}}; - } // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is truthy. SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->falsyType); @@ -766,14 +767,11 @@ TypeFamilyReductionResult orFamilyFn(const std::vector& typePara TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - { return {std::nullopt, false, {lhsTy}, {}}; - } else if (isPending(rhsTy, ctx->solver)) - { return {std::nullopt, false, {rhsTy}, {}}; - } // Or evalutes to the LHS type if the LHS is truthy, and the RHS type if LHS is falsy. SimplifyResult filteredLhs = simplifyIntersection(ctx->builtins, ctx->arena, lhsTy, ctx->builtins->truthyType); @@ -795,6 +793,13 @@ static TypeFamilyReductionResult comparisonFamilyFn( TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); @@ -818,12 +823,6 @@ static TypeFamilyReductionResult comparisonFamilyFn( if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) return {ctx->builtins->booleanType, false, {}, {}}; - // otherwise, check if we need to wait on either type to be further resolved - if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; - else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; - // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. ErrorVec dummy; @@ -895,6 +894,13 @@ TypeFamilyReductionResult eqFamilyFn(const std::vector& typePara TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + + // check to see if both operand types are resolved enough, and wait to reduce if not + if (isPending(lhsTy, ctx->solver)) + return {std::nullopt, false, {lhsTy}, {}}; + else if (isPending(rhsTy, ctx->solver)) + return {std::nullopt, false, {rhsTy}, {}}; + const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); @@ -910,12 +916,6 @@ TypeFamilyReductionResult eqFamilyFn(const std::vector& typePara if (is(lhsTy) || is(rhsTy)) return {ctx->builtins->booleanType, false, {}, {}}; - // otherwise, check if we need to wait on either type to be further resolved - if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; - else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; - // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. ErrorVec dummy; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3b9e95f8..a9747143 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -9,16 +9,21 @@ #include #include +LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) +LUAU_FASTINTVARIABLE(LuauTypeLengthLimit, 1000) +LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) + // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. -LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) -LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseDeclareClassIndexer, false) LUAU_FASTFLAGVARIABLE(LuauClipExtraHasEndProps, false) LUAU_FASTFLAG(LuauFloorDivision) LUAU_FASTFLAG(LuauCheckedFunctionSyntax) +LUAU_FASTFLAGVARIABLE(LuauBetterTypeUnionLimits, false) +LUAU_FASTFLAGVARIABLE(LuauBetterTypeRecLimits, false) + namespace Luau { @@ -245,13 +250,13 @@ AstStatBlock* Parser::parseBlockNoScope() while (!blockFollow(lexer.current())) { - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("block"); AstStat* stat = parseStat(); - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; if (lexer.current().type == ';') { @@ -378,13 +383,13 @@ AstStat* Parser::parseIf() { if (FFlag::LuauClipExtraHasEndProps) thenbody->hasEnd = true; - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("elseif"); elseLocation = lexer.current().location; elsebody = parseIf(); end = elsebody->location; DEPRECATED_hasEnd = elsebody->as()->DEPRECATED_hasEnd; - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; } else { @@ -625,7 +630,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug // parse funcname into a chain of indexing operators AstExpr* expr = parseNameExpr("function name"); - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; while (lexer.current().type == '.') { @@ -643,7 +648,7 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug incrementRecursionCounter("function name"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; // finish with : if (lexer.current().type == ':') @@ -1526,6 +1531,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) bool isUnion = false; bool isIntersection = false; + bool hasOptional = false; Location location = begin; @@ -1535,20 +1541,34 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) if (c == '|') { nextLexeme(); + + unsigned int oldRecursionCount = recursionCounter; parts.push_back(parseSimpleType(/* allowPack= */ false).type); + if (FFlag::LuauBetterTypeUnionLimits) + recursionCounter = oldRecursionCount; + isUnion = true; } else if (c == '?') { Location loc = lexer.current().location; nextLexeme(); - parts.push_back(allocator.alloc(loc, std::nullopt, nameNil, std::nullopt, loc)); + + if (!FFlag::LuauBetterTypeUnionLimits || !hasOptional) + parts.push_back(allocator.alloc(loc, std::nullopt, nameNil, std::nullopt, loc)); + isUnion = true; + hasOptional = true; } else if (c == '&') { nextLexeme(); + + unsigned int oldRecursionCount = recursionCounter; parts.push_back(parseSimpleType(/* allowPack= */ false).type); + if (FFlag::LuauBetterTypeUnionLimits) + recursionCounter = oldRecursionCount; + isIntersection = true; } else if (c == Lexeme::Dot3) @@ -1558,6 +1578,9 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) } else break; + + if (FFlag::LuauBetterTypeUnionLimits && parts.size() > unsigned(FInt::LuauTypeLengthLimit) + hasOptional) + ParseError::raise(parts.back()->location, "Exceeded allowed type length; simplify your type annotation to make the code compile"); } if (parts.size() == 1) @@ -1584,7 +1607,10 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) AstTypeOrPack Parser::parseTypeOrPack() { unsigned int oldRecursionCount = recursionCounter; - incrementRecursionCounter("type annotation"); + + // recursion counter is incremented in parseSimpleType + if (!FFlag::LuauBetterTypeRecLimits) + incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; @@ -1604,7 +1630,10 @@ AstTypeOrPack Parser::parseTypeOrPack() AstType* Parser::parseType(bool inDeclarationContext) { unsigned int oldRecursionCount = recursionCounter; - incrementRecursionCounter("type annotation"); + + // recursion counter is incremented in parseSimpleType + if (!FFlag::LuauBetterTypeRecLimits) + incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; @@ -1935,7 +1964,7 @@ AstExpr* Parser::parseExpr(unsigned int limit) }; static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op"); - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; // this handles recursive calls to parseSubExpr/parseExpr incrementRecursionCounter("expression"); @@ -1987,7 +2016,7 @@ AstExpr* Parser::parseExpr(unsigned int limit) incrementRecursionCounter("expression"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; return expr; } @@ -2054,7 +2083,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) AstExpr* expr = parsePrefixExpr(); - unsigned int recursionCounterOld = recursionCounter; + unsigned int oldRecursionCount = recursionCounter; while (true) { @@ -2114,7 +2143,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) incrementRecursionCounter("expression"); } - recursionCounter = recursionCounterOld; + recursionCounter = oldRecursionCount; return expr; } diff --git a/CMakeLists.txt b/CMakeLists.txt index 1cda7ef4..0dbfbee1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -248,3 +248,19 @@ if(LUAU_BUILD_WEB) # the output is a single .js file with an embedded wasm blob target_link_options(Luau.Web PRIVATE -sSINGLE_FILE=1) endif() + +# validate dependencies for internal libraries +foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.CodeGen Luau.VM) + if(TARGET ${LIB}) + get_target_property(DEPENDS ${LIB} LINK_LIBRARIES) + if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler") + message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components") + endif() + if(LIB MATCHES "Ast|Analysis|Compiler" AND DEPENDS MATCHES "CodeGen|VM") + message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components") + endif() + if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config") + message(FATAL_ERROR ${LIB} " is a compiler component but it depends on one of the analysis components") + endif() + endif() +endforeach() diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index fa615c13..6fdeac27 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -1564,11 +1564,16 @@ void AssemblyBuilderX64::log(OperandX64 op) case CategoryX64::mem: if (op.base == rip) { - logAppend("%s ptr [.start%+d]", getSizeName(op.memSize), op.imm); + if (op.memSize != SizeX64::none) + logAppend("%s ptr ", getSizeName(op.memSize)); + logAppend("[.start%+d]", op.imm); return; } - logAppend("%s ptr [", getSizeName(op.memSize)); + if (op.memSize != SizeX64::none) + logAppend("%s ptr ", getSizeName(op.memSize)); + + logAppend("["); if (op.base != noreg) logAppend("%s", getRegisterName(op.base)); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 40aee13a..56bbf904 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -85,6 +85,9 @@ static void buildArgumentTypeChecks(IrBuilder& build, Proto* proto) case LBC_TYPE_VECTOR: build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.vmExit(kVmExitEntryGuardPc)); break; + case LBC_TYPE_BUFFER: + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBUFFER), build.vmExit(kVmExitEntryGuardPc)); + break; } if (optional) diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 1d156af6..483e3e00 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -61,6 +61,8 @@ static const char* getTagName(uint8_t tag) return "tuserdata"; case LUA_TTHREAD: return "tthread"; + case LUA_TBUFFER: + return "tbuffer"; case LUA_TPROTO: return "tproto"; case LUA_TUPVAL: diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 2a4fd93f..3b6b5def 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -472,6 +472,9 @@ static BuiltinImplResult translateBuiltinBit32Extract( if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; + if (nparams == 2 && args.kind == IrOpKind::Constant && unsigned(int(build.function.doubleOp(args))) >= 32) + return {BuiltinImplType::None, -1}; + builtinCheckDouble(build, build.vmReg(arg), pcpos); builtinCheckDouble(build, args, pcpos); @@ -486,15 +489,14 @@ static BuiltinImplResult translateBuiltinBit32Extract( if (vb.kind == IrOpKind::Constant) { int f = int(build.function.doubleOp(vb)); + LUAU_ASSERT(unsigned(f) < 32); // checked above - if (unsigned(f) >= 32) - build.inst(IrCmd::JUMP, fallback); + value = n; - // TODO: this pair can be optimized using a bit-select instruction (bt on x86) if (f) - value = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); + value = build.inst(IrCmd::BITRSHIFT_UINT, value, build.constInt(f)); - if ((f + 1) < 32) + if (f + 1 < 32) value = build.inst(IrCmd::BITAND_UINT, value, build.constInt(1)); } else @@ -505,7 +507,6 @@ static BuiltinImplResult translateBuiltinBit32Extract( build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block); build.beginBlock(block); - // TODO: this pair can be optimized using a bit-select instruction (bt on x86) IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1)); } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 8d9ebc8a..763a8478 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -12,7 +12,6 @@ #include "lstate.h" #include "ltm.h" -LUAU_FASTFLAGVARIABLE(LuauImproveForN2, false) LUAU_FASTFLAG(LuauReduceStackSpills) LUAU_FASTFLAGVARIABLE(LuauInlineArrConstOffset, false) LUAU_FASTFLAGVARIABLE(LuauLowerAltLoopForn, false) @@ -635,22 +634,16 @@ static IrOp getLoopStepK(IrBuilder& build, int ra) void beforeInstForNPrep(IrBuilder& build, const Instruction* pc) { - if (FFlag::LuauImproveForN2) - { - int ra = LUAU_INSN_A(*pc); + int ra = LUAU_INSN_A(*pc); - IrOp stepK = getLoopStepK(build, ra); - build.loopStepStack.push_back(stepK); - } + IrOp stepK = getLoopStepK(build, ra); + build.loopStepStack.push_back(stepK); } void afterInstForNLoop(IrBuilder& build, const Instruction* pc) { - if (FFlag::LuauImproveForN2) - { - LUAU_ASSERT(!build.loopStepStack.empty()); - build.loopStepStack.pop_back(); - } + LUAU_ASSERT(!build.loopStepStack.empty()); + build.loopStepStack.pop_back(); } void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) @@ -660,119 +653,65 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos)); - if (FFlag::LuauImproveForN2) + LUAU_ASSERT(!build.loopStepStack.empty()); + IrOp stepK = build.loopStepStack.back(); + + // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails + // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant + // To avoid that overhead for an extremely rare case (that doesn't even typecheck), we exit to VM to handle it + IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + + if (stepK.kind == IrOpKind::Undef) { - LUAU_ASSERT(!build.loopStepStack.empty()); - IrOp stepK = build.loopStepStack.back(); + IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails - // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant - // To avoid that overhead for an extremely rare case (that doesn't even typecheck), we exit to VM to handle it - IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); - build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); - build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - - if (stepK.kind == IrOpKind::Undef) + if (FFlag::LuauLowerAltLoopForn) { - IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); - build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - if (FFlag::LuauLowerAltLoopForn) - { - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - - build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopStart, loopExit); - } - else - { - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); - - IrOp zero = build.constDouble(0.0); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - - // step > 0 - // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); - - // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx - // We invert the condition so that loopStart is the fallthrough (false) label - - // step > 0 is false, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); - - // step > 0 is true, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); - } + build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopStart, loopExit); } else { - double stepN = build.function.doubleOp(stepK); + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + IrOp zero = build.constDouble(0.0); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + + // step > 0 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx // We invert the condition so that loopStart is the fallthrough (false) label - if (stepN > 0) - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); - else - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + + // step > 0 is false, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + + // step > 0 is true, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); } } - else if (FFlag::LuauLowerAltLoopForn) - { - // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails - // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant - // To avoid that overhead for an extreemely rare case (that doesn't even typecheck), we exit to VM to handle it - IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); - build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); - build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); - build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - - build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopStart, loopExit); - } else { - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); + double stepN = build.function.doubleOp(stepK); - // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails - // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant - // To avoid that overhead for an extreemely rare case (that doesn't even typecheck), we exit to VM to handle it - IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); - build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); - build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); - build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - - IrOp zero = build.constDouble(0.0); - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - - // step <= 0 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); - - // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation - - // step <= 0 is false, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopStart, loopExit); - - // step <= 0 is true, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); + // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx + // We invert the condition so that loopStart is the fallthrough (false) label + if (stepN > 0) + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + else + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); } // Fallthrough in original bytecode is implicit, so we start next internal block here @@ -782,8 +721,7 @@ void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) // VM places interrupt in FORNLOOP, but that creates a likely spill point for short loops that use loop index as INTERRUPT always spills // We place the interrupt at the beginning of the loop body instead; VM uses FORNLOOP because it doesn't want to waste an extra instruction. // Because loop block may not have been started yet (as it's started when lowering the first instruction!), we need to defer INTERRUPT placement. - if (FFlag::LuauImproveForN2) - build.interruptRequested = true; + build.interruptRequested = true; } void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) @@ -793,95 +731,59 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); - if (FFlag::LuauImproveForN2) + // normally, the interrupt is placed at the beginning of the loop body by FORNPREP translation + // however, there are rare contrived cases where FORNLOOP ends up jumping to itself without an interrupt placed + // we detect this by checking if loopRepeat has any instructions (it should normally start with INTERRUPT) and emit a failsafe INTERRUPT if not + if (build.function.blockOp(loopRepeat).start == build.function.instructions.size()) + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + + LUAU_ASSERT(!build.loopStepStack.empty()); + IrOp stepK = build.loopStepStack.back(); + + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK; + + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + idx = build.inst(IrCmd::ADD_NUM, idx, step); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); + + if (stepK.kind == IrOpKind::Undef) { - LUAU_ASSERT(!build.loopStepStack.empty()); - IrOp stepK = build.loopStepStack.back(); - - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK; - - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - idx = build.inst(IrCmd::ADD_NUM, idx, step); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); - - if (stepK.kind == IrOpKind::Undef) + if (FFlag::LuauLowerAltLoopForn) { - if (FFlag::LuauLowerAltLoopForn) - { - build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopRepeat, loopExit); - } - else - { - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); - - IrOp zero = build.constDouble(0.0); - - // step > 0 - // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); - - // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx - - // step > 0 is false, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); - - // step > 0 is true, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); - } + build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopRepeat, loopExit); } else { - double stepN = build.function.doubleOp(stepK); + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + IrOp zero = build.constDouble(0.0); + + // step > 0 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx - if (stepN > 0) - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); - else - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step > 0 is false, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step > 0 is true, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); } } - else if (FFlag::LuauLowerAltLoopForn) - { - build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); - - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - idx = build.inst(IrCmd::ADD_NUM, idx, step); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); - - build.inst(IrCmd::JUMP_FORN_LOOP_COND, idx, limit, step, loopRepeat, loopExit); - } else { - build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + double stepN = build.function.doubleOp(stepK); - IrOp zero = build.constDouble(0.0); - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - idx = build.inst(IrCmd::ADD_NUM, idx, step); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); - - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); - - // step <= 0 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); - - // step <= 0 is false, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); - - // step <= 0 is true, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx + if (stepN > 0) + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + else + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); } // Fallthrough in original bytecode is implicit, so we start next internal block here diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index c9b2dae6..3315ec96 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -16,9 +16,8 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) LUAU_FASTFLAGVARIABLE(LuauReuseHashSlots2, false) -LUAU_FASTFLAGVARIABLE(LuauKeepVmapLinear, false) LUAU_FASTFLAGVARIABLE(LuauMergeTagLoads, false) -LUAU_FASTFLAGVARIABLE(LuauReuseArrSlots, false) +LUAU_FASTFLAGVARIABLE(LuauReuseArrSlots2, false) LUAU_FASTFLAG(LuauLowerAltLoopForn) namespace Luau @@ -505,6 +504,20 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_GETMETATABLE: case LBF_TONUMBER: case LBF_TOSTRING: + case LBF_BIT32_BYTESWAP: + case LBF_BUFFER_READI8: + case LBF_BUFFER_READU8: + case LBF_BUFFER_WRITEU8: + case LBF_BUFFER_READI16: + case LBF_BUFFER_READU16: + case LBF_BUFFER_WRITEU16: + case LBF_BUFFER_READI32: + case LBF_BUFFER_READU32: + case LBF_BUFFER_WRITEU32: + case LBF_BUFFER_READF32: + case LBF_BUFFER_WRITEF32: + case LBF_BUFFER_READF64: + case LBF_BUFFER_WRITEF64: break; case LBF_TABLE_INSERT: state.invalidateHeap(); @@ -940,7 +953,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::LOAD_ENV: break; case IrCmd::GET_ARR_ADDR: - if (!FFlag::LuauReuseArrSlots) + if (!FFlag::LuauReuseArrSlots2) break; for (uint32_t prevIdx : state.getArrAddrCache) @@ -1013,7 +1026,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::DUP_TABLE: break; case IrCmd::TRY_NUM_TO_INDEX: - if (!FFlag::LuauReuseArrSlots) + if (!FFlag::LuauReuseArrSlots2) break; for (uint32_t prevIdx : state.tryNumToIndexCache) @@ -1052,6 +1065,13 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { std::optional arrayIndex = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); + // Negative offsets will jump to fallback, no need to keep the check + if (FFlag::LuauReuseArrSlots2 && arrayIndex && *arrayIndex < 0) + { + replace(function, block, index, {IrCmd::JUMP, inst.c}); + break; + } + if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a); info && arrayIndex) { if (info->knownTableArraySize >= 0) @@ -1068,11 +1088,11 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& replace(function, block, index, {IrCmd::JUMP, inst.c}); } - return; // Break out from both the loop and the switch + break; } } - if (!FFlag::LuauReuseArrSlots) + if (!FFlag::LuauReuseArrSlots2) break; for (uint32_t prevIdx : state.checkArraySizeCache) @@ -1086,7 +1106,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& // If arguments are different, in case they are both constant, we can check if a larger bound was already tested if (!sameBoundary && inst.b.kind == IrOpKind::Constant && prev.b.kind == IrOpKind::Constant && - function.intOp(inst.b) < function.intOp(prev.b)) + unsigned(function.intOp(inst.b)) < unsigned(function.intOp(prev.b))) sameBoundary = true; if (sameBoundary) @@ -1268,14 +1288,11 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s constPropInInst(state, build, function, block, inst, index); } - if (!FFlag::LuauKeepVmapLinear) - { - // Value numbering and load/store propagation is not performed between blocks - state.invalidateValuePropagation(); + // Value numbering and load/store propagation is not performed between blocks + state.invalidateValuePropagation(); - // Same for table slot data propagation - state.invalidateHeapTableData(); - } + // Same for table slot data propagation + state.invalidateHeapTableData(); } static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, ConstPropState& state) @@ -1295,16 +1312,6 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite constPropInBlock(build, *block, state); - if (FFlag::LuauKeepVmapLinear) - { - // Value numbering and load/store propagation is not performed between blocks right now - // This is because cross-block value uses limit creation of linear block (restriction in collectDirectBlockJumpPath) - state.invalidateValuePropagation(); - - // Same for table slot data propagation - state.invalidateHeapTableData(); - } - // Blocks in a chain are guaranteed to follow each other // We force that by giving all blocks the same sorting key, but consecutive chain keys block->sortkey = startSortkey; diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index a4f1d67e..8096eec5 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -452,6 +452,7 @@ enum LuauBytecodeType LBC_TYPE_THREAD, LBC_TYPE_USERDATA, LBC_TYPE_VECTOR, + LBC_TYPE_BUFFER, LBC_TYPE_ANY = 15, LBC_TYPE_OPTIONAL_BIT = 1 << 7, @@ -560,6 +561,24 @@ enum LuauBuiltinFunction // tonumber/tostring LBF_TONUMBER, LBF_TOSTRING, + + // bit32.byteswap(n) + LBF_BIT32_BYTESWAP, + + // buffer. + LBF_BUFFER_READI8, + LBF_BUFFER_READU8, + LBF_BUFFER_WRITEU8, + LBF_BUFFER_READI16, + LBF_BUFFER_READU16, + LBF_BUFFER_WRITEU16, + LBF_BUFFER_READI32, + LBF_BUFFER_READU32, + LBF_BUFFER_WRITEU32, + LBF_BUFFER_READF32, + LBF_BUFFER_WRITEF32, + LBF_BUFFER_READF64, + LBF_BUFFER_WRITEF64, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index a15c8f08..407b76a4 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,10 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauBit32ByteswapBuiltin, false) + +LUAU_FASTFLAGVARIABLE(LuauBufferBuiltins, false) + namespace Luau { namespace Compile @@ -166,6 +170,8 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_BIT32_COUNTLZ; if (builtin.method == "countrz") return LBF_BIT32_COUNTRZ; + if (FFlag::LuauBit32ByteswapBuiltin && builtin.method == "byteswap") + return LBF_BIT32_BYTESWAP; } if (builtin.object == "string") @@ -188,6 +194,36 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_TABLE_UNPACK; } + if (FFlag::LuauBufferBuiltins && builtin.object == "buffer") + { + if (builtin.method == "readi8") + return LBF_BUFFER_READI8; + if (builtin.method == "readu8") + return LBF_BUFFER_READU8; + if (builtin.method == "writei8" || builtin.method == "writeu8") + return LBF_BUFFER_WRITEU8; + if (builtin.method == "readi16") + return LBF_BUFFER_READI16; + if (builtin.method == "readu16") + return LBF_BUFFER_READU16; + if (builtin.method == "writei16" || builtin.method == "writeu16") + return LBF_BUFFER_WRITEU16; + if (builtin.method == "readi32") + return LBF_BUFFER_READI32; + if (builtin.method == "readu32") + return LBF_BUFFER_READU32; + if (builtin.method == "writei32" || builtin.method == "writeu32") + return LBF_BUFFER_WRITEU32; + if (builtin.method == "readf32") + return LBF_BUFFER_READF32; + if (builtin.method == "writef32") + return LBF_BUFFER_WRITEF32; + if (builtin.method == "readf64") + return LBF_BUFFER_READF64; + if (builtin.method == "writef64") + return LBF_BUFFER_WRITEF64; + } + if (options.vectorCtor) { if (options.vectorLib) @@ -402,6 +438,26 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_TOSTRING: return {1, 1}; + + case LBF_BIT32_BYTESWAP: + return {1, 1, BuiltinInfo::Flag_NoneSafe}; + + case LBF_BUFFER_READI8: + case LBF_BUFFER_READU8: + case LBF_BUFFER_READI16: + case LBF_BUFFER_READU16: + case LBF_BUFFER_READI32: + case LBF_BUFFER_READU32: + case LBF_BUFFER_READF32: + case LBF_BUFFER_READF64: + return {2, 1, BuiltinInfo::Flag_NoneSafe}; + + case LBF_BUFFER_WRITEU8: + case LBF_BUFFER_WRITEU16: + case LBF_BUFFER_WRITEU32: + case LBF_BUFFER_WRITEF32: + case LBF_BUFFER_WRITEF64: + return {3, 0, BuiltinInfo::Flag_NoneSafe}; }; LUAU_UNREACHABLE(); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 6b88162d..83fb9ce5 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -2296,6 +2296,8 @@ static const char* getBaseTypeString(uint8_t type) return "userdata"; case LBC_TYPE_VECTOR: return "vector"; + case LBC_TYPE_BUFFER: + return "buffer"; case LBC_TYPE_ANY: return "any"; } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 411a9920..e0a0cac8 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,13 +26,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileFenvNoBuiltinFold, false) -LUAU_FASTFLAGVARIABLE(LuauCompileTopCold, false) - LUAU_FASTFLAG(LuauFloorDivision) LUAU_FASTFLAGVARIABLE(LuauCompileFixContinueValidation2, false) - -LUAU_FASTFLAGVARIABLE(LuauCompileContinueCloseUpvals, false) LUAU_FASTFLAGVARIABLE(LuauCompileIfElseAndOr, false) namespace Luau @@ -267,7 +262,7 @@ struct Compiler CompileError::raise(func->location, "Exceeded function instruction limit; split the function into parts to compile"); // since top-level code only executes once, it can be marked as cold if it has no loops (top-level code with loops might be profitable to compile natively) - if (FFlag::LuauCompileTopCold && func->functionDepth == 0 && !hasLoops) + if (func->functionDepth == 0 && !hasLoops) protoflags |= LPF_NATIVE_COLD; bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size()), protoflags); @@ -2649,8 +2644,7 @@ struct Compiler // (but it must still close upvalues defined in more nested blocks) // this is because the upvalues defined inside the loop body may be captured by a closure defined in the until // expression that continue will jump to. - if (FFlag::LuauCompileContinueCloseUpvals) - loops.back().localOffsetContinue = localStack.size(); + loops.back().localOffsetContinue = localStack.size(); // if continue was called from this statement, then any local defined after this in the loop body should not be accessed by until condition // it is sufficient to check this condition once, as if this holds for the first continue, it must hold for all subsequent continues. @@ -4016,7 +4010,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c } // builtin folding is enabled on optimization level 2 since we can't deoptimize folding at runtime - if (options.optimizationLevel >= 2 && (!FFlag::LuauCompileFenvNoBuiltinFold || (!compiler.getfenvUsed && !compiler.setfenvUsed))) + if (options.optimizationLevel >= 2 && (!compiler.getfenvUsed && !compiler.setfenvUsed)) { compiler.builtinsFold = &compiler.builtins; diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 8ac74d02..bb89cd15 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -3,6 +3,8 @@ #include "Luau/BytecodeBuilder.h" +LUAU_FASTFLAGVARIABLE(LuauCompileBufferAnnotation, false) + namespace Luau { @@ -27,6 +29,8 @@ static LuauBytecodeType getPrimitiveType(AstName name) return LBC_TYPE_STRING; else if (name == "thread") return LBC_TYPE_THREAD; + else if (FFlag::LuauCompileBufferAnnotation && name == "buffer") + return LBC_TYPE_BUFFER; else if (name == "any" || name == "unknown") return LBC_TYPE_ANY; else diff --git a/Sources.cmake b/Sources.cmake index 267e5826..2604514e 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -281,6 +281,7 @@ target_sources(Luau.VM PRIVATE VM/src/lbaselib.cpp VM/src/lbitlib.cpp VM/src/lbuffer.cpp + VM/src/lbuflib.cpp VM/src/lbuiltins.cpp VM/src/lcorolib.cpp VM/src/ldblib.cpp diff --git a/VM/include/lualib.h b/VM/include/lualib.h index dc8a01c7..367a0281 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -124,6 +124,9 @@ LUALIB_API int luaopen_string(lua_State* L); #define LUA_BITLIBNAME "bit32" LUALIB_API int luaopen_bit32(lua_State* L); +#define LUA_BUFFERLIBNAME "buffer" +LUALIB_API int luaopen_buffer(lua_State* L); + #define LUA_UTF8LIBNAME "utf8" LUALIB_API int luaopen_utf8(lua_State* L); diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 47445b80..627d599e 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -5,6 +5,8 @@ #include "lcommon.h" #include "lnumutils.h" +LUAU_FASTFLAGVARIABLE(LuauBit32Byteswap, false) + #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) @@ -210,6 +212,18 @@ static int b_countrz(lua_State* L) return 1; } +static int b_swap(lua_State* L) +{ + if (!FFlag::LuauBit32Byteswap) + luaL_error(L, "bit32.byteswap isn't enabled"); + + b_uint n = luaL_checkunsigned(L, 1); + n = (n << 24) | ((n << 8) & 0xff0000) | ((n >> 8) & 0xff00) | (n >> 24); + + lua_pushunsigned(L, n); + return 1; +} + static const luaL_Reg bitlib[] = { {"arshift", b_arshift}, {"band", b_and}, @@ -225,6 +239,7 @@ static const luaL_Reg bitlib[] = { {"rshift", b_rshift}, {"countlz", b_countlz}, {"countrz", b_countrz}, + {"byteswap", b_swap}, {NULL, NULL}, }; diff --git a/VM/src/lbuflib.cpp b/VM/src/lbuflib.cpp new file mode 100644 index 00000000..51ed5dac --- /dev/null +++ b/VM/src/lbuflib.cpp @@ -0,0 +1,286 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lualib.h" + +#include "lcommon.h" +#include "lbuffer.h" + +#if defined(LUAU_BIG_ENDIAN) +#include +#endif + +#include + +// while C API returns 'size_t' for binary compatibility in case of future extensions, +// in the current implementation, length and offset are limited to 31 bits +// because offset is limited to an integer, a single 64bit comparison can be used and will not overflow +#define isoutofbounds(offset, len, accessize) (uint64_t(unsigned(offset)) + (accessize) > uint64_t(len)) + +static_assert(MAX_BUFFER_SIZE <= INT_MAX, "current implementation can't handle a larger limit"); + +#if defined(LUAU_BIG_ENDIAN) +template +inline T buffer_swapbe(T v) +{ + if (sizeof(T) == 8) + return htole64(v); + else if (sizeof(T) == 4) + return htole32(v); + else if (sizeof(T) == 2) + return htole16(v); + else + return v; +} +#endif + +static int buffer_create(lua_State* L) +{ + int size = luaL_checkinteger(L, 1); + + if (size < 0) + luaL_error(L, "size cannot be negative"); + + lua_newbuffer(L, size); + return 1; +} + +static int buffer_fromstring(lua_State* L) +{ + size_t len = 0; + const char* val = luaL_checklstring(L, 1, &len); + + void* data = lua_newbuffer(L, len); + memcpy(data, val, len); + return 1; +} + +static int buffer_tostring(lua_State* L) +{ + size_t len = 0; + void* data = luaL_checkbuffer(L, 1, &len); + + lua_pushlstring(L, (char*)data, len); + return 1; +} + +template +static int buffer_readinteger(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int offset = luaL_checkinteger(L, 2); + + if (isoutofbounds(offset, len, sizeof(T))) + luaL_error(L, "buffer access out of bounds"); + + T val; + memcpy(&val, (char*)buf + offset, sizeof(T)); + +#if defined(LUAU_BIG_ENDIAN) + val = buffer_swapbe(val); +#endif + + lua_pushnumber(L, double(val)); + return 1; +} + +template +static int buffer_writeinteger(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int offset = luaL_checkinteger(L, 2); + int value = luaL_checkunsigned(L, 3); + + if (isoutofbounds(offset, len, sizeof(T))) + luaL_error(L, "buffer access out of bounds"); + + T val = T(value); + +#if defined(LUAU_BIG_ENDIAN) + val = buffer_swapbe(val); +#endif + + memcpy((char*)buf + offset, &val, sizeof(T)); + return 0; +} + +template +static int buffer_readfp(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int offset = luaL_checkinteger(L, 2); + + if (isoutofbounds(offset, len, sizeof(T))) + luaL_error(L, "buffer access out of bounds"); + + T val; + +#if defined(LUAU_BIG_ENDIAN) + static_assert(sizeof(T) == sizeof(StorageType), "type size must match to reinterpret data"); + StorageType tmp; + memcpy(&tmp, (char*)buf + offset, sizeof(tmp)); + tmp = buffer_swapbe(tmp); + + memcpy(&val, &tmp, sizeof(tmp)); +#else + memcpy(&val, (char*)buf + offset, sizeof(T)); +#endif + + lua_pushnumber(L, double(val)); + return 1; +} + +template +static int buffer_writefp(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int offset = luaL_checkinteger(L, 2); + double value = luaL_checknumber(L, 3); + + if (isoutofbounds(offset, len, sizeof(T))) + luaL_error(L, "buffer access out of bounds"); + + T val = T(value); + +#if defined(LUAU_BIG_ENDIAN) + static_assert(sizeof(T) == sizeof(StorageType), "type size must match to reinterpret data"); + StorageType tmp; + memcpy(&tmp, &val, sizeof(tmp)); + tmp = buffer_swapbe(tmp); + + memcpy((char*)buf + offset, &tmp, sizeof(tmp)); +#else + memcpy((char*)buf + offset, &val, sizeof(T)); +#endif + + return 0; +} + +static int buffer_readstring(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int offset = luaL_checkinteger(L, 2); + int size = luaL_checkinteger(L, 3); + + if (size < 0) + luaL_error(L, "size cannot be negative"); + + if (isoutofbounds(offset, len, unsigned(size))) + luaL_error(L, "buffer access out of bounds"); + + lua_pushlstring(L, (char*)buf + offset, size); + return 1; +} + +static int buffer_writestring(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int offset = luaL_checkinteger(L, 2); + size_t size = 0; + const char* val = luaL_checklstring(L, 3, &size); + int count = luaL_optinteger(L, 4, int(size)); + + if (count < 0) + luaL_error(L, "count cannot be negative"); + + if (size_t(count) > size) + luaL_error(L, "string length overflow"); + + // string size can't exceed INT_MAX at this point + if (isoutofbounds(offset, len, unsigned(count))) + luaL_error(L, "buffer access out of bounds"); + + memcpy((char*)buf + offset, val, count); + return 0; +} + +static int buffer_len(lua_State* L) +{ + size_t len = 0; + luaL_checkbuffer(L, 1, &len); + + lua_pushnumber(L, double(unsigned(len))); + return 1; +} + +static int buffer_copy(lua_State* L) +{ + size_t tlen = 0; + void* tbuf = luaL_checkbuffer(L, 1, &tlen); + int toffset = luaL_checkinteger(L, 2); + + size_t slen = 0; + void* sbuf = luaL_checkbuffer(L, 3, &slen); + int soffset = luaL_optinteger(L, 4, 0); + + int size = luaL_optinteger(L, 5, int(slen) - soffset); + + if (size < 0) + luaL_error(L, "buffer access out of bounds"); + + if (isoutofbounds(soffset, slen, unsigned(size))) + luaL_error(L, "buffer access out of bounds"); + + if (isoutofbounds(toffset, tlen, unsigned(size))) + luaL_error(L, "buffer access out of bounds"); + + memmove((char*)tbuf + toffset, (char*)sbuf + soffset, size); + return 0; +} + +static int buffer_fill(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int offset = luaL_checkinteger(L, 2); + unsigned value = luaL_checkunsigned(L, 3); + int size = luaL_optinteger(L, 4, int(len) - offset); + + if (size < 0) + luaL_error(L, "buffer access out of bounds"); + + if (isoutofbounds(offset, len, unsigned(size))) + luaL_error(L, "buffer access out of bounds"); + + memset((char*)buf + offset, value & 0xff, size); + return 0; +} + +static const luaL_Reg bufferlib[] = { + {"create", buffer_create}, + {"fromstring", buffer_fromstring}, + {"tostring", buffer_tostring}, + {"readi8", buffer_readinteger}, + {"readu8", buffer_readinteger}, + {"readi16", buffer_readinteger}, + {"readu16", buffer_readinteger}, + {"readi32", buffer_readinteger}, + {"readu32", buffer_readinteger}, + {"readf32", buffer_readfp}, + {"readf64", buffer_readfp}, + {"writei8", buffer_writeinteger}, + {"writeu8", buffer_writeinteger}, + {"writei16", buffer_writeinteger}, + {"writeu16", buffer_writeinteger}, + {"writei32", buffer_writeinteger}, + {"writeu32", buffer_writeinteger}, + {"writef32", buffer_writefp}, + {"writef64", buffer_writefp}, + {"readstring", buffer_readstring}, + {"writestring", buffer_writestring}, + {"len", buffer_len}, + {"copy", buffer_copy}, + {"fill", buffer_fill}, + {NULL, NULL}, +}; + +int luaopen_buffer(lua_State* L) +{ + luaL_register(L, LUA_BUFFERLIBNAME, bufferlib); + + return 1; +} diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index a916f73a..04852e87 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -8,8 +8,10 @@ #include "lgc.h" #include "lnumutils.h" #include "ldo.h" +#include "lbuffer.h" #include +#include #ifdef _MSC_VER #include @@ -1319,6 +1321,111 @@ static int luauF_tostring(lua_State* L, StkId res, TValue* arg0, int nresults, S return -1; } +static int luauF_byteswap(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + unsigned n; + luai_num2unsigned(n, a1); + + n = (n << 24) | ((n << 8) & 0xff0000) | ((n >> 8) & 0xff00) | (n >> 24); + + setnvalue(res, double(n)); + return 1; + } + + return -1; +} + +// because offset is limited to an integer, a single 64bit comparison can be used and will not overflow +#define checkoutofbounds(offset, len, accessize) (uint64_t(unsigned(offset)) + (accessize - 1) >= uint64_t(len)) + +template +static int luauF_readinteger(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ +#if !defined(LUAU_BIG_ENDIAN) + if (nparams >= 2 && nresults <= 1 && ttisbuffer(arg0) && ttisnumber(args)) + { + int offset; + luai_num2int(offset, nvalue(args)); + if (checkoutofbounds(offset, bufvalue(arg0)->len, sizeof(T))) + return -1; + + T val; + memcpy(&val, (char*)bufvalue(arg0)->data + offset, sizeof(T)); + setnvalue(res, double(val)); + return 1; + } +#endif + + return -1; +} + +template +static int luauF_writeinteger(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ +#if !defined(LUAU_BIG_ENDIAN) + if (nparams >= 3 && nresults <= 0 && ttisbuffer(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + int offset; + luai_num2int(offset, nvalue(args)); + if (checkoutofbounds(offset, bufvalue(arg0)->len, sizeof(T))) + return -1; + + unsigned value; + luai_num2unsigned(value, nvalue(args + 1)); + + T val = T(value); + memcpy((char*)bufvalue(arg0)->data + offset, &val, sizeof(T)); + return 0; + } +#endif + + return -1; +} + +template +static int luauF_readfp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ +#if !defined(LUAU_BIG_ENDIAN) + if (nparams >= 2 && nresults <= 1 && ttisbuffer(arg0) && ttisnumber(args)) + { + int offset; + luai_num2int(offset, nvalue(args)); + if (checkoutofbounds(offset, bufvalue(arg0)->len, sizeof(T))) + return -1; + + T val; + memcpy(&val, (char*)bufvalue(arg0)->data + offset, sizeof(T)); + setnvalue(res, double(val)); + return 1; + } +#endif + + return -1; +} + +template +static int luauF_writefp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ +#if !defined(LUAU_BIG_ENDIAN) + if (nparams >= 3 && nresults <= 0 && ttisbuffer(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + int offset; + luai_num2int(offset, nvalue(args)); + if (checkoutofbounds(offset, bufvalue(arg0)->len, sizeof(T))) + return -1; + + T val = T(nvalue(args + 1)); + memcpy((char*)bufvalue(arg0)->data + offset, &val, sizeof(T)); + return 0; + } +#endif + + return -1; +} + static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { return -1; @@ -1486,6 +1593,22 @@ const luau_FastFunction luauF_table[256] = { luauF_tonumber, luauF_tostring, + luauF_byteswap, + + luauF_readinteger, + luauF_readinteger, + luauF_writeinteger, + luauF_readinteger, + luauF_readinteger, + luauF_writeinteger, + luauF_readinteger, + luauF_readinteger, + luauF_writeinteger, + luauF_readfp, + luauF_writefp, + luauF_readfp, + luauF_writefp, + // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index fd95f596..b7491952 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -14,6 +14,7 @@ static const luaL_Reg lualibs[] = { {LUA_DBLIBNAME, luaopen_debug}, {LUA_UTF8LIBNAME, luaopen_utf8}, {LUA_BITLIBNAME, luaopen_bit32}, + {LUA_BUFFERLIBNAME, luaopen_buffer}, {NULL, NULL}, }; diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 4eab2893..75bcf398 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -3,13 +3,13 @@ static const std::string kNames[] = { "_G", - "_LOADED", "_VERSION", "__add", "__call", "__concat", "__div", "__eq", + "__idiv", "__index", "__iter", "__le", @@ -37,18 +37,22 @@ static const std::string kNames[] = { "boolean", "bor", "btest", + "buffer", "bxor", "byte", "ceil", "char", "charpattern", "clamp", + "clear", "clock", "clone", "close", "codepoint", "codes", + "collectgarbage", "concat", + "copy", "coroutine", "cos", "cosh", @@ -62,18 +66,19 @@ static const std::string kNames[] = { "error", "exp", "extract", + "fill", "find", "floor", "fmod", "foreach", "foreachi", "format", - "frexp", "freeze", + "frexp", + "fromstring", "function", "gcinfo", "getfenv", - "getinfo", "getmetatable", "getn", "gmatch", @@ -118,13 +123,24 @@ static const std::string kNames[] = { "randomseed", "rawequal", "rawget", + "rawlen", "rawset", + "readf32", + "readf64", + "readi16", + "readi32", + "readi8", + "readstring", + "readu16", + "readu32", + "readu8", "remove", "rep", "replace", "require", "resume", "reverse", + "round", "rrotate", "rshift", "running", @@ -138,7 +154,6 @@ static const std::string kNames[] = { "split", "sqrt", "status", - "stdin", "string", "sub", "table", @@ -148,6 +163,7 @@ static const std::string kNames[] = { "time", "tonumber", "tostring", + "tostring", "traceback", "type", "typeof", @@ -157,17 +173,28 @@ static const std::string kNames[] = { "utf8", "vector", "wrap", + "writef32", + "writef64", + "writei16", + "writei32", + "writei8", + "writestring", + "writeu16", + "writeu32", + "writeu8", "xpcall", "yield", }; static const std::string kTypes[] = { "any", + "boolean", + "buffer", "nil", "number", "string", - "boolean", "thread", + "vector", }; static const std::string kClasses[] = { diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 63e65dd4..ccf1ca17 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -588,6 +588,7 @@ TEST_CASE("LogTest") build.cmp(rsi, rdi); build.jcc(ConditionX64::Equal, start); build.lea(rcx, start); + build.lea(rcx, addr[rdx]); build.jmp(qword[rdx]); build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]); @@ -634,6 +635,7 @@ TEST_CASE("LogTest") cmp rsi,rdi je .L1 lea rcx,.L1 + lea rcx,[rdx] jmp qword ptr [rdx] vaddps ymm9,ymm12,ymmword ptr [rbp+0Ch] vaddpd ymm2,ymm7,qword ptr [.start-8] diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index e2c779e9..bc9a12ea 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -89,7 +89,7 @@ TEST_CASE("BytecodeIsStable") // Bytecode type encoding (serialized & in-memory) // Note: these *can* change retroactively *if* type version is bumped, but probably shouldn't - LUAU_ASSERT(LBC_TYPE_VECTOR == 8); // type version 1 + LUAU_ASSERT(LBC_TYPE_BUFFER == 9); // type version 1 } TEST_CASE("CompileToBytecode") @@ -1772,8 +1772,6 @@ RETURN R0 0 TEST_CASE("LoopContinueUntil") { - ScopedFastFlag sff("LuauCompileContinueCloseUpvals", true); - // it's valid to use locals defined inside the loop in until expression if they're defined before continue CHECK_EQ("\n" + compileFunction0("repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until r < 0.5"), R"( L0: GETIMPORT R0 2 [math.random] @@ -2026,8 +2024,6 @@ end TEST_CASE("LoopContinueUntilCapture") { - ScopedFastFlag sff("LuauCompileContinueCloseUpvals", true); - // validate continue upvalue closing behavior: continue must close locals defined in the nested scopes // but can't close locals defined in the loop scope - these are visible to the condition and will be closed // when evaluating the condition instead. @@ -7586,8 +7582,6 @@ RETURN R0 1 TEST_CASE("NoBuiltinFoldFenv") { - ScopedFastFlag sff("LuauCompileFenvNoBuiltinFold", true); - // builtin folding is disabled when getfenv/setfenv is used in the module CHECK_EQ("\n" + compileFunction(R"( getfenv() diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 61ea2416..b7f77711 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -289,11 +289,15 @@ TEST_CASE("Assert") TEST_CASE("Basic") { ScopedFastFlag sffs{"LuauFloorDivision", true}; - ScopedFastFlag sfff{"LuauImproveForN2", true}; runConformance("basic.lua"); } +TEST_CASE("Buffers") +{ + runConformance("buffers.lua"); +} + TEST_CASE("Math") { runConformance("math.lua"); @@ -381,8 +385,6 @@ TEST_CASE("Events") TEST_CASE("Constructs") { - ScopedFastFlag sff("LuauCompileContinueCloseUpvals", true); - runConformance("constructs.lua"); } @@ -408,6 +410,7 @@ TEST_CASE("GC") TEST_CASE("Bitwise") { + ScopedFastFlag sffs{"LuauBit32Byteswap", true}; runConformance("bitwise.lua"); } @@ -557,6 +560,8 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag luauBufferDefinitions{"LuauBufferDefinitions", true}; + runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::NullFileResolver fileResolver; @@ -1506,68 +1511,81 @@ TEST_CASE("Interrupt") lua_CompileOptions copts = defaultOptions(); copts.optimizationLevel = 1; // disable loop unrolling to get fixed expected hit results - static const int expectedhits[] = { - 2, - 9, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 6, - 18, - 13, - 13, - 13, - 13, - 16, - 23, - 21, - 25, - }; static int index; - index = 0; + StateRef globalState = runConformance("interrupt.lua", nullptr, nullptr, nullptr, &copts); - runConformance( - "interrupt.lua", - [](lua_State* L) { - auto* cb = lua_callbacks(L); + lua_State* L = globalState.get(); - // note: for simplicity here we setup the interrupt callback once - // however, this carries a noticeable performance cost. in a real application, - // it's advised to set interrupt callback on a timer from a different thread, - // and set it back to nullptr once the interrupt triggered. - cb->interrupt = [](lua_State* L, int gc) { - if (gc >= 0) - return; + // note: for simplicity here we setup the interrupt callback when the test starts + // however, this carries a noticeable performance cost. in a real application, + // it's advised to set interrupt callback on a timer from a different thread, + // and set it back to nullptr once the interrupt triggered. - CHECK(index < int(std::size(expectedhits))); + // define the interrupt to check the expected hits + static const int expectedhits[] = {11, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 20, 15, 15, 15, 15, 18, 25, 23, 26}; - lua_Debug ar = {}; - lua_getinfo(L, 0, "l", &ar); + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) { + if (gc >= 0) + return; - CHECK(ar.currentline == expectedhits[index]); + CHECK(index < int(std::size(expectedhits))); - index++; + lua_Debug ar = {}; + lua_getinfo(L, 0, "l", &ar); - // check that we can yield inside an interrupt - if (index == 5) - lua_yield(L, 0); - }; - }, - [](lua_State* L) { - CHECK(index == 5); // a single yield point - }, - nullptr, &copts); + CHECK(ar.currentline == expectedhits[index]); - CHECK(index == int(std::size(expectedhits))); + index++; + + // check that we can yield inside an interrupt + if (index == 4) + lua_yield(L, 0); + }; + + { + lua_State* T = lua_newthread(L); + + lua_getglobal(T, "test"); + + index = 0; + int status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_YIELD); + CHECK(index == 4); + + status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_OK); + CHECK(index == int(std::size(expectedhits))); + + lua_pop(L, 1); + } + + // redefine the interrupt to break after 10 iterations of a loop that would otherwise be infinite + // the test exposes a few global functions that we will call; the interrupt will force a yield + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) { + if (gc >= 0) + return; + + CHECK(index < 10); + if (++index == 10) + lua_yield(L, 0); + }; + + for (int test = 1; test <= 9; ++test) + { + lua_State* T = lua_newthread(L); + + std::string name = "infloop" + std::to_string(test); + lua_getglobal(T, name.c_str()); + + index = 0; + int status = lua_resume(T, nullptr, 0); + CHECK(status == LUA_YIELD); + CHECK(index == 10); + + // abandon the thread + lua_pop(L, 1); + } } TEST_CASE("UserdataApi") @@ -1889,6 +1907,8 @@ TEST_CASE("NativeTypeAnnotations") if (!codegen || !luau_codegen_supported()) return; + ScopedFastFlag luauCompileBufferAnnotation{"LuauCompileBufferAnnotation", true}; + lua_CompileOptions copts = defaultOptions(); copts.vectorCtor = "vector"; copts.vectorType = "vector"; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index c70f6933..4388c400 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -2062,7 +2062,7 @@ bb_fallback_1: TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksSameIndex") { - ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots", true}; + ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots2", true}; IrOp block = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); @@ -2117,9 +2117,9 @@ bb_fallback_1: )"); } -TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksLowerIndex") +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksSameValue") { - ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots", true}; + ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots2", true}; IrOp block = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); @@ -2182,9 +2182,9 @@ bb_fallback_1: )"); } -TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksSameValue") +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksLowerIndex") { - ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots", true}; + ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots2", true}; IrOp block = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); @@ -2240,7 +2240,7 @@ bb_fallback_1: TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateArrayElemChecksInvalidations") { - ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots", true}; + ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots2", true}; IrOp block = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); @@ -2298,6 +2298,55 @@ bb_fallback_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "ArrayElemChecksNegativeIndex") +{ + ScopedFastFlag luauReuseHashSlots{"LuauReuseArrSlots2", true}; + + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t[1] + t[0]' + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(0), fallback); + IrOp elem1 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(0)); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, elem1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + build.inst(IrCmd::CHECK_ARRAY_SIZE, table1, build.constInt(-1), fallback); // This will jump directly to fallback + IrOp elem2 = build.inst(IrCmd::GET_ARR_ADDR, table1, build.constInt(-1)); + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, elem2, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_POINTER R1 + CHECK_ARRAY_SIZE %0, 0i, bb_fallback_1 + %2 = GET_ARR_ADDR %0, 0i + %3 = LOAD_TVALUE %2, 0i + STORE_TVALUE R3, %3 + JUMP bb_fallback_1 + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp index 1ff22a32..8336a634 100644 --- a/tests/IrCallWrapperX64.test.cpp +++ b/tests/IrCallWrapperX64.test.cpp @@ -483,9 +483,9 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") checkMatch(R"( vmovups xmm2,xmmword ptr [r13] mov rax,rcx - lea rcx,none ptr [r12+8] + lea rcx,[r12+8] mov rbx,rdx - lea rdx,none ptr [r12+010h] + lea rdx,[r12+010h] call qword ptr [rax+rbx] )"); } @@ -500,7 +500,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "AddressInStackArguments") callWrap.call(qword[r14]); checkMatch(R"( - lea rax,none ptr [r12+010h] + lea rax,[r12+010h] mov qword ptr [rsp+020h],rax mov ecx,1 mov edx,2 diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index a900b7ab..0ee135a1 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1317,16 +1317,53 @@ end TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") { - ScopedFastInt sfis{"LuauRecursionLimit", 20}; + ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError( - "function f(): (((((((((Fail))))))))) end", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + "function f(): ((((((((((Fail)))))))))) end", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); matchParseError("function f(): () -> () -> () -> () -> () -> () -> () -> () -> () -> () -> () end", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); - matchParseError( - "local t: {a: {b: {c: {d: {e: {f: {}}}}}}}", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + matchParseError("local t: {a: {b: {c: {d: {e: {f: {g: {h: {i: {j: {}}}}}}}}}}}", + "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + + matchParseError("local f: ((((((((((Fail))))))))))", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + + matchParseError("local t: a & (b & (c & (d & (e & (f & (g & (h & (i & (j & nil)))))))))", + "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); +} + +TEST_CASE_FIXTURE(Fixture, "can_parse_complex_unions_successfully") +{ + ScopedFastInt sfis[] = {{"LuauRecursionLimit", 10}, {"LuauTypeLengthLimit", 10}}; + ScopedFastFlag sff{"LuauBetterTypeUnionLimits", true}; + + parse(R"( +local f: +() -> () +| +() -> () +| +{a: number} +| +{b: number} +| +((number)) +| +((number)) +| +(a & (b & nil)) +| +(a & (b & nil)) +)"); + + parse(R"( +local f: a? | b? | c? | d? | e? | f? | g? | h? +)"); + + matchParseError("local t: a & b & c & d & e & f & g & h & i & j & nil", + "Exceeded allowed type length; simplify your type annotation to make the code compile"); } TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index f53fc5d4..3f6d90fa 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -185,6 +185,7 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") LUAU_REQUIRE_NO_ERRORS(result); } +#if 0 TEST_CASE_FIXTURE(Fixture, "generic_aliases") { ScopedFastFlag sff[] = { @@ -224,6 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); CHECK_EQ(expected, toString(result.errors[0])); } +#endif TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") { diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index f71b1bf0..03534413 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -13,23 +13,99 @@ using namespace Luau; TEST_SUITE_BEGIN("AnnotationTests"); -TEST_CASE_FIXTURE(Fixture, "check_against_annotations") +TEST_CASE_FIXTURE(Fixture, "initializers_are_checked_against_annotations") { CheckResult result = check("local a: number = \"Hello Types!\""); LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "check_multi_assign") +TEST_CASE_FIXTURE(Fixture, "check_multi_initialize") { - CheckResult result = check("local a: number, b: string = \"994\", 888"); - CHECK_EQ(2, result.errors.size()); + CheckResult result = check(R"( + local a: number, b: string = "one", 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(get(result.errors[0])); + CHECK(get(result.errors[1])); } TEST_CASE_FIXTURE(Fixture, "successful_check") { - CheckResult result = check("local a: number, b: string = 994, \"eight eighty eight\""); + CheckResult result = check(R"( + local a: number, b: string = 1, "two" + )"); + LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "assignments_are_checked_against_annotations") +{ + CheckResult result = check(R"( + local x: number = 1 + x = "two" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "multi_assign_checks_against_annotations") +{ + CheckResult result = check(R"( + local a: number, b: string = 1, "two" + a, b = "one", 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(Location{{2, 15}, {2, 20}} == result.errors[0].location); + CHECK(Location{{2, 22}, {2, 23}} == result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "assignment_cannot_transform_a_table_property_type") +{ + CheckResult result = check(R"( + local a = {x=0} + a.x = "one" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(Location{{2, 14}, {2, 19}} == result.errors[0].location); +} + +TEST_CASE_FIXTURE(Fixture, "assignments_to_unannotated_parameters_can_transform_the_type") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function f(x) + x = 0 + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(unknown) -> number" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assignments_to_annotated_parameters_are_checked") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function f(x: string) + x = 0 + return x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(Location{{2, 16}, {2, 17}} == result.errors[0].location); + + CHECK("(string) -> number" == toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") @@ -42,6 +118,22 @@ TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "assignment_also_checks_subtyping") +{ + CheckResult result = check(R"( + function f(): number? + return nil + end + local x: number = 1 + local y: number? = f() + x = y + y = x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(Location{{6, 12}, {6, 13}} == result.errors[0].location); +} + TEST_CASE_FIXTURE(Fixture, "function_parameters_can_have_annotations") { CheckResult result = check(R"( @@ -191,7 +283,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_type_of_value_a_via_typeof_with_assignment") TEST_CASE_FIXTURE(Fixture, "table_annotation") { CheckResult result = check(R"( - local x: {a: number, b: string} + local x: {a: number, b: string} = {a=2, b="three"} local y = x.a local z = x.b )"); @@ -391,7 +483,7 @@ TEST_CASE_FIXTURE(Fixture, "two_type_params") { CheckResult result = check(R"( type Map = {[K]: V} - local m: Map = {}; + local m: Map = {} local a = m['foo'] local b = m[9] -- error here )"); @@ -572,8 +664,8 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti CHECK(isInArena(aType, mod.interfaceTypes)); CHECK(isInArena(bType, mod.interfaceTypes)); - CHECK_EQ(recordType, aType); - CHECK_EQ(recordType, bType); + CHECK(toString(recordType, {true}) == toString(aType, {true})); + CHECK(toString(recordType, {true}) == toString(bType, {true})); } TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 6892c78f..18b8ab8a 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -976,6 +976,7 @@ local y = x["Bar"] LUAU_REQUIRE_NO_ERRORS(result); } +#if 0 TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_degenerate_intersections") { ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; @@ -1025,5 +1026,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_more_realistic_intersections") LUAU_REQUIRE_ERRORS(result); } +#endif TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 23314535..97ee15e1 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -350,6 +350,7 @@ Table type 'a' not compatible with type 'Bad' because the former is missing fiel CHECK_EQ(expected, toString(result.errors[0])); } +#if 0 TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") { ScopedFastFlag sff[] = { @@ -371,6 +372,7 @@ TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") CHECK(toString(result.errors[0]) == expectedError); } +#endif TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") { diff --git a/tests/TypePath.test.cpp b/tests/TypePath.test.cpp index 5d4a49bf..53127c3d 100644 --- a/tests/TypePath.test.cpp +++ b/tests/TypePath.test.cpp @@ -93,6 +93,7 @@ TEST_SUITE_BEGIN("TypePathTraversal"); LUAU_REQUIRE_NO_ERRORS(result); \ } while (false); +#if 0 TEST_CASE_FIXTURE(Fixture, "empty_traversal") { CHECK(traverseForType(builtinTypes->numberType, kEmpty, builtinTypes) == builtinTypes->numberType); @@ -474,6 +475,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "complex_chains") CHECK(*result == builtinTypes->falseType); } } +#endif TEST_SUITE_END(); // TypePathTraversal diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index 3b117892..281ad274 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -101,6 +101,7 @@ assert(bit32.extract(0xa0001111, 28, 4) == 0xa) assert(bit32.extract(0xa0001111, 31, 1) == 1) assert(bit32.extract(0x50000111, 31, 1) == 0) assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679) +assert(bit32.extract(0xa0001111, 0) == 1) assert(bit32.extract(0xa0001111, 16) == 0) assert(bit32.extract(0xa0001111, 31) == 1) assert(bit32.extract(42, 1, 3) == 5) @@ -134,6 +135,11 @@ assert(bit32.countrz(0x80000000) == 31) assert(bit32.countrz(0x40000000) == 30) assert(bit32.countrz(0x7fffffff) == 0) +-- testing byteswap +assert(bit32.byteswap(0x10203040) == 0x40302010) +assert(bit32.byteswap(0) == 0) +assert(bit32.byteswap(-1) == 0xffffffff) + --[[ This test verifies a fix in luauF_replace() where if the 4th parameter was not a number, but the first three are numbers, it will @@ -164,5 +170,6 @@ assert(bit32.btest("1", 3) == true) assert(bit32.countlz("42") == 26) assert(bit32.countrz("42") == 1) assert(bit32.extract("42", 1, 3) == 5) +assert(bit32.byteswap("0xa1b2c3d4") == 0xd4c3b2a1) return('OK') diff --git a/tests/conformance/buffers.lua b/tests/conformance/buffers.lua new file mode 100644 index 00000000..a6b951ea --- /dev/null +++ b/tests/conformance/buffers.lua @@ -0,0 +1,540 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing byte buffer library") + +function call(fn, ...) + local ok, res = pcall(fn, ...) + assert(ok) + return res +end + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +local function simple_byte_reads() + local b = buffer.create(1024) + + assert(buffer.len(b) == 1024) + + assert(buffer.readi8(b, 5) == 0) + buffer.writei8(b, 10, 32) + assert(buffer.readi8(b, 10) == 32) + buffer.writei8(b, 15, 5) + buffer.writei8(b, 14, 4) + buffer.writei8(b, 13, 3) + buffer.writei8(b, 12, 2) + buffer.writei8(b, 11, 1) + assert(buffer.readi8(b, 11) == 1) + assert(buffer.readi8(b, 12) == 2) + assert(buffer.readi8(b, 13) == 3) + assert(buffer.readi8(b, 14) == 4) + assert(buffer.readi8(b, 15) == 5) + + local x = buffer.readi8(b, 14) + buffer.readi8(b, 13) + assert(x == 7) +end + +simple_byte_reads() + +local function offset_byte_reads(start: number) + local b = buffer.create(1024) + + buffer.writei8(b, start, 32) + assert(buffer.readi8(b, start) == 32) + buffer.writei8(b, start + 5, 5) + buffer.writei8(b, start + 4, 4) + buffer.writei8(b, start + 3, 3) + buffer.writei8(b, start + 2, 2) + buffer.writei8(b, start + 1, 1) + assert(buffer.readi8(b, start + 1) == 1) + assert(buffer.readi8(b, start + 2) == 2) + assert(buffer.readi8(b, start + 3) == 3) + assert(buffer.readi8(b, start + 4) == 4) + assert(buffer.readi8(b, start + 5) == 5) + + local x = buffer.readi8(b, start + 4) + buffer.readi8(b, start + 3) + assert(x == 7) +end + +offset_byte_reads(5) +offset_byte_reads(30) + +local function simple_float_reinterpret() + local b = buffer.create(1024) + + buffer.writei32(b, 10, 0x3f800000) + local one = buffer.readf32(b, 10) + assert(one == 1.0) + + buffer.writef32(b, 10, 2.75197) + local magic = buffer.readi32(b, 10) + assert(magic == 0x40302047) +end + +simple_float_reinterpret() + +local function simple_double_reinterpret() + local b = buffer.create(1024) + + buffer.writei32(b, 10, 0x00000000) + buffer.writei32(b, 14, 0x3ff00000) + local one = buffer.readf64(b, 10) + assert(one == 1.0) + + buffer.writef64(b, 10, 1.437576533064206) + local magic1 = buffer.readi32(b, 10) + local magic2 = buffer.readi32(b, 14) + + assert(magic1 == 0x40302010) + assert(magic2 == 0x3ff70050) +end + +simple_double_reinterpret() + +local function simple_string_ops() + local b = buffer.create(1024) + + buffer.writestring(b, 15, " world") + buffer.writestring(b, 10, "hello") + buffer.writei8(b, 21, string.byte('!')) + assert(buffer.readstring(b, 10, 12) == "hello world!") + + buffer.writestring(b, 10, "hellommm", 5) + assert(buffer.readstring(b, 10, 12) == "hello world!") + + buffer.writestring(b, 10, string.rep("hellommm", 1000), 5) + assert(buffer.readstring(b, 10, 12) == "hello world!") +end + +simple_string_ops() + +local function simple_copy_ops() + local b1 = buffer.create(1024) + local b2 = buffer.create(1024) + + buffer.writestring(b1, 200, "hello") + buffer.writestring(b1, 100, "world") + + buffer.copy(b1, 300, b1, 100, 5) + + buffer.writei8(b2, 35, string.byte(' ')) + buffer.writei8(b2, 41, string.byte('!')) + + buffer.copy(b2, 30, b1, 200, 5) + buffer.copy(b2, 36, b1, 300, 5) + + assert(buffer.readstring(b2, 30, 12) == "hello world!") + + local b3 = buffer.create(9) + buffer.writestring(b3, 0, "say hello") + buffer.copy(b2, 36, b3, 4) + assert(buffer.readstring(b2, 30, 12) == "hello hello!") + + local b4 = buffer.create(5) + buffer.writestring(b4, 0, "world") + buffer.copy(b2, 36, b4) + assert(buffer.readstring(b2, 30, 12) == "hello world!") + + buffer.writestring(b1, 200, "abcdefgh"); + buffer.copy(b1, 200, b1, 202, 6) + assert(buffer.readstring(b1, 200, 8) == "cdefghgh") + buffer.copy(b1, 202, b1, 200, 6) + assert(buffer.readstring(b1, 200, 8) == "cdcdefgh") +end + +simple_copy_ops() + +-- bounds checking + +local function createchecks() + assert(ecall(function() buffer.create(-1) end) == "size cannot be negative") + assert(ecall(function() buffer.create(-1000000) end) == "size cannot be negative") +end + +createchecks() + +local function boundchecks() + local b = buffer.create(1024) + + assert(call(function() return buffer.readi8(b, 1023) end) == 0) + assert(ecall(function() buffer.readi8(b, 1024) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writei8(b, 1023, 0) end) + assert(ecall(function() buffer.writei8(b, 1024, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -100000, 0) end) == "buffer access out of bounds") + + -- i16 + assert(call(function() return buffer.readi16(b, 1022) end) == 0) + assert(ecall(function() buffer.readi16(b, 1023) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -100000) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x7fffffff) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x7ffffffe) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x7ffffffd) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, 0x80000000) end) == "buffer access out of bounds") + + call(function() buffer.writei16(b, 1022, 0) end) + assert(ecall(function() buffer.writei16(b, 1023, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -100000, 0) end) == "buffer access out of bounds") + + -- i32 + assert(call(function() return buffer.readi32(b, 1020) end) == 0) + assert(ecall(function() buffer.readi32(b, 1021) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writei32(b, 1020, 0) end) + assert(ecall(function() buffer.writei32(b, 1021, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -100000, 0) end) == "buffer access out of bounds") + + -- f32 + assert(call(function() return buffer.readf32(b, 1020) end) == 0) + assert(ecall(function() buffer.readf32(b, 1021) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writef32(b, 1020, 0) end) + assert(ecall(function() buffer.writef32(b, 1021, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -100000, 0) end) == "buffer access out of bounds") + + -- f64 + assert(call(function() return buffer.readf64(b, 1016) end) == 0) + assert(ecall(function() buffer.readf64(b, 1017) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -100000) end) == "buffer access out of bounds") + + call(function() buffer.writef64(b, 1016, 0) end) + assert(ecall(function() buffer.writef64(b, 1017, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -100000, 0) end) == "buffer access out of bounds") + + -- string + assert(call(function() return buffer.readstring(b, 1016, 8) end) == "\0\0\0\0\0\0\0\0") + assert(ecall(function() buffer.readstring(b, 1017, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -1, -8) end) == "size cannot be negative") + assert(ecall(function() buffer.readstring(b, -100000, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -100000, 8) end) == "buffer access out of bounds") + + call(function() buffer.writestring(b, 1016, "abcdefgh") end) + assert(ecall(function() buffer.writestring(b, 1017, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -1, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -100000, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, 100, "abcd", -5) end) == "count cannot be negative") + assert(ecall(function() buffer.writestring(b, 100, "abcd", 50) end) == "string length overflow") + + -- copy + assert(ecall(function() buffer.copy(b, 30, b, 200, 1000) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, 200, -5) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, 2000, 10) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, -1, 10) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, -10, 10) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 30, b, -100000, 10) end) == "buffer access out of bounds") + + local b2 = buffer.create(1024) + assert(ecall(function() buffer.copy(b, -200, b, 200, 200) end) == "buffer access out of bounds") + assert(ecall(function() buffer.copy(b, 825, b, 200, 200) end) == "buffer access out of bounds") +end + +boundchecks() + +local function boundchecksnonconst(size, minus1, minusbig, intmax) + local b = buffer.create(size) + + assert(call(function() return buffer.readi8(b, size-1) end) == 0) + assert(ecall(function() buffer.readi8(b, size) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writei8(b, size-1, 0) end) + assert(ecall(function() buffer.writei8(b, size, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, minusbig, 0) end) == "buffer access out of bounds") + + -- i16 + assert(call(function() return buffer.readi16(b, size-2) end) == 0) + assert(ecall(function() buffer.readi16(b, size-1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, minusbig) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax-1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax-2) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, intmax+1) end) == "buffer access out of bounds") + + call(function() buffer.writei16(b, size-2, 0) end) + assert(ecall(function() buffer.writei16(b, size-1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, minusbig, 0) end) == "buffer access out of bounds") + + -- i32 + assert(call(function() return buffer.readi32(b, size-4) end) == 0) + assert(ecall(function() buffer.readi32(b, size-3) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writei32(b, size-4, 0) end) + assert(ecall(function() buffer.writei32(b, size-3, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, minusbig, 0) end) == "buffer access out of bounds") + + -- f32 + assert(call(function() return buffer.readf32(b, size-4) end) == 0) + assert(ecall(function() buffer.readf32(b, size-3) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writef32(b, size-4, 0) end) + assert(ecall(function() buffer.writef32(b, size-3, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, minusbig, 0) end) == "buffer access out of bounds") + + -- f64 + assert(call(function() return buffer.readf64(b, size-8) end) == 0) + assert(ecall(function() buffer.readf64(b, size-7) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, minus1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, minusbig) end) == "buffer access out of bounds") + + call(function() buffer.writef64(b, size-8, 0) end) + assert(ecall(function() buffer.writef64(b, size-7, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, minus1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, minusbig, 0) end) == "buffer access out of bounds") + + -- string + assert(call(function() return buffer.readstring(b, size-8, 8) end) == "\0\0\0\0\0\0\0\0") + assert(ecall(function() buffer.readstring(b, size-7, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, minus1, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, minusbig, 8) end) == "buffer access out of bounds") + + call(function() buffer.writestring(b, size-8, "abcdefgh") end) + assert(ecall(function() buffer.writestring(b, size-7, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, minus1, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, minusbig, "abcdefgh") end) == "buffer access out of bounds") +end + +boundchecksnonconst(1024, -1, -100000, 0x7fffffff) + +local function boundcheckssmall() + local b = buffer.create(1) + + assert(call(function() return buffer.readi8(b, 0) end) == 0) + assert(ecall(function() buffer.readi8(b, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -1) end) == "buffer access out of bounds") + + call(function() buffer.writei8(b, 0, 0) end) + assert(ecall(function() buffer.writei8(b, 1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -1, 0) end) == "buffer access out of bounds") + + -- i16 + assert(ecall(function() buffer.readi16(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi16(b, -2) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei16(b, -2, 0) end) == "buffer access out of bounds") + + -- i32 + assert(ecall(function() buffer.readi32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, -4) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei32(b, -4, 0) end) == "buffer access out of bounds") + + -- f32 + assert(ecall(function() buffer.readf32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, -4) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef32(b, -4, 0) end) == "buffer access out of bounds") + + -- f64 + assert(ecall(function() buffer.readf64(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, -8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writef64(b, -7, 0) end) == "buffer access out of bounds") + + -- string + assert(ecall(function() buffer.readstring(b, 0, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -1, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, -8, 8) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, 0, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -1, "abcdefgh") end) == "buffer access out of bounds") + assert(ecall(function() buffer.writestring(b, -7, "abcdefgh") end) == "buffer access out of bounds") +end + +boundcheckssmall() + +local function boundchecksempty() + local b = buffer.create(0) -- useless, but probably more generic + + assert(ecall(function() buffer.readi8(b, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi8(b, -1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, 1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, 0, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writei8(b, -1, 0) end) == "buffer access out of bounds") + + assert(ecall(function() buffer.readi16(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readi32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf32(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readf64(b, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, 0, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readstring(b, 0, 8) end) == "buffer access out of bounds") +end + +boundchecksempty() + +local function intuint() + local b = buffer.create(32) + + buffer.writeu32(b, 0, 0xffffffff) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) + + buffer.writei32(b, 0, -1) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) + + buffer.writei16(b, 0, 65535) + buffer.writei16(b, 2, -1) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) + + buffer.writeu16(b, 0, 65535) + buffer.writeu16(b, 2, -1) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == -1) + assert(buffer.readu32(b, 0) == 4294967295) +end + +intuint() + +local function intuinttricky() + local b = buffer.create(32) + + buffer.writeu8(b, 0, 0xffffffff) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == 255) + assert(buffer.readu16(b, 0) == 255) + assert(buffer.readi32(b, 0) == 255) + assert(buffer.readu32(b, 0) == 255) + + buffer.writeu16(b, 0, 0xffffffff) + assert(buffer.readi8(b, 0) == -1) + assert(buffer.readu8(b, 0) == 255) + assert(buffer.readi16(b, 0) == -1) + assert(buffer.readu16(b, 0) == 65535) + assert(buffer.readi32(b, 0) == 65535) + assert(buffer.readu32(b, 0) == 65535) + + buffer.writei32(b, 8, 0xffffffff) + buffer.writeu32(b, 12, 0xffffffff) + assert(buffer.readstring(b, 8, 4) == buffer.readstring(b, 12, 4)) + + buffer.writei32(b, 8, -2147483648) + buffer.writeu32(b, 12, 0x80000000) + assert(buffer.readstring(b, 8, 4) == buffer.readstring(b, 12, 4)) +end + +intuinttricky() + +local function fromtostring() + local b = buffer.fromstring("1234567890") + assert(buffer.tostring(b) == "1234567890") + + buffer.writestring(b, 4, "xyz") + assert(buffer.tostring(b) == "1234xyz890") + + local b2 = buffer.fromstring("abcd\0ef") + assert(buffer.tostring(b2) == "abcd\0ef") +end + +fromtostring() + +local function fill() + local b = buffer.create(10) + + buffer.fill(b, 0, 0x61) + assert(buffer.tostring(b) == "aaaaaaaaaa") + + buffer.fill(b, 0, 0x62, 5) + assert(buffer.tostring(b) == "bbbbbaaaaa") + + buffer.fill(b, 4, 0x63) + assert(buffer.tostring(b) == "bbbbcccccc") + + buffer.fill(b, 6, 0x64, 3) + assert(buffer.tostring(b) == "bbbbccdddc") + + buffer.fill(b, 2, 0xffffff65, 8) + assert(buffer.tostring(b) == "bbeeeeeeee") + + -- out of bounds + assert(ecall(function() buffer.fill(b, -10, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 11, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 0, 1, 11) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 5, 1, 6) end) == "buffer access out of bounds") + assert(ecall(function() buffer.fill(b, 5, 1, -1) end) == "buffer access out of bounds") +end + +fill() + +local function misc() + local b = buffer.create(1000) + + assert(select('#', buffer.writei32(b, 10, 40)) == 0) + assert(select('#', buffer.writef32(b, 20, 40.0)) == 0) +end + +misc() + +local function testslowcalls() + getfenv() + + simple_byte_reads() + offset_byte_reads(5) + offset_byte_reads(30) + simple_float_reinterpret() + simple_double_reinterpret() + simple_string_ops() + createchecks() + boundchecks() + boundchecksnonconst(1024, -1, -100000, 0x7fffffff) + boundcheckssmall() + boundchecksempty() + intuint() + intuinttricky() + fromtostring() + fill() + misc() +end + +testslowcalls() + +return('OK') diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua index c07f57e7..ca6d5c6c 100644 --- a/tests/conformance/interrupt.lua +++ b/tests/conformance/interrupt.lua @@ -1,25 +1,70 @@ -- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details print("testing interrupts") -function foo() - for i=1,10 do end - return +-- this function will be called by C code with a special interrupt handler that validates hit locations +function test() + function foo() + for i=1,10 do end + return + end + + foo() + + function bar() + local i = 0 + while i < 10 do + i += i + 1 + end + end + + bar() + + function baz() + end + + baz() end -foo() - -function bar() - local i = 0 - while i < 10 do - i += i + 1 - end +-- these functions will be called by C code with a special interrupt handler that terminates after a few invocations +function infloop1() + while true do end end -bar() - -function baz() +function infloop2() + while true do continue end end -baz() +function infloop3() + repeat until false +end + +function infloop4() + repeat continue until false +end + +function infloop5() + for i=0,0,0 do end +end + +function infloop6() + for i=0,0,0 do continue end +end + +function infloop7() + for i=1,math.huge do end +end + +function infloop8() + for i=1,math.huge do continue end +end + +function infloop9() + -- technically not a loop, but an exponentially recursive function + local function boom() + boom() + boom() + end + boom() +end return "OK" diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index d94e4a49..08d458f9 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -153,6 +153,26 @@ end assert(pcall(fuzzfail15) == true) +local function fuzzfail16() + _ = {[{[2]=77,_=_,[2]=_,}]=not _,} + _ = {77,[2]=11008,[2]=_,[0]=_,} +end + +assert(pcall(fuzzfail16) == true) + +local function fuzzfail17() + return bit32.extract(1293942816,1293942816) +end + +assert(pcall(fuzzfail17) == false) + +local function fuzzfail18() + return bit32.extract(7890276,0) +end + +assert(pcall(fuzzfail18) == true) +assert(fuzzfail18() == 0) + local function arraySizeInv1() local t = {1, 2, nil, nil, nil, nil, nil, nil, nil, true} diff --git a/tests/conformance/native_types.lua b/tests/conformance/native_types.lua index c375ab81..639ce80b 100644 --- a/tests/conformance/native_types.lua +++ b/tests/conformance/native_types.lua @@ -52,6 +52,8 @@ local function checkfunction(a: () -> ()) assert(is_native()) end local function checkthread(a: thread) assert(is_native()) end local function checkuserdata(a: userdata) assert(is_native()) end local function checkvector(a: vector) assert(is_native()) end +local function checkbuffer(a: buffer) assert(is_native()) end +local function checkoptbuffer(a: buffer?) assert(is_native()) end call(checktable, {}) ecall(checktable, 2) @@ -68,6 +70,12 @@ ecall(checkuserdata, 2) call(checkvector, vector(1, 2, 3)) ecall(checkvector, 2) +call(checkbuffer, buffer.create(10)) +ecall(checkbuffer, 2) +call(checkoptbuffer, buffer.create(10)) +call(checkoptbuffer, nil) +ecall(checkoptbuffer, 2) + local function mutation_causes_bad_exit(a: number, count: number, sum: number) repeat a = 's' diff --git a/tools/faillist.txt b/tools/faillist.txt index 3707c457..46454bbc 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,5 +1,3 @@ -AnnotationTests.cloned_interface_maintains_pointers_between_definitions -AnnotationTests.table_annotation AnnotationTests.two_type_params AnnotationTests.typeof_expr AnnotationTests.use_generic_type_alias @@ -11,6 +9,7 @@ AutocompleteTest.anonymous_autofilled_generic_on_argument_type_pack_vararg AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg AutocompleteTest.autocomplete_interpolated_string_as_singleton AutocompleteTest.autocomplete_oop_implicit_self +AutocompleteTest.autocomplete_response_perf1 AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singletons @@ -194,7 +193,6 @@ GenericsTests.self_recursive_instantiated_param GenericsTests.type_parameters_can_be_polytypes GenericsTests.typefuns_sharing_types IntersectionTypes.argument_is_intersection -IntersectionTypes.cli_80596_simplify_degenerate_intersections IntersectionTypes.error_detailed_intersection_all IntersectionTypes.error_detailed_intersection_part IntersectionTypes.fx_intersection_as_argument @@ -448,10 +446,8 @@ TryUnifyTests.uninhabited_table_sub_never TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.corecursive_types_generic TypeAliases.cyclic_types_of_named_table_fields_do_not_expand_when_stringified -TypeAliases.dependent_generic_aliases TypeAliases.dont_lose_track_of_PendingExpansionTypes_after_substitution TypeAliases.free_variables_from_typeof_in_aliases -TypeAliases.generic_aliases TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_aliases @@ -462,7 +458,6 @@ TypeAliases.mutually_recursive_types_swapsies_not_ok TypeAliases.recursive_types_restriction_not_ok TypeAliases.report_shadowed_aliases TypeAliases.saturate_to_first_type_pack -TypeAliases.stringify_optional_parameterized_alias TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_locations @@ -490,6 +485,8 @@ TypeInfer.globals TypeInfer.globals2 TypeInfer.globals_are_banned_in_strict_mode TypeInfer.if_statement +TypeInfer.infer_assignment_value_types +TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.infer_locals_via_assignment_from_its_call_site TypeInfer.infer_locals_with_nil_value TypeInfer.infer_through_group_expr @@ -529,10 +526,12 @@ TypeInferClasses.class_unification_type_mismatch_is_correct_order TypeInferClasses.detailed_class_unification_error TypeInferClasses.index_instance_property TypeInferClasses.indexable_classes +TypeInferClasses.intersections_of_unions_of_classes TypeInferClasses.optional_class_field_access_error TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.table_indexers_are_invariant TypeInferClasses.type_mismatch_invariance_required_for_error +TypeInferClasses.unions_of_intersections_of_classes TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class TypeInferFunctions.another_other_higher_order_function TypeInferFunctions.apply_of_lambda_with_inferred_and_explicit_types @@ -634,7 +633,6 @@ TypeInferModules.bound_free_table_export_is_ok TypeInferModules.do_not_modify_imported_types TypeInferModules.do_not_modify_imported_types_4 TypeInferModules.do_not_modify_imported_types_5 -TypeInferModules.general_require_call_expression TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated TypeInferModules.require @@ -681,6 +679,7 @@ TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_function_indirect TypeInferPrimitives.string_index TypeInferUnknownNever.array_like_table_of_never_is_inhabitable +TypeInferUnknownNever.assign_to_local_which_is_never TypeInferUnknownNever.assign_to_prop_which_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_never TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never @@ -696,7 +695,6 @@ TypePackTests.type_packs_with_tails_in_vararg_adjustment TypePackTests.unify_variadic_tails_in_arguments TypePackTests.unify_variadic_tails_in_arguments_free TypePackTests.variadic_argument_tail -TypePathToStringForError.basic TypeSingletons.enums_using_singletons_mismatch TypeSingletons.error_detailed_tagged_union_mismatch_bool TypeSingletons.error_detailed_tagged_union_mismatch_string @@ -704,7 +702,6 @@ TypeSingletons.function_args_infer_singletons TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch TypeSingletons.overloaded_function_call_with_singletons_mismatch -TypeSingletons.parametric_tagged_union_alias TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_singleton_strings TypeSingletons.table_properties_type_error_escapes @@ -724,7 +721,7 @@ UnionTypes.index_on_a_union_type_with_one_property_of_type_any UnionTypes.index_on_a_union_type_with_property_guaranteed_to_exist UnionTypes.index_on_a_union_type_works_at_arbitrary_depth UnionTypes.less_greedy_unification_with_union_types -UnionTypes.optional_arguments_table2 +UnionTypes.optional_arguments_table UnionTypes.optional_assignment_errors UnionTypes.optional_call_error UnionTypes.optional_field_access_error @@ -736,6 +733,7 @@ UnionTypes.optional_union_functions UnionTypes.optional_union_members UnionTypes.optional_union_methods UnionTypes.table_union_write_indirect +UnionTypes.unify_unsealed_table_union_check UnionTypes.union_of_functions UnionTypes.union_of_functions_mentioning_generic_typepacks UnionTypes.union_of_functions_mentioning_generics diff --git a/tools/fuzzfilter.py b/tools/fuzzfilter.py new file mode 100644 index 00000000..92891a0c --- /dev/null +++ b/tools/fuzzfilter.py @@ -0,0 +1,47 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a fuzzer binary and a list of crashing programs, this tool collects unique crash reasons and prints reproducers. + +import re +import sys +import subprocess + +def get_crash_reason(binary, file): + res = subprocess.run([binary, file], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) + if res.returncode == 0: + print(f"Warning: {binary} {file} returned 0") + return None + err = res.stderr.decode("utf-8") + + if (pos := err.find("ERROR: libFuzzer:")) != -1: + return err[pos:] + + print(f"Warning: {binary} {file} returned unrecognized error {err}") + return None + +def get_crash_fingerprint(reason): + # Due to ASLR addresses are different every time, so we filter them out + reason = re.sub(r"0x[0-9a-f]+", "0xXXXX", reason) + return reason + +binary = sys.argv[1] +files = sys.argv[2:] + +seen = set() + +for index, file in enumerate(files): + reason = get_crash_reason(binary, file) + if reason is None: + continue + fingerprint = get_crash_fingerprint(reason) + if fingerprint in seen: + # print a spinning ASCII wheel to indicate that we're making progress + print("-\|/"[index % 4] + "\r", end="") + continue + seen.add(fingerprint) + print(f"Reproducer: {binary} {file}") + print(f"Crash reason: {reason}") + print() + +print(f"Total unique crash reasons: {len(seen)}") \ No newline at end of file