From e6bf71871a6b9f601545dba8a42ce89c6069675c Mon Sep 17 00:00:00 2001
From: Arseny Kapoulkine
Date: Fri, 8 Nov 2024 16:23:09 -0800
Subject: [PATCH] CodeGen: Rewrite dot product lowering using a dedicated IR
 instruction (#1512)

Instead of doing the dot-product-related math in scalar IR, we lift the
computation into a dedicated IR instruction.

On x64, we can use VDPPS, which was more or less tailor-made for this
purpose. This is better than a manual scalar lowering that requires
reloading components from memory; it's not always a strict improvement
over the shuffle+add version (which we never had), but this can now be
adjusted in the IR lowering in an optimal fashion (maybe even based on
CPU vendor, although that'd create issues for offline compilation).

On A64, we can use either naive adds or paired adds, as there is no
dedicated vector-wide horizontal instruction until SVE. Both run at
about the same performance on M2, but paired adds require fewer
instructions and temporaries.

I've measured this using the mesh-normal-vector benchmark, changing the
benchmark to just report the time of the second loop inside
`calculate_normals`, testing master vs #1504 vs this PR, and increasing
the grid size to 400 for more stable timings.

On Zen 4 (7950X), this PR is comfortably ~8% faster vs master, while I
see neutral to negative results with #1504.

On M2 (base), this PR is ~28% faster vs master, while #1504 is only
about ~10% faster.

If I measure the second loop in `calculate_tangent_space` instead, I
get:

On Zen 4 (7950X), this PR is ~12% faster vs master, while #1504 is ~3%
faster.

On M2 (base), this PR is ~24% faster vs master, while #1504 is only
about ~13% faster.

Note that the loops in question are not quite optimal, as they store
and reload various vectors to dictionary values due to inappropriate
use of locals. The underlying gains in individual functions are thus
larger than the numbers above; for example, changing the
`calculate_normals` loop to use a local variable to store the
normalized vector (but still saving the result to the dictionary
value), I get a ~24% performance increase from this PR on Zen 4 vs
master instead of just 8% (#1504 is ~15% slower in this setup).
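For illustration, here is a minimal C++ intrinsics sketch (not part of
this change; `dotVec` is a hypothetical name) of the scalar semantics
that DOT_VEC implements, matching the x64 and A64 lowerings in this
patch:

    #if defined(__SSE4_1__)
    #include <immintrin.h>

    // x64 path: vdpps with mask 0x71 multiplies the lanes and sums the
    // products of lanes 0-2 into lane 0; the float result is then
    // widened to double, as vcvtss2sd does in the lowering below.
    double dotVec(__m128 a, __m128 b)
    {
        return (double)_mm_cvtss_f32(_mm_dp_ps(a, b, 0x71));
    }
    #elif defined(__aarch64__)
    #include <arm_neon.h>

    // A64 path: lane-wise multiply, one paired add for x+y, then the z
    // product, mirroring the fmul + faddp + dup + fadd + fcvt sequence.
    double dotVec(float32x4_t a, float32x4_t b)
    {
        float32x4_t prod = vmulq_f32(a, b);        // lane-wise products
        float xy = vpadds_f32(vget_low_f32(prod)); // pairwise add of lanes 0 and 1
        return (double)(xy + vgetq_lane_f32(prod, 2));
    }
    #endif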
---
 CodeGen/include/Luau/AssemblyBuilderA64.h |   1 +
 CodeGen/include/Luau/AssemblyBuilderX64.h |   2 +
 CodeGen/include/Luau/IrData.h             |   4 +
 CodeGen/include/Luau/IrUtils.h            |   1 +
 CodeGen/src/AssemblyBuilderA64.cpp        |   8 ++
 CodeGen/src/AssemblyBuilderX64.cpp        |   5 ++
 CodeGen/src/IrDump.cpp                    |   2 +
 CodeGen/src/IrLoweringA64.cpp             |  15 ++++
 CodeGen/src/IrLoweringX64.cpp             |  14 +++
 CodeGen/src/IrTranslateBuiltins.cpp       | 104 +++++++++++++++-------
 CodeGen/src/IrUtils.cpp                   |   2 +
 CodeGen/src/OptimizeConstProp.cpp         |   4 +-
 tests/AssemblyBuilderA64.test.cpp         |   3 +
 tests/AssemblyBuilderX64.test.cpp         |   2 +
 14 files changed, 135 insertions(+), 32 deletions(-)

diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h
index a4d857a4..9d337942 100644
--- a/CodeGen/include/Luau/AssemblyBuilderA64.h
+++ b/CodeGen/include/Luau/AssemblyBuilderA64.h
@@ -138,6 +138,7 @@ public:
     void fneg(RegisterA64 dst, RegisterA64 src);
     void fsqrt(RegisterA64 dst, RegisterA64 src);
     void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
+    void faddp(RegisterA64 dst, RegisterA64 src);
 
     // Vector component manipulation
     void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index);
diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h
index c52d95c5..30790ee5 100644
--- a/CodeGen/include/Luau/AssemblyBuilderX64.h
+++ b/CodeGen/include/Luau/AssemblyBuilderX64.h
@@ -167,6 +167,8 @@ public:
     void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle);
     void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset);
 
+    void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask);
+
     // Run final checks
     bool finalize();
 
diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h
index ae406bbc..b603af9e 100644
--- a/CodeGen/include/Luau/IrData.h
+++ b/CodeGen/include/Luau/IrData.h
@@ -194,6 +194,10 @@ enum class IrCmd : uint8_t
     // A: TValue
     UNM_VEC,
 
+    // Compute dot product between two vectors
+    // A, B: TValue
+    DOT_VEC,
+
     // Compute Luau 'not' operation on destructured TValue
     // A: tag
     // B: int (value)
diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h
index 8d48780f..08700573 100644
--- a/CodeGen/include/Luau/IrUtils.h
+++ b/CodeGen/include/Luau/IrUtils.h
@@ -176,6 +176,7 @@ inline bool hasResult(IrCmd cmd)
    case IrCmd::SUB_VEC:
    case IrCmd::MUL_VEC:
    case IrCmd::DIV_VEC:
+   case IrCmd::DOT_VEC:
    case IrCmd::UNM_VEC:
    case IrCmd::NOT_ANY:
    case IrCmd::CMP_ANY:
diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp
index b98a21f2..23384e57 100644
--- a/CodeGen/src/AssemblyBuilderA64.cpp
+++ b/CodeGen/src/AssemblyBuilderA64.cpp
@@ -586,6 +586,14 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src)
     placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000);
 }
 
+void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src)
+{
+    CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s);
+    CODEGEN_ASSERT(dst.kind == src.kind);
+
+    placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12));
+}
+
 void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
 {
     if (dst.kind == KindA64::d)
diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp
index 73c40679..1e646bcb 100644
--- a/CodeGen/src/AssemblyBuilderX64.cpp
+++ b/CodeGen/src/AssemblyBuilderX64.cpp
@@ -946,6 +946,11 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s
     placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66);
 }
 
+void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask)
+{
+    placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66);
+}
+
 bool AssemblyBuilderX64::finalize()
 {
     code.resize(codePos - code.data());
diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp
index 2846db54..f4806b31 100644
--- a/CodeGen/src/IrDump.cpp
+++ b/CodeGen/src/IrDump.cpp
@@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
         return "DIV_VEC";
     case IrCmd::UNM_VEC:
         return "UNM_VEC";
+    case IrCmd::DOT_VEC:
+        return "DOT_VEC";
     case IrCmd::NOT_ANY:
         return "NOT_ANY";
     case IrCmd::CMP_ANY:
diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp
index a63655cc..45ae5eeb 100644
--- a/CodeGen/src/IrLoweringA64.cpp
+++ b/CodeGen/src/IrLoweringA64.cpp
@@ -728,6 +728,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         build.fneg(inst.regA64, regOp(inst.a));
         break;
     }
+    case IrCmd::DOT_VEC:
+    {
+        inst.regA64 = regs.allocReg(KindA64::d, index);
+
+        RegisterA64 temp = regs.allocTemp(KindA64::q);
+        RegisterA64 temps = castReg(KindA64::s, temp);
+        RegisterA64 regs = castReg(KindA64::s, inst.regA64);
+
+        build.fmul(temp, regOp(inst.a), regOp(inst.b));
+        build.faddp(regs, temps); // x+y
+        build.dup_4s(temp, temp, 2);
+        build.fadd(regs, regs, temps); // +z
+        build.fcvt(inst.regA64, regs);
+        break;
+    }
     case IrCmd::NOT_ANY:
     {
         inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp
index d06cef13..3e4592bf 100644
--- a/CodeGen/src/IrLoweringX64.cpp
+++ b/CodeGen/src/IrLoweringX64.cpp
@@ -675,6 +675,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
         build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0));
         break;
     }
+    case IrCmd::DOT_VEC:
+    {
+        inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
+
+        ScopedRegX64 tmp1{regs};
+        ScopedRegX64 tmp2{regs};
+
+        RegisterX64 tmpa = vecOp(inst.a, tmp1);
+        RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);
+
+        build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float
+        build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64);
+        break;
+    }
     case IrCmd::NOT_ANY:
     {
         // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target
diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp
index cec18204..ebded522 100644
--- a/CodeGen/src/IrTranslateBuiltins.cpp
+++ b/CodeGen/src/IrTranslateBuiltins.cpp
@@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5;
 static const int kBit32BinaryOpUnrolledParams = 5;
 
 LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
+LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);
 
 namespace Luau
 {
@@ -907,15 +908,26 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(
 
     build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
 
-    IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
-    IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
-    IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+    IrOp sum;
 
-    IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
-    IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
-    IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+    if (FFlag::LuauVectorLibNativeDot)
+    {
+        IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
 
-    IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+        sum = build.inst(IrCmd::DOT_VEC, a, a);
+    }
+    else
+    {
+        IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
+        IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
+        IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+
+        IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
+        IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
+        IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+
+        sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+    }
 
     IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
 
@@ -945,25 +957,43 @@ static BuiltinImplResult translateBuiltinVectorNormalize(
 
     build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
 
-    IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
-    IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
-    IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+    if (FFlag::LuauVectorLibNativeDot)
+    {
+        IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
+        IrOp sum = build.inst(IrCmd::DOT_VEC, a, a);
 
-    IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
-    IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
-    IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+        IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
+        IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
+        IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv);
 
-    IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+        IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec);
 
-    IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
-    IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
+        result = build.inst(IrCmd::TAG_VECTOR, result);
 
-    IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
-    IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
-    IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);
+        build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
+    }
+    else
+    {
+        IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
+        IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
+        IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
 
-    build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
-    build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
+        IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
+        IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
+        IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
+
+        IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
+
+        IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
+        IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
+
+        IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
+        IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
+        IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);
+
+        build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
+        build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
+    }
 
     return {BuiltinImplType::Full, 1};
 }
@@ -1019,19 +1049,31 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
     build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
     build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos));
 
-    IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
-    IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
-    IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
+    IrOp sum;
 
-    IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
-    IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
-    IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
+    if (FFlag::LuauVectorLibNativeDot)
+    {
+        IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
+        IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0));
 
-    IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
-    IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
-    IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
+        sum = build.inst(IrCmd::DOT_VEC, a, b);
+    }
+    else
+    {
+        IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
+        IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
+        IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
 
-    IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
+        IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
+        IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
+        IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
+
+        IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
+        IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
+        IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
+
+        sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
+    }
 
     build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum);
     build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp
index ebf4c34b..c1183a47 100644
--- a/CodeGen/src/IrUtils.cpp
+++ b/CodeGen/src/IrUtils.cpp
@@ -75,6 +75,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
     case IrCmd::DIV_VEC:
     case IrCmd::UNM_VEC:
         return IrValueKind::Tvalue;
+    case IrCmd::DOT_VEC:
+        return IrValueKind::Double;
     case IrCmd::NOT_ANY:
     case IrCmd::CMP_ANY:
         return IrValueKind::Int;
diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp
index fa1b18d3..6d453765 100644
--- a/CodeGen/src/OptimizeConstProp.cpp
+++ b/CodeGen/src/OptimizeConstProp.cpp
@@ -768,7 +768,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
         if (tag == LUA_TBOOLEAN && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int)))
             canSplitTvalueStore = true;
-        else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
+        else if (tag == LUA_TNUMBER &&
+                 (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
             canSplitTvalueStore = true;
         else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst)
             canSplitTvalueStore = true;
 
@@ -1342,6 +1343,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
     case IrCmd::SUB_VEC:
     case IrCmd::MUL_VEC:
     case IrCmd::DIV_VEC:
+    case IrCmd::DOT_VEC:
         if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR)
             replace(function, inst.a, a->a);
 
diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp
index 2cd821b5..ee319a5f 100644
--- a/tests/AssemblyBuilderA64.test.cpp
+++ b/tests/AssemblyBuilderA64.test.cpp
@@ -400,6 +400,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath")
     SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841);
     SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD);
 
+    SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D);
+    SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D);
+
     SINGLE_COMPARE(frinta(d1, d2), 0x1E664041);
     SINGLE_COMPARE(frintm(d1, d2), 0x1E654041);
     SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041);
diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp
index 655fa8f1..016616e0 100644
--- a/tests/AssemblyBuilderX64.test.cpp
+++ b/tests/AssemblyBuilderX64.test.cpp
@@ -577,6 +577,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms")
 
     SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4);
     SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02);
+
+    SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02);
 }
 
 TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")