Summary
- 考虑交换
a
and b
(原提案)
- 尝试避免条件跳转(未成功优化)
- 重塑输入公式(预计增益 35%)
- 删除重复班次
- 展开循环:“最佳”装配
- 说服编译器给出最佳汇编
1. Consider swapping a
and b
一个改进是首先比较 a 和 b,如果满足则交换它们a<b
: 你应该使用b
两者中较小的一个,以便获得最少的循环数。请注意,您可以通过以下方式避免交换复制代码 (if (a<b)
然后跳转到镜像代码部分),但我怀疑它是否值得。
2. Trying to avoid conditional jumps (Not successful optimization)
Try:
uint16_t umul16_(uint16_t a, uint16_t b)
{
///Here swap if necessary
uint16_t accum=0;
while (b) {
accum += ((b&1) * uint16_t(0xffff)) & a; //Hopefully this multiplication is optimized away
b>>=1;
a+=a;
}
return accum;
}
从 Sergio 的反馈来看,这并没有带来任何改进。
3. Reshaping of the input formula
考虑到目标架构基本只有8bit指令,如果将输入变量的高低8位分开,可以这样写:
a = a1 * 0xff + a0;
b = b1 * 0xff + b0;
a * b = a1 * b1 * 0xffff + a0 * b1 * 0xff + a1 * b0 * 0xff + a0 * b0
现在,最酷的事情是我们可以扔掉这个词a1 * b1 * 0xffff
,因为0xffff
将其从您的寄存器中发送出去.
(16bit) a * b = a0 * b1 * 0xff + a1 * b0 * 0xff + a0 * b0
此外,a0*b1
and a1*b0
项可以被视为 8 位乘法,因为0xff
:任何超过 256 的部分都将是从登记册中发出.
到目前为止令人兴奋! ...但是,令人震惊的现实来了:a0 * b0
必须被视为 16 位乘法,因为您必须保留所有结果位。a0
必须保持在 16 位以允许左移。该乘法的迭代次数为a * b
, it is in part8位(因为b0)但是你仍然要考虑前面提到的2个8位乘法,以及最终的结果组合。我们需要进一步重塑!
所以现在我收集b0
.
(16bit) a * b = a0 * b1 * 0xff + b0 * (a0 + a1 * 0xff)
But
(a0 + a1 * 0xff) = a
所以我们得到:
(16bit) a * b = a0 * b1 * 0xff + b0 * a
如果N是原来的循环a * b
,现在第一项是 N/2 个周期的 8 位乘法,第二项是 N/2 个周期的 16 位 * 8 位乘法。考虑 M 是原始中每次迭代的指令数a*b
,8bit*8bit 迭代有一半的指令,16bit*8bit 大约有 M 的 80%(b0 比 b 少了一个移位指令)。放在一起我们有N/2*M/2+N/2*M*0.8 = N*M*0.65
复杂度,所以预计节省约 35%相对于原来的N*M
。声音有希望.
这是代码:
uint16_t umul16_(uint16_t a, uint16_t b)
{
uint8_t res1 = 0;
uint8_t a0 = a & 0xff; //This effectively needs to copy the data
uint8_t b0 = b & 0xff; //This should be optimized away
uint8_t b1 = b >>8; //This should be optimized away
//Here a0 and b1 could be swapped (to have b1 < a0)
while (b1) {///Maximum 8 cycles
if ( (b1 & 1) )
res1+=a0;
b1>>=1;
a0+=a0;
}
uint16_t res = (uint16_t) res1 * 256; //Should be optimized away, it's not even a copy!
//Here swapping wouldn't make much sense
while (b0) {///Maximum 8 cycles
if ( (b0 & 1) )
res+=a;
b0>>=1;
a+=a;
}
return res;
}
此外,从理论上讲,分成 2 个周期的分裂应该使跳过某些周期的机会加倍:N/2 可能稍微高估了。
进一步的微小改进在于避免了最后的不必要的转变a
变量。小旁注:如果 b0 或 b1 为零,则会导致 2 个额外指令。But它还保存了对 b0 和 b1 的第一次检查,这是最昂贵的,因为它无法检查zero flag
for 循环条件跳转的移位操作的状态。
uint16_t umul16_(uint16_t a, uint16_t b)
{
uint8_t res1 = 0;
uint8_t a0 = a & 0xff; //This effectively needs to copy the data
uint8_t b0 = b & 0xff; //This should be optimized away
uint8_t b1 = b >>8; //This should be optimized away
//Here a0 and b1 could be swapped (to have b1 < a0)
if ( (b1 & 1) )
res1+=a0;
b1>>=1;
while (b1) {///Maximum 7 cycles
a0+=a0;
if ( (b1 & 1) )
res1+=a0;
b1>>=1;
}
uint16_t res = (uint16_t) res1 * 256; //Should be optimized away, it's not even a copy!
//Here swapping wouldn't make much sense
if ( (b0 & 1) )
res+=a;
b0>>=1;
while (b0) {///Maximum 7 cycles
a+=a;
if ( (b0 & 1) )
res+=a;
b0>>=1;
}
return res;
}
4. Removing duplicated shift
还有改进的空间吗?Yes,作为字节a0
被转移两次。因此,将两个循环结合起来应该会有好处。说服编译器完全按照我们想要的方式执行可能有点棘手,尤其是对于结果寄存器。
所以,我们在同一个循环中处理b0
and b1
。首先要处理的是,循环退出条件是什么?到目前为止使用b0
/b1
清除状态很方便,因为它避免使用计数器。此外,在右移之后,如果运算结果为零,则可能已经设置了标志,并且该标志可能允许条件跳转而无需进一步评估。
现在循环退出条件可能是失败(b0 || b1)
。然而,这可能需要昂贵的计算。一种解决方案是比较 b0 和 b1 并跳转到 2 个不同的代码段: ifb1 > b0
我测试条件b1
,否则我测试条件b0
。我更喜欢另一种解决方案,有两个循环,第一次退出时b0
为零,第二个当b1
为零。在某些情况下我会进行零迭代b1
。关键是在第二个循环中我知道b0
为零,因此我可以减少执行的操作数量。
现在,让我们忘记退出条件并尝试加入上一节的 2 个循环。
uint16_t umul16_(uint16_t a, uint16_t b)
{
uint16_t res = 0;
uint8_t b0 = b & 0xff; //This should be optimized away
uint8_t b1 = b >>8; //This should be optimized away
//Swapping probably doesn't make much sense anymore
if ( (b1 & 1) )
res+=(uint16_t)((uint8_t)(a && 0xff))*256;
//Hopefully the compiler understands it has simply to add the low 8bit register of a to the high 8bit register of res
if ( (b0 & 1) )
res+=a;
b1>>=1;
b0>>=1;
while (b0) {///N cycles, maximum 7
a+=a;
if ( (b1 & 1) )
res+=(uint16_t)((uint8_t)(a & 0xff))*256;
if ( (b0 & 1) )
res+=a;
b1>>=1;
b0>>=1; //I try to put as last the one that will leave the carry flag in the desired state
}
uint8_t a0 = a & 0xff; //Again, not a real copy but a register selection
while (b1) {///P cycles, maximum 7 - N cycles
a0+=a0;
if ( (b1 & 1) )
res+=(uint16_t) a0 * 256;
b1>>=1;
}
return res;
}
感谢 Sergio 提供生成的程序集 (-Ofast)。乍一看,考虑到数量惊人mov
在代码中,编译器似乎没有解释,因为我想要我给他的提示来解释寄存器。
输入为:r22、r23 和 r24、25。
AVR指令集:快速参考 http://www.atmel.com/Images/8006S.pdf, 详细文档 http://www.atmel.com/images/doc0856.pdf
sbrs //Tests a single bit in a register and skips the next instruction if the bit is set. Skip takes 2 clocks.
ldi // Load immediate, 1 clock
sbiw // Subtracts immediate to *word*, 2 clocks
00000010 <umul16_Antonio5>:
10: 70 ff sbrs r23, 0
12: 39 c0 rjmp .+114 ; 0x86 <__SREG__+0x47>
14: 41 e0 ldi r20, 0x01 ; 1
16: 00 97 sbiw r24, 0x00 ; 0
18: c9 f1 breq .+114 ; 0x8c <__SREG__+0x4d>
1a: 34 2f mov r19, r20
1c: 20 e0 ldi r18, 0x00 ; 0
1e: 60 ff sbrs r22, 0
20: 07 c0 rjmp .+14 ; 0x30 <umul16_Antonio5+0x20>
22: 28 0f add r18, r24
24: 39 1f adc r19, r25
26: 04 c0 rjmp .+8 ; 0x30 <umul16_Antonio5+0x20>
28: e4 2f mov r30, r20
2a: 45 2f mov r20, r21
2c: 2e 2f mov r18, r30
2e: 34 2f mov r19, r20
30: 76 95 lsr r23
32: 66 95 lsr r22
34: b9 f0 breq .+46 ; 0x64 <__SREG__+0x25>
36: 88 0f add r24, r24
38: 99 1f adc r25, r25
3a: 58 2f mov r21, r24
3c: 44 27 eor r20, r20
3e: 42 0f add r20, r18
40: 53 1f adc r21, r19
42: 70 ff sbrs r23, 0
44: 02 c0 rjmp .+4 ; 0x4a <__SREG__+0xb>
46: 24 2f mov r18, r20
48: 35 2f mov r19, r21
4a: 42 2f mov r20, r18
4c: 53 2f mov r21, r19
4e: 48 0f add r20, r24
50: 59 1f adc r21, r25
52: 60 fd sbrc r22, 0
54: e9 cf rjmp .-46 ; 0x28 <umul16_Antonio5+0x18>
56: e2 2f mov r30, r18
58: 43 2f mov r20, r19
5a: e8 cf rjmp .-48 ; 0x2c <umul16_Antonio5+0x1c>
5c: 95 2f mov r25, r21
5e: 24 2f mov r18, r20
60: 39 2f mov r19, r25
62: 76 95 lsr r23
64: 77 23 and r23, r23
66: 61 f0 breq .+24 ; 0x80 <__SREG__+0x41>
68: 88 0f add r24, r24
6a: 48 2f mov r20, r24
6c: 50 e0 ldi r21, 0x00 ; 0
6e: 54 2f mov r21, r20
70: 44 27 eor r20, r20
72: 42 0f add r20, r18
74: 53 1f adc r21, r19
76: 70 fd sbrc r23, 0
78: f1 cf rjmp .-30 ; 0x5c <__SREG__+0x1d>
7a: 42 2f mov r20, r18
7c: 93 2f mov r25, r19
7e: ef cf rjmp .-34 ; 0x5e <__SREG__+0x1f>
80: 82 2f mov r24, r18
82: 93 2f mov r25, r19
84: 08 95 ret
86: 20 e0 ldi r18, 0x00 ; 0
88: 30 e0 ldi r19, 0x00 ; 0
8a: c9 cf rjmp .-110 ; 0x1e <umul16_Antonio5+0xe>
8c: 40 e0 ldi r20, 0x00 ; 0
8e: c5 cf rjmp .-118 ; 0x1a <umul16_Antonio5+0xa>
5. Unrolling the loop: The "optimal" assembly
有了所有这些信息,让我们尝试了解给定架构限制的“最佳”解决方案是什么。引用“最佳”是因为“最佳”很大程度上取决于输入数据和我们想要优化的内容。假设我们想要优化最坏情况下的周期数。如果我们考虑最坏的情况,循环展开是一个合理的选择:我们知道我们有 8 个周期,并且我们删除所有测试来了解我们是否完成(如果 b0 和 b1 为零)。到目前为止,我们使用了“移位,然后检查零标志”的技巧来检查是否必须退出循环。删除这个要求,我们可以使用不同的技巧:我们转移,并且我们检查进位位(移位时我们从寄存器中发出的位)了解我是否应该更新结果。给定指令集,在汇编“叙述”代码中,指令如下。
//Input: a = a1 * 256 + a0, b = b1 * 256 + b0
//Output: r = r1 * 256 + r0
Preliminary:
P0 r0 = 0 (CLR)
P1 r1 = 0 (CLR)
Main block:
0 Shift right b0 (LSR)
1 If carry is not set skip 2 instructions = jump to 4 (BRCC)
2 r0 = r0 + a0 (ADD)
3 r1 = r1 + a1 + carry from prev. (ADC)
4 Shift right b1 (LSR)
5 If carry is not set skip 1 instruction = jump to 7 (BRCC)
6 r1 = r1 + a0 (ADD)
7 a0 = a0 + a0 (ADD)
8 a1 = a1 + a1 + carry from prev. (ADC)
[Repeat same instructions for another 7 times]
如果没有引起跳转,则分支需要 1 条指令,否则需要 2 条指令。所有其他指令均为 1 个周期。所以b1状态对周期数没有影响,而如果b0 = 1,我们有9个周期,如果b0 = 0,我们有8个周期。在最坏的情况下,算上初始化,8次迭代并跳过a0和a1的最后更新(b0 = 11111111b),我们总共有8 * 9 + 2 - 2 =
72个周期。我不知道哪个 C++ 实现会说服编译器生成它。或许:
void iterate(uint8_t& b0,uint8_t& b1,uint16_t& a, uint16_t& r) {
const uint8_t temp0 = b0;
b0 >>=1;
if (temp0 & 0x01) {//Will this convince him to use the carry flag?
r += a;
}
const uint8_t temp1 = b1;
b1 >>=1;
if (temp1 & 0x01) {
r+=(uint16_t)((uint8_t)(a & 0xff))*256;
}
a += a;
}
uint16_t umul16_(uint16_t a, uint16_t b) {
uint16_t r = 0;
uint8_t b0 = b & 0xff;
uint8_t b1 = b >>8;
iterate(b0,b1,a,r);
iterate(b0,b1,a,r);
iterate(b0,b1,a,r);
iterate(b0,b1,a,r);
iterate(b0,b1,a,r);
iterate(b0,b1,a,r);
iterate(b0,b1,a,r);
iterate(b0,b1,a,r); //Hopefully he understands he doesn't need the last update for variable a
return r;
}
但是,考虑到前面的结果,要真正获得所需的代码,应该真正切换到汇编!
最后,人们还可以考虑对循环展开的更极端的解释:sbrc/sbrs 指令允许测试寄存器的特定位。因此,我们可以避免移位 b0 和 b1,并在每个周期检查不同的位。唯一的问题是这些指令只允许跳过下一条指令,而不允许自定义跳转。所以,在“叙事代码”中,它看起来像这样:
Main block:
0 Test Nth bit of b0 (SBRS). If set jump to 2 (+ 1cycle) otherwise continue with 1
1 Jump to 4 (RJMP)
2 r0 = r0 + a0 (ADD)
3 r1 = r1 + a1 + carry from prev. (ADC)
4 Test Nth bit of (SBRC). If cleared jump to 6 (+ 1cycle) otherwise continue with 5
5 r1 = r1 + a0 (ADD)
6 a0 = a0 + a0 (ADD)
7 a1 = a1 + a1 + carry from prev. (ADC)
虽然第二次替换可以节省1个周期,但第二次替换并没有明显的优势。然而,我相信 C++ 代码可能更容易被编译器解释。考虑到 8 个周期、初始化并跳过 a0 和 a1 的最后更新,我们现在有64个周期.
C++代码:
template<uint8_t mask>
void iterateWithMask(const uint8_t& b0,const uint8_t& b1, uint16_t& a, uint16_t& r) {
if (b0 & mask)
r += a;
if (b1 & mask)
r+=(uint16_t)((uint8_t)(a & 0xff))*256;
a += a;
}
uint16_t umul16_(uint16_t a, const uint16_t b) {
uint16_t r = 0;
const uint8_t b0 = b & 0xff;
const uint8_t b1 = b >>8;
iterateWithMask<0x01>(b0,b1,a,r);
iterateWithMask<0x02>(b0,b1,a,r);
iterateWithMask<0x04>(b0,b1,a,r);
iterateWithMask<0x08>(b0,b1,a,r);
iterateWithMask<0x10>(b0,b1,a,r);
iterateWithMask<0x20>(b0,b1,a,r);
iterateWithMask<0x40>(b0,b1,a,r);
iterateWithMask<0x80>(b0,b1,a,r);
//Hopefully he understands he doesn't need the last update for a
return r;
}
请注意,在此实现中,0x01、0x02 不是实际值,而只是提示编译器知道要测试哪个位。因此,掩码无法通过右移获得:与迄今为止看到的所有其他函数不同,这实际上没有等效的循环版本。
一个大问题是
r+=(uint16_t)((uint8_t)(a & 0xff))*256;
它应该只是高位寄存器的总和r
与较低的寄存器a
。
没有得到我想要的解释。其他选项:
r+=(uint16_t) 256 *((uint8_t)(a & 0xff));
6. Convincing the compiler to give the optimal assembly
我们还可以保留a
常数,并移动结果r
。在这种情况下我们处理b
从最高有效位开始。复杂性是相同的,但编译器可能更容易消化。另外,这次我们必须小心地显式地编写最后一个循环,它不能对 for 进行进一步的右移r
.
template<uint8_t mask>
void inverseIterateWithMask(const uint8_t& b0,const uint8_t& b1,const uint16_t& a, const uint8_t& a0, uint16_t& r) {
if (b0 & mask)
r += a;
if (b1 & mask)
r+=(uint16_t)256*a0; //Hopefully easier to understand for the compiler?
r += r;
}
uint16_t umul16_(const uint16_t a, const uint16_t b) {
uint16_t r = 0;
const uint8_t b0 = b & 0xff;
const uint8_t b1 = b >>8;
const uint8_t a0 = a & 0xff;
inverseIterateWithMask<0x80>(b0,b1,a,r);
inverseIterateWithMask<0x40>(b0,b1,a,r);
inverseIterateWithMask<0x20>(b0,b1,a,r);
inverseIterateWithMask<0x10>(b0,b1,a,r);
inverseIterateWithMask<0x08>(b0,b1,a,r);
inverseIterateWithMask<0x04>(b0,b1,a,r);
inverseIterateWithMask<0x02>(b0,b1,a,r);
//Last iteration:
if (b0 & 0x01)
r += a;
if (b1 & 0x01)
r+=(uint16_t)256*a0;
return r;
}