我一直在尝试各种选项来加速 PyTorch 中的一些 for 循环逻辑。这样做的两个明显的选择是使用numba https://stackoverflow.com/a/75580380/1804173 or 编写自定义 C++ 扩展 https://pytorch.org/tutorials/advanced/cpp_extension.html.
作为一个例子,我从数字信号处理中选择了“可变长度延迟线”。使用简单的 Python for 循环可以简单但低效地编写此代码:
def delay_line(samples, delays):
"""
:param samples: Float tensor of shape (N,)
:param delays: Int tensor of shape (N,)
The goal is basically to mix each `samples[i]` with the delayed sample
specified by a per-sample `delays[i]`.
"""
for i in range(len(samples)):
delay = int(delays[i].item())
index_delayed = i - delay
if index_delayed < 0:
index_delayed = 0
samples[i] = 0.5 * (samples[i] + samples[index_delayed])
知道 for 循环在 Python 中的执行情况有多糟糕,我希望通过在 C++ 中实现相同的循环可以获得明显更好的性能。下列的教程 https://pytorch.org/tutorials/advanced/cpp_extension.html,我想出了从 Python 到 C++ 的直译:
void delay_line(torch::Tensor samples, torch::Tensor delays) {
int64_t input_size = samples.size(-1);
for (int64_t i = 0; i < input_size; ++i) {
int64_t delay = delays[i].item<int64_t>();
int64_t index_delayed = i - delay;
if (index_delayed < 0) {
index_delayed = 0;
}
samples[i] = 0.5 * (samples[i] + samples[index_delayed]);
}
}
我还采用了 Python 函数并将其包装到各种 jit 装饰器中以获得该函数的 numba 和 torchscript 版本(请参阅我的其他answer https://stackoverflow.com/a/75580380/1804173有关 numba 包装的详细信息)。然后,我对所有版本执行了基准测试,这还取决于张量是驻留在 CPU 还是 GPU 上。结果相当令人惊讶:
╭──────────────┬──────────┬────────────────────╮
│ Method │ Device │ Median time [ms] │
├──────────────┼──────────┼────────────────────┤
│ plain_python │ CPU │ 13.481 │
│ torchscript │ CPU │ 6.318 │
│ numba │ CPU │ 0.016 │
│ cpp │ CPU │ 9.056 │
│ plain_python │ GPU │ 45.412 │
│ torchscript │ GPU │ 47.809 │
│ numba │ GPU │ 0.236 │
│ cpp │ GPU │ 31.145 │
╰──────────────┴──────────┴────────────────────╯
Notes: sample buffer size was fixed to 1024; results are medians of 100 executions to ignore artifacts from the initial jit overhead; input data creation and moving it to the device is excluded from the measurements; full benchmark script gist https://gist.github.com/bluenote10/3370da06204b94995614ed014410f6c2
最显着的结果:C++ 变体似乎出奇地慢。 numba 快两个数量级的事实表明问题确实可以更快地解决。事实上,C++ 变体仍然非常接近众所周知的缓慢的 Python for 循环,这可能表明有些事情不太正确。
我想知道什么可以解释 C++ 扩展的糟糕性能。第一个想到的就是缺少优化。不过,我已经确保编译使用了优化。切换自-O2
to -O3
也没有什么区别。
为了隔离 pybind11 函数调用的开销,我用空函数体替换了 C++ 函数,即不执行任何操作。这将时间减少到 2-3μs,这意味着时间确实花在该特定函数体上。
有什么想法为什么我会观察到如此糟糕的性能吗?我可以在 C++ 方面做些什么来匹配 numba 实现的性能吗?
额外问题:GPU 版本是否会比 CPU 版本慢很多?