如何在python中实现小批量梯度下降?

2024-03-14

我刚刚开始学习深度学习。当谈到梯度下降时,我发现自己陷入了困境。我知道如何实现批量梯度下降。我知道它是如何工作的以及小批量和随机梯度下降在理论上是如何工作的。但实在无法理解如何用代码实现。

import numpy as np
X = np.array([ [0,0,1],[0,1,1],[1,0,1],[1,1,1] ])
y = np.array([[0,1,1,0]]).T
alpha,hidden_dim = (0.5,4)
synapse_0 = 2*np.random.random((3,hidden_dim)) - 1
synapse_1 = 2*np.random.random((hidden_dim,1)) - 1
for j in xrange(60000):
    layer_1 = 1/(1+np.exp(-(np.dot(X,synapse_0))))
    layer_2 = 1/(1+np.exp(-(np.dot(layer_1,synapse_1))))
    layer_2_delta = (layer_2 - y)*(layer_2*(1-layer_2))
    layer_1_delta = layer_2_delta.dot(synapse_1.T) * (layer_1 * (1-layer_1))
    synapse_1 -= (alpha * layer_1.T.dot(layer_2_delta))
    synapse_0 -= (alpha * X.T.dot(layer_1_delta))

这是 ANDREW TRASK 博客中的示例代码。它很小而且很容易理解。该代码实现了批量梯度下降,但我想在此示例中实现小批量和随机梯度下降。我怎么能这样做呢?为了分别实现小批量和随机梯度下降,我必须在这段代码中添加/修改什么?你的帮助会对我有很大帮助。提前致谢。(我知道这个示例代码有几个例子,而我需要大数据集来分割成小批量。但我想知道如何实现它)


该函数返回给定输入和目标的小批量:

def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert inputs.shape[0] == targets.shape[0]
    if shuffle:
        indices = np.arange(inputs.shape[0])
        np.random.shuffle(indices)
    for start_idx in range(0, inputs.shape[0] - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt]

这告诉您如何使用它进行训练:

for n in xrange(n_epochs):
    for batch in iterate_minibatches(X, Y, batch_size, shuffle=True):
        x_batch, y_batch = batch
        l_train, acc_train = f_train(x_batch, y_batch)

    l_val, acc_val = f_val(Xt, Yt)
    logging.info('epoch ' + str(n) + ' ,train_loss ' + str(l_train) + ' ,acc ' + str(acc_train) + ' ,val_loss ' + str(l_val) + ' ,acc ' + str(acc_val))

显然,您需要根据您正在使用的优化库(例如 Lasagne、Keras)自行定义 f_train、f_val 和其他函数。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在python中实现小批量梯度下降? 的相关文章

随机推荐

  • 如何获取或生成 Google Cloud Run 服务的部署 URL

    如何在 CI 环境中以编程方式获取已部署服务的 URL 成功部署后确实会记录 URL 但如果我想以编程方式提取并使用 URL 作为部署后需求的一部分 例如 该怎么办 发布验收测试的 URL 只需使用该标志 format value stat
  • 如何监控Linux上进程的线程数?

    我想监视 Linux 上特定进程使用的线程数 有没有一种简单的方法可以在不影响流程性能的情况下获取此信息 try ps huH p
  • 如何在cx_Oracle和python 2.7中处理unicode数据?

    我在用 Python 2 7 cx Oracle 6 0 2 我在我的代码中做了类似的事情 import cx Oracle connection string s s s 192 168 8 168 1521 xe connection
  • 适用于 Azure Service Fabric 无状态 Web API 应用程序的 Swagger

    我正在开发 Web API 服务并作为微服务托管在 Azure Service Fabric 上 我需要为 API 定义实现 Swagger 并且我可以看到 SwaggerConfig Register 方法在应用程序启动时未调用 所以我无
  • mysql_close 和 pg_close 是否是必需的? [复制]

    这个问题在这里已经有答案了 可能的重复 使用 mysql close https stackoverflow com questions 2065282 using mysql close 是否需要 mysql close 和 pg clo
  • Java HashMap Get 基准测试(JMH 与循环)

    我的最终目标是使用标准 Java 集合作为基线 为多个 Java 原始集合库创建一套全面的基准测试 过去我曾使用循环方法来编写此类微基准 我将要进行基准测试的函数放入循环中并迭代 100 万次以上 以便 jit 有机会预热 我计算循环的总时
  • 如何在Python中使用AutoReg预测时间序列

    我正在尝试仅使用自动回归算法来构建老式模型 我发现它有一个实现statsmodel包裹 我已阅读文档 据我了解 它应该像 ARIMA 一样工作 所以 这是我的代码 import statsmodels api as sm model sm
  • 使用 AND 和 OR 的 C# 谓词生成器

    我有以下课程 public class testClass public string name get set public int id get set public int age get set 和以下代码 var list new
  • 如何在 MySQL 中返回数据透视表输出?

    如果我有一个看起来像这样的 MySQL 表 company name action pagecount Company A PRINT 3 Company A PRINT 2 Company A PRINT 3 Company B EMAI
  • AttributeError:模块“jaxlib.xla_extension”没有属性“PmapFunction”

    有人可以帮我修复在 check not jax transformed f 中的 usr local lib python3 7 dist packages haiku src transform py in check not jax t
  • Ruby Mechanize:点击链接

    在 Mechanize on Ruby 中 我必须为我访问的每个新页面分配一个新变量 例如 page2 page1 link with text gt Continue click page3 page2 link with text gt
  • Cucumber 在一段时间后逐步停止执行

    我的一个测试会等到事件发生Then步 如果测试工作正常 则没有问题 但如果测试失败 即没有触发任何事件 那么它就会挂起 我怎样才能设置超时Cucumber I know JUnit有一个超时参数 您可以在 Test annotation h
  • 使用 Spark SQL 跳过/获取

    如何使用 Spark SQL 实现跳过 获取查询 典型的服务器端网格分页 我在网上搜索过 只能找到非常基本的示例 例如 https databricks training s3 amazonaws com data exploration
  • 使用键盘快捷键聚焦于文本字段

    我有一个 macOS Monterrey 应用程序 其中包含TextField在工具栏上 我用它来搜索我的应用程序上的文本 现在 我正在尝试添加键盘快捷键以专注于TextField 我尝试了下面的代码 添加带有快捷方式的按钮作为测试这是否可
  • 在sqlite不同数据库中触发

    我有 2 个不同的数据库 A 和 B 我需要创建一个触发器 当我在数据库 A 的表 T1 中插入任何条目时 数据库 B 的表 T2 的条目将得到已删除 请给我推荐一个方法 这不可能 在SQLite中 触发器内部的DML只能修改同一数据库的表
  • 将字符串提取函数包装在 ifelse 语句中

    下面的问题是一个延伸这个问题 https stackoverflow com questions 74135095 adding a column to the data that looks for a list of words and
  • 在现实世界应用中使用语义网络技术的示例[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 您正在开发使用 RDF OWL SPARQL 技术的 可能是商业的 产品吗 如果是这样 您能描述一下您的产品吗 O Reilly 的
  • 写入/编辑 CSV 文件(不要重写整个文件!)

    我需要替换直接在 CSV 文件上操作的客户端的某些功能 该文件用作系统的配置文件 搜索到的大多数案例都是关于从 CSV 读取到其他格式的 其他将整个 CSV 放入内存 附加专用行和更改 然后将它们写回新文件 或覆盖现有文件 我想更聪明地完成
  • Jetpack Compose 应用程序范围内的条件 TopAppBar 最佳实践

    我有一个 Android Jetpack Compose 应用程序 它使用BottomNavigation and TopAppBar可组合项 从通过打开的选项卡BottomNavigation用户可以更深入地导航到导航图 问题 The T
  • 如何在python中实现小批量梯度下降?

    我刚刚开始学习深度学习 当谈到梯度下降时 我发现自己陷入了困境 我知道如何实现批量梯度下降 我知道它是如何工作的以及小批量和随机梯度下降在理论上是如何工作的 但实在无法理解如何用代码实现 import numpy as np X np ar