【机器学习】Numpy手写机器学习算法,3万行代码!

2023-10-27

一、前言

NumPy 作为 Python 生态中最受欢迎的科学计算包,很多读者已经非常熟悉它了。

它为 Python 提供高效率的多维数组计算,并提供了一系列高等数学函数,我们可以快速搭建模型的整个计算流程。毫不负责任地说,NumPy 就是现代深度学习框架的「爸爸」。

尽管目前使用 NumPy 写模型已经不是主流,但这种方式依然不失为是理解底层架构和深度学习原理的好方法。最近,来自普林斯顿的一位博士后将 NumPy 实现的所有机器学习模型全部开源,超过 3 万行代码、30 多个模型,并提供了相应的论文和一些实现的测试效果。

粗略估计,该项目大约有 30 个主要机器学习模型,此外还有 15 个用于预处理和计算的小工具,全部.py 文件数量有 62 个之多。平均每个模型的代码行数在 500 行以上,在神经网络模型的 layer.py 文件中,代码行数接近 4000。

这,应该是目前用 NumPy 手写机器学习模型的「最高境界」吧。

项目地址为:

https://github.com/ddbourgin/numpy-ml

二、作者简介

通过项目的代码目录,我们能发现,作者基本上把主流模型都实现了一遍,这个工作量简直惊为天人。

作者 David Bourgin 是一位大神,于 2018 年获得加州大学伯克利分校计算认知科学博士学位,随后在普林斯顿大学从事博士后研究。

尽管毕业不久,David 在顶级期刊与计算机会议上都发表了一些优秀论文。在 ICML 2019 中,其关于认知模型先验的研究就被接收为少有的 Oral 论文。

在这里插入图片描述

David Bourgin 就是用 NumPy 手写 ML 模型、手推反向传播的大神。这么多的工作量,当然还是需要很多参考资源的,David 会理解这些资源或实现,并以一种更易读的方式写出来。

他表示,从 autograd repo 学到了很多,但二者的不同之处在于,他显式地进行了所有梯度计算,以突出概念/数学的清晰性。当然,这么做的缺点也很明显,在每次需要微分一个新函数时,你都要写出它的公式……

估计 David Bourgin 在写完这个项目后,机器学习基础已经极其牢固了。

三、项目总体介绍

这个项目最大的特点是作者把机器学习模型都用 NumPy 手写了一遍,包括更显式的梯度计算和反向传播过程。可以说它就是一个机器学习框架了,只不过代码可读性会强很多。

David Bourgin 表示他一直在慢慢写或收集不同模型与模块的纯 NumPy 实现,它们跑起来可能没那么快,但是模型的具体过程一定足够直观。每当我们想了解模型 API 背后的实现,却又不想看复杂的框架代码,那么它可以作为快速的参考。

如下所示为项目文件,不同的文件夹即不同种类的代码集:

在这里插入图片描述
在这里插入图片描述
在每一个代码集下,作者都会提供不同实现的参考资料,例如模型的效果示例图、参考论文和参考链接等。

当然如此庞大的代码总会存在一些 Bug,作者也非常希望我们能一起完善这些实现。如果我们以前用纯 NumPy 实现过某些好玩的模型,那也可以直接提交 PR 请求。因为实现基本上都只依赖于 NumPy,那么环境配置就简单很多了,大家差不多都能跑得动。

四、手写 NumPy 全家福

作者在 GitHub 中提供了模型/模块的实现列表,列表结构基本就是代码文件的结构了。整体上,模型主要分为两部分,即传统机器学习模型与主流的深度学习模型。

其中浅层模型既有隐马尔可夫模型和提升方法这样的复杂模型,也包含了线性回归或最近邻等经典方法。而深度模型则主要从各种模块、层级、损失函数、最优化器等角度搭建代码架构,从而能快速构建各种神经网络。

除了模型外,整个项目还有一些辅助模块,包括一堆预处理相关的组件和有用的小工具。

该 repo 的模型或代码结构如下所示:

  1. 高斯混合模型
    EM 训练
  2. 隐马尔可夫模型
    维特比解码
    似然计算
    通过 Baum-Welch/forward-backward 算法进行 MLE 参数估计
  3. 隐狄利克雷分配模型(主题模型)
    用变分 EM 进行 MLE 参数估计的标准模型
    用 MCMC 进行 MAP 参数估计的平滑模型
  4. 神经网络
    4.1 层/层级运算
    Add
    Flatten
    Multiply
    Softmax
    全连接/Dense
    稀疏进化连接
    LSTM
    Elman 风格的 RNN
    最大+平均池化
    点积注意力
    受限玻尔兹曼机 (w. CD-n training)
    2D 转置卷积 (w. padding 和 stride)
    2D 卷积 (w. padding、dilation 和 stride)
    1D 卷积 (w. padding、dilation、stride 和 causality)
    4.2 模块
    双向 LSTM
    ResNet 风格的残差块(恒等变换和卷积)
    WaveNet 风格的残差块(带有扩张因果卷积)
    Transformer 风格的多头缩放点积注意力
    4.3 正则化项
    Dropout
    归一化
    批归一化(时间上和空间上)
    层归一化(时间上和空间上)
    4.4 优化器
    SGD w/ 动量
    AdaGrad
    RMSProp
    Adam
    4.5 学习率调度器
    常数
    指数
    Noam/Transformer
    Dlib 调度器
    4.6 权重初始化器
    Glorot/Xavier uniform 和 normal
    He/Kaiming uniform 和 normal
    标准和截断正态分布初始化
    4.7 损失
    交叉熵
    平方差
    Bernoulli VAE 损失
    带有梯度惩罚的 Wasserstein 损失
    4.8 激活函数
    ReLU
    Tanh
    Affine
    Sigmoid
    Leaky ReLU
    4.9 模型
    Bernoulli 变分自编码器
    带有梯度惩罚的 Wasserstein GAN
    4.10 神经网络工具
    col2im (MATLAB 端口)
    im2col (MATLAB 端口)
    conv1D
    conv2D
    deconv2D
    minibatch
  5. 基于树的模型
    决策树 (CART)
    [Bagging] 随机森林
    [Boosting] 梯度提升决策树
  6. 线性模型
    岭回归
    Logistic 回归
    最小二乘法
    贝叶斯线性回归 w/共轭先验
    7.n 元序列模型
    最大似然得分
    Additive/Lidstone 平滑
    简单 Good-Turing 平滑
  7. 强化学习模型
    使用交叉熵方法的智能体
    首次访问 on-policy 蒙特卡罗智能体
    加权增量重要采样蒙特卡罗智能体
    Expected SARSA 智能体
    TD-0 Q-learning 智能体
    Dyna-Q / Dyna-Q+ 优先扫描
  8. 非参数模型
    Nadaraya-Watson 核回归
    k 最近邻分类与回归
  9. 预处理
    离散傅立叶变换 (1D 信号)
    双线性插值 (2D 信号)
    最近邻插值 (1D 和 2D 信号)
    自相关 (1D 信号)
    信号窗口
    文本分词
    特征哈希
    特征标准化
    One-hot 编码/解码
    Huffman 编码/解码
    词频逆文档频率编码
  10. 工具
    相似度核
    距离度量
    优先级队列
    Ball tree 数据结构

五、项目示例

由于代码量庞大,这里整理了一些示例。

例如,实现点积注意力机制:

在这里插入图片描述

class DotProductAttention(LayerBase):
    def __init__(self, scale=True, dropout_p=0, init="glorot_uniform", optimizer=None):
        super().__init__(optimizer)
        self.init = init
        self.scale = scale
        self.dropout_p = dropout_p
        self.optimizer = self.optimizer
        self._init_params()

    def _fwd(self, Q, K, V):
        scale = 1 / np.sqrt(Q.shape[-1]) if self.scale else 1
        scores = Q @ K.swapaxes(-2, -1) * scale  # attention scores
        weights = self.softmax.forward(scores)  # attention weights
        Y = weights @ V
        return Y, weights

    def _bwd(self, dy, q, k, v, weights):
        d_k = k.shape[-1]
        scale = 1 / np.sqrt(d_k) if self.scale else 1

        dV = weights.swapaxes(-2, -1) @ dy
        dWeights = dy @ v.swapaxes(-2, -1)
        dScores = self.softmax.backward(dWeights)
        dQ = dScores @ k * scale
        dK = dScores.swapaxes(-2, -1) @ q * scale
        return dQ, dK, dV

在以上代码中,Q、K、V 三个向量输入到「_fwd」函数中,用于计算每个向量的注意力分数,并通过 softmax 的方式得到权重。而「_bwd」函数则计算 V、注意力权重、注意力分数、Q 和 K 的梯度,用于更新网络权重。

在一些实现中,作者也进行了测试,并给出了测试结果。如图为隐狄利克雷(Latent Dirichlet allocation,LDA)实现进行文本聚类的结果。左图为词语在特定主题中的分布热力图。右图则为文档在特定主题中的分布热力图。

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【机器学习】Numpy手写机器学习算法,3万行代码! 的相关文章

  • 没有任何元数据的 zip 文件

    我想找到一种简单的方法来压缩一堆文件 而无需任何文件元数据 例如时间戳 这zip命令似乎总是保留元数据 我没有找到禁用元数据的方法 我希望解决方案是一个命令或最多一个 python 脚本 谢谢 正如一些帖子已经指出的那样 zip 标头中的大
  • 行未从树视图复制

    该行未在树视图中复制 我在按行并复制并粘贴到未粘贴的任何地方后制作了弹出复制 The code popup tk Menu tree opportunity tearoff 0 def row copy item tree opportun
  • 如何确定非阻塞套接字是否真正连接?

    这个问题不仅限于Python 这是一个一般的套接字问题 我有一个非阻塞套接字 想要连接到一台可访问的机器 在另一端 该端口不存在 为什么 select 仍然成功 我预计会超时 sock send 因管道损坏而失败 select 之后如何确定
  • 从字符串到类型的词法转换

    最近 我尝试用Python存储和读取文件中的信息 遇到了一个小问题 我想从文本文件中读取类型信息 从 string 到 int 或 float 的类型转换非常有效 但从 string 到 type 的类型转换似乎是另一个问题 当然 我尝试了
  • 如何从 PyCharm 项目中获取我的“exe”[重复]

    这个问题在这里已经有答案了 通过 PyCharm 在 Python 上编写一些项目 我想从中获取一个exe文件 我尝试过 另存为 gt XXX exe 但是 当我尝试执行它时出现错误 此类操作系统不支持该文件 附注 我有win7 x64 它
  • 可以在 TensorFlow 中使用排名相关作为成本函数吗?

    我正在处理偶尔充满异常值的极其嘈杂的数据 因此我主要依靠相关性来衡量我的神经网络的准确性 是否可以明确使用诸如等级相关性 斯皮尔曼相关系数 之类的东西作为我的成本函数 到目前为止 我主要依赖 MSE 作为相关性的代理 我现在面临三个主要障碍
  • 优化 Keras 以使用所有可用的 CPU 资源

    好吧 我真的不知道我在说什么 所以请耐心听我说 我正在使用 Theano 后端运行 Keras 以在 MNIST 图像上运行基本的神经网络 目前只是一个教程 过去 我一直使用我的旧 HP 笔记本电脑 因为我有 Windows 和 Ubunt
  • Pandas重置索引未生效[重复]

    这个问题在这里已经有答案了 我不确定我在哪里误入歧途 但我似乎无法重置数据帧上的索引 当我跑步时test head 我得到以下输出 正如您所看到的 数据帧是一个切片 因此索引超出范围 我想做的是重置该数据帧的索引 所以我跑test rese
  • 带图像的简单 GUI [关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我试图在简单的 GUI 上显示一些卡
  • 具有多个元素的数组的真值是二义性错误吗? Python

    from numpy import from pylab import from math import def TentMap a x if x gt 0 and x lt 0 5 return 2 a x elif x gt 0 5 a
  • 获取列表中倒数第二个元素[重复]

    这个问题在这里已经有答案了 我可以通过以下方式获取列表的倒数第二个元素 gt gt gt lst a b c d e f gt gt gt print lst len lst 2 e 有没有比使用更好的方法print lst len lst
  • Python:计算数据帧列中所有行中特定字符的实例数

    我有一个包含列 toaddress ccaddress body 的数据框 df 我想迭代数据帧的索引 以获取 toaddress 和 ccaddress 字段中电子邮件地址的最小 最大和平均数量 这是通过计算这两列中每个字段中的 和 的实
  • 在可编辑的QSqlQueryModel中实现setEditStrategy

    这是后续这个问题 https stackoverflow com questions 49752388 editable qtableview of complex sql query 在那里 我们创建了 QSqlQueryModel 的可
  • 将 Pandas 列中的列表拆分为单独的列

    这是我在 pandas 数据框中的 特征 列 Feature Cricket 82379 Kabaddi 255 Reality 4751 Cricket 15640 Wildlife 730 LiveTV 13 Football 4129
  • conda-env list / conda info --envs 如何查找环境?

    我一直在尝试 anaconda miniconda 因为我的用户使用随 miniconda 安装的结构生物学程序 并且作者都没有 A 考虑到可能存在其他 miniconda 应用程序 B 他们的程序将在多用户环境中使用 因此 使用 Arch
  • 如何有效地从 loadmat 函数生成的嵌套 numpy 数组中提取值?

    python中是否有更有效的方法从嵌套的python列表中提取数据 例如A array array 12000000 dtype object 我一直在使用A 0 0 0 0 当你有很多像 A 这样的数据时 这似乎不是一个有效的方法 我也用
  • 如何通过字符串匹配加速 pandas 行过滤?

    我经常需要过滤 pandas 数据框df by df df col name string value 并且我想加快行选择操作 有没有快速的方法可以做到这一点 例如 In 1 df mul df 3000 2000 3 reset inde
  • 如何在sphinx中启用数学?

    我在用sphinx http sphinx pocoo org index html与pngmath http sphinx pocoo org ext math html module sphinx ext pngmath扩展来记录我的代
  • 异步和协程与任务队列

    我一直在阅读有关 python 3 中的 asyncio 模块的内容 以及更广泛地了解 python 中的协程的内容 但我不明白是什么让 asyncio 成为如此出色的工具 我的感觉是 你可以用协程做的所有事情 通过使用基于多处理模块 例如
  • 在 Python 模块中使用 InstaLoader

    我正在尝试使用 Instaloader 下载与主题标签相关的照片以进行图像分析 我在GitHub存储库中找到了一个全面的方法 如何在终端中执行它 但是 我需要将脚本集成到Python笔记本中 这是脚本 instaloader no vide

随机推荐

  • C语言实现格林威治时间转北京时间+根据日期计算星期几

    C语言实现格林威治时间转北京时间 根据日期计算星期几 北京时间 GMT时间 8小时 main c Created on 2021年12月16日 Author hello include
  • 达梦实现高可用性的实现(failover功能/负载均衡/虚拟ip透明切换)

    达梦实现高可用性的实现 failover功能 负载均衡 虚拟ip透明切换 一 failover功能 基于守护进程和监视器两个内在工具实现 守护进程 监视器 数据守护和读写分离集群 共享存储集群 二 负载均衡 基于jdbc接口和客户端实现读写
  • 都是 HBase 上的 SQL 引擎,Kylin 和 Phoenix 有什么不同?

    大数据时代 数据的价值越来越被重视 企业从海量大数据中挖掘所需要的信息 用来驱动业务决策以获得更大的商业价值 与此同时 出现了越来越多的大数据技术帮助企业进行大数据分析 例如 Apache Hadoop Hive Spark Presto
  • log4j日志级别

    log4j的8个日志级别 ALL 最低等级的 用于打开所有日志记录 TRACE designates finer grained informational events than the DEBUG Since 1 2 12 很低的日志级
  • 传统手机ODM厂商加快布局TWS代工 未来5年可期

    年初 闻泰科技 600745 SH 在投资交流平台上也表示 公司之前通讯业务主要以手机 平板为主 现已新增笔电 IoT模块 CPE 工业网关 TWS耳机等新产品线 随着华为 小米 OPPO和vivo等手机大厂加码TWS耳机产品线 闻泰 华勤
  • 【VIM】同时在多行的某字符前批量添加内容

    DRAFT SAVED ARCHIVED DELETED DRAFT draft SAVED saved ARCHIVED archived DELETED deleted s v w 1 L 1 2 参考 https superuser
  • PPTP 相关命令

    1 ifconfig grep ppp 查看连接的用户网络情况 2 last grep still grep ppp 查看连接的用户名 3 UnixBench跑分测试 wget https byte unixbench googlecode
  • 第十四讲几何布局

    分两部分 第一部分就是通过双击第一人称的firstCharacter蓝图 选择枪和人 在细节浏览器中隐藏 第二部分是盖房子 几个长方体的尺寸设置 其中 锁定光源到摄像头这个选项没了
  • NodeJS中http模块开发

    NodeJS中http模块开发 web服务器初体验 request对象 method的处理 url的处理 GET请求处理 POST请求处理 headers的处理 response对象 返回响应结果 返回状态码 响应头文件 http模块发送网
  • JAVA集合之——Comparable和Comparator

    JAVA集合之 Comparable和Comparator 从TreeSet可以清晰的看到Comparable和Comparator的区别 这里再集中整理一下 Comparable 是一个对象本身就已经支持自比较所需要实现的接口 如 Str
  • 思科刀片服务器统一计算系统,全面解析:思科UCS统一计算刀片服务器

    IT168 评论 思科统一计算系统是下一代数据中心平台 在一个紧密结合的系统中整合了计算 网络 存储接入与虚拟化功能 旨在降低总体拥有成本 TCO 同时提高业务灵活性 该系统包含一个低延时无丢包万兆以太网统一网络阵列 以及多台企业级x86架
  • Java 常见笔面试题

    2013年年底的时候 我看到了网上流传的一个叫做 Java面试题大全 的东西 认真的阅读了以后发现里面的很多题目是重复且没有价值的题目 还有不少的参考答案也是错误的 于是我花了半个月时间对这个所谓的 Java面试大全 进行了全面的修订并重新
  • USB驱动移植及mdev热插拔的实现

    基于之前移植的的内核 把驱动分别进行移植 这篇主要进行USB驱动移植 并阐明与热插拔相关的mdev 在2 6 30内核中 USB驱动已经比较完善了 移植是只要简单对配置单进行修改即可 添加的内容如下 Device Drivers gt SC
  • STM32入门篇2之外部中断

    外部中断 STM32入门统一版完整链接 更新中 中断 在主程序运行过程中 出现了特定的中断触发条件 中断源 使得CPU暂停当前正在运行的程序 转而去处理中断程序 处理完成后又返回原来被暂停的位置继续运行 中断优先级 当有多个中断源同时申请中
  • mysql分库分表

    一 概述 分库分表的顺序应该是先垂直分 后水平分 单个库太大 如果是因为表多而数据多 应使用垂直切分 根据业务切分成不同的库 如果是因为单张表的数据量太大 需要用水平切分 即把表的数据按某种规则切分成多张表 甚至多个库上的多张表 二 分库
  • 错误调试-debugger

    在浏览器中调试 在编写更复杂的代码前 让我们先来聊聊调试吧 调试 是指在一个脚本中找出并修复错误的过程 所有的现代浏览器和大多数其他环境都支持调试工具 开发者工具中的一个令调试更加容易的特殊用户界面 它也可以让我们一步步地跟踪代码以查看当前
  • 【数据结构】二叉树

    一 树的基本概念 1 1 树的概念 树是一种非线性的数据结构 它是由n n gt 0 个有限结点组成一个具有层次关系的集合 把它叫做树是因为它看起来像一棵倒挂的树 也就是说它是根朝上 而叶朝下的 有一个特殊的结点 称为根结点 根节点没有前驱
  • vue prop属性使用方法

    Prop作用是在子组件中接收父组件的值 参考
  • 晚上下班之后可以做什么副业,业余时间需要利用起来

    对大多数普通人来说 他们晚上有很多空闲时间 但他们总是在手机上玩游戏 刷视频 白白度过一夜 事实上 近年来 很多朋友都想利用晚上的时间做一些副业 因为目前的工资已经不能满足自己的需求 再加上生活各方面的压力 他们像山一样压着自己 然而 晚上
  • 【机器学习】Numpy手写机器学习算法,3万行代码!

    目录 Numpy手写机器学习算法 一 前言 二 作者简介 三 项目总体介绍 四 手写 NumPy 全家福 五 项目示例 一 前言 NumPy 作为 Python 生态中最受欢迎的科学计算包 很多读者已经非常熟悉它了 它为 Python 提供