JavaScript玩转机器学习:​​​​​​​训练模型

2023-05-16

JavaScript玩转机器学习:训练模型

 

本指南假定您已经阅读了模型和图层指南。

在TensorFlow.js中,有两种方法来训练机器学习模型:

  1. 通过LayersModel.fit()或使用Layers API LayersModel.fitDataset()
  2. 将Core API与结合使用Optimizer.minimize()

首先,我们将研究Layers API,这是用于构建和训练模型的高级API。然后,我们将展示如何使用Core API训练相同的模型。

 

介绍

机器学习模型是一种具有可学习参数的函数,该函数将输入映射到所需输出。通过在数据上训练模型可以获得最佳参数。

培训涉及以下几个步骤:

  • 获取一批数据到模型。
  • 要求模型做出预测。
  • 将预测与“真实”值进行比较。
  • 确定每个参数的更改量,以便模型将来可以对该批次进行更好的预测。

训练有素的模型将提供从输入到所需输出的准确映射。

 

模型参数

让我们使用Layers API定义一个简单的2层模型:

const model = tf.sequential({
 layers: [
   tf.layers.dense({inputShape: [784], units: 32, activation: 'relu'}),
   tf.layers.dense({units: 10, activation: 'softmax'}),
 ]
});
 

在内部,模型具有可通过数据训练学习的参数(通常称为权重)。让我们打印与此模型关联的权重的名称及其形状:

model.weights.forEach(w => {
 console.log(w.name, w.shape);
});
 

我们得到以下输出:

> dense_Dense1/kernel [784, 32]
> dense_Dense1/bias [32]
> dense_Dense2/kernel [32, 10]
> dense_Dense2/bias [10]
 

总共有4个权重,每个密集层2个。这是可以预期的,因为致密层表示一个函数,该函数可通过等式将输入张量映射x到输出张量yy = Ax + b其中A(内核)和b(偏差)是致密层的参数。

注意:默认情况下,密集层包含偏差,但是可以通过{useBias: false}在创建密集层时在选项中指定来排除它。

model.summary() 如果您想获得模型概述并查看参数总数,这是一种有用的方法:

层(类型)输出形状参数#
density_Dense1(密集)[null,32]25120
density_Dense2(密集)[null,10]330
总参数:25450 可训练参数:
25450
非可训练参数:0

模型中的每个权重都由一个Variable对象后端。在TensorFlow.js中,a Variable是浮点数Tensor,其中一种assign()用于更新其值的附加方法。Layers API使用最佳实践自动初始化权重。为了演示起见,我们可以通过调用assign()基础变量来覆盖权重:

model.weights.forEach(w => {
  const newVals = tf.randomNormal(w.shape);
  // w.val is an instance of tf.Variable
  w.val.assign(newVals);
});
 

优化器,损失和指标

在进行任何培训之前,您需要确定三件事:

  1. 优化器。在给定当前模型预测的情况下,优化器的工作是决定更改模型中每个参数的数量。使用Layers API时,您可以提供现有优化程序的字符串标识符(例如'sgd''adam'),也可以提供Optimizer该类的实例。
  2. 损失函数。模型将尝试最小化的目标。其目标是给出模型预测的“错误程度”的一个数字。损耗是对每一批数据计算的,因此模型可以更新其权重。使用Layers API时,您可以提供现有损失函数的字符串标识符(例如'categoricalCrossentropy'),也可以提供任何采用预测值和真实值并返回损失的函数。请参阅我们的API文档中的可用损失列表。
  3. 指标列表。与损失类似,指标会计算一个数字,总结了模型的运行情况。通常在每个时期结束时对整个数据计算指标。至少,我们要监视我们的损失随着时间的推移而下降。但是,我们经常需要更人性化的指标,例如准确性。使用Layers API时,您可以提供现有指标的字符串标识符(例如'accuracy'),也可以提供采用预测值和真实值并返回分数的任何函数。请参阅我们的API文档中的可用指标列表。

确定后,LayersModel通过调用model.compile()提供的选项来编译a :

model.compile({
  optimizer: 'sgd',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
 

在编译过程中,模型将进行一些验证,以确保您选择的选项彼此兼容。

 

训练

有两种方法来训练a LayersModel

  • 使用model.fit()并提供数据作为一个大张量。
  • model.fitDataset()通过Dataset对象使用和提供数据。

 

model.fit()

如果您的数据集适合主内存,并且可以作为单个张量使用,则可以通过调用fit()方法来训练模型:

// Generate dummy data.
const data = tf.randomNormal([100, 784]);
const labels = tf.randomUniform([100, 10]);

function onBatchEnd(batch, logs) {
  console.log('Accuracy', logs.acc);
}

// Train for 5 epochs with batch size of 32.
model.fit(data, labels, {
   epochs: 5,
   batchSize: 32,
   callbacks: {onBatchEnd}
 }).then(info => {
   console.log('Final accuracy', info.history.acc);
 });
 

在幕后,model.fit()可以为我们做很多事情:

  • 将数据分为训练和验证集,并使用验证集衡量训练期间的进度。
  • 仅在拆分后才对数据进行随机排序。为了安全起见,您应该在将数据传递给之前对其进行预混洗fit()
  • 将大数据张量拆分为大小较小的张量 batchSize.
  • optimizer.minimize()在计算有关数据批次的模型损失时调用。
  • 它可以在每个时期或批次的开始和结束时通知您。在我们的情况下,我们会在每个批次的末尾使用该callbacks.onBatchEnd选项通知我们。其他选项包括:onTrainBeginonTrainEndonEpochBeginonEpochEndonBatchBegin
  • 它屈服于主线​​程,以确保可以及时处理JS事件循环中排队的任务。

欲了解更多信息,请参阅文件的fit()。请注意,如果您选择使用Core API,则必须自己实现此逻辑。

 

model.fitDataset()

如果您的数据不能完全容纳在内存中或正在流式传输中,则可以通过调用fitDataset()Dataset对象的来训练模型。这是相同的训练代码,但具有包装生成器函数的数据集:

function* data() {
 for (let i = 0; i < 100; i++) {
   // Generate one sample at a time.
   yield tf.randomNormal([784]);
 }
}

function* labels() {
 for (let i = 0; i < 100; i++) {
   // Generate one sample at a time.
   yield tf.randomUniform([10]);
 }
}

const xs = tf.data.generator(data);
const ys = tf.data.generator(labels);
// We zip the data and labels together, shuffle and batch 32 samples at a time.
const ds = tf.data.zip({xs, ys}).shuffle(100 /* bufferSize */).batch(32);

// Train the model for 5 epochs.
model.fitDataset(ds, {epochs: 5}).then(info => {
 console.log('Accuracy', info.history.acc);
});
 

有关数据集的更多信息,请参阅文件的model.fitDataset()

 

预测新数据

训练完模型后,您可以调用model.predict()对看不见的数据进行预测:

// Predict 3 random samples.
const prediction = model.predict(tf.randomNormal([3, 784]));
prediction.print();
 

注意:正如我们在“ 模型和层”指南中提到的那样,LayersModel期望输入的最外部尺寸为批处理大小。在上面的示例中,批次大小为3。

 

核心API

之前,我们提到了在TensorFlow.js中训练机器学习模型的两种方法。

一般的经验法则是尝试首先使用Layers API,因为它是根据广为采用的Keras API建模的。Layers API还提供了各种现成的解决方案,例如权重初始化,模型序列化,监控培训,可移植性和安全性检查。

您可能需要在任何时候使用Core API:

  • 您需要最大的灵活性或控制力。
  • 而且您不需要序列化,也可以实现自己的序列化逻辑。

有关此API的更多信息,请阅读“ 模型和层”指南中的“核心API”部分。

与上述使用Core API编写的模型相同,如下所示:

// The weights and biases for the two dense layers.
const w1 = tf.variable(tf.randomNormal([784, 32]));
const b1 = tf.variable(tf.randomNormal([32]));
const w2 = tf.variable(tf.randomNormal([32, 10]));
const b2 = tf.variable(tf.randomNormal([10]));

function model(x) {
  return x.matMul(w1).add(b1).relu().matMul(w2).add(b2);
}
 

除了Layers API外,Data API还可以与Core API无缝协作。让我们重用我们先前在model.fitDataset()部分中定义的数据集,它为我们进行混洗和批处理:

const xs = tf.data.generator(data);
const ys = tf.data.generator(labels);
// Zip the data and labels together, shuffle and batch 32 samples at a time.
const ds = tf.data.zip({xs, ys}).shuffle(100 /* bufferSize */).batch(32);
 

让我们训练模型:

const optimizer = tf.train.sgd(0.1 /* learningRate */);
// Train for 5 epochs.
for (let epoch = 0; epoch < 5; epoch++) {
  await ds.forEachAsync(({xs, ys}) => {
    optimizer.minimize(() => {
      const predYs = model(xs);
      const loss = tf.losses.softmaxCrossEntropy(ys, predYs);
      loss.data().then(l => console.log('Loss', l));
      return loss;
    });
  });
  console.log('Epoch', epoch);
}
 

上面的代码是使用Core API训练模型时的标准配方:

  • 循环历元数。
  • 在每个时期内,循环遍历您的批次数据。使用时Datasetdataset.forEachAsync() 是一种方便的循环批处理的方法。
  • 对于每个批次,call都会通过计算相对于我们先前定义的四个变量的梯度optimizer.minimize(f)来执行f并最小化其输出。
  • f计算损失。它使用模型的预测和真实值调用预定义的损失函数之一。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

JavaScript玩转机器学习:​​​​​​​训练模型 的相关文章

  • SQLServer--动态SQL拆分字符串,并将结果存进临时表

    存储过程的代码可参考如下 xff1a USE NewUserTest GO Object StoredProcedure dbo Splite Script Date 04 03 2018 10 23 52 SET ANSI NULLS O
  • 树莓派 ffmpeg 录制 USB 摄像头+话筒 视频+音频 mp4

    record video amp audio in the same file ffmpeg y f alsa ac 1 i hw 1 acodec pcm s16le f v4l2 framerate 25 video size 640x
  • C语言数组带下标赋值

    好记性不如烂笔头 c语言数组带下标赋值 xff0c 初始化的时候数组元素的值不受顺序影响 xff0c 在有些时候方便扩展一幕了然 span class token keyword int span array span class toke
  • Linux 2.4 Packet Filtering HOWTO

    Linux 2 4 Packet Filtering HOWTO 简体中文版 Rusty Russell mailing list netfilter 64 lists samba org Revision 1 3 Date 2002 06
  • lammps案例:Cu三点弯曲模拟

    大家好 xff0c 我是小马老师 本文分享一个Cu弯曲的案例 本案例参考三点弯曲实验 xff0c 三点弯曲试验是将试样放在弯曲装置上 xff0c 在试样上加载进行弯曲试验 xff0c 直到达到规定的弯曲程度或发生断裂 模拟原理是在z方向固定
  • PHP 中的数组函数

    文章目录 array change key casearray chunkarray columnarray combinearray count valuesarray diffarray diff keyarray diff assoc
  • Linux下运行bash脚本显示/usr/bin/env bash\r没有那个文件或目录

    Linux下运行bash脚本显示 usr bin env bash r 没有那个文件或目录 错误原因 这主要是因为bash后面多了 r这个字符的原因 在linux终端下 xff0c 输出 r会什么都不显示 xff0c 只是把光标移到行首 于
  • ROS软路由设置

    ROS软路由设置 不要怀疑软路由的性能 xff0c 也不用担心所谓的耗电多少 所谓的软路由耗电大 xff0c 只不过是商家搞的噱头而已 软路由完全不需要显示器 键盘鼠标 甚至 xff0c 可以在BIOS 里设置系统启动完即关闭硬盘 至于主板
  • QT +go 开发 GUI程序(一)

    如果你是一个墨守成规的coding xff0c 请移步其他内容 xff0c 这部分内容可能不适合你 如果你希望到外面看看 xff0c 感受新鲜的技术以及自由自在的氛围 xff0c 请继续 当然你也要付出一定的精力去学习如何科学上网 xff0
  • Android 7.0系统权限问题

    Android 7 0系统在运行应用的时候 对权限做了诸多限制 normal dangerous signature signatureOrSystem 取决于保护级别 xff0c 在确定是否授予权限时 xff0c 系统可能采取不同的操作
  • 跨平台,开源,免费的单片机IDE开发环境搭建-SDCC+eclipse

    关于如何使用sdcc编译器 xff0c 参见 单片机开发 xff0c 推荐开源跨平台的SDCC编译器 xff0c 其中较为详细叙述了使用方法和执行效率 1 xff0c IDE基本环境 SDCC在eclipse有一个插件 xff0c 版本1
  • LINUX中添加用户时为用户设置了全名(FULL NAME)

    more etc passwd 每行第4个冒号后面的字母就是full name
  • 系统开发系列 之Java中打印日志的几种方式

    在Java 中实现记录日志的方式有很多种 xff1a 最简单的方式 xff0c 就是system println out error 这样直接在控制台打印消息了 Java util logging 在JDK 1 4 版本之后 xff0c 提
  • 如何获取Android设备唯一识别码

    来自 xff1a http syawlaus com remindme E5 A6 82 E4 BD 95 E8 8E B7 E5 8F 96android E8 AE BE E5 A4 87 E5 94 AF E4 B8 80 E8 AF
  • 对时间操作

    TextBox1 Text 61 DateTime Today ToString 34 yyyy年M月d日 34 点第一个BUTTON时 TextBox1 Text 61 DateTime Today AddDays 1 ToString
  • docker访问samba服务器做持久化

    需求 xff1a 在window上 xff0c 我们经常使用 192 168 24x 1xx xxx 这样的路径访问网络共享文件服务器 xff0c 测试人员将访该文件服务器做持久化给到类似rancher这样的k8s管理平台上的docker操
  • 数据库迁移思路梳理

    1 分析系统 xff1a 进一步分析系统的功能和需求确认 业务需求分析 应用分析 评估工作量 2 制定方案 xff1a 确定迁移的重点和难点 xff0c 制定迁移方案 2 1确定数据库结构 xff1a 明确数据表 表中字段和各字段的数据类型
  • Maven 国内镜像仓库

    镜像仓库目标 当我们未定义任何远程仓库时 xff0c 使用 Maven 更新依赖时 xff0c 其会去默认远程仓库中拉取 xff0c 默认远程仓库 是国外地址 xff0c 所以在国内访问特别慢 xff0c 想提升访问速度 xff0c 需要将
  • 命名难,难于上青天

    Photo by Jorik Kleen on Unsplash Quora 问答社区的一个开发者投票统计 xff0c 程序员最大的难题是 xff1a 如何命名 xff08 例如 xff1a 给变量 xff0c 类 xff0c 函数等等 x
  • 企业发放的奖金根据利润提成

    案例 xff1a 利润I低于或等于10万元时 xff0c 奖金可提10 xff05 xff1b 利润高于10万元 xff0c 低于20万元 xff08 10000 lt I 200000 xff09 时 xff0c 其中10万元按10 xf

随机推荐