Improve A64 lowering for vector operations by using vector instructions (#1164)

This change replaces scalar versions of vector opcodes for A64 with
actual vector instructions.

We take the approach similar to X64: patch last component with zero,
perform the math, patch last component with type tag. I'm hoping that in
the future the type tag will be placed separately (separate IR opcode?)
because right now chains of math operations result in excessive type tag
operations.

To patch the type tag without always keeping a mask in a register,
ins.4s instructions can be used; unfortunately it's only capable of
patching a register in-place, so we need an extra register copy in case
it's not last-use. Usually it's last-use so the patch is free; probably
with IR rework mentioned above all of this can be improved (e.g.
load-with-patch will never need to copy).

~It's not 100% clear if we *have* to patch type tag: Apple does preserve
denormals but we'd need to benchmark this to see if there's an actual
performance impact. But for now we're playing it safe.~

This was tested by running the conformance tests, and new opcode
implementations were checked by comparing the result with
https://armconverter.com/.

Performance testing is complicated by the fact that OSS Luau doesn't
support vector constructor out of the box, and other limitations of
codegen. I've hacked vector constructor/type into REPL and confirmed
that on a test that calls this function in a loop (not inlined):

```
function fma(a: vector, b: vector, c: vector)
        return a * b + c
end
```

... this PR improves performance by ~6% (note that probably most of the
overhead here is the call dispatch; I didn't want to brave testing a
more complex expression). The assembly for an individual operation
changes as follows:

Before:

```
#   %14 = MUL_VEC %12, %13                                    ; useCount: 2, lastUse: %22
 dup         s29,v31.s[0]
 dup         s28,v30.s[0]
 fmul        s29,s29,s28
 ins         v31.s[0],v29.s[0]
 dup         s29,v31.s[1]
 dup         s28,v30.s[1]
 fmul        s29,s29,s28
 ins         v31.s[1],v29.s[0]
 dup         s29,v31.s[2]
 dup         s28,v30.s[2]
 fmul        s29,s29,s28
 ins         v31.s[2],v29.s[0]
```

After:

```
#   %14 = MUL_VEC %12, %13                                    ; useCount: 2, lastUse: %22
 ins         v31.s[3],w31
 ins         v30.s[3],w31
 fmul        v31.4s,v31.4s,v30.4s
 movz        w17,#4
 ins         v31.s[3],w17
```

**edit** final form (see comments):

```
#   %14 = MUL_VEC %12, %13                                    ; useCount: 2, lastUse: %22
 fmul        v31.4s,v31.4s,v30.4s
 movz        w17,#4
 ins         v31.s[3],w17
```
This commit is contained in:
Arseny Kapoulkine 2024-02-16 08:30:35 -08:00 committed by GitHub
parent ea14e65ea0
commit c5f4d973d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 172 additions and 66 deletions

View File

@ -211,7 +211,6 @@ private:
void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0, int N = 0);
void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t sizes, uint8_t op, uint8_t op2);
void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op);
void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op);
void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0);
@ -230,6 +229,7 @@ private:
void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op);
void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms);
void placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift);
void placeVR(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint16_t op, uint8_t op2);
void place(uint32_t word);

View File

@ -63,13 +63,22 @@ AssemblyBuilderA64::~AssemblyBuilderA64()
void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src)
{
CODEGEN_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp);
CODEGEN_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x));
if (dst.kind != KindA64::q)
{
CODEGEN_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp);
CODEGEN_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x));
if (dst == sp || src == sp)
placeR1("mov", dst, src, 0b00'100010'0'000000000000);
if (dst == sp || src == sp)
placeR1("mov", dst, src, 0b00'100010'0'000000000000);
else
placeSR2("mov", dst, src, 0b01'01010);
}
else
placeSR2("mov", dst, src, 0b01'01010);
{
CODEGEN_ASSERT(dst.kind == src.kind);
placeR1("mov", dst, src, 0b10'01110'10'1'00000'00011'1 | (src.index << 6));
}
}
void AssemblyBuilderA64::mov(RegisterA64 dst, int src)
@ -575,12 +584,18 @@ void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src
placeR3("fadd", dst, src1, src2, 0b11110'01'1, 0b0010'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);
placeR3("fadd", dst, src1, src2, 0b11110'00'1, 0b0010'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);
placeVR("fadd", dst, src1, src2, 0b0'01110'0'0'1, 0b11010'1);
}
}
void AssemblyBuilderA64::fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
@ -591,12 +606,18 @@ void AssemblyBuilderA64::fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src
placeR3("fdiv", dst, src1, src2, 0b11110'01'1, 0b0001'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);
placeR3("fdiv", dst, src1, src2, 0b11110'00'1, 0b0001'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);
placeVR("fdiv", dst, src1, src2, 0b1'01110'00'1, 0b11111'1);
}
}
void AssemblyBuilderA64::fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
@ -607,12 +628,18 @@ void AssemblyBuilderA64::fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src
placeR3("fmul", dst, src1, src2, 0b11110'01'1, 0b0000'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);
placeR3("fmul", dst, src1, src2, 0b11110'00'1, 0b0000'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);
placeVR("fmul", dst, src1, src2, 0b1'01110'00'1, 0b11011'1);
}
}
void AssemblyBuilderA64::fneg(RegisterA64 dst, RegisterA64 src)
@ -623,12 +650,18 @@ void AssemblyBuilderA64::fneg(RegisterA64 dst, RegisterA64 src)
placeR1("fneg", dst, src, 0b000'11110'01'1'0000'10'10000);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src.kind == KindA64::s);
CODEGEN_ASSERT(src.kind == KindA64::s);
placeR1("fneg", dst, src, 0b000'11110'00'1'0000'10'10000);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src.kind == KindA64::q);
placeR1("fneg", dst, src, 0b011'01110'1'0'10000'01111'10);
}
}
void AssemblyBuilderA64::fsqrt(RegisterA64 dst, RegisterA64 src)
@ -646,12 +679,18 @@ void AssemblyBuilderA64::fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src
placeR3("fsub", dst, src1, src2, 0b11110'01'1, 0b0011'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);
placeR3("fsub", dst, src1, src2, 0b11110'00'1, 0b0011'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);
placeVR("fsub", dst, src1, src2, 0b0'01110'10'1, 0b11010'1);
}
}
void AssemblyBuilderA64::ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index)
@ -952,18 +991,6 @@ void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64
commit();
}
void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t sizes, uint8_t op, uint8_t op2)
{
if (logText)
log(name, dst, src1, src2);
CODEGEN_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst.kind == KindA64::d || dst.kind == KindA64::q);
CODEGEN_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);
place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | (sizes << 29));
commit();
}
void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op)
{
if (logText)
@ -1226,6 +1253,17 @@ void AssemblyBuilderA64::placeER(const char* name, RegisterA64 dst, RegisterA64
commit();
}
void AssemblyBuilderA64::placeVR(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint16_t op, uint8_t op2)
{
if (logText)
logAppend(" %-12sv%d.4s,v%d.4s,v%d.4s\n", name, dst.index, src1.index, src2.index);
CODEGEN_ASSERT(dst.kind == KindA64::q && dst.kind == src1.kind && dst.kind == src2.kind);
place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | (1 << 30));
commit();
}
void AssemblyBuilderA64::place(uint32_t word)
{
CODEGEN_ASSERT(codePos < codeEnd);

View File

@ -12,6 +12,7 @@
#include "lgc.h"
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenFixBufferLenCheckA64, false)
LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorA64, false)
namespace Luau
{
@ -673,15 +674,26 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
if (FFlag::LuauCodeGenVectorA64)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fadd(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
build.fadd(inst.regA64, regOp(inst.a), regOp(inst.b));
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fadd(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
@ -689,15 +701,26 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
if (FFlag::LuauCodeGenVectorA64)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fsub(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
build.fsub(inst.regA64, regOp(inst.a), regOp(inst.b));
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fsub(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
@ -705,15 +728,26 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
if (FFlag::LuauCodeGenVectorA64)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fmul(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
build.fmul(inst.regA64, regOp(inst.a), regOp(inst.b));
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fmul(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
@ -721,15 +755,26 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
if (FFlag::LuauCodeGenVectorA64)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fdiv(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
build.fdiv(inst.regA64, regOp(inst.a), regOp(inst.b));
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fdiv(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
@ -737,13 +782,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a});
RegisterA64 tempa = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
if (FFlag::LuauCodeGenVectorA64)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.fneg(tempa, tempa);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
build.fneg(inst.regA64, regOp(inst.a));
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
RegisterA64 tempa = regs.allocTemp(KindA64::s);
for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.fneg(tempa, tempa);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}

View File

@ -218,6 +218,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves")
{
SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0);
SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0);
SINGLE_COMPARE(mov(q0, q1), 0x4EA11C20);
SINGLE_COMPARE(movz(x0, 42), 0xD2800540);
SINGLE_COMPARE(movz(w0, 42), 0x52800540);
@ -501,6 +502,15 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "PrePostIndexing")
SINGLE_COMPARE(str(q0, mem(x1, 1, AddressKindA64::post)), 0x3C801420);
}
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "SIMDMath")
{
SINGLE_COMPARE(fadd(q0, q1, q2), 0x4E22D420);
SINGLE_COMPARE(fsub(q0, q1, q2), 0x4EA2D420);
SINGLE_COMPARE(fmul(q0, q1, q2), 0x6E22DC20);
SINGLE_COMPARE(fdiv(q0, q1, q2), 0x6E22FC20);
SINGLE_COMPARE(fneg(q0, q1), 0x6EA0F820);
}
TEST_CASE("LogTest")
{
AssemblyBuilderA64 build(/* logText= */ true);
@ -552,6 +562,7 @@ TEST_CASE("LogTest")
build.ins_4s(q31, 1, q29, 2);
build.dup_4s(s29, q31, 2);
build.dup_4s(q29, q30, 0);
build.fmul(q0, q1, q2);
build.setLabel(l);
build.ret();
@ -594,6 +605,7 @@ TEST_CASE("LogTest")
ins v31.s[1],v29.s[2]
dup s29,v31.s[2]
dup v29.4s,v30.s[0]
fmul v0.4s,v1.4s,v2.4s
.L1:
ret
)";