TensorFlow 相当于 PyTorch 的 Transforms.Normalize()

2024-05-15

我正在尝试推断最初在 PyTorch 中构建的 TFLite 模型。我一直在遵循PyTorch 实现 https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/blob/1ee551d619641268c2ebd80134101db6e962f45f/demo/inference.py#L93并且必须沿着 RGB 通道预处理图像。我找到了最接近的 TensorFlow 等价物transforms.Normalize() to be tf.image.per_image_standardization() (文档 https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization)。虽然这是一场很不错的搭配,tf.image.per_image_standardization()这是通过跨渠道获取均值和标准差并将其应用于它们来实现的。这是他们的完整实现here https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/python/ops/image_ops_impl.py

def per_image_standardization(image):
  """Linearly scales `image` to have zero mean and unit norm.
  This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
  of all values in image, and
  `adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))`.
  `stddev` is the standard deviation of all values in `image`. It is capped
  away from zero to protect against division by 0 when handling uniform images.
  Args:
    image: 3-D tensor of shape `[height, width, channels]`.
  Returns:
    The standardized image with same shape as `image`.
  Raises:
    ValueError: if the shape of 'image' is incompatible with this function.
  """
  image = ops.convert_to_tensor(image, name='image')
  _Check3DImage(image, require_static=False)
  num_pixels = math_ops.reduce_prod(array_ops.shape(image))

  image = math_ops.cast(image, dtype=dtypes.float32)
  image_mean = math_ops.reduce_mean(image)

  variance = (math_ops.reduce_mean(math_ops.square(image)) -
              math_ops.square(image_mean))
  variance = gen_nn_ops.relu(variance)
  stddev = math_ops.sqrt(variance)

  # Apply a minimum normalization that protects us against uniform images.
  min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, dtypes.float32))
  pixel_value_scale = math_ops.maximum(stddev, min_stddev)
  pixel_value_offset = image_mean

  image = math_ops.subtract(image, pixel_value_offset)
  image = math_ops.div(image, pixel_value_scale)
  return image

而 PyTorch 的transforms.Normalize()允许我们提及要应用于每个通道的平均值和标准差,如下所示。

# transformation
    pose_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

在 TensorFlow 2.x 中获得此功能的方法是什么?

Edit:我创建了一个快速的错误,似乎通过定义一个函数来解决这个问题:

def normalize_image(image, mean, std):
    for channel in range(3):
        image[:,:,channel] = (image[:,:,channel] - mean[channel])/std[channel]
    
    return image

我不确定这有多有效,但似乎可以完成工作。在输入到模型之前,我仍然必须将输出转换为张量。


您提到的解决方法似乎没问题。但使用for...loop计算标准化为each RGB通道为单幅图像当您处理数据管道中的大型数据集时可能会有点问题(generator or tf.data)。但无论如何都没关系。这是您的方法的演示,稍后我们将提供两种可能适合您的替代方案。

from PIL import Image 
from matplotlib.pyplot import imshow, subplot, title, hist

# load image (RGB)
img = Image.open('/content/9.jpg')

def normalize_image(image, mean, std):
    for channel in range(3):
        image[:,:,channel] = (image[:,:,channel] - mean[channel]) / std[channel]
    return image

OP_approach = normalize_image(np.array(img) / 255.0, 
                            mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])

现在,让我们观察一下变换属性。

plt.figure(figsize=(25,10))
subplot(121); imshow(OP_approach); title(f'Normalized Image \n min-px: \
    {OP_approach.min()} \n max-pix: {OP_approach.max()}')
subplot(122); hist(OP_approach.ravel(), bins=50, density=True); \ 
                                    title('Histogram - pixel distribution')

归一化后最小和最大像素的范围是(-2.1179039301310043, 2.6399999999999997) 分别。

Option 2

我们可以使用tf。 keras...标准化 https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing/Normalization预处理层做同样的事情。它需要两个重要的论点,它们是mean and, variance(的平方std).

from tensorflow.keras.experimental.preprocessing import Normalization

input_data = np.array(img)/255
layer = Normalization(mean=[0.485, 0.456, 0.406], 
                      variance=[np.square(0.299), 
                                np.square(0.224), 
                                np.square(0.225)])

plt.figure(figsize=(25,10))
subplot(121); imshow(layer(input_data).numpy()); title(f'Normalized Image \n min-px: \
   {layer(input_data).numpy().min()} \n max-pix: {layer(input_data).numpy().max()}')
subplot(122); hist(layer(input_data).numpy().ravel(), bins=50, density=True);\
   title('Histogram - pixel distribution')

归一化后最小和最大像素的范围是(-2.0357144, 2.64) 分别。

Option 3

这更像是减去平均值mean并除以平均值std.

norm_img = ((tf.cast(np.array(img), tf.float32) / 255.0) - 0.449) / 0.226

plt.figure(figsize=(25,10))
subplot(121); imshow(norm_img.numpy()); title(f'Normalized Image \n min-px: \
{norm_img.numpy().min()} \n max-pix: {norm_img.numpy().max()}')
subplot(122); hist(norm_img.numpy().ravel(), bins=50, density=True); \
title('Histogram - pixel distribution')

归一化后最小和最大像素的范围是(-1.9867257, 2.4380531) 分别。最后,如果我们比较pytorch方式,这些方法之间没有太大区别。

import torchvision.transforms as transforms

transform_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
])
norm_pt = transform_norm(img)

plt.figure(figsize=(25,10))
subplot(121); imshow(np.array(norm_pt).transpose(1, 2, 0));\
  title(f'Normalized Image \n min-px: \
  {np.array(norm_pt).min()} \n max-pix: {np.array(norm_pt).max()}')
subplot(122); hist(np.array(norm_pt).ravel(), bins=50, density=True); \
  title('Histogram - pixel distribution')

归一化后最小和最大像素的范围是(-2.117904, 2.64) 分别。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow 相当于 PyTorch 的 Transforms.Normalize() 的相关文章

随机推荐

  • 如何使用 AJAX/jQuery 显示打印内容?

    所以我试图理解整个 AJAX jQuery 的事情 现在 当我单独运行这个 PHP 脚本时 我必须等待并观察轮子旋转 直到循环完成然后加载 while row mysql fetch array res postcode to storm
  • 使用 boost::thread 特定的 ptr<>::get() 是否会很慢?有什么解决方法吗?

    我目前正在使用 Valgrind 的 Callgrind 分析一个存在性能问题的应用程序 在查看分析数据时 似乎有 25 的处理时间花费在boost detail get tss data在主要目的是物理模拟和可视化的应用程序中 get t
  • 应对失败的“未来”

    给出以下两种方法 def f Future Int Future 10 def g Future Int Future 5 我想把它们写成 scala gt import scala concurrent Future import sca
  • 如何在Fluentui(Office ui Fabric)中创建“危险”按钮?

    如何在Microsoft Fluentui库中创建 危险 红色 按钮 就像 bootstrap 等其他 ui 框架中的那样 有
  • Android版本App更新代码

    我读到如果我们想更新Google Play中的应用程序 版本代码应该高于以前的apk文件 我有一个版本代码为 20 且版本名称为 1 0 的应用程序 那么要更新app 应该如何增加版本号呢 应该增加10吗 或者仅仅 1 就足够了 即版本代码
  • 如何防止控件在 TableLayoutPanel 内调整大小时视觉上滞后?

    我有一个基于多个嵌套的中等复杂度的布局TableLayoutPanels 调整窗体大小会导致更深嵌套表内的控件在视觉上滞后于调整大小 首先 这使得它们看起来像是在调整表单大小时四处移动 但更糟糕的是 当它们滞后到足以离开分配的表格单元格时
  • 使用 C# 获取 ec2-instance 标签

    我不是开发人员 所以也许答案是有不同的解决方案 但我无法真正从 python 或其他东西翻译它 我尝试使用 AWS NET SDK 查找实例 然后获取实例的标签 我已经能够确定实例是否已启动并正在运行 我还了解了如何创建和删除标签 不在下面
  • 如何正确自定义 Django LoginView

    我试图弄清楚如何根据用户当天是否第一次登录来自定义 django LoginView 我当前已设置 LoginView 使其默认为 settings py 文件中的 LOGIN REDIRECT URL book author 这工作完美无
  • 为 Windows 98 编译 Qt

    我需要支持 Windows 98 Qt 文档声称这是可能的 但没有说明 Qt 4 6 的分布式二进制文件不能在 Win98 上运行 而且我采样的大多数 Qt 应用程序也不能在 Win98 上运行 对于几个确实在 98 上运行的应用程序 我询
  • 带路径压缩算法的加权 Quick-Union

    有一种 带路径压缩的加权快速联合 算法 代码 public class WeightedQU private int id private int iz public WeightedQU int N id new int N iz new
  • 动态创建和下载Doc文件

    因此 我尝试动态创建 doc 文件并让用户在单击按钮时下载该文件 这些是我找到的用于下载文件的标头 header Content Description File Transfer header Content Type applicati
  • paymentId 和 TRANSACTIONID 之间的区别

    我正在从 REST 转向经典 API 而且我对两者都是新手 作为一名开发人员 我想记录付款的唯一标识符 以便将网站中的销售与 Paypal 付款 ID 相关联 例如我想要退款时 REST API 曾经给我付款 ID https stacko
  • 如何将列中的天数添加到 DB2 中的当前日期?

    我正在编写此 SQL 来动态计算一定的天数 如下所示 但我不知道如何让它工作 因为我不断收到错误 select Current Date Dynamic numbr of days calculation here from TableNa
  • 使用一次递归调用实现递归

    给定一个函数如下 f n f n 1 f n 3 f n 4 f 0 1 f 1 2 f 2 3 f 3 4 我知道使用递归来实现它 并在一个函数内进行三个递归调用 但我想在函数内仅使用一次递归调用来完成此操作 怎样才能做到呢 要实现使用
  • DataGridView 使用 Structure 和 LINQ 来排序 txt 文件

    当我的程序出现问题时 我能够将所有数据拉入网格并进入正确的列 行 但是 我相信我的 LINQ 查询是错误的 它没有使第三列正确划分并插入正确的数据 我的结果 https gyazo com 0f307a10dff4c015a361708ec
  • 使用 Mock 对 Laravel 5 Mail 进行单元测试

    有没有办法在 Laravel 5 中测试 Mail 尝试了我在互联网上看到的唯一合法的模拟示例 但它似乎只适用于 Laravel 4 下面的当前代码 mock Mockery mock Swift Mailer this gt app ma
  • 反转js对象中的键值

    我不知道如何改变 first de second ab de third de to de first second third ab second 我想将唯一值与包含键的列表相关联 我尝试过的 但我认为我离它还很远 const data
  • 将具有值的产品属性添加到 Woocommerce 中的产品

    我正在使用此代码添加自定义属性 attributes array array name gt Size options gt array S L XL XXL position gt 1 visible gt 1 variation gt
  • 编码杂志[关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • TensorFlow 相当于 PyTorch 的 Transforms.Normalize()

    我正在尝试推断最初在 PyTorch 中构建的 TFLite 模型 我一直在遵循PyTorch 实现 https github com leoxiaobin deep high resolution net pytorch blob 1ee