diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h
index a4d857a4..9d337942 100644
--- a/CodeGen/include/Luau/AssemblyBuilderA64.h
+++ b/CodeGen/include/Luau/AssemblyBuilderA64.h
@@ -138,6 +138,7 @@ public:
     void fneg(RegisterA64 dst, RegisterA64 src);
     void fsqrt(RegisterA64 dst, RegisterA64 src);
     void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
+    void faddp(RegisterA64 dst, RegisterA64 src);

     // Vector component manipulation
     void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index);
diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h
index c52d95c5..30790ee5 100644
--- a/CodeGen/include/Luau/AssemblyBuilderX64.h
+++ b/CodeGen/include/Luau/AssemblyBuilderX64.h
@@ -167,6 +167,8 @@ public:
     void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle);
     void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset);

+    void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask);
+
    // Run final checks
    bool finalize();

diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h
index ae406bbc..b603af9e 100644
--- a/CodeGen/include/Luau/IrData.h
+++ b/CodeGen/include/Luau/IrData.h
@@ -194,6 +194,10 @@ enum class IrCmd : uint8_t
     // A: TValue
     UNM_VEC,

+    // Compute dot product between two vectors
+    // A, B: TValue
+    DOT_VEC,
+
     // Compute Luau 'not' operation on destructured TValue
     // A: tag
     // B: int (value)
diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h
index 8d48780f..08700573 100644
--- a/CodeGen/include/Luau/IrUtils.h
+++ b/CodeGen/include/Luau/IrUtils.h
@@ -176,6 +176,7 @@ inline bool hasResult(IrCmd cmd)
    case IrCmd::SUB_VEC:
    case IrCmd::MUL_VEC:
    case IrCmd::DIV_VEC:
+   case IrCmd::DOT_VEC:
    case IrCmd::UNM_VEC:
    case IrCmd::NOT_ANY:
    case IrCmd::CMP_ANY:
diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp
index b98a21f2..23384e57 100644
--- a/CodeGen/src/AssemblyBuilderA64.cpp
+++ b/CodeGen/src/AssemblyBuilderA64.cpp
@@ -586,6 +586,14 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src)
     placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000);
 }

+void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src)
+{
+    CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s);
+    CODEGEN_ASSERT(dst.kind == src.kind);
+
+    placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12));
+}
+
 void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
 {
     if (dst.kind == KindA64::d)
diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp
index 73c40679..1e646bcb 100644
--- a/CodeGen/src/AssemblyBuilderX64.cpp
+++ b/CodeGen/src/AssemblyBuilderX64.cpp
@@ -946,6 +946,11 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s
     placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66);
 }

+void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask)
+{
+    placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66);
+}
+
 bool AssemblyBuilderX64::finalize()
 {
     code.resize(codePos - code.data());
diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp
index 2846db54..f4806b31 100644
--- a/CodeGen/src/IrDump.cpp
+++ b/CodeGen/src/IrDump.cpp
@@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
         return "DIV_VEC";
     case IrCmd::UNM_VEC:
         return "UNM_VEC";
+    case IrCmd::DOT_VEC:
+        return "DOT_VEC";
     case IrCmd::NOT_ANY:
         return "NOT_ANY";
     case IrCmd::CMP_ANY:
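Editor's note: a minimal scalar sketch of what the new DOT_VEC instruction computes (illustrative C++, not part of the patch). A vector TValue carries x, y, z in its first three 32-bit lanes; the fourth lane overlaps the type tag, which is why the sum must exclude it. Both lowerings below accumulate the three products in single precision and widen to double at the end.

// Reference semantics for DOT_VEC (hypothetical helper, not patch code)
static double dotVecReference(const float a[4], const float b[4])
{
    // Lanes 0..2 hold x, y, z; lane 3 overlaps the TValue tag and is excluded
    float sum = a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
    return (double)sum; // result is a Luau number, i.e. a double
}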
"DOT_VEC"; case IrCmd::NOT_ANY: return "NOT_ANY"; case IrCmd::CMP_ANY: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index a63655cc..45ae5eeb 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -728,6 +728,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fneg(inst.regA64, regOp(inst.a)); break; } + case IrCmd::DOT_VEC: + { + inst.regA64 = regs.allocReg(KindA64::d, index); + + RegisterA64 temp = regs.allocTemp(KindA64::q); + RegisterA64 temps = castReg(KindA64::s, temp); + RegisterA64 regs = castReg(KindA64::s, inst.regA64); + + build.fmul(temp, regOp(inst.a), regOp(inst.b)); + build.faddp(regs, temps); // x+y + build.dup_4s(temp, temp, 2); + build.fadd(regs, regs, temps); // +z + build.fcvt(inst.regA64, regs); + break; + } case IrCmd::NOT_ANY: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index d06cef13..3e4592bf 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -675,6 +675,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0)); break; } + case IrCmd::DOT_VEC: + { + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); + + ScopedRegX64 tmp1{regs}; + ScopedRegX64 tmp2{regs}; + + RegisterX64 tmpa = vecOp(inst.a, tmp1); + RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); + + build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float + build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64); + break; + } case IrCmd::NOT_ANY: { // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index cec18204..ebded522 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5; static const int kBit32BinaryOpUnrolledParams = 5; LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen); +LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot); namespace Luau { @@ -907,15 +908,26 @@ static BuiltinImplResult translateBuiltinVectorMagnitude( build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); - IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); - IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); - IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + IrOp sum; - IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); - IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); - IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + if (FFlag::LuauVectorLibNativeDot) + { + IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0)); - IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + sum = build.inst(IrCmd::DOT_VEC, a, a); + } + else + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + } IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); @@ -945,25 +957,43 @@ static 
diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp
index cec18204..ebded522 100644
--- a/CodeGen/src/IrTranslateBuiltins.cpp
+++ b/CodeGen/src/IrTranslateBuiltins.cpp
@@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5;
 static const int kBit32BinaryOpUnrolledParams = 5;

 LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
+LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);

 namespace Luau
 {
@@ -907,15 +908,26 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
     build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));

-    IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
-    IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
-    IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+    IrOp sum;

-    IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
-    IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
-    IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+    if (FFlag::LuauVectorLibNativeDot)
+    {
+        IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));

-    IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+        sum = build.inst(IrCmd::DOT_VEC, a, a);
+    }
+    else
+    {
+        IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
+        IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
+        IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+
+        IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
+        IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
+        IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+
+        sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+    }

     IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
@@ -945,25 +957,43 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
     build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));

-    IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
-    IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
-    IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+    if (FFlag::LuauVectorLibNativeDot)
+    {
+        IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
+        IrOp sum = build.inst(IrCmd::DOT_VEC, a, a);

-    IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
-    IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
-    IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+        IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
+        IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
+        IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv);

-    IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+        IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec);

-    IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
-    IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
+        result = build.inst(IrCmd::TAG_VECTOR, result);

-    IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
-    IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
-    IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);
+        build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
+    }
+    else
+    {
+        IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
+        IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
+        IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));

-    build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
-    build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
+        IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
+        IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
+        IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+
+        IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+
+        IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
+        IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
+
+        IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
+        IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
+        IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);
+
+        build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
+        build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
+    }

     return {BuiltinImplType::Full, 1};
 }
@@ -1019,19 +1049,31 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
     build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
     build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos));

-    IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
-    IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
-    IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
+    IrOp sum;

-    IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
-    IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
-    IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
+    if (FFlag::LuauVectorLibNativeDot)
+    {
+        IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
+        IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0));

-    IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
-    IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
-    IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
+        sum = build.inst(IrCmd::DOT_VEC, a, b);
+    }
+    else
+    {
+        IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
+        IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
+        IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);

-    IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
+        IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
+        IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
+        IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
+
+        IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+        IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
+        IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
+
+        sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
+    }

     build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum);
     build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
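Editor's note: an illustrative scalar reference for the three flagged builtin paths above (hypothetical helper names, not patch code). All three now funnel through the same DOT_VEC reduction: magnitude is sqrt(dot(v, v)), and normalize scales by its reciprocal, which is what the NUM_TO_VEC broadcast plus MUL_VEC expresses in the IR.

#include <cmath>

struct Vec3
{
    float x, y, z;
};

static double vecDot(Vec3 a, Vec3 b)
{
    return (double)(a.x * b.x + a.y * b.y + a.z * b.z); // DOT_VEC(a, b)
}

static double vecMagnitude(Vec3 v)
{
    return std::sqrt(vecDot(v, v)); // SQRT_NUM over DOT_VEC(v, v)
}

static Vec3 vecNormalize(Vec3 v)
{
    float inv = (float)(1.0 / vecMagnitude(v)); // DIV_NUM, then NUM_TO_VEC broadcast
    return {v.x * inv, v.y * inv, v.z * inv};   // MUL_VEC
}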
diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp
index ebf4c34b..c1183a47 100644
--- a/CodeGen/src/IrUtils.cpp
+++ b/CodeGen/src/IrUtils.cpp
@@ -75,6 +75,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
    case IrCmd::DIV_VEC:
    case IrCmd::UNM_VEC:
        return IrValueKind::Tvalue;
+   case IrCmd::DOT_VEC:
+       return IrValueKind::Double;
    case IrCmd::NOT_ANY:
    case IrCmd::CMP_ANY:
        return IrValueKind::Int;
diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp
index fa1b18d3..6d453765 100644
--- a/CodeGen/src/OptimizeConstProp.cpp
+++ b/CodeGen/src/OptimizeConstProp.cpp
@@ -768,7 +768,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
         if (tag == LUA_TBOOLEAN &&
             (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int)))
             canSplitTvalueStore = true;
-        else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
+        else if (tag == LUA_TNUMBER &&
+                 (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
             canSplitTvalueStore = true;
         else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst)
             canSplitTvalueStore = true;
@@ -1342,6 +1343,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
    case IrCmd::SUB_VEC:
    case IrCmd::MUL_VEC:
    case IrCmd::DIV_VEC:
+   case IrCmd::DOT_VEC:
        if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR)
            replace(function, inst.a, a->a);
diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp
index 2cd821b5..ee319a5f 100644
--- a/tests/AssemblyBuilderA64.test.cpp
+++ b/tests/AssemblyBuilderA64.test.cpp
@@ -400,6 +400,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath")
     SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841);
     SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD);

+    SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D);
+    SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D);
+
     SINGLE_COMPARE(frinta(d1, d2), 0x1E664041);
     SINGLE_COMPARE(frintm(d1, d2), 0x1E654041);
     SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041);
diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp
index 655fa8f1..016616e0 100644
--- a/tests/AssemblyBuilderX64.test.cpp
+++ b/tests/AssemblyBuilderX64.test.cpp
@@ -577,6 +577,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms")
     SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4);
     SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02);
+
+    SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02);
 }

 TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")
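Editor's note: the faddp encodings expected by the A64 test can be rederived from the builder change itself. A small self-contained check (illustrative, not part of the test suite); the field layout mirrors placeR1's op << 10 | Rn << 5 | Rd placement and the base opcode from AssemblyBuilderA64::faddp above.

#include <cassert>
#include <cstdint>

static uint32_t encodeFaddp(uint32_t rd, uint32_t rn, bool isDouble)
{
    // Base opcode from the builder; the sz bit (bit 12 of op) selects d vs s
    uint32_t op = 0b011'11110'0'0'11000'01101'10 | ((uint32_t)isDouble << 12);
    return (op << 10) | (rn << 5) | rd;
}

int main()
{
    assert(encodeFaddp(29, 28, false) == 0x7E30DB9D); // faddp s29, s28
    assert(encodeFaddp(29, 28, true) == 0x7E70DB9D);  // faddp d29, d28
    return 0;
}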