diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index a8b17c55..7a6188a6 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -764,16 +764,17 @@ TypeId makeStringMetatable(NotNull builtinTypes) const TypeId numberType = builtinTypes->numberType; const TypeId booleanType = builtinTypes->booleanType; const TypeId stringType = builtinTypes->stringType; - const TypeId anyType = builtinTypes->anyType; const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); const TypePackId oneStringPack = arena->addTypePack({stringType}); - const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); + const TypePackId anyTypePack = builtinTypes->anyTypePack; - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; + const TypePackId variadicTailPack = FFlag::DebugLuauDeferredConstraintResolution ? builtinTypes->unknownTypePack : anyTypePack; + + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; formatFTV.magicFunction = &magicFunctionFormat; const TypeId formatFn = arena->addType(formatFTV); attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); @@ -820,13 +821,13 @@ TypeId makeStringMetatable(NotNull builtinTypes) {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, anyTypePack}), + arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack, })}}, {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, {"unpack", {arena->addType(FunctionType{ arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, + variadicTailPack, })}}, }; diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 17bffe87..7fcd4d9e 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -2270,10 +2270,6 @@ std::tuple ConstraintGenerator::checkBinary( if (!key) return {leftType, rightType, nullptr}; - auto augmentForErrorSupression = [&](TypeId ty) -> TypeId { - return arena->addType(UnionType{{ty, builtinTypes->errorType}}); - }; - TypeId discriminantTy = builtinTypes->neverType; if (typeguard->type == "nil") discriminantTy = builtinTypes->nilType; @@ -2288,9 +2284,9 @@ std::tuple ConstraintGenerator::checkBinary( else if (typeguard->type == "buffer") discriminantTy = builtinTypes->bufferType; else if (typeguard->type == "table") - discriminantTy = augmentForErrorSupression(builtinTypes->tableType); + discriminantTy = builtinTypes->tableType; else if (typeguard->type == "function") - discriminantTy = augmentForErrorSupression(builtinTypes->functionType); + discriminantTy = builtinTypes->functionType; else if (typeguard->type == "userdata") { // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 996d638b..40f0f8b9 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -472,6 +472,11 @@ struct FreeTypeSearcher : TypeOnceVisitor result->push_back({ty, location}); return false; } + + bool visit(TypeId, const ClassType&) override + { + return false; + } }; } // namespace @@ -672,13 +677,13 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, builtinTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; - std::optional anyified = anyify.substitute(c.variables); - LUAU_ASSERT(anyified); - unify(constraint, *anyified, c.variables); + unify(constraint, builtinTypes->anyTypePack, c.variables); return true; } - TypeId nextTy = follow(iteratorTypes[0]); + TypeId nextTy = follow(iterator.head[0]); if (get(nextTy)) return block_(nextTy); if (get(nextTy)) { TypeId tableTy = builtinTypes->nilType; - if (iteratorTypes.size() >= 2) - tableTy = iteratorTypes[1]; + if (iterator.head.size() >= 2) + tableTy = iterator.head[1]; TypeId firstIndexTy = builtinTypes->nilType; - if (iteratorTypes.size() >= 3) - firstIndexTy = iteratorTypes[2]; + if (iterator.head.size() >= 3) + firstIndexTy = iterator.head[2]; return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); } else - return tryDispatchIterableTable(iteratorTypes[0], c, constraint, force); + return tryDispatchIterableTable(iterator.head[0], c, constraint, force); return true; } @@ -1174,10 +1176,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull expectedArgs = flatten(ftv->argTypes).first; const std::vector argPackHead = flatten(argsPack).first; - for (size_t i = 0; i < c.callSite->args.size && i < expectedArgs.size() && i < argPackHead.size(); ++i) + // If this is a self call, the types will have more elements than the AST call. + // We don't attempt to perform bidirectional inference on the self type. + const size_t typeOffset = c.callSite->self ? 1 : 0; + + for (size_t i = 0; i < c.callSite->args.size && i + typeOffset < expectedArgs.size() && i + typeOffset < argPackHead.size(); ++i) { - const TypeId expectedArgTy = follow(expectedArgs[i]); - const TypeId actualArgTy = follow(argPackHead[i]); + const TypeId expectedArgTy = follow(expectedArgs[i + typeOffset]); + const TypeId actualArgTy = follow(argPackHead[i + typeOffset]); const AstExpr* expr = c.callSite->args.data[i]; (*c.astExpectedTypes)[expr] = expectedArgTy; @@ -1375,7 +1381,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNulllocation); + bindBlockedType(a, b, subjectType, constraint->location); }; if (existingPropType) @@ -1387,6 +1393,8 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) subjectType = follow(mt->table); @@ -1419,7 +1427,7 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNulllocation); return true; } @@ -1802,21 +1810,15 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl } TypeId nextFn = iterRets.head[0]; - TypeId table = iterRets.head.size() == 2 ? iterRets.head[1] : freshType(arena, builtinTypes, constraint->scope); if (std::optional instantiatedNextFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, nextFn)) { - const TypeId firstIndex = freshType(arena, builtinTypes, constraint->scope); - - // nextTy : (iteratorTy, indexTy?) -> (indexTy, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); - const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); - - const TypeId expectedNextTy = arena->addType(FunctionType{nextArgPack, nextRetPack}); - unify(constraint, *instantiatedNextFn, expectedNextTy); + const FunctionType* nextFn = get(*instantiatedNextFn); + LUAU_ASSERT(nextFn); + const TypePackId nextRetPack = nextFn->retTypes; pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, nextRetPack}); + return true; } else { @@ -1864,31 +1866,13 @@ bool ConstraintSolver::tryDispatchIterableFunction( return false; } - TypeId firstIndex; - TypeId retIndex; - if (isNil(firstIndexTy) || isOptional(firstIndexTy)) - { - // FIXME freshType is suspect here - firstIndex = arena->addType(UnionType{{freshType(arena, builtinTypes, constraint->scope), builtinTypes->nilType}}); - retIndex = firstIndex; - } - else - { - firstIndex = firstIndexTy; - retIndex = arena->addType(UnionType{{firstIndexTy, builtinTypes->nilType}}); - } + const FunctionType* nextFn = get(nextTy); + // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. + LUAU_ASSERT(nextFn); + const TypePackId nextRetPack = nextFn->retTypes; - // nextTy : (tableTy, indexTy?) -> (indexTy?, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({tableTy, firstIndex}); - const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{retIndex}, valueTailTy}); - - const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); - bool ok = unify(constraint, nextTy, expectedNextTy); - - // if there are no errors from unifying the two, we can pass forward the expected type as our selected resolution. - if (ok) - (*c.astForInNextTypes)[c.nextAstFragment] = expectedNextTy; + // the type of the `nextAstFragment` is the `nextTy`. + (*c.astForInNextTypes)[c.nextAstFragment] = nextTy; auto it = begin(nextRetPack); std::vector modifiedNextRetHead; @@ -1988,7 +1972,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, result}; } } - else if (auto mt = get(subjectType)) + else if (auto mt = get(subjectType); mt && context == ValueContext::RValue) { auto [blocked, result] = lookupTableProp(mt->table, propName, context, suppressSimplification, seen); if (!blocked.empty() || result) @@ -2023,6 +2007,8 @@ std::pair, std::optional> ConstraintSolver::lookupTa else return lookupTableProp(indexType, propName, context, suppressSimplification, seen); } + else if (get(mtt)) + return lookupTableProp(mtt, propName, context, suppressSimplification, seen); } else if (auto ct = get(subjectType)) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 82b78149..b608e28a 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1169,6 +1169,8 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vectorname = sourceModule.name; result->humanReadableName = sourceModule.humanReadableName; + result->mode = sourceModule.mode.value_or(Mode::NoCheck); + result->internalTypes.owningModule = result.get(); result->interfaceTypes.owningModule = result.get(); @@ -1199,7 +1201,7 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vectorerrors = std::move(cg.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), result->humanReadableName, moduleResolver, + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), result->name, moduleResolver, requireCycles, logger.get(), limits}; if (options.randomizeConstraintResolutionSeed) @@ -1294,8 +1296,8 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect catch (const InternalCompilerError& err) { InternalCompilerError augmented = err.location.has_value() - ? InternalCompilerError{err.message, sourceModule.humanReadableName, *err.location} - : InternalCompilerError{err.message, sourceModule.humanReadableName}; + ? InternalCompilerError{err.message, sourceModule.name, *err.location} + : InternalCompilerError{err.message, sourceModule.name}; throw augmented; } } diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp index b7d01984..8bce3efd 100644 --- a/Analysis/src/OverloadResolution.cpp +++ b/Analysis/src/OverloadResolution.cpp @@ -236,6 +236,8 @@ std::pair OverloadResolver::checkOverload_ */ Location argLocation; + if (reason.superPath.components.size() <= 1) + break; if (const Luau::TypePath::Index* pathIndexComponent = get_if(&reason.superPath.components.at(1))) { diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index dcfc1965..c4eb9368 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -1033,9 +1033,17 @@ TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right) std::optional TypeSimplifier::basicIntersect(TypeId left, TypeId right) { - if (get(left)) + if (get(left) && get(right)) return right; + if (get(right) && get(left)) + return left; + if (get(left)) + return arena->addType(UnionType{{right, builtinTypes->errorType}}); if (get(right)) + return arena->addType(UnionType{{left, builtinTypes->errorType}}); + if (get(left)) + return right; + if (get(right)) return left; if (get(left)) return left; @@ -1120,9 +1128,17 @@ TypeId TypeSimplifier::intersect(TypeId left, TypeId right) left = simplify(left); right = simplify(right); - if (get(left)) + if (get(left) && get(right)) return right; + if (get(right) && get(left)) + return left; + if (get(left)) + return arena->addType(UnionType{{right, builtinTypes->errorType}}); if (get(right)) + return arena->addType(UnionType{{left, builtinTypes->errorType}}); + if (get(left)) + return right; + if (get(right)) return left; if (get(left)) return left; @@ -1278,9 +1294,11 @@ TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet& seen) { TypeId negatedTy = follow(nt->ty); if (get(negatedTy)) + return arena->addType(UnionType{{builtinTypes->neverType, builtinTypes->errorType}}); + else if (get(negatedTy)) return builtinTypes->neverType; else if (get(negatedTy)) - return builtinTypes->anyType; + return builtinTypes->unknownType; if (auto nnt = get(negatedTy)) return simplify(nnt->ty, seen); } diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index 6abf9a3b..ecdb039f 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -726,7 +726,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId if (TypePackId* other = env.mappedGenericPacks.find(*superTail)) // TODO: TypePath can't express "slice of a pack + its tail". - results.push_back(isCovariantWith(env, *other, subTailPack).withSuperComponent(TypePath::PackField::Tail)); + results.push_back(isContravariantWith(env, subTailPack, *other).withSuperComponent(TypePath::PackField::Tail)); else env.mappedGenericPacks.try_insert(*superTail, subTailPack); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 853e46ff..261c578f 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -1269,7 +1269,16 @@ struct TypeChecker2 return; else if (isOptional(fnTy)) { - reportError(OptionalValueAccess{fnTy}, call->func->location); + switch (shouldSuppressErrors(NotNull{&normalizer}, fnTy)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + reportError(NormalizationTooComplex{}, call->func->location); + // fallthrough intentional + case ErrorSuppression::DoNotSuppress: + reportError(OptionalValueAccess{fnTy}, call->func->location); + } return; } diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 5456c423..986fcacd 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -15,6 +15,7 @@ #include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" #include "Luau/VecDeque.h" @@ -861,20 +862,33 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, con // lt< 'a, t> -> 'a is t - we'll solve the constraint, return and solve lt -> bool // lt< t, 'a> -> same as above bool canSubmitConstraint = ctx->solver && ctx->constraint; + bool lhsFree = get(lhsTy) != nullptr; + bool rhsFree = get(rhsTy) != nullptr; if (canSubmitConstraint) { - if (get(lhsTy) && get(rhsTy) == nullptr) + // Implement injective type families for comparison type families + // lt implies t is number + // lt implies t is number + if (lhsFree && isNumber(rhsTy)) + asMutable(lhsTy)->ty.emplace(ctx->builtins->numberType); + else if (rhsFree && isNumber(lhsTy)) + asMutable(rhsTy)->ty.emplace(ctx->builtins->numberType); + else if (lhsFree && get(rhsTy) == nullptr) { auto c1 = ctx->solver->pushConstraint(ctx->scope, {}, EqualityConstraint{lhsTy, rhsTy}); const_cast(ctx->constraint)->dependencies.emplace_back(c1); } - else if (get(rhsTy) && get(lhsTy) == nullptr) + else if (rhsFree && get(lhsTy) == nullptr) { auto c1 = ctx->solver->pushConstraint(ctx->scope, {}, EqualityConstraint{rhsTy, lhsTy}); const_cast(ctx->constraint)->dependencies.emplace_back(c1); } } + // The above might have caused the operand types to be rebound, we need to follow them again + lhsTy = follow(lhsTy); + rhsTy = follow(rhsTy); + // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) return {std::nullopt, false, {lhsTy}, {}}; diff --git a/Analysis/src/TypePath.cpp b/Analysis/src/TypePath.cpp index 50507263..76f24421 100644 --- a/Analysis/src/TypePath.cpp +++ b/Analysis/src/TypePath.cpp @@ -432,6 +432,13 @@ struct TraversalState if (auto tt = get(current); tt && tt->indexer) indexer = &(*tt->indexer); + else if (auto mt = get(current)) + { + if (auto mtTab = get(follow(mt->table)); mtTab && mtTab->indexer) + indexer = &(*mtTab->indexer); + else if (auto mtMt = get(follow(mt->metatable)); mtMt && mtMt->indexer) + indexer = &(*mtMt->indexer); + } // Note: we don't appear to walk the class hierarchy for indexers else if (auto ct = get(current); ct && ct->indexer) indexer = &(*ct->indexer); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 93f8a851..3a075ecd 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -401,6 +401,9 @@ Unifier::Unifier(NotNull normalizer, NotNull scope, const Loc , sharedState(*normalizer->sharedState) { LUAU_ASSERT(sharedState.iceHandler); + + // Unifier is not usable when this flag is enabled! Please consider using Subtyping instead. + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection, const LiteralProperties* literalProperties) diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 02e2bf67..5faa9553 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -580,6 +580,11 @@ struct FreeTypeSearcher : TypeVisitor return false; } + + bool visit(TypeId, const ClassType&) override + { + return false; + } }; struct MutatingGeneralizer : TypeOnceVisitor diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index be0f7198..bed7e0e3 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCache32BitAsmConsts, false) - namespace Luau { namespace CodeGen @@ -1041,33 +1039,24 @@ OperandX64 AssemblyBuilderX64::i64(int64_t value) OperandX64 AssemblyBuilderX64::f32(float value) { - if (FFlag::LuauCache32BitAsmConsts) + uint32_t as32BitKey; + static_assert(sizeof(as32BitKey) == sizeof(value), "Expecting float to be 32-bit"); + memcpy(&as32BitKey, &value, sizeof(value)); + + if (as32BitKey != ~0u) { - uint32_t as32BitKey; - static_assert(sizeof(as32BitKey) == sizeof(value), "Expecting float to be 32-bit"); - memcpy(&as32BitKey, &value, sizeof(value)); - - if (as32BitKey != ~0u) - { - if (int32_t* prev = constCache32.find(as32BitKey)) - return OperandX64(SizeX64::dword, noreg, 1, rip, *prev); - } - - size_t pos = allocateData(4, 4); - writef32(&data[pos], value); - int32_t offset = int32_t(pos - data.size()); - - if (as32BitKey != ~0u) - constCache32[as32BitKey] = offset; - - return OperandX64(SizeX64::dword, noreg, 1, rip, offset); - } - else - { - size_t pos = allocateData(4, 4); - writef32(&data[pos], value); - return OperandX64(SizeX64::dword, noreg, 1, rip, int32_t(pos - data.size())); + if (int32_t* prev = constCache32.find(as32BitKey)) + return OperandX64(SizeX64::dword, noreg, 1, rip, *prev); } + + size_t pos = allocateData(4, 4); + writef32(&data[pos], value); + int32_t offset = int32_t(pos - data.size()); + + if (as32BitKey != ~0u) + constCache32[as32BitKey] = offset; + + return OperandX64(SizeX64::dword, noreg, 1, rip, offset); } OperandX64 AssemblyBuilderX64::f64(double value) diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 6a5703d1..2a296949 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,10 +11,9 @@ #include "lstate.h" #include "lgc.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenFixBufferLenCheckA64, false) LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorA64, false) -LUAU_FASTFLAG(LuauCodegenVectorTag) +LUAU_FASTFLAG(LuauCodegenVectorTag2) namespace Luau { @@ -680,7 +679,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { build.fadd(inst.regA64, regOp(inst.a), regOp(inst.b)); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) { RegisterA64 tempw = regs.allocTemp(KindA64::w); build.mov(tempw, LUA_TVECTOR); @@ -710,7 +709,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { build.fsub(inst.regA64, regOp(inst.a), regOp(inst.b)); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) { RegisterA64 tempw = regs.allocTemp(KindA64::w); build.mov(tempw, LUA_TVECTOR); @@ -740,7 +739,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { build.fmul(inst.regA64, regOp(inst.a), regOp(inst.b)); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) { RegisterA64 tempw = regs.allocTemp(KindA64::w); build.mov(tempw, LUA_TVECTOR); @@ -770,7 +769,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { build.fdiv(inst.regA64, regOp(inst.a), regOp(inst.b)); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) { RegisterA64 tempw = regs.allocTemp(KindA64::w); build.mov(tempw, LUA_TVECTOR); @@ -800,7 +799,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { build.fneg(inst.regA64, regOp(inst.a)); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) { RegisterA64 tempw = regs.allocTemp(KindA64::w); build.mov(tempw, LUA_TVECTOR); @@ -1184,7 +1183,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fcvt(temps, tempd); build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) { build.mov(tempw, LUA_TVECTOR); build.ins_4s(inst.regA64, tempw, 3); @@ -1629,11 +1628,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterA64 tempx = castReg(KindA64::x, temp); build.sub(tempx, tempx, regOp(inst.b)); // implicit uxtw build.cmp(tempx, uint16_t(accessSize)); - - if (DFFlag::LuauCodeGenFixBufferLenCheckA64) - build.b(ConditionA64::Less, target); // note: this is a signed 64-bit comparison so that out of bounds offset fails - else - build.b(ConditionA64::LessEqual, target); // note: this is a signed 64-bit comparison so that out of bounds offset fails + build.b(ConditionA64::Less, target); // note: this is a signed 64-bit comparison so that out of bounds offset fails } } else if (inst.b.kind == IrOpKind::Constant) diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index b2b0ced2..bf82be52 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -15,7 +15,7 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAG(LuauCodegenVectorTag) +LUAU_FASTFLAG(LuauCodegenVectorTag2) LUAU_FASTFLAGVARIABLE(LuauCodegenVectorOptAnd, false) namespace Luau @@ -612,7 +612,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vaddps(inst.regX64, tmpa, tmpb); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); break; } @@ -627,7 +627,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); build.vsubps(inst.regX64, tmpa, tmpb); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); break; } @@ -642,7 +642,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); build.vmulps(inst.regX64, tmpa, tmpb); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) build.vorps(inst.regX64, inst.regX64, vectorOrMaskOp()); break; } @@ -657,7 +657,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); build.vdivps(inst.regX64, tmpa, tmpb); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3); break; } @@ -677,7 +677,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vxorpd(inst.regX64, inst.regX64, build.f32x4(-0.0, -0.0, -0.0, -0.0)); } - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3); break; } @@ -983,7 +983,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) static_assert(sizeof(asU32) == sizeof(value), "Expecting float to be 32-bit"); memcpy(&asU32, &value, sizeof(value)); - if (FFlag::LuauCodegenVectorTag) + if (FFlag::LuauCodegenVectorTag2) build.vmovaps(inst.regX64, build.u32x4(asU32, asU32, asU32, 0)); else build.vmovaps(inst.regX64, build.u32x4(asU32, asU32, asU32, LUA_TVECTOR)); @@ -993,7 +993,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vcvtsd2ss(inst.regX64, inst.regX64, memRegDoubleOp(inst.a)); build.vpshufps(inst.regX64, inst.regX64, inst.regX64, 0b00'00'00'00); - if (!FFlag::LuauCodegenVectorTag) + if (!FFlag::LuauCodegenVectorTag2) build.vpinsrd(inst.regX64, inst.regX64, build.i32(LUA_TVECTOR), 3); } break; @@ -2237,7 +2237,7 @@ OperandX64 IrLoweringX64::bufferAddrOp(IrOp bufferOp, IrOp indexOp) RegisterX64 IrLoweringX64::vecOp(IrOp op, ScopedRegX64& tmp) { - if (FFlag::LuauCodegenVectorOptAnd && FFlag::LuauCodegenVectorTag) + if (FFlag::LuauCodegenVectorOptAnd && FFlag::LuauCodegenVectorTag2) { IrInst source = function.instOp(op); CODEGEN_ASSERT(source.cmd != IrCmd::SUBSTITUTE); // we don't process substitutions @@ -2298,7 +2298,7 @@ OperandX64 IrLoweringX64::vectorAndMaskOp() OperandX64 IrLoweringX64::vectorOrMaskOp() { - CODEGEN_ASSERT(!FFlag::LuauCodegenVectorTag); + CODEGEN_ASSERT(!FFlag::LuauCodegenVectorTag2); if (vectorOrMask.base == noreg) vectorOrMask = build.u32x4(0, 0, 0, LUA_TVECTOR); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 686d5130..995225a6 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -12,8 +12,7 @@ #include "lstate.h" #include "ltm.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenLuData, false) -LUAU_FASTFLAGVARIABLE(LuauCodegenVector, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenVectorTag2, false) LUAU_FASTFLAGVARIABLE(LuauCodegenVectorTag, false) namespace Luau @@ -354,100 +353,97 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, { BytecodeTypes bcTypes = build.function.getBytecodeTypesAt(pcpos); - if (FFlag::LuauCodegenVector) + // Special fast-paths for vectors, matching the cases we have in VM + if (bcTypes.a == LBC_TYPE_VECTOR && bcTypes.b == LBC_TYPE_VECTOR && (tm == TM_ADD || tm == TM_SUB || tm == TM_MUL || tm == TM_DIV)) { - // Special fast-paths for vectors, matching the cases we have in VM - if (bcTypes.a == LBC_TYPE_VECTOR && bcTypes.b == LBC_TYPE_VECTOR && (tm == TM_ADD || tm == TM_SUB || tm == TM_MUL || tm == TM_DIV)) + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); + + IrOp vb = build.inst(IrCmd::LOAD_TVALUE, opb); + IrOp vc = build.inst(IrCmd::LOAD_TVALUE, opc); + IrOp result; + + switch (tm) { - build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); - build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); - - IrOp vb = build.inst(IrCmd::LOAD_TVALUE, opb); - IrOp vc = build.inst(IrCmd::LOAD_TVALUE, opc); - IrOp result; - - switch (tm) - { - case TM_ADD: - result = build.inst(IrCmd::ADD_VEC, vb, vc); - break; - case TM_SUB: - result = build.inst(IrCmd::SUB_VEC, vb, vc); - break; - case TM_MUL: - result = build.inst(IrCmd::MUL_VEC, vb, vc); - break; - case TM_DIV: - result = build.inst(IrCmd::DIV_VEC, vb, vc); - break; - default: - CODEGEN_ASSERT(!"Unknown TM op"); - } - - if (FFlag::LuauCodegenVectorTag) - result = build.inst(IrCmd::TAG_VECTOR, result); - - build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); - return; + case TM_ADD: + result = build.inst(IrCmd::ADD_VEC, vb, vc); + break; + case TM_SUB: + result = build.inst(IrCmd::SUB_VEC, vb, vc); + break; + case TM_MUL: + result = build.inst(IrCmd::MUL_VEC, vb, vc); + break; + case TM_DIV: + result = build.inst(IrCmd::DIV_VEC, vb, vc); + break; + default: + CODEGEN_ASSERT(!"Unknown TM op"); } - else if (bcTypes.a == LBC_TYPE_NUMBER && bcTypes.b == LBC_TYPE_VECTOR && (tm == TM_MUL || tm == TM_DIV)) + + if (FFlag::LuauCodegenVectorTag2) + result = build.inst(IrCmd::TAG_VECTOR, result); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); + return; + } + else if (bcTypes.a == LBC_TYPE_NUMBER && bcTypes.b == LBC_TYPE_VECTOR && (tm == TM_MUL || tm == TM_DIV)) + { + if (rb != -1) + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)), build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); + + IrOp vb = build.inst(IrCmd::NUM_TO_VEC, loadDoubleOrConstant(build, opb)); + IrOp vc = build.inst(IrCmd::LOAD_TVALUE, opc); + IrOp result; + + switch (tm) { - if (rb != -1) - build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)), build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - - build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); - - IrOp vb = build.inst(IrCmd::NUM_TO_VEC, loadDoubleOrConstant(build, opb)); - IrOp vc = build.inst(IrCmd::LOAD_TVALUE, opc); - IrOp result; - - switch (tm) - { - case TM_MUL: - result = build.inst(IrCmd::MUL_VEC, vb, vc); - break; - case TM_DIV: - result = build.inst(IrCmd::DIV_VEC, vb, vc); - break; - default: - CODEGEN_ASSERT(!"Unknown TM op"); - } - - if (FFlag::LuauCodegenVectorTag) - result = build.inst(IrCmd::TAG_VECTOR, result); - - build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); - return; + case TM_MUL: + result = build.inst(IrCmd::MUL_VEC, vb, vc); + break; + case TM_DIV: + result = build.inst(IrCmd::DIV_VEC, vb, vc); + break; + default: + CODEGEN_ASSERT(!"Unknown TM op"); } - else if (bcTypes.a == LBC_TYPE_VECTOR && bcTypes.b == LBC_TYPE_NUMBER && (tm == TM_MUL || tm == TM_DIV)) + + if (FFlag::LuauCodegenVectorTag2) + result = build.inst(IrCmd::TAG_VECTOR, result); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); + return; + } + else if (bcTypes.a == LBC_TYPE_VECTOR && bcTypes.b == LBC_TYPE_NUMBER && (tm == TM_MUL || tm == TM_DIV)) + { + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); + + if (rc != -1) + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + + IrOp vb = build.inst(IrCmd::LOAD_TVALUE, opb); + IrOp vc = build.inst(IrCmd::NUM_TO_VEC, loadDoubleOrConstant(build, opc)); + IrOp result; + + switch (tm) { - build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); - - if (rc != -1) - build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rc)), build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - - IrOp vb = build.inst(IrCmd::LOAD_TVALUE, opb); - IrOp vc = build.inst(IrCmd::NUM_TO_VEC, loadDoubleOrConstant(build, opc)); - IrOp result; - - switch (tm) - { - case TM_MUL: - result = build.inst(IrCmd::MUL_VEC, vb, vc); - break; - case TM_DIV: - result = build.inst(IrCmd::DIV_VEC, vb, vc); - break; - default: - CODEGEN_ASSERT(!"Unknown TM op"); - } - - if (FFlag::LuauCodegenVectorTag) - result = build.inst(IrCmd::TAG_VECTOR, result); - - build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); - return; + case TM_MUL: + result = build.inst(IrCmd::MUL_VEC, vb, vc); + break; + case TM_DIV: + result = build.inst(IrCmd::DIV_VEC, vb, vc); + break; + default: + CODEGEN_ASSERT(!"Unknown TM op"); } + + if (FFlag::LuauCodegenVectorTag2) + result = build.inst(IrCmd::TAG_VECTOR, result); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); + return; } IrOp fallback; @@ -467,30 +463,10 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, bcTypes.b == LBC_TYPE_NUMBER ? build.vmExit(pcpos) : getInitializedFallback(build, fallback)); } - IrOp vb, vc; + IrOp vb = loadDoubleOrConstant(build, opb); + IrOp vc; IrOp result; - if (FFlag::LuauCodegenVector) - { - vb = loadDoubleOrConstant(build, opb); - } - else - { - if (opb.kind == IrOpKind::VmConst) - { - CODEGEN_ASSERT(build.function.proto); - TValue protok = build.function.proto->k[vmConstOp(opb)]; - - CODEGEN_ASSERT(protok.tt == LUA_TNUMBER); - - vb = build.constDouble(protok.value.n); - } - else - { - vb = build.inst(IrCmd::LOAD_DOUBLE, opb); - } - } - if (opc.kind == IrOpKind::VmConst) { CODEGEN_ASSERT(build.function.proto); @@ -600,13 +576,13 @@ void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos) int ra = LUAU_INSN_A(*pc); int rb = LUAU_INSN_B(*pc); - if (FFlag::LuauCodegenVector && bcTypes.a == LBC_TYPE_VECTOR) + if (bcTypes.a == LBC_TYPE_VECTOR) { build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)), build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); IrOp vb = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); IrOp va = build.inst(IrCmd::UNM_VEC, vb); - if (FFlag::LuauCodegenVectorTag) + if (FFlag::LuauCodegenVectorTag2) va = build.inst(IrCmd::TAG_VECTOR, va); build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), va); return; @@ -940,10 +916,7 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo // setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 2), build.constInt(0)); - - if (FFlag::LuauCodegenLuData) - build.inst(IrCmd::STORE_EXTRA, build.vmReg(ra + 2), build.constInt(LU_TAG_ITERATOR)); - + build.inst(IrCmd::STORE_EXTRA, build.vmReg(ra + 2), build.constInt(LU_TAG_ITERATOR)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA)); build.inst(IrCmd::JUMP, target); @@ -976,10 +949,7 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp // setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 2), build.constInt(0)); - - if (FFlag::LuauCodegenLuData) - build.inst(IrCmd::STORE_EXTRA, build.vmReg(ra + 2), build.constInt(LU_TAG_ITERATOR)); - + build.inst(IrCmd::STORE_EXTRA, build.vmReg(ra + 2), build.constInt(LU_TAG_ITERATOR)); build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA)); build.inst(IrCmd::JUMP, target); @@ -1225,7 +1195,7 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); - if (FFlag::LuauCodegenVector && bcTypes.a == LBC_TYPE_VECTOR) + if (bcTypes.a == LBC_TYPE_VECTOR) { build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TVECTOR), build.vmExit(pcpos)); diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 4214d015..d765b800 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -17,9 +17,8 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) -LUAU_FASTFLAG(LuauCodegenVector) -LUAU_FASTFLAG(LuauCodegenVectorTag) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenCheckGcEffectFix, false) +LUAU_FASTFLAG(LuauCodegenVectorTag2) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenCoverForgprepEffect, false) namespace Luau { @@ -712,11 +711,11 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& uint8_t tag = state.tryGetTag(inst.b); // We know the tag of some instructions that result in TValue - if (FFlag::LuauCodegenVector && tag == 0xff) + if (tag == 0xff) { if (IrInst* arg = function.asInstOp(inst.b)) { - if (FFlag::LuauCodegenVectorTag) + if (FFlag::LuauCodegenVectorTag2) { if (arg->cmd == IrCmd::TAG_VECTOR) tag = LUA_TVECTOR; @@ -1050,11 +1049,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { state.checkedGc = true; - if (DFFlag::LuauCodeGenCheckGcEffectFix) - { - // GC assist might modify table data (hash part) - state.invalidateHeapTableData(); - } + // GC assist might modify table data (hash part) + state.invalidateHeapTableData(); } break; case IrCmd::BARRIER_OBJ: @@ -1264,20 +1260,21 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: case IrCmd::DIV_VEC: - if (FFlag::LuauCodegenVectorTag) + if (FFlag::LuauCodegenVectorTag2) { if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR) - inst.a = a->a; + replace(function, inst.a, a->a); + if (IrInst* b = function.asInstOp(inst.b); b && b->cmd == IrCmd::TAG_VECTOR) - inst.b = b->a; + replace(function, inst.b, b->a); } break; case IrCmd::UNM_VEC: - if (FFlag::LuauCodegenVectorTag) + if (FFlag::LuauCodegenVectorTag2) { if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR) - inst.a = a->a; + replace(function, inst.a, a->a); } break; @@ -1409,6 +1406,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 0u}); state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 1u}); state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 2u}); + + if (DFFlag::LuauCodeGenCoverForgprepEffect) + state.invalidateUserCall(); break; } } diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp index 911750b0..b2a5f7fd 100644 --- a/CodeGen/src/OptimizeFinalX64.cpp +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauCodegenMathMemArgs, false) - namespace Luau { namespace CodeGen @@ -116,7 +114,7 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: { - if (FFlag::LuauCodegenMathMemArgs && inst.a.kind == IrOpKind::Inst) + if (inst.a.kind == IrOpKind::Inst) { IrInst& arg = function.instOp(inst.a); diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 910e259a..05d44f82 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -108,7 +108,7 @@ // upper bound for number of size classes used by page allocator #ifndef LUA_SIZECLASSES -#define LUA_SIZECLASSES 32 +#define LUA_SIZECLASSES 40 #endif // available number of separate memory categories diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 8b264ef4..6628918f 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -120,9 +120,19 @@ static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); static_assert(offsetof(Buffer, data) == ABISWITCH(8, 8, 8), "size mismatch for buffer header"); +LUAU_FASTFLAGVARIABLE(LuauExtendedSizeClasses, false) + const size_t kSizeClasses = LUA_SIZECLASSES; -const size_t kMaxSmallSize = 512; -const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata +const size_t kMaxSmallSize_DEPRECATED = 512; // TODO: remove with FFlagLuauExtendedSizeClasses +const size_t kMaxSmallSize = 1024; +const size_t kLargePageThreshold = 512; // larger pages are used for objects larger than this size to fit more of them into a page + +// constant factor to reduce our page sizes by, to increase the chances that pages we allocate will +// allow external allocators to allocate them without wasting space due to rounding introduced by their heap meta data +const size_t kExternalAllocatorMetaDataReduction = 24; + +const size_t kSmallPageSize = 16 * 1024 - kExternalAllocatorMetaDataReduction; +const size_t kLargePageSize = 32 * 1024 - kExternalAllocatorMetaDataReduction; const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms const size_t kGCOLinkOffset = (sizeof(GCheader) + sizeof(void*) - 1) & ~(sizeof(void*) - 1); // GCO pages contain freelist links after the GC header @@ -143,6 +153,7 @@ struct SizeClassConfig // - we first allocate sizes classes in multiples of 8 // - after the first cutoff we allocate size classes in multiples of 16 // - after the second cutoff we allocate size classes in multiples of 32 + // - after the third cutoff we allocate size classes in multiples of 64 // this balances internal fragmentation vs external fragmentation for (int size = 8; size < 64; size += 8) sizeOfClass[classCount++] = size; @@ -150,7 +161,10 @@ struct SizeClassConfig for (int size = 64; size < 256; size += 16) sizeOfClass[classCount++] = size; - for (int size = 256; size <= 512; size += 32) + for (int size = 256; size < 512; size += 32) + sizeOfClass[classCount++] = size; + + for (int size = 512; size <= 1024; size += 64) sizeOfClass[classCount++] = size; LUAU_ASSERT(size_t(classCount) <= kSizeClasses); @@ -169,7 +183,8 @@ struct SizeClassConfig const SizeClassConfig kSizeClassConfig; // size class for a block of size sz; returns -1 for size=0 because empty allocations take no space -#define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSize ? kSizeClassConfig.classForSize[sz] : -1) +#define sizeclass(sz) \ + (size_t((sz)-1) < (FFlag::LuauExtendedSizeClasses ? kMaxSmallSize : kMaxSmallSize_DEPRECATED) ? kSizeClassConfig.classForSize[sz] : -1) // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) @@ -247,16 +262,34 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata) { - int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); - int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; + if (FFlag::LuauExtendedSizeClasses) + { + int sizeOfClass = kSizeClassConfig.sizeOfClass[sizeClass]; + int pageSize = sizeOfClass > int(kLargePageThreshold) ? kLargePageSize : kSmallPageSize; + int blockSize = sizeOfClass + (storeMetadata ? kBlockHeader : 0); + int blockCount = (pageSize - offsetof(lua_Page, data)) / blockSize; - lua_Page* page = newpage(L, gcopageset, kPageSize, blockSize, blockCount); + lua_Page* page = newpage(L, gcopageset, pageSize, blockSize, blockCount); - // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) - LUAU_ASSERT(!freepageset[sizeClass]); - freepageset[sizeClass] = page; + // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) + LUAU_ASSERT(!freepageset[sizeClass]); + freepageset[sizeClass] = page; - return page; + return page; + } + else + { + int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); + int blockCount = (kSmallPageSize - offsetof(lua_Page, data)) / blockSize; + + lua_Page* page = newpage(L, gcopageset, kSmallPageSize, blockSize, blockCount); + + // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) + LUAU_ASSERT(!freepageset[sizeClass]); + freepageset[sizeClass] = page; + + return page; + } } static void freepage(lua_State* L, lua_Page** gcopageset, lua_Page* page) diff --git a/VM/src/lnumprint.cpp b/VM/src/lnumprint.cpp index c09b1be2..763675e0 100644 --- a/VM/src/lnumprint.cpp +++ b/VM/src/lnumprint.cpp @@ -11,8 +11,6 @@ #include #endif -LUAU_FASTFLAGVARIABLE(LuauSciNumberSkipTrailDot, false) - // This work is based on: // Raffaello Giulietti. The Schubfach way to render doubles. 2021 // https://drive.google.com/file/d/1IEeATSVnEE6TkrHlCYNY2GjaraBjOT4f/edit @@ -363,7 +361,7 @@ char* luai_num2str(char* buf, double n) char* exp = trimzero(buf + declen + 1); - if (FFlag::LuauSciNumberSkipTrailDot && exp[-1] == '.') + if (exp[-1] == '.') exp--; return printexp(exp, dot - 1); diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index 514a4359..640bd96e 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -48,7 +48,7 @@ int luaO_rawequalObj(const TValue* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2) && (!FFlag::LuauTaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); + return pvalue(t1) == pvalue(t2) && lightuserdatatag(t1) == lightuserdatatag(t2); default: LUAU_ASSERT(iscollectable(t1)); return gcvalue(t1) == gcvalue(t2); @@ -71,7 +71,7 @@ int luaO_rawequalKey(const TKey* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2) && (!FFlag::LuauTaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); + return pvalue(t1) == pvalue(t2) && lightuserdatatag(t1) == lightuserdatatag(t2); default: LUAU_ASSERT(iscollectable(t1)); return gcvalue(t1) == gcvalue(t2); diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 44f2bccc..1f84e2da 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -5,8 +5,6 @@ #include "lua.h" #include "lcommon.h" -LUAU_FASTFLAG(LuauTaggedLuData) - /* ** Union of all collectible objects */ diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 03d7cf39..85669e97 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauInterruptablePatternMatch, false) - // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -432,18 +430,15 @@ static const char* match(MatchState* ms, const char* s, const char* p) if (ms->matchdepth-- == 0) luaL_error(ms->L, "pattern too complex"); - if (DFFlag::LuauInterruptablePatternMatch) - { - lua_State* L = ms->L; - void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; + lua_State* L = ms->L; + void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; - if (LUAU_UNLIKELY(!!interrupt)) - { - // this interrupt is not yieldable - L->nCcalls++; - interrupt(L, -1); - L->nCcalls--; - } + if (LUAU_UNLIKELY(!!interrupt)) + { + // this interrupt is not yieldable + L->nCcalls++; + interrupt(L, -1); + L->nCcalls--; } init: // using goto's to optimize tail recursion diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 23369027..09c3d824 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -129,7 +129,7 @@ const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) if (ttisstring(type)) return tsvalue(type); } - else if (FFlag::LuauTaggedLuData && ttislightuserdata(o)) + else if (ttislightuserdata(o)) { int tag = lightuserdatatag(o); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 2ed4819e..74e30c94 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -133,8 +133,6 @@ // Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. #define VM_HAS_NATIVE 1 -LUAU_FASTFLAGVARIABLE(LuauTaggedLuData, false) - LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { ptrdiff_t base = savestack(L, L->base); @@ -1110,9 +1108,7 @@ reentry: VM_NEXT(); case LUA_TLIGHTUSERDATA: - pc += (pvalue(ra) == pvalue(rb) && (!FFlag::LuauTaggedLuData || lightuserdatatag(ra) == lightuserdatatag(rb))) - ? LUAU_INSN_D(insn) - : 1; + pc += (pvalue(ra) == pvalue(rb) && lightuserdatatag(ra) == lightuserdatatag(rb)) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); @@ -1227,9 +1223,7 @@ reentry: VM_NEXT(); case LUA_TLIGHTUSERDATA: - pc += (pvalue(ra) != pvalue(rb) || (FFlag::LuauTaggedLuData && lightuserdatatag(ra) != lightuserdatatag(rb))) - ? LUAU_INSN_D(insn) - : 1; + pc += (pvalue(ra) != pvalue(rb) || lightuserdatatag(ra) != lightuserdatatag(rb)) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index a2186c5f..4db8bba7 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -288,7 +288,7 @@ int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2) && (!FFlag::LuauTaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); + return pvalue(t1) == pvalue(t2) && lightuserdatatag(t1) == lightuserdatatag(t2); case LUA_TUSERDATA: { tm = get_compTM(L, uvalue(t1)->metatable, uvalue(t2)->metatable, TM_EQ); diff --git a/bench/other/boatbomber-HashLib/Base64.lua b/bench/other/boatbomber-HashLib/Base64.lua new file mode 100644 index 00000000..f2b1172a --- /dev/null +++ b/bench/other/boatbomber-HashLib/Base64.lua @@ -0,0 +1,130 @@ +-- @original: https://gist.github.com/Reselim/40d62b17d138cc74335a1b0709e19ce2 +local Alphabet = {} +local Indexes = {} + +-- A-Z +for Index = 65, 90 do + table.insert(Alphabet, Index) +end + +-- a-z +for Index = 97, 122 do + table.insert(Alphabet, Index) +end + +-- 0-9 +for Index = 48, 57 do + table.insert(Alphabet, Index) +end + +table.insert(Alphabet, 43) -- + +table.insert(Alphabet, 47) -- / + +for Index, Character in ipairs(Alphabet) do + Indexes[Character] = Index +end + +local Base64 = {} + +local bit32_rshift = bit32.rshift +local bit32_lshift = bit32.lshift +local bit32_band = bit32.band + +--[[** + Encodes a string in Base64. + @param [t:string] Input The input string to encode. + @returns [t:string] The string encoded in Base64. +**--]] +function Base64.Encode(Input) + local Output = {} + local Length = 0 + + for Index = 1, #Input, 3 do + local C1, C2, C3 = string.byte(Input, Index, Index + 2) + + local A = bit32_rshift(C1, 2) + local B = bit32_lshift(bit32_band(C1, 3), 4) + bit32_rshift(C2 or 0, 4) + local C = bit32_lshift(bit32_band(C2 or 0, 15), 2) + bit32_rshift(C3 or 0, 6) + local D = bit32_band(C3 or 0, 63) + + Length = Length + 1 + Output[Length] = Alphabet[A + 1] + + Length = Length + 1 + Output[Length] = Alphabet[B + 1] + + Length = Length + 1 + Output[Length] = C2 and Alphabet[C + 1] or 61 + + Length = Length + 1 + Output[Length] = C3 and Alphabet[D + 1] or 61 + end + + local NewOutput = {} + local NewLength = 0 + local IndexAdd4096Sub1 + + for Index = 1, Length, 4096 do + NewLength = NewLength + 1 + IndexAdd4096Sub1 = Index + 4096 - 1 + + NewOutput[NewLength] = string.char( + table.unpack(Output, Index, IndexAdd4096Sub1 > Length and Length or IndexAdd4096Sub1) + ) + end + + return table.concat(NewOutput) +end + +--[[** + Decodes a string from Base64. + @param [t:string] Input The input string to decode. + @returns [t:string] The newly decoded string. +**--]] +function Base64.Decode(Input) + local Output = {} + local Length = 0 + + for Index = 1, #Input, 4 do + local C1, C2, C3, C4 = string.byte(Input, Index, Index + 3) + + local I1 = Indexes[C1] - 1 + local I2 = Indexes[C2] - 1 + local I3 = (Indexes[C3] or 1) - 1 + local I4 = (Indexes[C4] or 1) - 1 + + local A = bit32_lshift(I1, 2) + bit32_rshift(I2, 4) + local B = bit32_lshift(bit32_band(I2, 15), 4) + bit32_rshift(I3, 2) + local C = bit32_lshift(bit32_band(I3, 3), 6) + I4 + + Length = Length + 1 + Output[Length] = A + + if C3 ~= 61 then + Length = Length + 1 + Output[Length] = B + end + + if C4 ~= 61 then + Length = Length + 1 + Output[Length] = C + end + end + + local NewOutput = {} + local NewLength = 0 + local IndexAdd4096Sub1 + + for Index = 1, Length, 4096 do + NewLength = NewLength + 1 + IndexAdd4096Sub1 = Index + 4096 - 1 + + NewOutput[NewLength] = string.char( + table.unpack(Output, Index, IndexAdd4096Sub1 > Length and Length or IndexAdd4096Sub1) + ) + end + + return table.concat(NewOutput) +end + +return Base64 diff --git a/bench/other/boatbomber-HashLib/HashLib.spec.lua b/bench/other/boatbomber-HashLib/HashLib.spec.lua new file mode 100644 index 00000000..c5734b30 --- /dev/null +++ b/bench/other/boatbomber-HashLib/HashLib.spec.lua @@ -0,0 +1,39 @@ +local function describe(phrase, callback) end +local function it(phrase, callback) end +local function expect(value) end + +return function() + local HashLib = require(script.Parent) + local sha256 = HashLib.sha256 + + describe("HashLib.sha256", function() + it("should properly encode strings", function() + expect(sha256("abc").to.equal("ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")) + expect( + sha256("The quick brown fox jumps over the lazy dog").to.equal( + "d7a8fbb307d7809469ca9abcb0082e4f8d5651e46d3cdb762d02d0bf37c9e592" + ) + ) + expect(sha256("123456").to.equal("8d969eef6ecad3c29a3a629280e686cf0c3f5d5a86aff3ca12020c923adc6c92")) + end) + + it("should create a private closure that works", function() + local AppendNextChunk = sha256() + AppendNextChunk("The quick brown fox") + AppendNextChunk(" jumps ") + AppendNextChunk("") -- chunk may be an empty string + AppendNextChunk("over the lazy dog") + expect(AppendNextChunk()).to.equal("d7a8fbb307d7809469ca9abcb0082e4f8d5651e46d3cdb762d02d0bf37c9e592") + end) + + it("should allow the private closure to work if called twice", function() + local AppendNextChunk = sha256() + AppendNextChunk("The quick brown fox") + AppendNextChunk(" jumps ") + AppendNextChunk("") -- chunk may be an empty string + AppendNextChunk("over the lazy dog") + AppendNextChunk() + expect(AppendNextChunk()).to.equal("d7a8fbb307d7809469ca9abcb0082e4f8d5651e46d3cdb762d02d0bf37c9e592") + end) + end) +end diff --git a/bench/other/boatbomber-HashLib/LICENSE b/bench/other/boatbomber-HashLib/LICENSE new file mode 100644 index 00000000..32dc1f23 --- /dev/null +++ b/bench/other/boatbomber-HashLib/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 boatbomber + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/bench/other/boatbomber-HashLib/init.lua b/bench/other/boatbomber-HashLib/init.lua new file mode 100644 index 00000000..9d44d8ca --- /dev/null +++ b/bench/other/boatbomber-HashLib/init.lua @@ -0,0 +1,1555 @@ +--[=[------------------------------------------------------------------------------------------------------------------------ +-- HashLib by Egor Skriptunoff, boatbomber, and howmanysmall + +-------------------------------------------------------------------------------------------------------------------------- + +Module was originally written by Egor Skriptunoff and distributed under an MIT license. +It can be found here: https://github.com/Egor-Skriptunoff/pure_lua_SHA/blob/master/sha2.lua + +That version was around 3000 lines long, and supported Lua versions 5.1, 5.2, 5.3, and 5.4, and LuaJIT. +Although that is super cool, Roblox only uses Lua 5.1, so that was extreme overkill. + +I, boatbomber, worked to port it to Roblox in a way that doesn't overcomplicate it with support of unreachable +cases. Then, howmanysmall did some final optimizations that really squeeze out all the performance possible. +It's gotten stupid fast, thanks to her! + +After quite a bit of work and benchmarking, this is what we were left with. +Enjoy! + +-------------------------------------------------------------------------------------------------------------------------- + +DESCRIPTION: + This module contains functions to calculate SHA digest: + MD5, SHA-1, + SHA-224, SHA-256, SHA-512/224, SHA-512/256, SHA-384, SHA-512, + SHA3-224, SHA3-256, SHA3-384, SHA3-512, SHAKE128, SHAKE256, + HMAC + Additionally, it has a few extra utility functions: + hex_to_bin + base64_to_bin + bin_to_base64 + Written in pure Lua. +USAGE: + Input data should be a string + Result (SHA digest) is returned in hexadecimal representation as a string of lowercase hex digits. + Simplest usage example: + local HashLib = require(script.HashLib) + local your_hash = HashLib.sha256("your string") +API: + HashLib.md5 + HashLib.sha1 + SHA2 hash functions: + HashLib.sha224 + HashLib.sha256 + HashLib.sha512_224 + HashLib.sha512_256 + HashLib.sha384 + HashLib.sha512 + SHA3 hash functions: + HashLib.sha3_224 + HashLib.sha3_256 + HashLib.sha3_384 + HashLib.sha3_512 + HashLib.shake128 + HashLib.shake256 + Misc utilities: + HashLib.hmac (Applicable to any hash function from this module except SHAKE*) + HashLib.hex_to_bin + HashLib.base64_to_bin + HashLib.bin_to_base64 + +--]=] +--------------------------------------------------------------------------- + +local Base64 = require(script.Base64) + +-------------------------------------------------------------------------------- +-- LOCALIZATION FOR VM OPTIMIZATIONS +-------------------------------------------------------------------------------- + +local ipairs = ipairs + +-------------------------------------------------------------------------------- +-- 32-BIT BITWISE FUNCTIONS +-------------------------------------------------------------------------------- +-- Only low 32 bits of function arguments matter, high bits are ignored +-- The result of all functions (except HEX) is an integer inside "correct range": +-- for "bit" library: (-TWO_POW_31)..(TWO_POW_31-1) +-- for "bit32" library: 0..(TWO_POW_32-1) +local bit32_band = bit32.band -- 2 arguments +local bit32_bor = bit32.bor -- 2 arguments +local bit32_bxor = bit32.bxor -- 2..5 arguments +local bit32_lshift = bit32.lshift -- second argument is integer 0..31 +local bit32_rshift = bit32.rshift -- second argument is integer 0..31 +local bit32_lrotate = bit32.lrotate -- second argument is integer 0..31 +local bit32_rrotate = bit32.rrotate -- second argument is integer 0..31 + +-------------------------------------------------------------------------------- +-- CREATING OPTIMIZED INNER LOOP +-------------------------------------------------------------------------------- +-- Arrays of SHA2 "magic numbers" (in "INT64" and "FFI" branches "*_lo" arrays contain 64-bit values) +local sha2_K_lo, sha2_K_hi, sha2_H_lo, sha2_H_hi, sha3_RC_lo, sha3_RC_hi = {}, {}, {}, {}, {}, {} +local sha2_H_ext256 = { + [224] = {}, + [256] = sha2_H_hi, +} + +local sha2_H_ext512_lo, sha2_H_ext512_hi = { + [384] = {}, + [512] = sha2_H_lo, +}, { + [384] = {}, + [512] = sha2_H_hi, +} + +local md5_K, md5_sha1_H = {}, { 0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0 } +local md5_next_shift = { + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 28, + 25, + 26, + 27, + 0, + 0, + 10, + 9, + 11, + 12, + 0, + 15, + 16, + 17, + 18, + 0, + 20, + 22, + 23, + 21, +} +local HEX64, XOR64A5, lanes_index_base -- defined only for branches that internally use 64-bit integers: "INT64" and "FFI" +local common_W = {} -- temporary table shared between all calculations (to avoid creating new temporary table every time) +local K_lo_modulo, hi_factor, hi_factor_keccak = 4294967296, 0, 0 + +local TWO_POW_NEG_56 = 2 ^ -56 +local TWO_POW_NEG_17 = 2 ^ -17 + +local TWO_POW_2 = 2 ^ 2 +local TWO_POW_3 = 2 ^ 3 +local TWO_POW_4 = 2 ^ 4 +local TWO_POW_5 = 2 ^ 5 +local TWO_POW_6 = 2 ^ 6 +local TWO_POW_7 = 2 ^ 7 +local TWO_POW_8 = 2 ^ 8 +local TWO_POW_9 = 2 ^ 9 +local TWO_POW_10 = 2 ^ 10 +local TWO_POW_11 = 2 ^ 11 +local TWO_POW_12 = 2 ^ 12 +local TWO_POW_13 = 2 ^ 13 +local TWO_POW_14 = 2 ^ 14 +local TWO_POW_15 = 2 ^ 15 +local TWO_POW_16 = 2 ^ 16 +local TWO_POW_17 = 2 ^ 17 +local TWO_POW_18 = 2 ^ 18 +local TWO_POW_19 = 2 ^ 19 +local TWO_POW_20 = 2 ^ 20 +local TWO_POW_21 = 2 ^ 21 +local TWO_POW_22 = 2 ^ 22 +local TWO_POW_23 = 2 ^ 23 +local TWO_POW_24 = 2 ^ 24 +local TWO_POW_25 = 2 ^ 25 +local TWO_POW_26 = 2 ^ 26 +local TWO_POW_27 = 2 ^ 27 +local TWO_POW_28 = 2 ^ 28 +local TWO_POW_29 = 2 ^ 29 +local TWO_POW_30 = 2 ^ 30 +local TWO_POW_31 = 2 ^ 31 +local TWO_POW_32 = 2 ^ 32 +local TWO_POW_40 = 2 ^ 40 + +local TWO56_POW_7 = 256 ^ 7 + +-- Implementation for Lua 5.1/5.2 (with or without bitwise library available) +local function sha256_feed_64(H, str, offs, size) + -- offs >= 0, size >= 0, size is multiple of 64 + local W, K = common_W, sha2_K_hi + local h1, h2, h3, h4, h5, h6, h7, h8 = H[1], H[2], H[3], H[4], H[5], H[6], H[7], H[8] + for pos = offs, offs + size - 1, 64 do + for j = 1, 16 do + pos = pos + 4 + local a, b, c, d = string.byte(str, pos - 3, pos) + W[j] = ((a * 256 + b) * 256 + c) * 256 + d + end + + for j = 17, 64 do + local a, b = W[j - 15], W[j - 2] + W[j] = bit32_bxor(bit32_rrotate(a, 7), bit32_lrotate(a, 14), bit32_rshift(a, 3)) + + bit32_bxor(bit32_lrotate(b, 15), bit32_lrotate(b, 13), bit32_rshift(b, 10)) + + W[j - 7] + + W[j - 16] + end + + local a, b, c, d, e, f, g, h = h1, h2, h3, h4, h5, h6, h7, h8 + for j = 1, 64 do + local z = bit32_bxor(bit32_rrotate(e, 6), bit32_rrotate(e, 11), bit32_lrotate(e, 7)) + + bit32_band(e, f) + + bit32_band(-1 - e, g) + + h + + K[j] + + W[j] + h = g + g = f + f = e + e = z + d + d = c + c = b + b = a + a = z + + bit32_band(d, c) + + bit32_band(a, bit32_bxor(d, c)) + + bit32_bxor(bit32_rrotate(a, 2), bit32_rrotate(a, 13), bit32_lrotate(a, 10)) + end + + h1, h2, h3, h4 = (a + h1) % 4294967296, (b + h2) % 4294967296, (c + h3) % 4294967296, (d + h4) % 4294967296 + h5, h6, h7, h8 = (e + h5) % 4294967296, (f + h6) % 4294967296, (g + h7) % 4294967296, (h + h8) % 4294967296 + end + + H[1], H[2], H[3], H[4], H[5], H[6], H[7], H[8] = h1, h2, h3, h4, h5, h6, h7, h8 +end + +local function sha512_feed_128(H_lo, H_hi, str, offs, size) + -- offs >= 0, size >= 0, size is multiple of 128 + -- W1_hi, W1_lo, W2_hi, W2_lo, ... Wk_hi = W[2*k-1], Wk_lo = W[2*k] + local W, K_lo, K_hi = common_W, sha2_K_lo, sha2_K_hi + local h1_lo, h2_lo, h3_lo, h4_lo, h5_lo, h6_lo, h7_lo, h8_lo = + H_lo[1], H_lo[2], H_lo[3], H_lo[4], H_lo[5], H_lo[6], H_lo[7], H_lo[8] + local h1_hi, h2_hi, h3_hi, h4_hi, h5_hi, h6_hi, h7_hi, h8_hi = + H_hi[1], H_hi[2], H_hi[3], H_hi[4], H_hi[5], H_hi[6], H_hi[7], H_hi[8] + for pos = offs, offs + size - 1, 128 do + for j = 1, 16 * 2 do + pos = pos + 4 + local a, b, c, d = string.byte(str, pos - 3, pos) + W[j] = ((a * 256 + b) * 256 + c) * 256 + d + end + + for jj = 34, 160, 2 do + local a_lo, a_hi, b_lo, b_hi = W[jj - 30], W[jj - 31], W[jj - 4], W[jj - 5] + local tmp1 = bit32_bxor( + bit32_rshift(a_lo, 1) + bit32_lshift(a_hi, 31), + bit32_rshift(a_lo, 8) + bit32_lshift(a_hi, 24), + bit32_rshift(a_lo, 7) + bit32_lshift(a_hi, 25) + ) % 4294967296 + bit32_bxor( + bit32_rshift(b_lo, 19) + bit32_lshift(b_hi, 13), + bit32_lshift(b_lo, 3) + bit32_rshift(b_hi, 29), + bit32_rshift(b_lo, 6) + bit32_lshift(b_hi, 26) + ) % 4294967296 + W[jj - 14] + W[jj - 32] + + local tmp2 = tmp1 % 4294967296 + W[jj - 1] = bit32_bxor( + bit32_rshift(a_hi, 1) + bit32_lshift(a_lo, 31), + bit32_rshift(a_hi, 8) + bit32_lshift(a_lo, 24), + bit32_rshift(a_hi, 7) + ) + bit32_bxor( + bit32_rshift(b_hi, 19) + bit32_lshift(b_lo, 13), + bit32_lshift(b_hi, 3) + bit32_rshift(b_lo, 29), + bit32_rshift(b_hi, 6) + ) + W[jj - 15] + W[jj - 33] + (tmp1 - tmp2) / 4294967296 + + W[jj] = tmp2 + end + + local a_lo, b_lo, c_lo, d_lo, e_lo, f_lo, g_lo, h_lo = h1_lo, h2_lo, h3_lo, h4_lo, h5_lo, h6_lo, h7_lo, h8_lo + local a_hi, b_hi, c_hi, d_hi, e_hi, f_hi, g_hi, h_hi = h1_hi, h2_hi, h3_hi, h4_hi, h5_hi, h6_hi, h7_hi, h8_hi + for j = 1, 80 do + local jj = 2 * j + local tmp1 = bit32_bxor( + bit32_rshift(e_lo, 14) + bit32_lshift(e_hi, 18), + bit32_rshift(e_lo, 18) + bit32_lshift(e_hi, 14), + bit32_lshift(e_lo, 23) + bit32_rshift(e_hi, 9) + ) % 4294967296 + (bit32_band(e_lo, f_lo) + bit32_band(-1 - e_lo, g_lo)) % 4294967296 + h_lo + K_lo[j] + W[jj] + + local z_lo = tmp1 % 4294967296 + local z_hi = bit32_bxor( + bit32_rshift(e_hi, 14) + bit32_lshift(e_lo, 18), + bit32_rshift(e_hi, 18) + bit32_lshift(e_lo, 14), + bit32_lshift(e_hi, 23) + bit32_rshift(e_lo, 9) + ) + bit32_band(e_hi, f_hi) + bit32_band(-1 - e_hi, g_hi) + h_hi + K_hi[j] + W[jj - 1] + (tmp1 - z_lo) / 4294967296 + + h_lo = g_lo + h_hi = g_hi + g_lo = f_lo + g_hi = f_hi + f_lo = e_lo + f_hi = e_hi + tmp1 = z_lo + d_lo + e_lo = tmp1 % 4294967296 + e_hi = z_hi + d_hi + (tmp1 - e_lo) / 4294967296 + d_lo = c_lo + d_hi = c_hi + c_lo = b_lo + c_hi = b_hi + b_lo = a_lo + b_hi = a_hi + tmp1 = z_lo + + (bit32_band(d_lo, c_lo) + bit32_band(b_lo, bit32_bxor(d_lo, c_lo))) % 4294967296 + + bit32_bxor( + bit32_rshift(b_lo, 28) + bit32_lshift(b_hi, 4), + bit32_lshift(b_lo, 30) + bit32_rshift(b_hi, 2), + bit32_lshift(b_lo, 25) + bit32_rshift(b_hi, 7) + ) + % 4294967296 + a_lo = tmp1 % 4294967296 + a_hi = z_hi + + (bit32_band(d_hi, c_hi) + bit32_band(b_hi, bit32_bxor(d_hi, c_hi))) + + bit32_bxor( + bit32_rshift(b_hi, 28) + bit32_lshift(b_lo, 4), + bit32_lshift(b_hi, 30) + bit32_rshift(b_lo, 2), + bit32_lshift(b_hi, 25) + bit32_rshift(b_lo, 7) + ) + + (tmp1 - a_lo) / 4294967296 + end + + a_lo = h1_lo + a_lo + h1_lo = a_lo % 4294967296 + h1_hi = (h1_hi + a_hi + (a_lo - h1_lo) / 4294967296) % 4294967296 + a_lo = h2_lo + b_lo + h2_lo = a_lo % 4294967296 + h2_hi = (h2_hi + b_hi + (a_lo - h2_lo) / 4294967296) % 4294967296 + a_lo = h3_lo + c_lo + h3_lo = a_lo % 4294967296 + h3_hi = (h3_hi + c_hi + (a_lo - h3_lo) / 4294967296) % 4294967296 + a_lo = h4_lo + d_lo + h4_lo = a_lo % 4294967296 + h4_hi = (h4_hi + d_hi + (a_lo - h4_lo) / 4294967296) % 4294967296 + a_lo = h5_lo + e_lo + h5_lo = a_lo % 4294967296 + h5_hi = (h5_hi + e_hi + (a_lo - h5_lo) / 4294967296) % 4294967296 + a_lo = h6_lo + f_lo + h6_lo = a_lo % 4294967296 + h6_hi = (h6_hi + f_hi + (a_lo - h6_lo) / 4294967296) % 4294967296 + a_lo = h7_lo + g_lo + h7_lo = a_lo % 4294967296 + h7_hi = (h7_hi + g_hi + (a_lo - h7_lo) / 4294967296) % 4294967296 + a_lo = h8_lo + h_lo + h8_lo = a_lo % 4294967296 + h8_hi = (h8_hi + h_hi + (a_lo - h8_lo) / 4294967296) % 4294967296 + end + + H_lo[1], H_lo[2], H_lo[3], H_lo[4], H_lo[5], H_lo[6], H_lo[7], H_lo[8] = + h1_lo, h2_lo, h3_lo, h4_lo, h5_lo, h6_lo, h7_lo, h8_lo + H_hi[1], H_hi[2], H_hi[3], H_hi[4], H_hi[5], H_hi[6], H_hi[7], H_hi[8] = + h1_hi, h2_hi, h3_hi, h4_hi, h5_hi, h6_hi, h7_hi, h8_hi +end + +local function md5_feed_64(H, str, offs, size) + -- offs >= 0, size >= 0, size is multiple of 64 + local W, K, md5_next_shift = common_W, md5_K, md5_next_shift + local h1, h2, h3, h4 = H[1], H[2], H[3], H[4] + for pos = offs, offs + size - 1, 64 do + for j = 1, 16 do + pos = pos + 4 + local a, b, c, d = string.byte(str, pos - 3, pos) + W[j] = ((d * 256 + c) * 256 + b) * 256 + a + end + + local a, b, c, d = h1, h2, h3, h4 + local s = 25 + for j = 1, 16 do + local F = bit32_rrotate(bit32_band(b, c) + bit32_band(-1 - b, d) + a + K[j] + W[j], s) + b + s = md5_next_shift[s] + a = d + d = c + c = b + b = F + end + + s = 27 + for j = 17, 32 do + local F = bit32_rrotate(bit32_band(d, b) + bit32_band(-1 - d, c) + a + K[j] + W[(5 * j - 4) % 16 + 1], s) + + b + s = md5_next_shift[s] + a = d + d = c + c = b + b = F + end + + s = 28 + for j = 33, 48 do + local F = bit32_rrotate(bit32_bxor(bit32_bxor(b, c), d) + a + K[j] + W[(3 * j + 2) % 16 + 1], s) + b + s = md5_next_shift[s] + a = d + d = c + c = b + b = F + end + + s = 26 + for j = 49, 64 do + local F = bit32_rrotate(bit32_bxor(c, bit32_bor(b, -1 - d)) + a + K[j] + W[(j * 7 - 7) % 16 + 1], s) + b + s = md5_next_shift[s] + a = d + d = c + c = b + b = F + end + + h1 = (a + h1) % 4294967296 + h2 = (b + h2) % 4294967296 + h3 = (c + h3) % 4294967296 + h4 = (d + h4) % 4294967296 + end + + H[1], H[2], H[3], H[4] = h1, h2, h3, h4 +end + +local function sha1_feed_64(H, str, offs, size) + -- offs >= 0, size >= 0, size is multiple of 64 + local W = common_W + local h1, h2, h3, h4, h5 = H[1], H[2], H[3], H[4], H[5] + for pos = offs, offs + size - 1, 64 do + for j = 1, 16 do + pos = pos + 4 + local a, b, c, d = string.byte(str, pos - 3, pos) + W[j] = ((a * 256 + b) * 256 + c) * 256 + d + end + + for j = 17, 80 do + W[j] = bit32_lrotate(bit32_bxor(W[j - 3], W[j - 8], W[j - 14], W[j - 16]), 1) + end + + local a, b, c, d, e = h1, h2, h3, h4, h5 + for j = 1, 20 do + local z = bit32_lrotate(a, 5) + bit32_band(b, c) + bit32_band(-1 - b, d) + 0x5A827999 + W[j] + e -- constant = math.floor(TWO_POW_30 * sqrt(2)) + e = d + d = c + c = bit32_rrotate(b, 2) + b = a + a = z + end + + for j = 21, 40 do + local z = bit32_lrotate(a, 5) + bit32_bxor(b, c, d) + 0x6ED9EBA1 + W[j] + e -- TWO_POW_30 * sqrt(3) + e = d + d = c + c = bit32_rrotate(b, 2) + b = a + a = z + end + + for j = 41, 60 do + local z = bit32_lrotate(a, 5) + bit32_band(d, c) + bit32_band(b, bit32_bxor(d, c)) + 0x8F1BBCDC + W[j] + e -- TWO_POW_30 * sqrt(5) + e = d + d = c + c = bit32_rrotate(b, 2) + b = a + a = z + end + + for j = 61, 80 do + local z = bit32_lrotate(a, 5) + bit32_bxor(b, c, d) + 0xCA62C1D6 + W[j] + e -- TWO_POW_30 * sqrt(10) + e = d + d = c + c = bit32_rrotate(b, 2) + b = a + a = z + end + + h1 = (a + h1) % 4294967296 + h2 = (b + h2) % 4294967296 + h3 = (c + h3) % 4294967296 + h4 = (d + h4) % 4294967296 + h5 = (e + h5) % 4294967296 + end + + H[1], H[2], H[3], H[4], H[5] = h1, h2, h3, h4, h5 +end + +local function keccak_feed(lanes_lo, lanes_hi, str, offs, size, block_size_in_bytes) + -- This is an example of a Lua function having 79 local variables :-) + -- offs >= 0, size >= 0, size is multiple of block_size_in_bytes, block_size_in_bytes is positive multiple of 8 + local RC_lo, RC_hi = sha3_RC_lo, sha3_RC_hi + local qwords_qty = block_size_in_bytes / 8 + for pos = offs, offs + size - 1, block_size_in_bytes do + for j = 1, qwords_qty do + local a, b, c, d = string.byte(str, pos + 1, pos + 4) + lanes_lo[j] = bit32_bxor(lanes_lo[j], ((d * 256 + c) * 256 + b) * 256 + a) + pos = pos + 8 + a, b, c, d = string.byte(str, pos - 3, pos) + lanes_hi[j] = bit32_bxor(lanes_hi[j], ((d * 256 + c) * 256 + b) * 256 + a) + end + + local L01_lo, L01_hi, L02_lo, L02_hi, L03_lo, L03_hi, L04_lo, L04_hi, L05_lo, L05_hi, L06_lo, L06_hi, L07_lo, L07_hi, L08_lo, L08_hi, L09_lo, L09_hi, L10_lo, L10_hi, L11_lo, L11_hi, L12_lo, L12_hi, L13_lo, L13_hi, L14_lo, L14_hi, L15_lo, L15_hi, L16_lo, L16_hi, L17_lo, L17_hi, L18_lo, L18_hi, L19_lo, L19_hi, L20_lo, L20_hi, L21_lo, L21_hi, L22_lo, L22_hi, L23_lo, L23_hi, L24_lo, L24_hi, L25_lo, L25_hi = + lanes_lo[1], + lanes_hi[1], + lanes_lo[2], + lanes_hi[2], + lanes_lo[3], + lanes_hi[3], + lanes_lo[4], + lanes_hi[4], + lanes_lo[5], + lanes_hi[5], + lanes_lo[6], + lanes_hi[6], + lanes_lo[7], + lanes_hi[7], + lanes_lo[8], + lanes_hi[8], + lanes_lo[9], + lanes_hi[9], + lanes_lo[10], + lanes_hi[10], + lanes_lo[11], + lanes_hi[11], + lanes_lo[12], + lanes_hi[12], + lanes_lo[13], + lanes_hi[13], + lanes_lo[14], + lanes_hi[14], + lanes_lo[15], + lanes_hi[15], + lanes_lo[16], + lanes_hi[16], + lanes_lo[17], + lanes_hi[17], + lanes_lo[18], + lanes_hi[18], + lanes_lo[19], + lanes_hi[19], + lanes_lo[20], + lanes_hi[20], + lanes_lo[21], + lanes_hi[21], + lanes_lo[22], + lanes_hi[22], + lanes_lo[23], + lanes_hi[23], + lanes_lo[24], + lanes_hi[24], + lanes_lo[25], + lanes_hi[25] + + for round_idx = 1, 24 do + local C1_lo = bit32_bxor(L01_lo, L06_lo, L11_lo, L16_lo, L21_lo) + local C1_hi = bit32_bxor(L01_hi, L06_hi, L11_hi, L16_hi, L21_hi) + local C2_lo = bit32_bxor(L02_lo, L07_lo, L12_lo, L17_lo, L22_lo) + local C2_hi = bit32_bxor(L02_hi, L07_hi, L12_hi, L17_hi, L22_hi) + local C3_lo = bit32_bxor(L03_lo, L08_lo, L13_lo, L18_lo, L23_lo) + local C3_hi = bit32_bxor(L03_hi, L08_hi, L13_hi, L18_hi, L23_hi) + local C4_lo = bit32_bxor(L04_lo, L09_lo, L14_lo, L19_lo, L24_lo) + local C4_hi = bit32_bxor(L04_hi, L09_hi, L14_hi, L19_hi, L24_hi) + local C5_lo = bit32_bxor(L05_lo, L10_lo, L15_lo, L20_lo, L25_lo) + local C5_hi = bit32_bxor(L05_hi, L10_hi, L15_hi, L20_hi, L25_hi) + + local D_lo = bit32_bxor(C1_lo, C3_lo * 2 + (C3_hi % TWO_POW_32 - C3_hi % TWO_POW_31) / TWO_POW_31) + local D_hi = bit32_bxor(C1_hi, C3_hi * 2 + (C3_lo % TWO_POW_32 - C3_lo % TWO_POW_31) / TWO_POW_31) + + local T0_lo = bit32_bxor(D_lo, L02_lo) + local T0_hi = bit32_bxor(D_hi, L02_hi) + local T1_lo = bit32_bxor(D_lo, L07_lo) + local T1_hi = bit32_bxor(D_hi, L07_hi) + local T2_lo = bit32_bxor(D_lo, L12_lo) + local T2_hi = bit32_bxor(D_hi, L12_hi) + local T3_lo = bit32_bxor(D_lo, L17_lo) + local T3_hi = bit32_bxor(D_hi, L17_hi) + local T4_lo = bit32_bxor(D_lo, L22_lo) + local T4_hi = bit32_bxor(D_hi, L22_hi) + + L02_lo = (T1_lo % TWO_POW_32 - T1_lo % TWO_POW_20) / TWO_POW_20 + T1_hi * TWO_POW_12 + L02_hi = (T1_hi % TWO_POW_32 - T1_hi % TWO_POW_20) / TWO_POW_20 + T1_lo * TWO_POW_12 + L07_lo = (T3_lo % TWO_POW_32 - T3_lo % TWO_POW_19) / TWO_POW_19 + T3_hi * TWO_POW_13 + L07_hi = (T3_hi % TWO_POW_32 - T3_hi % TWO_POW_19) / TWO_POW_19 + T3_lo * TWO_POW_13 + L12_lo = T0_lo * 2 + (T0_hi % TWO_POW_32 - T0_hi % TWO_POW_31) / TWO_POW_31 + L12_hi = T0_hi * 2 + (T0_lo % TWO_POW_32 - T0_lo % TWO_POW_31) / TWO_POW_31 + L17_lo = T2_lo * TWO_POW_10 + (T2_hi % TWO_POW_32 - T2_hi % TWO_POW_22) / TWO_POW_22 + L17_hi = T2_hi * TWO_POW_10 + (T2_lo % TWO_POW_32 - T2_lo % TWO_POW_22) / TWO_POW_22 + L22_lo = T4_lo * TWO_POW_2 + (T4_hi % TWO_POW_32 - T4_hi % TWO_POW_30) / TWO_POW_30 + L22_hi = T4_hi * TWO_POW_2 + (T4_lo % TWO_POW_32 - T4_lo % TWO_POW_30) / TWO_POW_30 + + D_lo = bit32_bxor(C2_lo, C4_lo * 2 + (C4_hi % TWO_POW_32 - C4_hi % TWO_POW_31) / TWO_POW_31) + D_hi = bit32_bxor(C2_hi, C4_hi * 2 + (C4_lo % TWO_POW_32 - C4_lo % TWO_POW_31) / TWO_POW_31) + + T0_lo = bit32_bxor(D_lo, L03_lo) + T0_hi = bit32_bxor(D_hi, L03_hi) + T1_lo = bit32_bxor(D_lo, L08_lo) + T1_hi = bit32_bxor(D_hi, L08_hi) + T2_lo = bit32_bxor(D_lo, L13_lo) + T2_hi = bit32_bxor(D_hi, L13_hi) + T3_lo = bit32_bxor(D_lo, L18_lo) + T3_hi = bit32_bxor(D_hi, L18_hi) + T4_lo = bit32_bxor(D_lo, L23_lo) + T4_hi = bit32_bxor(D_hi, L23_hi) + + L03_lo = (T2_lo % TWO_POW_32 - T2_lo % TWO_POW_21) / TWO_POW_21 + T2_hi * TWO_POW_11 + L03_hi = (T2_hi % TWO_POW_32 - T2_hi % TWO_POW_21) / TWO_POW_21 + T2_lo * TWO_POW_11 + L08_lo = (T4_lo % TWO_POW_32 - T4_lo % TWO_POW_3) / TWO_POW_3 + T4_hi * TWO_POW_29 % TWO_POW_32 + L08_hi = (T4_hi % TWO_POW_32 - T4_hi % TWO_POW_3) / TWO_POW_3 + T4_lo * TWO_POW_29 % TWO_POW_32 + L13_lo = T1_lo * TWO_POW_6 + (T1_hi % TWO_POW_32 - T1_hi % TWO_POW_26) / TWO_POW_26 + L13_hi = T1_hi * TWO_POW_6 + (T1_lo % TWO_POW_32 - T1_lo % TWO_POW_26) / TWO_POW_26 + L18_lo = T3_lo * TWO_POW_15 + (T3_hi % TWO_POW_32 - T3_hi % TWO_POW_17) / TWO_POW_17 + L18_hi = T3_hi * TWO_POW_15 + (T3_lo % TWO_POW_32 - T3_lo % TWO_POW_17) / TWO_POW_17 + L23_lo = (T0_lo % TWO_POW_32 - T0_lo % TWO_POW_2) / TWO_POW_2 + T0_hi * TWO_POW_30 % TWO_POW_32 + L23_hi = (T0_hi % TWO_POW_32 - T0_hi % TWO_POW_2) / TWO_POW_2 + T0_lo * TWO_POW_30 % TWO_POW_32 + + D_lo = bit32_bxor(C3_lo, C5_lo * 2 + (C5_hi % TWO_POW_32 - C5_hi % TWO_POW_31) / TWO_POW_31) + D_hi = bit32_bxor(C3_hi, C5_hi * 2 + (C5_lo % TWO_POW_32 - C5_lo % TWO_POW_31) / TWO_POW_31) + + T0_lo = bit32_bxor(D_lo, L04_lo) + T0_hi = bit32_bxor(D_hi, L04_hi) + T1_lo = bit32_bxor(D_lo, L09_lo) + T1_hi = bit32_bxor(D_hi, L09_hi) + T2_lo = bit32_bxor(D_lo, L14_lo) + T2_hi = bit32_bxor(D_hi, L14_hi) + T3_lo = bit32_bxor(D_lo, L19_lo) + T3_hi = bit32_bxor(D_hi, L19_hi) + T4_lo = bit32_bxor(D_lo, L24_lo) + T4_hi = bit32_bxor(D_hi, L24_hi) + + L04_lo = T3_lo * TWO_POW_21 % TWO_POW_32 + (T3_hi % TWO_POW_32 - T3_hi % TWO_POW_11) / TWO_POW_11 + L04_hi = T3_hi * TWO_POW_21 % TWO_POW_32 + (T3_lo % TWO_POW_32 - T3_lo % TWO_POW_11) / TWO_POW_11 + L09_lo = T0_lo * TWO_POW_28 % TWO_POW_32 + (T0_hi % TWO_POW_32 - T0_hi % TWO_POW_4) / TWO_POW_4 + L09_hi = T0_hi * TWO_POW_28 % TWO_POW_32 + (T0_lo % TWO_POW_32 - T0_lo % TWO_POW_4) / TWO_POW_4 + L14_lo = T2_lo * TWO_POW_25 % TWO_POW_32 + (T2_hi % TWO_POW_32 - T2_hi % TWO_POW_7) / TWO_POW_7 + L14_hi = T2_hi * TWO_POW_25 % TWO_POW_32 + (T2_lo % TWO_POW_32 - T2_lo % TWO_POW_7) / TWO_POW_7 + L19_lo = (T4_lo % TWO_POW_32 - T4_lo % TWO_POW_8) / TWO_POW_8 + T4_hi * TWO_POW_24 % TWO_POW_32 + L19_hi = (T4_hi % TWO_POW_32 - T4_hi % TWO_POW_8) / TWO_POW_8 + T4_lo * TWO_POW_24 % TWO_POW_32 + L24_lo = (T1_lo % TWO_POW_32 - T1_lo % TWO_POW_9) / TWO_POW_9 + T1_hi * TWO_POW_23 % TWO_POW_32 + L24_hi = (T1_hi % TWO_POW_32 - T1_hi % TWO_POW_9) / TWO_POW_9 + T1_lo * TWO_POW_23 % TWO_POW_32 + + D_lo = bit32_bxor(C4_lo, C1_lo * 2 + (C1_hi % TWO_POW_32 - C1_hi % TWO_POW_31) / TWO_POW_31) + D_hi = bit32_bxor(C4_hi, C1_hi * 2 + (C1_lo % TWO_POW_32 - C1_lo % TWO_POW_31) / TWO_POW_31) + + T0_lo = bit32_bxor(D_lo, L05_lo) + T0_hi = bit32_bxor(D_hi, L05_hi) + T1_lo = bit32_bxor(D_lo, L10_lo) + T1_hi = bit32_bxor(D_hi, L10_hi) + T2_lo = bit32_bxor(D_lo, L15_lo) + T2_hi = bit32_bxor(D_hi, L15_hi) + T3_lo = bit32_bxor(D_lo, L20_lo) + T3_hi = bit32_bxor(D_hi, L20_hi) + T4_lo = bit32_bxor(D_lo, L25_lo) + T4_hi = bit32_bxor(D_hi, L25_hi) + + L05_lo = T4_lo * TWO_POW_14 + (T4_hi % TWO_POW_32 - T4_hi % TWO_POW_18) / TWO_POW_18 + L05_hi = T4_hi * TWO_POW_14 + (T4_lo % TWO_POW_32 - T4_lo % TWO_POW_18) / TWO_POW_18 + L10_lo = T1_lo * TWO_POW_20 % TWO_POW_32 + (T1_hi % TWO_POW_32 - T1_hi % TWO_POW_12) / TWO_POW_12 + L10_hi = T1_hi * TWO_POW_20 % TWO_POW_32 + (T1_lo % TWO_POW_32 - T1_lo % TWO_POW_12) / TWO_POW_12 + L15_lo = T3_lo * TWO_POW_8 + (T3_hi % TWO_POW_32 - T3_hi % TWO_POW_24) / TWO_POW_24 + L15_hi = T3_hi * TWO_POW_8 + (T3_lo % TWO_POW_32 - T3_lo % TWO_POW_24) / TWO_POW_24 + L20_lo = T0_lo * TWO_POW_27 % TWO_POW_32 + (T0_hi % TWO_POW_32 - T0_hi % TWO_POW_5) / TWO_POW_5 + L20_hi = T0_hi * TWO_POW_27 % TWO_POW_32 + (T0_lo % TWO_POW_32 - T0_lo % TWO_POW_5) / TWO_POW_5 + L25_lo = (T2_lo % TWO_POW_32 - T2_lo % TWO_POW_25) / TWO_POW_25 + T2_hi * TWO_POW_7 + L25_hi = (T2_hi % TWO_POW_32 - T2_hi % TWO_POW_25) / TWO_POW_25 + T2_lo * TWO_POW_7 + + D_lo = bit32_bxor(C5_lo, C2_lo * 2 + (C2_hi % TWO_POW_32 - C2_hi % TWO_POW_31) / TWO_POW_31) + D_hi = bit32_bxor(C5_hi, C2_hi * 2 + (C2_lo % TWO_POW_32 - C2_lo % TWO_POW_31) / TWO_POW_31) + + T1_lo = bit32_bxor(D_lo, L06_lo) + T1_hi = bit32_bxor(D_hi, L06_hi) + T2_lo = bit32_bxor(D_lo, L11_lo) + T2_hi = bit32_bxor(D_hi, L11_hi) + T3_lo = bit32_bxor(D_lo, L16_lo) + T3_hi = bit32_bxor(D_hi, L16_hi) + T4_lo = bit32_bxor(D_lo, L21_lo) + T4_hi = bit32_bxor(D_hi, L21_hi) + + L06_lo = T2_lo * TWO_POW_3 + (T2_hi % TWO_POW_32 - T2_hi % TWO_POW_29) / TWO_POW_29 + L06_hi = T2_hi * TWO_POW_3 + (T2_lo % TWO_POW_32 - T2_lo % TWO_POW_29) / TWO_POW_29 + L11_lo = T4_lo * TWO_POW_18 + (T4_hi % TWO_POW_32 - T4_hi % TWO_POW_14) / TWO_POW_14 + L11_hi = T4_hi * TWO_POW_18 + (T4_lo % TWO_POW_32 - T4_lo % TWO_POW_14) / TWO_POW_14 + L16_lo = (T1_lo % TWO_POW_32 - T1_lo % TWO_POW_28) / TWO_POW_28 + T1_hi * TWO_POW_4 + L16_hi = (T1_hi % TWO_POW_32 - T1_hi % TWO_POW_28) / TWO_POW_28 + T1_lo * TWO_POW_4 + L21_lo = (T3_lo % TWO_POW_32 - T3_lo % TWO_POW_23) / TWO_POW_23 + T3_hi * TWO_POW_9 + L21_hi = (T3_hi % TWO_POW_32 - T3_hi % TWO_POW_23) / TWO_POW_23 + T3_lo * TWO_POW_9 + + L01_lo = bit32_bxor(D_lo, L01_lo) + L01_hi = bit32_bxor(D_hi, L01_hi) + L01_lo, L02_lo, L03_lo, L04_lo, L05_lo = + bit32_bxor(L01_lo, bit32_band(-1 - L02_lo, L03_lo)), + bit32_bxor(L02_lo, bit32_band(-1 - L03_lo, L04_lo)), + bit32_bxor(L03_lo, bit32_band(-1 - L04_lo, L05_lo)), + bit32_bxor(L04_lo, bit32_band(-1 - L05_lo, L01_lo)), + bit32_bxor(L05_lo, bit32_band(-1 - L01_lo, L02_lo)) + L01_hi, L02_hi, L03_hi, L04_hi, L05_hi = + bit32_bxor(L01_hi, bit32_band(-1 - L02_hi, L03_hi)), + bit32_bxor(L02_hi, bit32_band(-1 - L03_hi, L04_hi)), + bit32_bxor(L03_hi, bit32_band(-1 - L04_hi, L05_hi)), + bit32_bxor(L04_hi, bit32_band(-1 - L05_hi, L01_hi)), + bit32_bxor(L05_hi, bit32_band(-1 - L01_hi, L02_hi)) + L06_lo, L07_lo, L08_lo, L09_lo, L10_lo = + bit32_bxor(L09_lo, bit32_band(-1 - L10_lo, L06_lo)), + bit32_bxor(L10_lo, bit32_band(-1 - L06_lo, L07_lo)), + bit32_bxor(L06_lo, bit32_band(-1 - L07_lo, L08_lo)), + bit32_bxor(L07_lo, bit32_band(-1 - L08_lo, L09_lo)), + bit32_bxor(L08_lo, bit32_band(-1 - L09_lo, L10_lo)) + L06_hi, L07_hi, L08_hi, L09_hi, L10_hi = + bit32_bxor(L09_hi, bit32_band(-1 - L10_hi, L06_hi)), + bit32_bxor(L10_hi, bit32_band(-1 - L06_hi, L07_hi)), + bit32_bxor(L06_hi, bit32_band(-1 - L07_hi, L08_hi)), + bit32_bxor(L07_hi, bit32_band(-1 - L08_hi, L09_hi)), + bit32_bxor(L08_hi, bit32_band(-1 - L09_hi, L10_hi)) + L11_lo, L12_lo, L13_lo, L14_lo, L15_lo = + bit32_bxor(L12_lo, bit32_band(-1 - L13_lo, L14_lo)), + bit32_bxor(L13_lo, bit32_band(-1 - L14_lo, L15_lo)), + bit32_bxor(L14_lo, bit32_band(-1 - L15_lo, L11_lo)), + bit32_bxor(L15_lo, bit32_band(-1 - L11_lo, L12_lo)), + bit32_bxor(L11_lo, bit32_band(-1 - L12_lo, L13_lo)) + L11_hi, L12_hi, L13_hi, L14_hi, L15_hi = + bit32_bxor(L12_hi, bit32_band(-1 - L13_hi, L14_hi)), + bit32_bxor(L13_hi, bit32_band(-1 - L14_hi, L15_hi)), + bit32_bxor(L14_hi, bit32_band(-1 - L15_hi, L11_hi)), + bit32_bxor(L15_hi, bit32_band(-1 - L11_hi, L12_hi)), + bit32_bxor(L11_hi, bit32_band(-1 - L12_hi, L13_hi)) + L16_lo, L17_lo, L18_lo, L19_lo, L20_lo = + bit32_bxor(L20_lo, bit32_band(-1 - L16_lo, L17_lo)), + bit32_bxor(L16_lo, bit32_band(-1 - L17_lo, L18_lo)), + bit32_bxor(L17_lo, bit32_band(-1 - L18_lo, L19_lo)), + bit32_bxor(L18_lo, bit32_band(-1 - L19_lo, L20_lo)), + bit32_bxor(L19_lo, bit32_band(-1 - L20_lo, L16_lo)) + L16_hi, L17_hi, L18_hi, L19_hi, L20_hi = + bit32_bxor(L20_hi, bit32_band(-1 - L16_hi, L17_hi)), + bit32_bxor(L16_hi, bit32_band(-1 - L17_hi, L18_hi)), + bit32_bxor(L17_hi, bit32_band(-1 - L18_hi, L19_hi)), + bit32_bxor(L18_hi, bit32_band(-1 - L19_hi, L20_hi)), + bit32_bxor(L19_hi, bit32_band(-1 - L20_hi, L16_hi)) + L21_lo, L22_lo, L23_lo, L24_lo, L25_lo = + bit32_bxor(L23_lo, bit32_band(-1 - L24_lo, L25_lo)), + bit32_bxor(L24_lo, bit32_band(-1 - L25_lo, L21_lo)), + bit32_bxor(L25_lo, bit32_band(-1 - L21_lo, L22_lo)), + bit32_bxor(L21_lo, bit32_band(-1 - L22_lo, L23_lo)), + bit32_bxor(L22_lo, bit32_band(-1 - L23_lo, L24_lo)) + L21_hi, L22_hi, L23_hi, L24_hi, L25_hi = + bit32_bxor(L23_hi, bit32_band(-1 - L24_hi, L25_hi)), + bit32_bxor(L24_hi, bit32_band(-1 - L25_hi, L21_hi)), + bit32_bxor(L25_hi, bit32_band(-1 - L21_hi, L22_hi)), + bit32_bxor(L21_hi, bit32_band(-1 - L22_hi, L23_hi)), + bit32_bxor(L22_hi, bit32_band(-1 - L23_hi, L24_hi)) + L01_lo = bit32_bxor(L01_lo, RC_lo[round_idx]) + L01_hi = L01_hi + RC_hi[round_idx] -- RC_hi[] is either 0 or 0x80000000, so we could use fast addition instead of slow XOR + end + + lanes_lo[1] = L01_lo + lanes_hi[1] = L01_hi + lanes_lo[2] = L02_lo + lanes_hi[2] = L02_hi + lanes_lo[3] = L03_lo + lanes_hi[3] = L03_hi + lanes_lo[4] = L04_lo + lanes_hi[4] = L04_hi + lanes_lo[5] = L05_lo + lanes_hi[5] = L05_hi + lanes_lo[6] = L06_lo + lanes_hi[6] = L06_hi + lanes_lo[7] = L07_lo + lanes_hi[7] = L07_hi + lanes_lo[8] = L08_lo + lanes_hi[8] = L08_hi + lanes_lo[9] = L09_lo + lanes_hi[9] = L09_hi + lanes_lo[10] = L10_lo + lanes_hi[10] = L10_hi + lanes_lo[11] = L11_lo + lanes_hi[11] = L11_hi + lanes_lo[12] = L12_lo + lanes_hi[12] = L12_hi + lanes_lo[13] = L13_lo + lanes_hi[13] = L13_hi + lanes_lo[14] = L14_lo + lanes_hi[14] = L14_hi + lanes_lo[15] = L15_lo + lanes_hi[15] = L15_hi + lanes_lo[16] = L16_lo + lanes_hi[16] = L16_hi + lanes_lo[17] = L17_lo + lanes_hi[17] = L17_hi + lanes_lo[18] = L18_lo + lanes_hi[18] = L18_hi + lanes_lo[19] = L19_lo + lanes_hi[19] = L19_hi + lanes_lo[20] = L20_lo + lanes_hi[20] = L20_hi + lanes_lo[21] = L21_lo + lanes_hi[21] = L21_hi + lanes_lo[22] = L22_lo + lanes_hi[22] = L22_hi + lanes_lo[23] = L23_lo + lanes_hi[23] = L23_hi + lanes_lo[24] = L24_lo + lanes_hi[24] = L24_hi + lanes_lo[25] = L25_lo + lanes_hi[25] = L25_hi + end +end + +-------------------------------------------------------------------------------- +-- MAGIC NUMBERS CALCULATOR +-------------------------------------------------------------------------------- +-- Q: +-- Is 53-bit "double" math enough to calculate square roots and cube roots of primes with 64 correct bits after decimal point? +-- A: +-- Yes, 53-bit "double" arithmetic is enough. +-- We could obtain first 40 bits by direct calculation of p^(1/3) and next 40 bits by one step of Newton's method. +do + local function mul(src1, src2, factor, result_length) + -- src1, src2 - long integers (arrays of digits in base TWO_POW_24) + -- factor - small integer + -- returns long integer result (src1 * src2 * factor) and its floating point approximation + local result, carry, value, weight = table.create(result_length), 0, 0, 1 + for j = 1, result_length do + for k = math.max(1, j + 1 - #src2), math.min(j, #src1) do + carry = carry + factor * src1[k] * src2[j + 1 - k] -- "int32" is not enough for multiplication result, that's why "factor" must be of type "double" + end + + local digit = carry % TWO_POW_24 + result[j] = math.floor(digit) + carry = (carry - digit) / TWO_POW_24 + value = value + digit * weight + weight = weight * TWO_POW_24 + end + + return result, value + end + + local idx, step, p, one, sqrt_hi, sqrt_lo = 0, { 4, 1, 2, -2, 2 }, 4, { 1 }, sha2_H_hi, sha2_H_lo + repeat + p = p + step[p % 6] + local d = 1 + repeat + d = d + step[d % 6] + if d * d > p then + -- next prime number is found + local root = p ^ (1 / 3) + local R = root * TWO_POW_40 + R = mul(table.create(1, math.floor(R)), one, 1, 2) + local _, delta = mul(R, mul(R, R, 1, 4), -1, 4) + local hi = R[2] % 65536 * 65536 + math.floor(R[1] / 256) + local lo = R[1] % 256 * 16777216 + math.floor(delta * (TWO_POW_NEG_56 / 3) * root / p) + + if idx < 16 then + root = math.sqrt(p) + R = root * TWO_POW_40 + R = mul(table.create(1, math.floor(R)), one, 1, 2) + _, delta = mul(R, R, -1, 2) + local hi = R[2] % 65536 * 65536 + math.floor(R[1] / 256) + local lo = R[1] % 256 * 16777216 + math.floor(delta * TWO_POW_NEG_17 / root) + local idx = idx % 8 + 1 + sha2_H_ext256[224][idx] = lo + sqrt_hi[idx], sqrt_lo[idx] = hi, lo + hi * hi_factor + if idx > 7 then + sqrt_hi, sqrt_lo = sha2_H_ext512_hi[384], sha2_H_ext512_lo[384] + end + end + + idx = idx + 1 + sha2_K_hi[idx], sha2_K_lo[idx] = hi, lo % K_lo_modulo + hi * hi_factor + break + end + until p % d == 0 + until idx > 79 +end + +-- Calculating IVs for SHA512/224 and SHA512/256 +for width = 224, 256, 32 do + local H_lo, H_hi = {}, nil + if XOR64A5 then + for j = 1, 8 do + H_lo[j] = XOR64A5(sha2_H_lo[j]) + end + else + H_hi = {} + for j = 1, 8 do + H_lo[j] = bit32_bxor(sha2_H_lo[j], 0xA5A5A5A5) % 4294967296 + H_hi[j] = bit32_bxor(sha2_H_hi[j], 0xA5A5A5A5) % 4294967296 + end + end + + sha512_feed_128(H_lo, H_hi, "SHA-512/" .. tostring(width) .. "\128" .. string.rep("\0", 115) .. "\88", 0, 128) + sha2_H_ext512_lo[width] = H_lo + sha2_H_ext512_hi[width] = H_hi +end + +-- Constants for MD5 +do + for idx = 1, 64 do + -- we can't use formula math.floor(abs(sin(idx))*TWO_POW_32) because its result may be beyond integer range on Lua built with 32-bit integers + local hi, lo = math.modf(math.abs(math.sin(idx)) * TWO_POW_16) + md5_K[idx] = hi * 65536 + math.floor(lo * TWO_POW_16) + end +end + +-- Constants for SHA3 +do + local sh_reg = 29 + local function next_bit() + local r = sh_reg % 2 + sh_reg = bit32_bxor((sh_reg - r) / 2, 142 * r) + return r + end + + for idx = 1, 24 do + local lo, m = 0, nil + for _ = 1, 6 do + m = m and m * m * 2 or 1 + lo = lo + next_bit() * m + end + + local hi = next_bit() * m + sha3_RC_hi[idx], sha3_RC_lo[idx] = hi, lo + hi * hi_factor_keccak + end +end + +-------------------------------------------------------------------------------- +-- MAIN FUNCTIONS +-------------------------------------------------------------------------------- +local function sha256ext(width, message) + -- Create an instance (private objects for current calculation) + local Array256 = sha2_H_ext256[width] -- # == 8 + local length, tail = 0, "" + local H = table.create(8) + H[1], H[2], H[3], H[4], H[5], H[6], H[7], H[8] = + Array256[1], Array256[2], Array256[3], Array256[4], Array256[5], Array256[6], Array256[7], Array256[8] + + local function partial(message_part) + if message_part then + local partLength = #message_part + if tail then + length = length + partLength + local offs = 0 + local tailLength = #tail + if tail ~= "" and tailLength + partLength >= 64 then + offs = 64 - tailLength + sha256_feed_64(H, tail .. string.sub(message_part, 1, offs), 0, 64) + tail = "" + end + + local size = partLength - offs + local size_tail = size % 64 + sha256_feed_64(H, message_part, offs, size - size_tail) + tail = tail .. string.sub(message_part, partLength + 1 - size_tail) + return partial + else + error("Adding more chunks is not allowed after receiving the result", 2) + end + else + if tail then + local final_blocks = table.create(10) --{tail, "\128", string.rep("\0", (-9 - length) % 64 + 1)} + final_blocks[1] = tail + final_blocks[2] = "\128" + final_blocks[3] = string.rep("\0", (-9 - length) % 64 + 1) + + tail = nil + -- Assuming user data length is shorter than (TWO_POW_53)-9 bytes + -- Anyway, it looks very unrealistic that someone would spend more than a year of calculations to process TWO_POW_53 bytes of data by using this Lua script :-) + -- TWO_POW_53 bytes = TWO_POW_56 bits, so "bit-counter" fits in 7 bytes + length = length * (8 / TWO56_POW_7) -- convert "byte-counter" to "bit-counter" and move decimal point to the left + for j = 4, 10 do + length = length % 1 * 256 + final_blocks[j] = string.char(math.floor(length)) + end + + final_blocks = table.concat(final_blocks) + sha256_feed_64(H, final_blocks, 0, #final_blocks) + local max_reg = width / 32 + for j = 1, max_reg do + H[j] = string.format("%08x", H[j] % 4294967296) + end + + H = table.concat(H, "", 1, max_reg) + end + + return H + end + end + + if message then + -- Actually perform calculations and return the SHA256 digest of a message + return partial(message)() + else + -- Return function for chunk-by-chunk loading + -- User should feed every chunk of input data as single argument to this function and finally get SHA256 digest by invoking this function without an argument + return partial + end +end + +local function sha512ext(width, message) + -- Create an instance (private objects for current calculation) + local length, tail, H_lo, H_hi = + 0, + "", + table.pack(table.unpack(sha2_H_ext512_lo[width])), + not HEX64 and table.pack(table.unpack(sha2_H_ext512_hi[width])) + + local function partial(message_part) + if message_part then + local partLength = #message_part + if tail then + length = length + partLength + local offs = 0 + if tail ~= "" and #tail + partLength >= 128 then + offs = 128 - #tail + sha512_feed_128(H_lo, H_hi, tail .. string.sub(message_part, 1, offs), 0, 128) + tail = "" + end + + local size = partLength - offs + local size_tail = size % 128 + sha512_feed_128(H_lo, H_hi, message_part, offs, size - size_tail) + tail = tail .. string.sub(message_part, partLength + 1 - size_tail) + return partial + else + error("Adding more chunks is not allowed after receiving the result", 2) + end + else + if tail then + local final_blocks = table.create(3) --{tail, "\128", string.rep("\0", (-17-length) % 128 + 9)} + final_blocks[1] = tail + final_blocks[2] = "\128" + final_blocks[3] = string.rep("\0", (-17 - length) % 128 + 9) + + tail = nil + -- Assuming user data length is shorter than (TWO_POW_53)-17 bytes + -- TWO_POW_53 bytes = TWO_POW_56 bits, so "bit-counter" fits in 7 bytes + length = length * (8 / TWO56_POW_7) -- convert "byte-counter" to "bit-counter" and move floating point to the left + for j = 4, 10 do + length = length % 1 * 256 + final_blocks[j] = string.char(math.floor(length)) + end + + final_blocks = table.concat(final_blocks) + sha512_feed_128(H_lo, H_hi, final_blocks, 0, #final_blocks) + local max_reg = math.ceil(width / 64) + + if HEX64 then + for j = 1, max_reg do + H_lo[j] = HEX64(H_lo[j]) + end + else + for j = 1, max_reg do + H_lo[j] = string.format("%08x", H_hi[j] % 4294967296) + .. string.format("%08x", H_lo[j] % 4294967296) + end + + H_hi = nil + end + + H_lo = string.sub(table.concat(H_lo, "", 1, max_reg), 1, width / 4) + end + + return H_lo + end + end + + if message then + -- Actually perform calculations and return the SHA512 digest of a message + return partial(message)() + else + -- Return function for chunk-by-chunk loading + -- User should feed every chunk of input data as single argument to this function and finally get SHA512 digest by invoking this function without an argument + return partial + end +end + +local function md5(message) + -- Create an instance (private objects for current calculation) + local H, length, tail = table.create(4), 0, "" + H[1], H[2], H[3], H[4] = md5_sha1_H[1], md5_sha1_H[2], md5_sha1_H[3], md5_sha1_H[4] + + local function partial(message_part) + if message_part then + local partLength = #message_part + if tail then + length = length + partLength + local offs = 0 + if tail ~= "" and #tail + partLength >= 64 then + offs = 64 - #tail + md5_feed_64(H, tail .. string.sub(message_part, 1, offs), 0, 64) + tail = "" + end + + local size = partLength - offs + local size_tail = size % 64 + md5_feed_64(H, message_part, offs, size - size_tail) + tail = tail .. string.sub(message_part, partLength + 1 - size_tail) + return partial + else + error("Adding more chunks is not allowed after receiving the result", 2) + end + else + if tail then + local final_blocks = table.create(3) --{tail, "\128", string.rep("\0", (-9 - length) % 64)} + final_blocks[1] = tail + final_blocks[2] = "\128" + final_blocks[3] = string.rep("\0", (-9 - length) % 64) + tail = nil + length = length * 8 -- convert "byte-counter" to "bit-counter" + for j = 4, 11 do + local low_byte = length % 256 + final_blocks[j] = string.char(low_byte) + length = (length - low_byte) / 256 + end + + final_blocks = table.concat(final_blocks) + md5_feed_64(H, final_blocks, 0, #final_blocks) + for j = 1, 4 do + H[j] = string.format("%08x", H[j] % 4294967296) + end + + H = string.gsub(table.concat(H), "(..)(..)(..)(..)", "%4%3%2%1") + end + + return H + end + end + + if message then + -- Actually perform calculations and return the MD5 digest of a message + return partial(message)() + else + -- Return function for chunk-by-chunk loading + -- User should feed every chunk of input data as single argument to this function and finally get MD5 digest by invoking this function without an argument + return partial + end +end + +local function sha1(message) + -- Create an instance (private objects for current calculation) + local H, length, tail = table.pack(table.unpack(md5_sha1_H)), 0, "" + + local function partial(message_part) + if message_part then + local partLength = #message_part + if tail then + length = length + partLength + local offs = 0 + if tail ~= "" and #tail + partLength >= 64 then + offs = 64 - #tail + sha1_feed_64(H, tail .. string.sub(message_part, 1, offs), 0, 64) + tail = "" + end + + local size = partLength - offs + local size_tail = size % 64 + sha1_feed_64(H, message_part, offs, size - size_tail) + tail = tail .. string.sub(message_part, partLength + 1 - size_tail) + return partial + else + error("Adding more chunks is not allowed after receiving the result", 2) + end + else + if tail then + local final_blocks = table.create(10) --{tail, "\128", string.rep("\0", (-9 - length) % 64 + 1)} + final_blocks[1] = tail + final_blocks[2] = "\128" + final_blocks[3] = string.rep("\0", (-9 - length) % 64 + 1) + tail = nil + + -- Assuming user data length is shorter than (TWO_POW_53)-9 bytes + -- TWO_POW_53 bytes = TWO_POW_56 bits, so "bit-counter" fits in 7 bytes + length = length * (8 / TWO56_POW_7) -- convert "byte-counter" to "bit-counter" and move decimal point to the left + for j = 4, 10 do + length = length % 1 * 256 + final_blocks[j] = string.char(math.floor(length)) + end + + final_blocks = table.concat(final_blocks) + sha1_feed_64(H, final_blocks, 0, #final_blocks) + for j = 1, 5 do + H[j] = string.format("%08x", H[j] % 4294967296) + end + + H = table.concat(H) + end + + return H + end + end + + if message then + -- Actually perform calculations and return the SHA-1 digest of a message + return partial(message)() + else + -- Return function for chunk-by-chunk loading + -- User should feed every chunk of input data as single argument to this function and finally get SHA-1 digest by invoking this function without an argument + return partial + end +end + +local function keccak(block_size_in_bytes, digest_size_in_bytes, is_SHAKE, message) + -- "block_size_in_bytes" is multiple of 8 + if type(digest_size_in_bytes) ~= "number" then + -- arguments in SHAKE are swapped: + -- NIST FIPS 202 defines SHAKE(message,num_bits) + -- this module defines SHAKE(num_bytes,message) + -- it's easy to forget about this swap, hence the check + error("Argument 'digest_size_in_bytes' must be a number", 2) + end + + -- Create an instance (private objects for current calculation) + local tail, lanes_lo, lanes_hi = "", table.create(25, 0), hi_factor_keccak == 0 and table.create(25, 0) + local result + + --~ pad the input N using the pad function, yielding a padded bit string P with a length divisible by r (such that n = len(P)/r is integer), + --~ break P into n consecutive r-bit pieces P0, ..., Pn-1 (last is zero-padded) + --~ initialize the state S to a string of b 0 bits. + --~ absorb the input into the state: For each block Pi, + --~ extend Pi at the end by a string of c 0 bits, yielding one of length b, + --~ XOR that with S and + --~ apply the block permutation f to the result, yielding a new state S + --~ initialize Z to be the empty string + --~ while the length of Z is less than d: + --~ append the first r bits of S to Z + --~ if Z is still less than d bits long, apply f to S, yielding a new state S. + --~ truncate Z to d bits + local function partial(message_part) + if message_part then + local partLength = #message_part + if tail then + local offs = 0 + if tail ~= "" and #tail + partLength >= block_size_in_bytes then + offs = block_size_in_bytes - #tail + keccak_feed( + lanes_lo, + lanes_hi, + tail .. string.sub(message_part, 1, offs), + 0, + block_size_in_bytes, + block_size_in_bytes + ) + tail = "" + end + + local size = partLength - offs + local size_tail = size % block_size_in_bytes + keccak_feed(lanes_lo, lanes_hi, message_part, offs, size - size_tail, block_size_in_bytes) + tail = tail .. string.sub(message_part, partLength + 1 - size_tail) + return partial + else + error("Adding more chunks is not allowed after receiving the result", 2) + end + else + if tail then + -- append the following bits to the message: for usual SHA3: 011(0*)1, for SHAKE: 11111(0*)1 + local gap_start = is_SHAKE and 31 or 6 + tail = tail + .. ( + #tail + 1 == block_size_in_bytes and string.char(gap_start + 128) + or string.char(gap_start) .. string.rep("\0", (-2 - #tail) % block_size_in_bytes) .. "\128" + ) + keccak_feed(lanes_lo, lanes_hi, tail, 0, #tail, block_size_in_bytes) + tail = nil + + local lanes_used = 0 + local total_lanes = math.floor(block_size_in_bytes / 8) + local qwords = {} + + local function get_next_qwords_of_digest(qwords_qty) + -- returns not more than 'qwords_qty' qwords ('qwords_qty' might be non-integer) + -- doesn't go across keccak-buffer boundary + -- block_size_in_bytes is a multiple of 8, so, keccak-buffer contains integer number of qwords + if lanes_used >= total_lanes then + keccak_feed(lanes_lo, lanes_hi, "\0\0\0\0\0\0\0\0", 0, 8, 8) + lanes_used = 0 + end + + qwords_qty = math.floor(math.min(qwords_qty, total_lanes - lanes_used)) + if hi_factor_keccak ~= 0 then + for j = 1, qwords_qty do + qwords[j] = HEX64(lanes_lo[lanes_used + j - 1 + lanes_index_base]) + end + else + for j = 1, qwords_qty do + qwords[j] = string.format("%08x", lanes_hi[lanes_used + j] % 4294967296) + .. string.format("%08x", lanes_lo[lanes_used + j] % 4294967296) + end + end + + lanes_used = lanes_used + qwords_qty + return string.gsub( + table.concat(qwords, "", 1, qwords_qty), + "(..)(..)(..)(..)(..)(..)(..)(..)", + "%8%7%6%5%4%3%2%1" + ), + qwords_qty * 8 + end + + local parts = {} -- digest parts + local last_part, last_part_size = "", 0 + + local function get_next_part_of_digest(bytes_needed) + -- returns 'bytes_needed' bytes, for arbitrary integer 'bytes_needed' + bytes_needed = bytes_needed or 1 + if bytes_needed <= last_part_size then + last_part_size = last_part_size - bytes_needed + local part_size_in_nibbles = bytes_needed * 2 + local result = string.sub(last_part, 1, part_size_in_nibbles) + last_part = string.sub(last_part, part_size_in_nibbles + 1) + return result + end + + local parts_qty = 0 + if last_part_size > 0 then + parts_qty = 1 + parts[parts_qty] = last_part + bytes_needed = bytes_needed - last_part_size + end + + -- repeats until the length is enough + while bytes_needed >= 8 do + local next_part, next_part_size = get_next_qwords_of_digest(bytes_needed / 8) + parts_qty = parts_qty + 1 + parts[parts_qty] = next_part + bytes_needed = bytes_needed - next_part_size + end + + if bytes_needed > 0 then + last_part, last_part_size = get_next_qwords_of_digest(1) + parts_qty = parts_qty + 1 + parts[parts_qty] = get_next_part_of_digest(bytes_needed) + else + last_part, last_part_size = "", 0 + end + + return table.concat(parts, "", 1, parts_qty) + end + + if digest_size_in_bytes < 0 then + result = get_next_part_of_digest + else + result = get_next_part_of_digest(digest_size_in_bytes) + end + end + + return result + end + end + + if message then + -- Actually perform calculations and return the SHA3 digest of a message + return partial(message)() + else + -- Return function for chunk-by-chunk loading + -- User should feed every chunk of input data as single argument to this function and finally get SHA3 digest by invoking this function without an argument + return partial + end +end + +local function HexToBinFunction(hh) + return string.char(tonumber(hh, 16)) +end + +local function hex2bin(hex_string) + return (string.gsub(hex_string, "%x%x", HexToBinFunction)) +end + +local base64_symbols = { + ["+"] = 62, + ["-"] = 62, + [62] = "+", + ["/"] = 63, + ["_"] = 63, + [63] = "/", + ["="] = -1, + ["."] = -1, + [-1] = "=", +} + +local symbol_index = 0 +for j, pair in ipairs({ "AZ", "az", "09" }) do + for ascii = string.byte(pair), string.byte(pair, 2) do + local ch = string.char(ascii) + base64_symbols[ch] = symbol_index + base64_symbols[symbol_index] = ch + symbol_index = symbol_index + 1 + end +end + +local function bin2base64(binary_string) + local stringLength = #binary_string + local result = table.create(math.ceil(stringLength / 3)) + local length = 0 + + for pos = 1, #binary_string, 3 do + local c1, c2, c3, c4 = string.byte(string.sub(binary_string, pos, pos + 2) .. "\0", 1, -1) + length = length + 1 + result[length] = base64_symbols[math.floor(c1 / 4)] + .. base64_symbols[c1 % 4 * 16 + math.floor(c2 / 16)] + .. base64_symbols[c3 and c2 % 16 * 4 + math.floor(c3 / 64) or -1] + .. base64_symbols[c4 and c3 % 64 or -1] + end + + return table.concat(result) +end + +local function base642bin(base64_string) + local result, chars_qty = {}, 3 + for pos, ch in string.gmatch(string.gsub(base64_string, "%s+", ""), "()(.)") do + local code = base64_symbols[ch] + if code < 0 then + chars_qty = chars_qty - 1 + code = 0 + end + + local idx = pos % 4 + if idx > 0 then + result[-idx] = code + else + local c1 = result[-1] * 4 + math.floor(result[-2] / 16) + local c2 = (result[-2] % 16) * 16 + math.floor(result[-3] / 4) + local c3 = (result[-3] % 4) * 64 + code + result[#result + 1] = string.sub(string.char(c1, c2, c3), 1, chars_qty) + end + end + + return table.concat(result) +end + +local block_size_for_HMAC -- this table will be initialized at the end of the module +--local function pad_and_xor(str, result_length, byte_for_xor) +-- return string.gsub(str, ".", function(c) +-- return string.char(bit32_bxor(string.byte(c), byte_for_xor)) +-- end) .. string.rep(string.char(byte_for_xor), result_length - #str) +--end + +-- For the sake of speed of converting hexes to strings, there's a map of the conversions here +local BinaryStringMap = {} +for Index = 0, 255 do + BinaryStringMap[string.format("%02x", Index)] = string.char(Index) +end + +-- Update 02.14.20 - added AsBinary for easy GameAnalytics replacement. +local function hmac(hash_func, key, message, AsBinary) + -- Create an instance (private objects for current calculation) + local block_size = block_size_for_HMAC[hash_func] + if not block_size then + error("Unknown hash function", 2) + end + + local KeyLength = #key + if KeyLength > block_size then + key = string.gsub(hash_func(key), "%x%x", HexToBinFunction) + KeyLength = #key + end + + local append = hash_func()(string.gsub(key, ".", function(c) + return string.char(bit32_bxor(string.byte(c), 0x36)) + end) .. string.rep("6", block_size - KeyLength)) -- 6 = string.char(0x36) + + local result + + local function partial(message_part) + if not message_part then + result = result + or hash_func( + string.gsub(key, ".", function(c) + return string.char(bit32_bxor(string.byte(c), 0x5c)) + end) + .. string.rep("\\", block_size - KeyLength) -- \ = string.char(0x5c) + .. (string.gsub(append(), "%x%x", HexToBinFunction)) + ) + + return result + elseif result then + error("Adding more chunks is not allowed after receiving the result", 2) + else + append(message_part) + return partial + end + end + + if message then + -- Actually perform calculations and return the HMAC of a message + local FinalMessage = partial(message)() + return AsBinary and (string.gsub(FinalMessage, "%x%x", BinaryStringMap)) or FinalMessage + else + -- Return function for chunk-by-chunk loading of a message + -- User should feed every chunk of the message as single argument to this function and finally get HMAC by invoking this function without an argument + return partial + end +end + +local sha = { + md5 = md5, + sha1 = sha1, + -- SHA2 hash functions: + sha224 = function(message) + return sha256ext(224, message) + end, + + sha256 = function(message) + return sha256ext(256, message) + end, + + sha512_224 = function(message) + return sha512ext(224, message) + end, + + sha512_256 = function(message) + return sha512ext(256, message) + end, + + sha384 = function(message) + return sha512ext(384, message) + end, + + sha512 = function(message) + return sha512ext(512, message) + end, + + -- SHA3 hash functions: + sha3_224 = function(message) + return keccak((1600 - 2 * 224) / 8, 224 / 8, false, message) + end, + + sha3_256 = function(message) + return keccak((1600 - 2 * 256) / 8, 256 / 8, false, message) + end, + + sha3_384 = function(message) + return keccak((1600 - 2 * 384) / 8, 384 / 8, false, message) + end, + + sha3_512 = function(message) + return keccak((1600 - 2 * 512) / 8, 512 / 8, false, message) + end, + + shake128 = function(message, digest_size_in_bytes) + return keccak((1600 - 2 * 128) / 8, digest_size_in_bytes, true, message) + end, + + shake256 = function(message, digest_size_in_bytes) + return keccak((1600 - 2 * 256) / 8, digest_size_in_bytes, true, message) + end, + + -- misc utilities: + hmac = hmac, -- HMAC(hash_func, key, message) is applicable to any hash function from this module except SHAKE* + hex_to_bin = hex2bin, -- converts hexadecimal representation to binary string + base64_to_bin = base642bin, -- converts base64 representation to binary string + bin_to_base64 = bin2base64, -- converts binary string to base64 representation + base64_encode = Base64.Encode, + base64_decode = Base64.Decode, +} + +block_size_for_HMAC = { + [sha.md5] = 64, + [sha.sha1] = 64, + [sha.sha224] = 64, + [sha.sha256] = 64, + [sha.sha512_224] = 128, + [sha.sha512_256] = 128, + [sha.sha384] = 128, + [sha.sha512] = 128, + [sha.sha3_224] = (1600 - 2 * 224) / 8, + [sha.sha3_256] = (1600 - 2 * 256) / 8, + [sha.sha3_384] = (1600 - 2 * 384) / 8, + [sha.sha3_512] = (1600 - 2 * 512) / 8, +} + +return sha diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 758c522e..eaea3545 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAG(LuauCache32BitAsmConsts) - using namespace Luau::CodeGen; using namespace Luau::CodeGen::X64; @@ -748,7 +746,6 @@ TEST_CASE("ConstantStorage") TEST_CASE("ConstantStorageDedup") { - ScopedFastFlag luauCache32BitAsmConsts{FFlag::LuauCache32BitAsmConsts, true}; AssemblyBuilderX64 build(/* logText= */ false); for (int i = 0; i <= 3000; i++) diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 78adfeed..db6cd327 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -18,10 +18,18 @@ ClassFixture::ClassFixture() unfreeze(arena); + TypeId connectionType = arena.addType(ClassType{"Connection", {}, nullopt, nullopt, {}, {}, "Connection"}); + TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, {"BaseField", {numberType}}, + + {"Touched", {connectionType}}, + }; + + getMutable(connectionType)->props = { + {"Connect", {makeFunction(arena, connectionType, {makeFunction(arena, nullopt, {baseClassInstanceType}, {})}, {})}} }; TypeId baseClassType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 4ce13658..8f217783 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -26,11 +26,7 @@ extern bool verbose; extern bool codegen; extern int optimizationLevel; -LUAU_FASTFLAG(LuauTaggedLuData) -LUAU_FASTFLAG(LuauSciNumberSkipTrailDot) -LUAU_DYNAMIC_FASTFLAG(LuauInterruptablePatternMatch) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) -LUAU_DYNAMIC_FASTFLAG(LuauCodeGenFixBufferLenCheckA64) LUAU_DYNAMIC_FASTFLAG(LuauCodegenTrackingMultilocationFix) static lua_CompileOptions defaultOptions() @@ -1459,8 +1455,6 @@ TEST_CASE("Coverage") TEST_CASE("StringConversion") { - ScopedFastFlag luauSciNumberSkipTrailDot{FFlag::LuauSciNumberSkipTrailDot, true}; - runConformance("strconv.lua"); } @@ -1654,8 +1648,6 @@ TEST_CASE("Interrupt") } }; - ScopedFastFlag luauInterruptablePatternMatch{DFFlag::LuauInterruptablePatternMatch, true}; - for (int test = 1; test <= 5; ++test) { lua_State* T = lua_newthread(L); @@ -1764,8 +1756,6 @@ TEST_CASE("UserdataApi") TEST_CASE("LightuserdataApi") { - ScopedFastFlag luauTaggedLuData{FFlag::LuauTaggedLuData, true}; - StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -2040,7 +2030,6 @@ TEST_CASE("SafeEnv") TEST_CASE("Native") { - ScopedFastFlag luauCodeGenFixBufferLenCheckA64{DFFlag::LuauCodeGenFixBufferLenCheckA64, true}; ScopedFastFlag luauCodegenTrackingMultilocationFix{DFFlag::LuauCodegenTrackingMultilocationFix, true}; // This tests requires code to run natively, otherwise all 'is_native' checks will fail diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index a6d8a9c6..3651dfeb 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -117,7 +117,17 @@ std::optional TestFileResolver::resolveModule(const ModuleInfo* cont std::string TestFileResolver::getHumanReadableModuleName(const ModuleName& name) const { - return name; + // We have a handful of tests that need to distinguish between a canonical + // ModuleName and the human-readable version so we apply a simple transform + // here: We replace all slashes with dots. + std::string result = name; + for (size_t i = 0; i < result.size(); ++i) + { + if (result[i] == '/') + result[i] = '.'; + } + + return result; } std::optional TestFileResolver::getEnvironmentForModule(const ModuleName& name) const diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 039decb5..411d4914 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -14,6 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauFreezeArena); +LUAU_FASTFLAG(DebugLuauMagicTypes); namespace { @@ -1273,4 +1274,63 @@ TEST_CASE_FIXTURE(FrontendFixture, "markdirty_early_return") } } +TEST_CASE_FIXTURE(FrontendFixture, "attribute_ices_to_the_correct_module") +{ + ScopedFastFlag sff{FFlag::DebugLuauMagicTypes, true}; + + fileResolver.source["game/one"] = R"( + require(game.two) + )"; + + fileResolver.source["game/two"] = R"( + local a: _luau_ice + )"; + + try + { + frontend.check("game/one"); + } + catch (InternalCompilerError& err) + { + CHECK("game/two" == err.moduleName); + return; + } + + FAIL("Expected an InternalCompilerError!"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "checked_modules_have_the_correct_mode") +{ + fileResolver.source["game/A"] = R"( + --!nocheck + local a: number = "five" + )"; + + fileResolver.source["game/B"] = R"( + --!nonstrict + local a = math.abs("five") + )"; + + fileResolver.source["game/C"] = R"( + --!strict + local a = 10 + )"; + + frontend.check("game/A"); + frontend.check("game/B"); + frontend.check("game/C"); + + ModulePtr moduleA = frontend.moduleResolver.getModule("game/A"); + REQUIRE(moduleA); + CHECK(moduleA->mode == Mode::NoCheck); + + ModulePtr moduleB = frontend.moduleResolver.getModule("game/B"); + REQUIRE(moduleB); + CHECK(moduleB->mode == Mode::Nonstrict); + + ModulePtr moduleC = frontend.moduleResolver.getModule("game/C"); + REQUIRE(moduleC); + CHECK(moduleC->mode == Mode::Strict); +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index ca8a0b0f..78f51809 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -11,9 +11,9 @@ #include -using namespace Luau::CodeGen; +LUAU_FASTFLAG(LuauCodegenVectorTag2) -LUAU_DYNAMIC_FASTFLAG(LuauCodeGenCheckGcEffectFix) +using namespace Luau::CodeGen; class IrBuilderFixture { @@ -2060,8 +2060,6 @@ bb_fallback_1: TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateHashSlotChecksInvalidation") { - ScopedFastFlag luauCodeGenCheckGcEffectFix{DFFlag::LuauCodeGenCheckGcEffectFix, true}; - IrOp block = build.block(IrBlockKind::Internal); IrOp fallback = build.block(IrBlockKind::Fallback); @@ -2498,6 +2496,85 @@ bb_fallback_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "TagVectorSkipErrorFix") +{ + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; + + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp a = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); + IrOp b = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1)); + + IrOp mul = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::MUL_VEC, a, b)); + + IrOp t1 = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::ADD_VEC, mul, mul)); + IrOp t2 = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::SUB_VEC, mul, mul)); + + IrOp t3 = build.inst(IrCmd::TAG_VECTOR, build.inst(IrCmd::DIV_VEC, t1, build.inst(IrCmd::UNM_VEC, t2))); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(0), t3); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::Yes) == R"( +bb_0: ; useCount: 0 + %0 = LOAD_TVALUE R0 ; useCount: 1, lastUse: %0 + %1 = LOAD_TVALUE R1 ; useCount: 1, lastUse: %0 + %2 = MUL_VEC %0, %1 ; useCount: 4, lastUse: %0 + %4 = ADD_VEC %2, %2 ; useCount: 1, lastUse: %0 + %6 = SUB_VEC %2, %2 ; useCount: 1, lastUse: %0 + %8 = UNM_VEC %6 ; useCount: 1, lastUse: %0 + %9 = DIV_VEC %4, %8 ; useCount: 1, lastUse: %0 + %10 = TAG_VECTOR %9 ; useCount: 1, lastUse: %0 + STORE_TVALUE R0, %10 ; %11 + RETURN R0, 1u ; %12 + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp followup = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp tbl = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::CHECK_READONLY, tbl, build.vmExit(1)); + + build.inst(IrCmd::FALLBACK_FORGPREP, build.constUint(2), build.vmReg(1), followup); + + build.beginBlock(followup); + build.inst(IrCmd::CHECK_READONLY, tbl, build.vmExit(2)); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(3)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1 +; out regs: R1, R2, R3 + %0 = LOAD_POINTER R0 + CHECK_READONLY %0, exit(1) + FALLBACK_FORGPREP 2u, R1, bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R1, R2, R3 + CHECK_READONLY %0, exit(2) + RETURN R1, 3i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 13f44dca..31661711 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -12,9 +12,7 @@ #include -LUAU_FASTFLAG(LuauCodegenVector) -LUAU_FASTFLAG(LuauCodegenVectorTag) -LUAU_FASTFLAG(LuauCodegenMathMemArgs) +LUAU_FASTFLAG(LuauCodegenVectorTag2) static std::string getCodegenAssembly(const char* source) { @@ -65,8 +63,7 @@ TEST_SUITE_BEGIN("IrLowering"); TEST_CASE("VectorReciprocal") { - ScopedFastFlag luauCodegenVector{FFlag::LuauCodegenVector, true}; - ScopedFastFlag luauCodegenVectorTag{FFlag::LuauCodegenVectorTag, true}; + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vecrcp(a: vector) @@ -93,8 +90,6 @@ bb_bytecode_1: TEST_CASE("VectorComponentRead") { - ScopedFastFlag luauCodegenVector{FFlag::LuauCodegenVector, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function compsum(a: vector) return a.X + a.Y + a.Z @@ -129,8 +124,7 @@ bb_bytecode_1: TEST_CASE("VectorAdd") { - ScopedFastFlag luauCodegenVector{FFlag::LuauCodegenVector, true}; - ScopedFastFlag luauCodegenVectorTag{FFlag::LuauCodegenVectorTag, true}; + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3add(a: vector, b: vector) @@ -158,8 +152,7 @@ bb_bytecode_1: TEST_CASE("VectorMinus") { - ScopedFastFlag luauCodegenVector{FFlag::LuauCodegenVector, true}; - ScopedFastFlag luauCodegenVectorTag{FFlag::LuauCodegenVectorTag, true}; + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3minus(a: vector) @@ -185,8 +178,7 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv") { - ScopedFastFlag luauCodegenVector{FFlag::LuauCodegenVector, true}; - ScopedFastFlag luauCodegenVectorTag{FFlag::LuauCodegenVectorTag, true}; + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) @@ -222,10 +214,45 @@ bb_bytecode_1: )"); } +TEST_CASE("VectorSubMulDiv2") +{ + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; + + CHECK_EQ("\n" + getCodegenAssembly(R"( +local function vec3combo(a: vector) + local tmp = a * a + return (tmp - tmp) / (tmp + tmp) +end +)"), + R"( +; function vec3combo($arg0) line 2 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %8 = LOAD_TVALUE R0 + %10 = MUL_VEC %8, %8 + %11 = TAG_VECTOR %10 + STORE_TVALUE R1, %11 + %19 = SUB_VEC %10, %10 + %20 = TAG_VECTOR %19 + STORE_TVALUE R3, %20 + %28 = ADD_VEC %10, %10 + %29 = TAG_VECTOR %28 + STORE_TVALUE R4, %29 + %37 = DIV_VEC %19, %28 + %38 = TAG_VECTOR %37 + STORE_TVALUE R2, %38 + INTERRUPT 4u + RETURN R2, 1i +)"); +} + TEST_CASE("VectorMulDivMixed") { - ScopedFastFlag luauCodegenVector{FFlag::LuauCodegenVector, true}; - ScopedFastFlag luauCodegenVectorTag{FFlag::LuauCodegenVectorTag, true}; + ScopedFastFlag luauCodegenVectorTag2{FFlag::LuauCodegenVectorTag2, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) @@ -281,8 +308,6 @@ bb_bytecode_1: TEST_CASE("ExtraMathMemoryOperands") { - ScopedFastFlag luauCodegenMathMemArgs{FFlag::LuauCodegenMathMemArgs, true}; - CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: number, b: number, c: number, d: number, e: number) return math.floor(a) + math.ceil(b) + math.round(c) + math.sqrt(d) + math.abs(e) diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp index 94776296..536de39d 100644 --- a/tests/Simplify.test.cpp +++ b/tests/Simplify.test.cpp @@ -130,16 +130,17 @@ TEST_CASE_FIXTURE(SimplifyFixture, "overload_negation_refinement_is_never") TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_other_tops_and_bottom_types") { + CHECK(unknownTy == intersect(unknownTy, unknownTy)); - CHECK(unknownTy == intersect(unknownTy, anyTy)); - CHECK(unknownTy == intersect(anyTy, unknownTy)); + CHECK("*error-type* | unknown" == intersectStr(unknownTy, anyTy)); + CHECK("*error-type* | unknown" == intersectStr(anyTy, unknownTy)); CHECK(neverTy == intersect(unknownTy, neverTy)); CHECK(neverTy == intersect(neverTy, unknownTy)); - CHECK(neverTy == intersect(unknownTy, errorTy)); - CHECK(neverTy == intersect(errorTy, unknownTy)); + CHECK(errorTy == intersect(unknownTy, errorTy)); + CHECK(errorTy == intersect(errorTy, unknownTy)); } TEST_CASE_FIXTURE(SimplifyFixture, "nil") @@ -179,17 +180,37 @@ TEST_CASE_FIXTURE(SimplifyFixture, "boolean_and_truthy_and_falsy") TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types") { - CHECK("'a" == intersectStr(anyTy, freeTy)); - CHECK("'a" == intersectStr(freeTy, anyTy)); + CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); + CHECK("'a | *error-type*" == intersectStr(freeTy, anyTy)); - CHECK("b" == intersectStr(anyTy, genericTy)); - CHECK("b" == intersectStr(genericTy, anyTy)); + CHECK("*error-type* | b" == intersectStr(anyTy, genericTy)); + CHECK("*error-type* | b" == intersectStr(genericTy, anyTy)); - CHECK(blockedTy == intersect(anyTy, blockedTy)); - CHECK(blockedTy == intersect(blockedTy, anyTy)); + auto anyRhsBlocked = get(intersect(anyTy, blockedTy)); + auto anyLhsBlocked = get(intersect(blockedTy, anyTy)); - CHECK(pendingTy == intersect(anyTy, pendingTy)); - CHECK(pendingTy == intersect(pendingTy, anyTy)); + REQUIRE(anyRhsBlocked); + REQUIRE(anyRhsBlocked->options.size() == 2); + CHECK(blockedTy == anyRhsBlocked->options[0]); + CHECK(errorTy == anyRhsBlocked->options[1]); + + REQUIRE(anyLhsBlocked); + REQUIRE(anyLhsBlocked->options.size() == 2); + CHECK(blockedTy == anyLhsBlocked->options[0]); + CHECK(errorTy == anyLhsBlocked->options[1]); + + auto anyRhsPending = get(intersect(anyTy, pendingTy)); + auto anyLhsPending = get(intersect(pendingTy, anyTy)); + + REQUIRE(anyRhsPending); + REQUIRE(anyRhsPending->options.size() == 2); + CHECK(pendingTy == anyRhsPending->options[0]); + CHECK(errorTy == anyRhsPending->options[1]); + + REQUIRE(anyLhsPending); + REQUIRE(anyLhsPending->options.size() == 2); + CHECK(pendingTy == anyLhsPending->options[0]); + CHECK(errorTy == anyLhsPending->options[1]); } TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types") @@ -197,22 +218,14 @@ TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types") CHECK(freeTy == intersect(unknownTy, freeTy)); CHECK(freeTy == intersect(freeTy, unknownTy)); - TypeId t = nullptr; + CHECK(genericTy == intersect(unknownTy, genericTy)); + CHECK(genericTy == intersect(genericTy, unknownTy)); - t = intersect(unknownTy, genericTy); - CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); - t = intersect(genericTy, unknownTy); - CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); + CHECK(blockedTy == intersect(unknownTy, blockedTy)); + CHECK(blockedTy == intersect(unknownTy, blockedTy)); - t = intersect(unknownTy, blockedTy); - CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); - t = intersect(blockedTy, unknownTy); - CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); - - t = intersect(unknownTy, pendingTy); - CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); - t = intersect(pendingTy, unknownTy); - CHECK_MESSAGE(isIntersection(t), "Should be an intersection but got " << t); + CHECK(pendingTy == intersect(unknownTy, pendingTy)); + CHECK(pendingTy == intersect(unknownTy, pendingTy)); } TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") @@ -274,8 +287,8 @@ TEST_CASE_FIXTURE(SimplifyFixture, "primitives") CHECK(neverTy == intersect(neverTy, tableTy)); CHECK(neverTy == intersect(tableTy, neverTy)); - CHECK(numberTy == intersect(anyTy, numberTy)); - CHECK(numberTy == intersect(numberTy, anyTy)); + CHECK("*error-type* | number" == intersectStr(anyTy, numberTy)); + CHECK("*error-type* | number" == intersectStr(numberTy, anyTy)); CHECK(neverTy == intersect(stringTy, nilTy)); CHECK(neverTy == intersect(nilTy, stringTy)); @@ -504,7 +517,15 @@ TEST_CASE_FIXTURE(SimplifyFixture, "some_tables_are_really_never") CHECK(neverTy == intersect(t1, numberTy)); CHECK(neverTy == intersect(numberTy, t1)); - CHECK(neverTy == intersect(t1, t1)); + CHECK(t1 == intersect(t1, t1)); + + TypeId notUnknownTy = mkNegation(unknownTy); + + TypeId t2 = mkTable({{"someKey", notUnknownTy}}); + + CHECK(neverTy == intersect(t2, numberTy)); + CHECK(neverTy == intersect(numberTy, t2)); + CHECK(neverTy == intersect(t2, t2)); } TEST_CASE_FIXTURE(SimplifyFixture, "simplify_stops_at_cycles") @@ -520,20 +541,26 @@ TEST_CASE_FIXTURE(SimplifyFixture, "simplify_stops_at_cycles") tt->props["cyclic"] = Property{t2}; t2t->props["cyclic"] = Property{t}; - CHECK(t == intersect(t, anyTy)); - CHECK(t == intersect(anyTy, t)); + CHECK(t == intersect(t, unknownTy)); + CHECK(t == intersect(unknownTy, t)); - CHECK(t2 == intersect(t2, anyTy)); - CHECK(t2 == intersect(anyTy, t2)); + CHECK(t2 == intersect(t2, unknownTy)); + CHECK(t2 == intersect(unknownTy, t2)); + + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(t, anyTy)); + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(anyTy, t)); + + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(t2, anyTy)); + CHECK("*error-type* | t1 where t1 = { cyclic: { cyclic: t1 } }" == intersectStr(anyTy, t2)); } TEST_CASE_FIXTURE(SimplifyFixture, "free_type_bound_by_any_with_any") { - CHECK(freeTy == intersect(freeTy, anyTy)); - CHECK(freeTy == intersect(anyTy, freeTy)); + CHECK("'a | *error-type*" == intersectStr(freeTy, anyTy)); + CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); - CHECK(freeTy == intersect(freeTy, anyTy)); - CHECK(freeTy == intersect(anyTy, freeTy)); + CHECK("'a | *error-type*" == intersectStr(freeTy, anyTy)); + CHECK("'a | *error-type*" == intersectStr(anyTy, freeTy)); } TEST_SUITE_END(); diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index dbc706ae..734ff036 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -509,7 +509,7 @@ TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_works_on_classes") CheckResult result = check(R"( type KeysOfMyObject = keyof - local function ok(idx: KeysOfMyObject): "BaseMethod" | "BaseField" return idx end + local function ok(idx: KeysOfMyObject): "BaseMethod" | "BaseField" | "Touched" return idx end local function err(idx: KeysOfMyObject): "BaseMethod" return idx end )"); @@ -518,7 +518,7 @@ TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_works_on_classes") TypePackMismatch* tpm = get(result.errors[0]); REQUIRE(tpm); CHECK_EQ("\"BaseMethod\"", toString(tpm->wantedTp)); - CHECK_EQ("\"BaseField\" | \"BaseMethod\"", toString(tpm->givenTp)); + CHECK_EQ("\"BaseField\" | \"BaseMethod\" | \"Touched\"", toString(tpm->givenTp)); } TEST_CASE_FIXTURE(ClassFixture, "keyof_type_family_errors_if_it_has_nonclass_part") diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 5d6b9b16..8d14f56b 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -32,7 +32,15 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(builtinTypes->anyType, requireType("a")); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // Bug: We do not simplify at the right time + CHECK_EQ("any?", toString(requireType("a"))); + } + else + { + CHECK_EQ(builtinTypes->anyType, requireType("a")); + } } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") @@ -64,7 +72,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") { CheckResult result = check(R"( - local bar: any + local bar = nil :: any local a for b in bar do @@ -74,13 +82,21 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("a"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // Bug: We do not simplify at the right time + CHECK_EQ("any?", toString(requireType("a"))); + } + else + { + CHECK_EQ("any", toString(requireType("a"))); + } } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") { CheckResult result = check(R"( - local bar: any + local bar = nil :: any local a for b in bar() do @@ -90,7 +106,39 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("a"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // Bug: We do not simplify at the right time + CHECK_EQ("any?", toString(requireType("a"))); + } + else + { + CHECK_EQ("any", toString(requireType("a"))); + } +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any_pack") +{ + CheckResult result = check(R"( + function bar(): ...any end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // Bug: We do not simplify at the right time + CHECK_EQ("any?", toString(requireType("a"))); + } + else + { + CHECK_EQ("any", toString(requireType("a"))); + } } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") @@ -104,7 +152,16 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*error-type*", toString(requireType("a"))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // Bug: We do not simplify at the right time + CHECK_EQ("*error-type*?", toString(requireType("a"))); + } + else + { + CHECK_EQ("*error-type*", toString(requireType("a"))); + } } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -118,9 +175,21 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // CLI-97375(awe): `bar()` is returning `nil` here, which isn't wrong necessarily, + // but then we're signaling an additional error for the access on `nil`. + LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("*error-type*", toString(requireType("a"))); + // Bug: We do not simplify at the right time + CHECK_EQ("*error-type*?", toString(requireType("a"))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*error-type*", toString(requireType("a"))); + } } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 38c6748c..7a3397ce 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -967,7 +967,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") CHECK_EQ("number", toString(requireTypeAtPosition({10, 49}))); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("unknown & ~number", toString(requireTypeAtPosition({10, 66}))); + CHECK_EQ("~number", toString(requireTypeAtPosition({10, 66}))); else CHECK_EQ("unknown", toString(requireTypeAtPosition({10, 66}))); } @@ -1497,7 +1497,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("unknown & ~string", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("~string", toString(requireTypeAtPosition({5, 28}))); } else { diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index a9ae8fc6..6d300769 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4022,8 +4022,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_write_property") LUAU_REQUIRE_NO_ERRORS(result); - // CHECK("({ y: number }) -> ()" == toString(requireType("f"))); - CHECK("({ y: number & unknown }) -> ()" == toString(requireType("f"))); + CHECK("({ y: number }) -> ()" == toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_error_suppression") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 8877c762..ce2cfe6b 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -9,6 +9,7 @@ #include "Luau/VisitType.h" #include "Fixture.h" +#include "ClassFixture.h" #include "ScopedFlags.h" #include "doctest.h" @@ -1219,6 +1220,26 @@ TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_callback_property") CHECK(location.end.line == 7); } +TEST_CASE_FIXTURE(ClassFixture, "bidirectional_inference_of_class_methods") +{ + CheckResult result = check(R"( + local c = ChildClass.New() + + -- Instead of reporting that the lambda is the wrong type, report that we are using its argument improperly. + c.Touched:Connect(function(other) + print(other.ThisDoesNotExist) + end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* err = get(result.errors[0]); + REQUIRE(err); + + CHECK("ThisDoesNotExist" == err->key); + CHECK("BaseClass" == toString(err->table)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict") { CheckResult result = check(R"( diff --git a/tests/Unifier2.test.cpp b/tests/Unifier2.test.cpp index 842f9e06..8a3bc8de 100644 --- a/tests/Unifier2.test.cpp +++ b/tests/Unifier2.test.cpp @@ -182,4 +182,17 @@ TEST_CASE_FIXTURE(Unifier2Fixture, "generalize_a_type_that_is_bounded_by_another CHECK(builtinTypes.unknownType == follow(t2)); } +TEST_CASE_FIXTURE(Unifier2Fixture, "dont_traverse_into_class_types_when_generalizing") +{ + auto [propTy, _] = freshType(); + + TypeId cursedClass = arena.addType(ClassType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, ""}); + + auto genClass = u2.generalize(cursedClass); + REQUIRE(genClass); + + auto genPropTy = get(*genClass)->props.at("oh_no").readTy; + CHECK(is(*genPropTy)); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 89675715..008f0e9c 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -4,7 +4,6 @@ AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg AutocompleteTest.autocomplete_response_perf1 AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.do_wrong_compatible_nonself_calls -AutocompleteTest.type_correct_expected_argument_type_suggestion_self AutocompleteTest.type_correct_suggestion_for_overloads BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types @@ -92,7 +91,6 @@ GenericsTests.factories_of_generics GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_factories -GenericsTests.generic_functions_dont_cache_type_parameters GenericsTests.generic_functions_in_types GenericsTests.generic_type_families_work_in_subtyping GenericsTests.generic_type_pack_parentheses @@ -244,7 +242,6 @@ TableTests.generic_table_instantiation_potential_regression TableTests.indexer_mismatch TableTests.indexers_get_quantified_too TableTests.indexing_from_a_table_should_prefer_properties_when_possible -TableTests.inequality_operators_imply_exactly_matching_types TableTests.infer_indexer_from_its_variable_type_and_unifiable TableTests.inferred_return_type_of_free_table TableTests.instantiate_table_cloning_3 @@ -264,7 +261,6 @@ TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr TableTests.okay_to_add_property_to_unsealed_tables_by_assignment TableTests.okay_to_add_property_to_unsealed_tables_by_function_call TableTests.only_ascribe_synthetic_names_at_module_scope -TableTests.oop_polymorphic TableTests.open_table_unification_2 TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 @@ -367,11 +363,6 @@ TypeInferAnyError.any_type_propagates TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any TypeInferAnyError.call_to_any_yields_any TypeInferAnyError.can_subscript_any -TypeInferAnyError.for_in_loop_iterator_is_any -TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferAnyError.for_in_loop_iterator_is_error -TypeInferAnyError.for_in_loop_iterator_is_error2 -TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.intersection_of_any_can_have_props TypeInferAnyError.metatable_of_any_can_be_a_table TypeInferAnyError.quantify_any_does_not_bind_to_itself @@ -442,6 +433,7 @@ TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_explore_raycast_minimization +TypeInferLoops.dcr_iteration_fragmented_keys TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.dcr_xpath_candidates TypeInferLoops.for_in_loop @@ -449,10 +441,9 @@ TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_va TypeInferLoops.for_in_loop_error_on_iterator_requiring_args_but_none_given TypeInferLoops.for_in_loop_on_error TypeInferLoops.for_in_loop_on_non_function -TypeInferLoops.for_in_loop_with_custom_iterator -TypeInferLoops.for_in_loop_with_incompatible_args_to_iterator TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_an_iterator_of_type_any +TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_loop TypeInferLoops.ipairs_produces_integral_indices TypeInferLoops.iterate_over_free_table @@ -483,7 +474,6 @@ TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOOP.promise_type_error_too_complex -TypeInferOOP.react_style_oo TypeInferOperators.add_type_family_works TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable