Google Colab:为什么 CPU 比 TPU 快?

2024-05-01

我正在使用 Google colabTPU训练一个简单的Keras模型。删除分布式strategy并在CPUTPU。这怎么可能?

import timeit
import os
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# Load Iris dataset
x = load_iris().data
y = load_iris().target

# Split data to train and validation set
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.30, shuffle=False)

# Convert train data type to use TPU 
x_train = x_train.astype('float32')
x_val = x_val.astype('float32')

# Specify a distributed strategy to use TPU
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

# Use the strategy to create and compile a Keras model
with strategy.scope():
  model = Sequential()
  model.add(Dense(32, input_shape=(4,), activation=tf.nn.relu, name="relu"))
  model.add(Dense(3, activation=tf.nn.softmax, name="softmax"))
  model.compile(optimizer=Adam(learning_rate=0.1), loss='logcosh')

start = timeit.default_timer()

# Fit the Keras model on the dataset
model.fit(x_train, y_train, batch_size=20, epochs=20, validation_data=[x_val, y_val], verbose=0, steps_per_epoch=2)

print('\nTime: ', timeit.default_timer() - start)

谢谢你的问题。

我认为这里发生的事情是一个开销问题——因为 TPU 运行在一个单独的虚拟机上(可通过grpc://$COLAB_TPU_ADDR),每次调用在 TPU 上运行模型都会产生一定量的开销,因为客户端(本例中为 Colab 笔记本)将图形发送到 TPU,然后编译并运行。与运行所需的时间相比,此开销很小。 ResNet50 适用于一个时期,但与运行示例中的简单模型相比要大一些。

为了在 TPU 上获得最佳效果,我们建议使用tf.data.数据集 https://www.tensorflow.org/api_docs/python/tf/data/Dataset。我更新了您的 TensorFlow 2.2 示例:

%tensorflow_version 2.x
import timeit
import os
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# Load Iris dataset
x = load_iris().data
y = load_iris().target

# Split data to train and validation set
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.30, shuffle=False)

# Convert train data type to use TPU 
x_train = x_train.astype('float32')
x_val = x_val.astype('float32')

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(20)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(20)

# Use the strategy to create and compile a Keras model
with strategy.scope():
  model = Sequential()
  model.add(Dense(32, input_shape=(4,), activation=tf.nn.relu, name="relu"))
  model.add(Dense(3, activation=tf.nn.softmax, name="softmax"))
  model.compile(optimizer=Adam(learning_rate=0.1), loss='logcosh')

start = timeit.default_timer()

# Fit the Keras model on the dataset
model.fit(train_dataset, epochs=20, validation_data=val_dataset)

print('\nTime: ', timeit.default_timer() - start)

运行大约需要 30 秒,而在 CPU 上运行大约需要 1.3 秒。通过重复数据集并运行一个长周期而不是几个小周期,我们可以大大减少这里的开销。我用这个替换了数据集设置:

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).repeat(20).batch(20)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(20)

并更换了fit用这个调用:

model.fit(train_dataset, validation_data=val_dataset)

这使我的运行时间减少到大约 6 秒。这仍然比 CPU 慢,但是对于这样一个可以轻松在本地运行的小型模型来说,这并不奇怪。一般来说,您会发现在较大模型中使用 TPU 会带来更多好处。我建议仔细看一下TensorFlow 的官方 TPU 指南 https://www.tensorflow.org/guide/tpu,它为 MNIST 数据集提供了一个更大的图像分类模型。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Google Colab:为什么 CPU 比 TPU 快? 的相关文章

随机推荐

  • 高级自定义字段 – 具有多个输入的自定义字段类型

    我正在尝试为 ACF 创建一个新的字段类型 其中包含多个输入或存储值数组 原因是我希望为一组输入字段提供一些交互性和自定义布局 我按照这个教程http www advancedcustomfields com resources tutor
  • 在 jQuery 中从 asp.net runat 服务器获取 ID

    我正在尝试使用 ASP NET 在 jQuery 中制作一些东西 但身份证来自runat server 与 HTML 中使用的 id 不同 我曾经用它来从这种情况中获取ID val 但在这种情况下 它不起作用 我不知道为什么 JavaScr
  • Neo4j 2.0 唯一约束错误“节点已存在”,当它不存在时

    我在 Neo4j 唯一约束方面遇到了一些麻烦 其中 CREATE cypher 语句由于节点已经存在而无法执行 问题是 它不 存在 此外 昨天使用这些确切数据的精确流程也有效 我的neo4j版本是ubuntu 12 04 3上的commun
  • Javascript 'this' 覆盖 Z 组合器和所有其他递归函数

    背景 我有一个由a实现的递归函数Z 组合器如图所示here https stackoverflow com questions 17645356 anonymous recursion any way to replace javascri
  • WScript.Shell.Exec - 从 stdout 读取输出

    我的 VBScript 不显示我执行的任何命令的结果 我知道命令被执行 但我想捕获结果 我已经测试了多种方法来执行此操作 例如以下方法 Const WshFinished 1 Const WshFailed 2 strCommand pin
  • 测试抽象模型 - django 2.2.4 / sqlite3 2.6.0

    我正在尝试使用 django 2 2 4 sqlite3 2 6 0 python 3 6 8 测试一些简单的抽象混合 目前 我在使用架构编辑器从测试数据库中删除模型时遇到问题 我有以下测试用例 from django test impor
  • Web 服务版本控制策略的优缺点

    更新20100224 我真的不需要某些供应商网站上的一些蹩脚定义 我正在寻找的是实际实施以及实际实施这些东西的人们在日常 IT 业务周期中面临的挑战 更多内容如下 尚未制定 采用任何退休策略 显然需要制定一项策略 我对您如何制定该战略并将其
  • 矩形相当于文本的文本锚点表示属性吗?

    是否有一个与文本的文本锚点表示属性等效的矩形 我希望能够从左侧 右侧或根据情况定位矩形 我知道这可以通过一些简单的计算来完成 但我只是想知道是否已经存在内置的东西 文本锚点演示属性上的链接 https developer mozilla o
  • Shutil.rmtree() 引发异常 WindowsError:访问被拒绝:

    尝试使用 python 脚本自动删除文件 我得到 Traceback most recent call last Python script 5 line 8 in
  • 减少 CSS 网格中的行间距

    我想知道如何减少行间距 我尝试过将边距和填充设置为 0 但似乎没有什么效果 左侧为桌面视图 右侧为移动视图 content margin 0 padding 0 width 100 display grid grid gap 5px gri
  • 优化spark sql中分区数据写入S3

    我在每个 Spark 作业运行中从 HDFS 读取大约 700 GB 的数据 我的工作读取这些数据 过滤大约 60 的数据 将其分区如下 val toBePublishedSignals hiveCtx sql some query toB
  • 使用 PLINQ 扩展时是否会传输线程标识?

    我正在使用 AsParallel ForAll 在 ASP NET 请求上下文中并行枚举集合 枚举方法依赖于System Threading Thread CurrentPrincipal 我是否可以依赖用于将 System Threadi
  • 使用 crypto.getRandomValues() 生成 0 到 1 的随机数

    看起来 Math random 会生成 0 1 范围内的 64 位浮点数 而新的 crypto getRandomValues API 仅返回整数 使用此 API 生成 0 1 中的数字的理想方法是什么 这似乎有效 但似乎不太理想 ints
  • 如何避免获取 .repo/manifest.xml?

    如何避免获取 repo manifest xml 故意修改的 我不想在回购同步期间对其进行修改 我已经做了一个repo init 这一步就完成了 我对manifest xml做了一个小修改 删除了一些同步不需要的项目 当我们进行存储库同步时
  • 模数在 Javascript 中不起作用

    我试图理解为什么模运算不能按预期工作 我需要验证 IBAN 该算法包括进行取模 根据维基百科 在此输入链接描述 https en wikipedia org wiki International Bank Account Number Va
  • Hibernate加载惰性代理,但我只需要PK

    我有这些实体 Entity public class Room ManyToOne optional true fetch FetchType LAZY private Player player1 Entity public class
  • 如何使用openJDK11运行Eclipse?

    怎样必须eclipse ini看起来像是让 Eclipse Photon 2018 09 或 2018 12 在 openJDK11 上运行 我已经安装了 openJDK 11 0 1 和 Eclipse 2018 09 我有一个包含 XM
  • Java Swing并发显示JTextArea

    我需要执行 显示从 Arraylist 到 JTextArea 的一系列事件 但是 每个事件的执行时间不同 以下是我的目标的一个简单示例 public void start ActionEvent e SwingUtilities invo
  • 读/写结构到文件 - c

    我正在用 C 语言创建一个学生数据库 我需要做的最后一件事是能够读取我创建的数据库并将其写入文件 所以我已经有了一个充满指向学生结构的指针的数组 我需要将其写入文件 一旦我写完它 我也需要能够将它读回到我的数组中 我真的不知道该怎么做 这是
  • Google Colab:为什么 CPU 比 TPU 快?

    我正在使用 Google colabTPU训练一个简单的Keras模型 删除分布式strategy并在CPU比TPU 这怎么可能 import timeit import os import tensorflow as tf from sk