Theano 中的 numpy.matmul

2024-03-16

TL;DR
我想复制的功能numpy.matmul in theano。最好的方法是什么?

过短;不明白
看着theano.tensor.dot and theano.tensor.tensordot,我没有看到一种简单的方法来进行简单的批量矩阵乘法。即,将 N 维张量的最后两个维度视为矩阵,并将它们相乘。我需要诉诸一些愚蠢的用法吗theano.tensor.batched_dot?或者*颤抖*自己循环播放而不广播!?


当前的拉取请求不支持广播,所以我现在想出了这个。我可能会清理它,添加更多功能,并提交我自己的 PR 作为临时解决方案。在那之前,我希望这对某人有帮助! 我包含了测试以显示它复制 numpy.matmul,前提是输入符合我更严格的(临时)断言。

此外,.scan 停止迭代序列argmin(*sequencelengths)迭代。因此,我相信不匹配的数组形状不会引发任何异常。

import theano as th
import theano.tensor as tt
import numpy as np


def matmul(a: tt.TensorType, b: tt.TensorType, _left=False):
    """Replicates the functionality of numpy.matmul, except that
    the two tensors must have the same number of dimensions, and their ndim must exceed 1."""

    # TODO ensure that broadcastability is maintained if both a and b are broadcastable on a dim.

    assert a.ndim == b.ndim  # TODO support broadcasting for differing ndims.
    ndim = a.ndim
    assert ndim >= 2

    # If we should left multiply, just swap references.
    if _left:
        tmp = a
        a = b
        b = tmp

    # If a and b are 2 dimensional, compute their matrix product.
    if ndim == 2:
        return tt.dot(a, b)
    # If they are larger...
    else:
        # If a is broadcastable but b is not.
        if a.broadcastable[0] and not b.broadcastable[0]:
            # Scan b, but hold a steady.
            # Because b will be passed in as a, we need to left multiply to maintain
            #  matrix orientation.
            output, _ = th.scan(matmul, sequences=[b], non_sequences=[a[0], 1])
        # If b is broadcastable but a is not.
        elif b.broadcastable[0] and not a.broadcastable[0]:
            # Scan a, but hold b steady.
            output, _ = th.scan(matmul, sequences=[a], non_sequences=[b[0]])
        # If neither dimension is broadcastable or they both are.
        else:
            # Scan through the sequences, assuming the shape for this dimension is equal.
            output, _ = th.scan(matmul, sequences=[a, b])
        return output


def matmul_test() -> bool:
    vlist = []
    flist = []
    ndlist = []
    for i in range(2, 30):
        dims = int(np.random.random() * 4 + 2)

        # Create a tuple of tensors with potentially different broadcastability.
        vs = tuple(
            tt.TensorVariable(
                tt.TensorType('float64',
                              tuple((p < .3) for p in np.random.ranf(dims-2))
                              # Make full matrices
                              + (False, False)
                )
            )
            for _ in range(2)
        )
        vs = tuple(tt.swapaxes(v, -2, -1) if j % 2 == 0 else v for j, v in enumerate(vs))

        f = th.function([*vs], [matmul(*vs)])

        # Create the default shape for the test ndarrays
        defshape = tuple(int(np.random.random() * 5 + 1) for _ in range(dims))
        # Create a test array matching the broadcastability of each v, for each v.
        nds = tuple(
            np.random.ranf(
                tuple(s if not v.broadcastable[j] else 1 for j, s in enumerate(defshape))
            )
            for v in vs
        )
        nds = tuple(np.swapaxes(nd, -2, -1) if j % 2 == 0 else nd for j, nd in enumerate(nds))

        ndlist.append(nds)
        vlist.append(vs)
        flist.append(f)

    for i in range(len(ndlist)):
        assert np.allclose(flist[i](*ndlist[i]), np.matmul(*ndlist[i]))

    return True


if __name__ == "__main__":
    print("matmul_test -> " + str(matmul_test()))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Theano 中的 numpy.matmul 的相关文章

随机推荐

  • git reset --soft 的实际用途?

    我使用 git 已经有一个多月了 事实上 我昨天才第一次使用重置 但软重置对我来说仍然没有多大意义 我知道我可以使用软重置来编辑提交 而无需更改索引或工作目录 就像我所做的那样git commit amend 这两个命令真的一样吗 rese
  • C++20 范围和排序

    我正在处理 C 20 的最后 4 个大版本 试图学习新的主要功能 尝试来自网络的一些与范围相关的代码 我写了 std vector ints 6 5 2 8 auto even int i return 0 i 2 ranges auto
  • 在 MongoDB 中的对象中插入数组

    我是 MongoDB 的新手 我想像这样插入 mongodb 数据 但我不知道如何做 image cab tags NNP 0 NN 1 image castle tags NNP 2 NN 1 我的代码是 BasicDBObject ob
  • 登录管理后,Django 开发服务器停止

    我已经在 python 3 7 中安装了 django 3 0 并启动了一个基本的 django 项目 我创建了一个超级用户并使用运行开发服务器python manage py runserver 当我去localhost 8000 adm
  • 如何使用 R 将日期时间格式转换为“ddmmyyyy”?

    我的约会dataframe看起来像这样 Date Values 1JAN2018 80 23DEC2019 21 3 我怎样才能将其格式化为ddmmyyyy日期以便我可以使用ggplot创建时间序列图 我做了什么 Date lt as Da
  • 处理父小部件中的点击事件

    在我的应用程序树中 我有两个小部件 GestureDetector onTap gt print Outer child IconButton icon Icon Icons add onPressed gt print Inner 他们都
  • 将 OnClickListener 关闭然后再打开

    我在用户单击按钮后将 OnClickListener 设置为关闭 confirm setOnClickListener null 这使得该按钮不可单击 但我希望在用户单击另一个按钮后它可以单击 我怎样才能做到这一点 Just set con
  • 枚举所有正在运行的数据库

    我正在编写一个小型数据库管理程序 如果您提供数据库 它可以正常工作 但如果您不知道安装了哪个数据库 则效果不佳 如何枚举所有正在运行的数据库 例如程序的输出 Port xy MS SQL Server 2005 Port ab Postgr
  • 没有子元素的 Javascript 元素 html [关闭]

    Closed 这个问题需要调试细节 help minimal reproducible example 目前不接受答案 在我的 javascript 代码中 我需要获取元素的定义 但没有其内容 既不是文本也不是子元素 例如 为了 div c
  • 如何将十进制基数 (10) 转换为负二进制基数 (-2)?

    我想编写一个程序将十进制转换为负二进制 我不知道如何从十进制转换为负二进制 我不知道如何找到规则以及它是如何运作的 例子 7 base10 gt 11011 base 2 我只知道是这样7 2 0 1 2 1 1 2 2 0 2 3 1 2
  • 如果进程附加了 CLR 调试器,.NET 代码运行速度是否会变慢?

    正如标题所说 我正在运行一个很长的程序 并且它附加了 CLR 调试器 因此我可以捕获和检查异常 我获得的性能是否与不使用调试器运行它相当 或者我是否付出了严重的 2 10 倍或更多 代价 最重要的是 工具 选项 调试 常规 抑制模块加载的
  • Java 中 if/else 与 switch 语句的相对性能差异是什么?

    担心我的 Web 应用程序的性能 我想知道 if else 或 switch 语句中哪一个在性能方面更好 我完全同意过早优化是应该避免的观点 但 Java VM 确实有可用于 switch 的特殊字节码 See WM Spec http d
  • 如何在flutter中使用injectable和get_it的共享首选项?

    我在flutter中使用injectable和get it包 我有一个共同的偏好类别 LazySingleton class SharedPref final String token token SharedPreferences pre
  • 使用 TextMode Number 回发后,TextBox 失去值

    遇到奇怪的问题 我有一个简单的页面TextBox
  • 如何使用 ScalaPB 序列化/反序列化使用“oneof”的 protobuf 消息?

    我正在使用 ScalaPB 编译 Scala 案例类来序列化我的 protobuf 消息 我有一个 proto包含以下消息的文件 message WrapperMessage oneof msg Login login 1 Register
  • AutoMapper 地图子属性也定义了地图

    我有以下域对象 public class DomainClass public int Id get set public string A get set public string B get set 我有以下两个要映射到的对象 pub
  • send() 函数返回的字节数多于 C++ 所需的字节数

    我正在做一个套接字程序 在我的服务器与设备连接后 我试图向他发送一条消息 但 send 函数返回的字节数大于数组中存储的字节数 并且消息没有被发送 这是我的代码 StartSendingMessages int retorno CStrin
  • 是否有任何 jQuery 版本符合 Promise/A 规范?

    在阅读了几篇文章之后 我开始知道 jQuery 中存在 Promise 实现 但我不确定 jQuery 的任何版本是否兼容 Promise A 2015 更新 jQuery 3 0 与 Promises A 兼容 看这个问题在 GitHub
  • SocketCluster 中间件握手与承诺

    我正在构建一个同时服务 http 和 ws 的应用程序 用户首先通过 HTTP 登录 Laravel 服务器 这会返回一个 JWT 用于允许通过 WS 登录 Ihv 添加了一个 MIDDLEWARE HANDSHAKE 来获取令牌并向 La
  • Theano 中的 numpy.matmul

    TL DR我想复制的功能numpy matmul in theano 最好的方法是什么 过短 不明白看着theano tensor dot and theano tensor tensordot 我没有看到一种简单的方法来进行简单的批量矩阵