如何在 C++ 中高效应用多项式而无需循环? [关闭]

2023-11-24

我想获得一些复杂函数的准确近似值(pow, exp, log, log2...) 比 C++ 标准库中 cmath 提供的更快。

为此,我想利用浮点编码方式并使用位操作获取指数和尾数,然后进行多项式近似。尾数在 1 和 2 之间,因此我使用 n 阶多项式来近似 [1, 2] 中域 x 中的目标函数,并对浮点表达式进行位操作和简单数学运算以使计算有效。

I used np.polyfit生成多项式。作为示例,以下是我用来在 1

P = np.array(
    [
        0.01459855,
        -0.17811046,
        0.95074541,
        -2.91450247,
        5.67353733,
        -7.39616658,
        7.08511059,
        -3.23521156,
    ],
    dtype=float,
)

要应用多项式,请将第一项乘以 x 的 7 次方,第二项乘以 x 的 6 次方,依此类推...

In code:

P[0] * x**7 + P[1] * x**6 + P[2] * x**5 + P[3] * x**4 + P[4] * x**3 + P[5] * x**2 + P[6] * x + P[7]

当然,这是非常低效的,首先计算较大的幂,因此存在大量重复计算,如果我们颠倒顺序,我们可以根据先前的幂计算出当前的幂,如下所示:

PR = P[::-1]
s = 0
c = 1
for i in PR:
    s += i * c
    c *= x

这正是我在 C++ 中所做的:

constexpr double LOG2_POLY7[8] = {
    -3.23521156,
    7.08511059,
    -7.39616658,
    5.67353733,
    -2.91450247,
    0.95074541,
    -0.17811046,
    0.01459855,
};
constexpr float FU = 1.0 / (1 << 23);

inline float fast_log2_accurate(float f) {
    uint32_t bits = *reinterpret_cast<uint32_t*>(&f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = 0;
    double m = 1;
    for (const double& p : LOG2_POLY7) {
        s += p * m;
        m *= m1;
    }
    return e + s;
}

它比 cmath 的 log2 快得多,同时获得相同的精度:

log2(3.1415927f) = 1.651496171951294 : 42.68856048583984 nanoseconds
fast_log2_accurate(3.1415927f) = 1.651496171951294 : 9.967899322509766 nanoseconds

我使用 Visual Studio 2022 进行编译,编译器标志:

/permissive- /ifcOutput "x64\Release\" /GS /GL /W3 /Gy /Zc:wchar_t /Zi /Gm- /O2 /Ob1 /sdl /Fd"x64\Release\vc143.pdb" /Zc:inline /fp:fast /D "NDEBUG" /D "_CONSOLE" /D "_UNICODE" /D "UNICODE" /errorReport:prompt /WX- /Zc:forScope /std:c17 /Gd /Oi /MD /std:c++20 /FC /Fa"x64\Release\" /EHsc /nologo /Fo"x64\Release\" /Ot /Fp"x64\Release\exponentiation.pch" /diagnostics:column /arch:AVX2

但我认为它可以更有效。有一个循环开销,如果我可以优化循环,它应该会更快。

如何在没有循环的情况下应用多项式?


如果循环已经展开,那么是否可以使用 SIMD 指令进行计算以使其更快?


我对下面提供的解决方案以及我之前编写的其他一些函数进行了基准测试:

#include <vector>
#include <numbers>
using std::vector;

using std::numbers;
using numbers::ln2;
using numbers::pi;

constexpr double LOG2_POLY7[8] = {
    -3.23521156,
    7.08511059,
    -7.39616658,
    5.67353733,
    -2.91450247,
    0.95074541,
    -0.17811046,
    0.01459855,
};
constexpr float FU = 1.0 / (1 << 23);

inline float fast_log2_accurate(float f) {
    uint32_t bits = std::bit_cast<uint32_t>(f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = 0;
    double m = 1;
    for (const double& p : LOG2_POLY7) {
        s += p * m;
        m *= m1;
    }
    return e + s;
}

template <int N> inline double poly(const double* a, const float x) {
    return (a[0] + x * poly<N - 1>(a + 1, x));
}

template <> inline double poly<0>(const double* a, const float x) {
    return x * a[0];
}

inline float fast_log2_accurate2(float f) {
    uint32_t bits = std::bit_cast<uint32_t>(f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    return e + poly<8>(LOG2_POLY7, m1);
}


inline float fast_log2_accurate3(float f) {
    uint32_t bits = std::bit_cast<uint32_t>(f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = 0;
    double m = 1;
    for (int i = 0; i < 8; i++) {
        s += LOG2_POLY7[i] * m;
        m *= m1;
    }
    return e + s;
}

vector<float> log2_float() {
    int lim = 1 << 24;
    vector<float> table(lim);
    for (int i = 0; i < lim; i++) {
        table[i] = float(log(i) / ln2) - 150;
    }
    return table;
}
const vector<float> LOG2_FLOAT = log2_float();

inline float fast_log2(float f) {
    uint32_t bits = std::bit_cast<uint32_t>(f);
    int e = (bits >> 23) & 0xff;
    int m = bits & 0x7fffff;
    return (e == 0 ? LOG2_FLOAT[m << 1] : e + LOG2_FLOAT[m | 0x00800000]);
}

inline float fast_log(float f) {
    return fast_log2(f) * ln2;
}

vector<double> log2_double() {
    int lim = 1 << 24;
    vector<double> table(lim);
    for (uint64_t i = 0; i < lim; i++) {
        table[i] = log(i << 29) / ln2 - 1075;
    }
    return table;
}
const vector<double> LOG2_DOUBLE = log2_double();

inline double fast_log2_double(double d) {
    uint64_t bits = std::bit_cast<uint64_t>(d);
    uint64_t e = (bits >> 52) & 0x7ff;
    uint64_t m = bits & 0xfffffffffffff;
    return (e == 0 ? LOG2_DOUBLE[m >> 28] : e + LOG2_DOUBLE[(m | 0x10000000000000) >> 29]);
}
fast_log2(3.1415927f) = 1.651496887207031 : 0.2610206604003906 nanoseconds
log2f(3.1415927f) = 1.651496171951294 : 33.27693939208984 nanoseconds
fast_log2_double(pi) = 1.651496060131421 : 0.3225326538085938 nanoseconds
fast_log2_accurate(3.1415927f) = 1.651496171951294 : 8.907032012939453 nanoseconds
fast_log2_accurate3(3.1415927f) = 1.651496171951294 : 7.831001281738281 nanoseconds
fast_log2_accurate2(3.1415927f) = 1.651496171951294 : 13.57889175415039 nanoseconds

虽然使用查找表的两个函数是无与伦比的,但它们相当不准确。我明确地使用了log2f在我的基准测试中。正如您所看到的,在 MSVC 中它非常慢。

正如预期的那样,递归函数会显着减慢代码速度。使用旧式循环可使代码运行速度加快 2 纳秒。然而我无法对使用的那个进行基准测试std::index_sequence,它导致编译器错误,我无法解决它。


我的基准测试代码中有一个错误,导致递归版本计时不准确,它使测量的时间更长,我修复了这个问题。

最新答案的解决方案:

inline float fast_log2_accurate4(float f) {
    uint32_t bits = std::bit_cast<uint32_t>(f);
    int e = ((bits >> 23) & 0xff) - 127;
    float m = 1 + (bits & 0x7fffff) * FU;
    float s_even = LOG2_POLY7[0];
    float s_odd = LOG2_POLY7[1] * m;
    float m2 = m * m;
    float m_even = m2;
    float m_odd = m * m2;

    for (int i = 2; i < 8; i += 2) {
        s_even += LOG2_POLY7[i] * m_even;
        s_odd += LOG2_POLY7[i + 1] * m_odd;
        m_even *= m2;
        m_odd *= m2;
    }
    return e + s_even + s_odd;
}
fast_log2_accurate4(3.1415927f) = 1.651496887207031 : 17.01173782348633 nanoseconds

它不像我的代码那么准确并且需要更长的时间,因为每次迭代都更昂贵。


之前索引序列版本编译失败,因为我使用了double[8]代替std::array<double, 8>,我以为它们是同一个东西!经指出后,我修复了该问题,并且编译成功。

基准:

ln(256) = 5.545613288879395 : 3.985881805419922 nanoseconds
log(256) = 5.545177459716797 : 7.047939300537109 nanoseconds
fast_log2(3.1415927f) = 1.651496887207031 : 0.25787353515625 nanoseconds
log2f(3.1415927f) = 1.651496171951294 : 35.03541946411133 nanoseconds
fast_log2_double(pi) = 1.651496060131421 : 0.3331184387207031 nanoseconds
fast_log2_accurate(3.1415927f) = 1.651496171951294 : 9.366512298583984 nanoseconds
fast_log2_accurate3(3.1415927f) = 1.651496171951294 : 7.454872131347656 nanoseconds
fast_log2_accurate2(3.1415927f) = 1.651496171951294 : 14.07079696655273 nanoseconds
fast_log2_accurate4(3.1415927f) = 1.651496887207031 : 16.6351318359375 nanoseconds
fast_log2_accurate5(3.1415927f) = 1.651496171951294 : 7.868862152099609 nanoseconds

事实证明ln速度非常快,击败它的唯一方法是使用查找表,但它只给出 3 个正确的十进制数字,np.log(256) 给出 5.545177444479562。相比之下,我最快的函数可以给出 6 个正确的十进制数字,并且速度快十倍。我只需要把它乘以ln2 to get ln(x),而且这样会更准确。

我对解决方案进行了多次基准测试,并且fast_log2_accurate5是索引序列版本。它和旧式循环版本始终比我基于范围的 for 循环版本更快。有时 for 循环版本更快,有时索引序列版本更快。在这个级别上,测量值波动很大,并且我同时运行许多其他程序。

但看起来索引序列版本的性能比for循环版本稳定得多,所以我会接受它。


Update:

我重新审视了代码,我对索引序列版本做了一点改动,我只是添加了inline在...前面do_loop函数,这个小小的改变,使得代码在不到一纳秒的时间内运行,我可以添加更多术语,而不会减慢代码太多,它仍然会比log2同时获得非常准确的结果。

相比之下std::apply即使使用内联,版本也很慢:

constexpr std::array<double, 8> LOG2_POLY7A = {
    -3.23521156,
    7.08511059,
    -7.39616658,
    5.67353733,
    -2.91450247,
    0.95074541,
    -0.17811046,
    0.01459855,
};

template <std::size_t... I>
inline double do_loop(double m1, std::index_sequence<I...>) {
    double s = 0;
    double m = 1;
    ((s += std::get<I>(LOG2_POLY7A) * m, m *= m1), ...);
    return s;
}

inline float fast_log2_accurate5(float f) {
    uint32_t bits = std::bit_cast<uint32_t>(f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = do_loop(m1, std::make_index_sequence<8>{});
    return e + s;
}

inline double do_loop1(double m1) {
    double s = 0;
    double m = 1;
    auto worker = [&](auto&...term) {
        ((s += term * m, m *= m1), ...);
        };
    std::apply(worker, LOG2_POLY7A);
    return s;
}

inline float fast_log2_accurate6(float f) {
    uint32_t bits = *reinterpret_cast<uint32_t*>(&f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = do_loop1(m1);
    return e + s;
}
fast_log2_accurate5(3.1415927f) = 1.651496171951294 : 0.9766578674316406 nanoseconds
fast_log2_accurate6(3.1415927f) = 1.651496171951294 : 7.168102264404297 nanoseconds

您可以使用以下命令展开循环std::index_sequence, 如下:

#include <array>
#include <cstdint>
#include <utility>

constexpr std::size_t size_log2 = 8;

constexpr std::array<double, size_log2> LOG2_POLY7 = {
    -3.23521156,
    7.08511059,
    -7.39616658,
    5.67353733,
    -2.91450247,
    0.95074541,
    -0.17811046,
    0.01459855,
};
constexpr float FU = 1.0 / (1 << 23);

template <std::size_t... I>
double do_loop(double m1, std::index_sequence<I...>) {
    double s = 0;
    double m = 1;
    ((s += std::get<I>(LOG2_POLY7) * m, m *= m1),...);
    return s;
}

inline float fast_log2_accurate(float f) {
    uint32_t bits = *reinterpret_cast<uint32_t*>(&f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = do_loop(m1, std::make_index_sequence<size_log2>{});
    return e + s;
}

测试代码godbolt.

请注意,我还移动了变量m里面的do_loop函数,因为只有那里才需要它。并且,根据评论中的建议,do_loop返回存储在变量中的结果s在问题中。 与您的初始版本相比,循环的展开是在编译时完成的,并且避免了,例如,比较p with LOG2_POLY7.end()在每次迭代时。与往常一样,应该对实际增益进行基准测试。

另外,正如评论中所述,可以通过使用来简化循环std::apply转换数组LOG2_POLY7进入可变参数;见下文或其他答案.

正如评论中所问,人们可以概括一下do_loop适用于泛型的函数std::array。然而,这意味着额外的例程。这是因为,在do_loop上面,type of the index_sequence由非类型模板参数确定std::size_t... I;这些参数,以及第二个参数的类型do_loop被推导出来。 现在,给定一个通用的std::array,你可以推断出大小N,但要使do_loop你需要转换这个的工作N到一系列索引0...N-1: 正是如此index_sequence is for.

因此,对于更通用的例程,可以将上面的代码替换为:

template <std::size_t N, std::size_t... I>
double do_loop_impl(double m1, const std::array<double, N>& data, std::index_sequence<I...>) {
    double s = 0;
    double m = 1;
    ((s += std::get<I>(data) * m, m *= m1),...);
    return s;
}

template <std::size_t N>
double do_loop(double m1, const std::array<double, N>& data) {
    return do_loop_impl(m1, data, std::make_index_sequence<N>{});
}

inline float fast_log2_accurate(float f) {
    uint32_t bits = *reinterpret_cast<uint32_t*>(&f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = do_loop(m1, LOG2_POLY7);
    return e + s;
}

或者,可以使用std::apply,稍微概括一下其他答案:

template <std::size_t N>
double do_loop(double m1, const std::array<double, N>& data) {
    return std::apply([m1](auto... p) {
                      double s = 0;
                      double m = 1;
                      ((s += p * m, m *= m1), ...);
                      return s; }, data);
}

inline float fast_log2_accurate(float f) {
    uint32_t bits = *reinterpret_cast<uint32_t*>(&f);
    int e = ((bits >> 23) & 0xff) - 127;
    double m1 = 1 + (bits & 0x7fffff) * FU;
    double s = do_loop(m1, LOG2_POLY7);
    return e + s;
}

(记得include <tuple>如果你使用std::apply).

最后,如果您对速度感兴趣,您可以考虑 SIMD 矢量化,请参阅例如openMP SIMD 指令, the 矢量类库, the xsimd 库。这可能需要重新考虑循环。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 C++ 中高效应用多项式而无需循环? [关闭] 的相关文章

随机推荐

  • 收到异常“枚举时集合发生了变异”

    当我使用此代码时 我收到 Collection was mutated while being enumerated 异常 任何人都可以建议我如何摆脱这种情况 PaymentTerms currentElement for currentE
  • 使用 Log::Log4perl 制作自记录模块

    有没有办法使用日志 Log4perl制作一个智能自记录模块 即使在没有调用脚本且未初始化 Log4perl 的情况下 也可以将其操作记录到文件中 据我从文档中可以看出 使用 Log4perl 的唯一方法是在运行脚本中从配置初始化它 然后实现
  • 为什么浏览器在刷新同一页面时会发送两个请求?

    我创建了一个简单的 Node js 应用程序 它记录日志以控制台当前request url对于每个传入的 HTTP 请求 当我在 Mac OS X ML 上的 Chrome 中刷新页面时 我收到对同一页面的两个请求 Why 相比之下 当我使
  • 如何从静态 javascript 获取 ember / emberjs 中视图实例的引用?

    我在网络上 SOF 和 Google 看到了很多有关此问题的问题 但到目前为止还没有明确的答案 我有一个常见的 Ember 应用程序 带有各种视图和控制器 我的一个视图有一个实例方法 我想从静态上下文中调用它 因此在一个普通的 javasc
  • 为什么 6 个内置常量中有 2 个是可赋值的?

    在有关的文档中内置常量 不包括site常量 指出 注 姓名None False True and debug 无法重新分配 对它们的分配 即使作为属性名称 也会引发SyntaxError 因此它们可以被视为 真 常数 如果我没错的话 Tru
  • 免费源代码控制

    到目前为止 尽管我做了很多小型家庭项目 但我从未在自己的项目中使用过任何源代码管理 我现在即将部署我的第一个个人公共网站 并认为这是建立一些东西的好时机 我正在寻找的主要内容之一是版本控制 标签等 与 Visual Studio 2010
  • 如何制作一个接受尾随垃圾的 DateTimeFormatter?

    我正在改装一些旧的SimpleDateFormat使用新 Java 8 的代码DateTimeFormatter SimpleDateFormat 因此旧代码接受日期后面包含内容的字符串 例如 20130311nonsense 这DateT
  • 当目录名中有空格时如何使用copyfile?

    我正在尝试在 Windows 下执行简单的文件复制任务 但遇到一些问题 我的第一次尝试是使用 import shutils source C Documents and Settings Some directory My file txt
  • 在 Angular 中处理 forEach Ajax 调用的正确方法

    我需要使用 for 循环更新数组中每个对象的数据 一旦捕获所有数据 就运行一个函数 我不想在其中混合 jQuery 并以正确的 Angular 方式进行 这就是我正在做的事情 scope units u1 u2 u3 scope data
  • 使用 Elastic BeanStalk + Django 设置 ElastiCache Redis

    另一个堆栈溢出answer说您需要设置一个elasticache config文件来自动使用ElastiCache创建Redis服务器 但是 我可以在 AWS Elasticache 上创建一个 Redis 实例并将其端点添加到 Djang
  • 在 C++ 中插入和删除整数中的逗号

    这里非常菜鸟 所以最好假设我对任何答案一无所知 我一直在编写一个小应用程序 它运行良好 但可读性对我的数字来说是一场噩梦 本质上 我想做的就是在屏幕上显示的数字中添加逗号以使其更易于阅读 有没有一种快速且简单的方法可以做到这一点 我一直在使
  • 如何删除 Eclipse Mars Jboss Tools 工具栏项目

    我已经安装了 Eclipse Mars 并且还从 eclipse 市场安装了 Jboss Tools 我的问题是 安装 JBoss 工具后 我的菜单栏中似乎有一组服务器控件 这些控件是按照早期学习中心风格创建的 我已经尝试过 窗口 gt 透
  • Rijndael 256 加密:Java 和 .NET 不匹配

    我需要将 Rijandael 加密的 powershell 脚本转换为 Java 这是源powershell代码 Reflection Assembly LoadWithPartialName System Security Add Typ
  • 您遵循个人软件流程吗?您的组织/团队是否遵循团队软件流程? [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心以获得指导 了解更多信息 维基百科上的
  • 延迟回调直到脚本添加到文档中?

    如何让回调在脚本实际附加到文档之前不运行 function addScript filepath callback if filepath var fileref document createElement script fileref
  • Strapi - 限制用户仅获取与他相关的数据

    通常 登录用户会获取内容类型的所有条目 我创建了一个 片段 内容类型 id name content users lt lt gt gt snippets lt lt gt gt 表示 具有并属于许多 关系 我创建了一些测试用户并提出请求
  • 在 pyspark 列表中对不同数据帧列求和的正确方法是什么?

    我想对 Spark 数据框中的不同列求和 Code from pyspark sql import functions as F cols A p1 B p1 df spark createDataFrame 1 2 4 89 12 60
  • 如何使用 jQuery 动画更改背景图像?

    我想使用慢速动画更改背景图像 但它不起作用 body stop animate background url 1 jpg slow 语法有问题吗 您可以通过将图像不透明度淡化为 0 然后更改背景图像 最后再次淡化图像来获得类似的效果 这将需
  • 在 R 中强制字符向量编码从“未知”到“UTF-8”

    我有一个问题字符向量编码不一致 in R 我从中读取表格的文本文件已编码 通过Notepad in UTF 8 我尝试过UTF 8 without BOM 也 我想从这个文本文件中读取表格 然后将其转换data table set a ke
  • 如何在 C++ 中高效应用多项式而无需循环? [关闭]

    Closed 这个问题需要多问focused 目前不接受答案 我想获得一些复杂函数的准确近似值 pow exp log log2 比 C 标准库中 cmath 提供的更快 为此 我想利用浮点编码方式并使用位操作获取指数和尾数 然后进行多项式