神经网络拟合曲线及讨论

2023-05-16

神经网络拟合曲线及讨论

问题说明

神经网络能否拟合x^2 + y^2 = 100在第一象限的曲线?

设计思路

第一象限的曲线方程如下所示:
y = 100 − x 2 y = \sqrt{100-x^2} y=100x2
在[0, 10]中等距生成1000个点,划分训练集、开发集和测试集,构建神经网络训练。神经网络架构图如下,采用两层神经网络进行拟合,根据情况调节神经网络的深度以及隐藏层神经元个数。

模型图

代码实现

import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


def setup_seed(seed):
    # 固定所有的随机数种子,使结果能够复现
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def generate_data(num=1000):
    x = np.linspace(0, 10, num)
    y = [np.sqrt(100 - i ** 2) for i in x]
    return list(zip(x, y))


def distribute_dataset(data, rate):
    total_num = len(data)
    train_pos = int(total_num * rate[0])
    dev_pos = int(total_num * (rate[0] + rate[1]))
    train_data = data[: train_pos]
    dev_data = data[train_pos: dev_pos]
    test_data = data[dev_pos:]
    return train_data, dev_data, test_data
    

class MyDataset(Dataset):
    # 继承torch工具类的Dataset,进而构造DataLoader进行批量数据读取
    # 继承Dataset需要实现以下3个函数
    def __init__(self, dataset):
        super().__init__()
        data_x, data_y = zip(*dataset)
        self.data_x = data_x
        self.data_y = data_y

    def __getitem__(self, item):
        return self.data_x[item], self.data_y[item]

    def __len__(self):
        assert len(self.data_x) == len(self.data_y)
        return len(self.data_x)


class MyModel(nn.Module):
    # 继承nn.Module构建模型
    def __init__(self, hidden_size):
        super(MyModel, self).__init__()
        self.h1 = nn.Linear(1, hidden_size)
        self.h2 = nn.Linear(hidden_size, 1)

    def forward(self, inputs):
        out1 = self.h1(inputs)
        out2 = self.h2(F.relu(out1))
        return out2


if __name__ == '__main__':
    # 可调超参数
    data_num = 1000
    tdt_rate = [0.7, 0.2, 0.1]
    seed = 1
    hidden_size = 128
    batch_size = 4
    lr = 1e-3
    epoch = 50

    # 固定随机数
    setup_seed(seed)
    # 生成数据
    total_data = generate_data(data_num)
    random.shuffle(total_data)
    # 构建训练集、开发集、测试集
    train_data, dev_data, test_data = distribute_dataset(total_data, tdt_rate)
    train_ds = MyDataset(train_data)
    dev_ds = MyDataset(dev_data)
    test_ds = MyDataset(test_data)

    # 构建模型
    model = MyModel(hidden_size)
    loss = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # 训练集训练
    train_loader = DataLoader(train_ds, batch_size=batch_size)
    for e in range(epoch):
        total_l = []
        for x, y in train_loader:
            # 调整数据形式 x -> [batch_size, 1], y -> [batch_size]
            x = x.unsqueeze(-1).type(torch.float32)
            y = y.type(torch.float32)

            # 均方误差 (预测值 - 真实值)^2
            l = loss(model(x).squeeze(-1), y)

            # 邢将军法 |sign(预测值) * 预测值^2 - 真实值^2|
            # y_ = model(x).squeeze(-1)
            # sign = torch.sign(y_)
            # l = torch.abs(sign * y_ ** 2 - y ** 2).sum()

            # |预测值 - 真实值|
            # l = torch.abs(model(x).squeeze(-1) - y).sum()

            # |预测值^3 - 真实值^3|
            # l = torch.abs(model(x).squeeze(-1) ** 3 - y ** 3).sum()

            l.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_l.append(l.item())
        print(f"epoch {e+1}, avg_loss {sum(total_l) / len(total_l)}")

    # for e in range(epoch):
    #     # (x, y) -> (y, x) 增加靠近10的数据量,让模型再靠近10的地方拟合地更好
    #     total_l = []
    #     for x, y in train_loader:
    #         x, y = y, x
    #         # 调整数据形式 x -> [batch_size, 1], y -> [batch_size]
    #         x = x.unsqueeze(-1).type(torch.float32)
    #         y = y.type(torch.float32)
    #
    #         # 均方误差 (预测值 - 真实值)^2
    #         # l = loss(model(x).squeeze(-1), y)
    #
    #         # 邢将军法 |sign(预测值) * 预测值^2 - 真实值^2|
    #         y_ = model(x).squeeze(-1)
    #         sign = torch.sign(y_)
    #         l = torch.abs(sign * y_ ** 2 - y ** 2).sum()
    #
    #         # |预测值 - 真实值|
    #         # l = torch.abs(model(x).squeeze(-1) - y).sum()
    #
    #         # |预测值^3 - 真实值^3|
    #         # l = torch.abs(model(x).squeeze(-1) ** 3 - y ** 3).sum()
    #
    #         l.backward()
    #         optimizer.step()
    #         optimizer.zero_grad()
    #         total_l.append(l.item())
    #     print(f"reverse  epoch {e+1}, avg_loss {sum(total_l) / len(total_l)}")

    # 开发集调超参数
    dev_loader = DataLoader(dev_ds, batch_size=batch_size)
    with torch.no_grad():
        xs, ys, y_s = [], [], []
        for x, y in dev_loader:
            xs.extend(list(x.numpy()))
            ys.extend(list(y.numpy()))
            x = x.unsqueeze(-1).type(torch.float32)
            y = y.type(torch.float32)
            y_ = model(x).squeeze(-1)
            y_s.extend(list(y_.numpy()))
        # 排序
        total = [(a, b, c) for a, b, c in zip(xs, ys, y_s)]
        total.sort(key=lambda i: i[0])
        xs, ys, y_s = zip(*total)
        plt.plot(xs, ys, label='real')
        plt.plot(xs, y_s, label='pred')
        plt.legend()
        plt.title('dev')
        plt.show()

    # 测试集测试
    test_loader = DataLoader(test_ds, batch_size=batch_size)
    with torch.no_grad():
        xs, ys, y_s = [], [], []
        for x, y in test_loader:
            xs.extend(list(x.numpy()))
            ys.extend(list(y.numpy()))
            x = x.unsqueeze(-1).type(torch.float32)
            y = y.type(torch.float32)
            y_ = model(x).squeeze(-1)
            y_s.extend(list(y_.numpy()))
        # 排序
        total = [(a, b, c) for a, b, c in zip(xs, ys, y_s)]
        total.sort(key=lambda i: i[0])
        xs, ys, y_s = zip(*total)
        plt.plot(xs, ys, label='real')
        plt.plot(xs, y_s, label='pred')
        plt.legend()
        plt.title('test')
        plt.show()

神经网络的学习能力很强,获得不错的拟合效果。

均方误差dev

均方误差test

讨论

新的损失函数

上述代码中的loss选用的是均方误差——(预测值 - 真实值) ^2,邢将军提出了新的损失函数——|sign(预测值) * 预测值^2 - 真实值^2|,其中sign是符号函数,公式如下:
sign函数

损失函数描述预测值和真实值之间的误差,当预测值趋近于真实值时,损失函数应越来越小。预测值^2 - 真实值^2 满足损失函数的基本定义。

损失函数不能为负数,网络的优化目标是最小化损失函数,如果损失可为负数,网络将把损失推向负无穷,所以加上绝对值,成为**|预测值^2 - 真实值^2|**。

在此次拟合中,真实值的取值范围是[0, 10],永为正数。当预测值趋近于真实值的相反数时,损失依然在不断减小。如果初始梯度方向是向着真实值相反数的方向,最终就会导致出现关于x轴对称的拟合曲线出现。所以当预测值趋近于真实值的相反数时,要让损失变大。这里使用sign函数,当预测值为负值时,将产生更大的损失。最终的损失函数为**|sign(预测值) * 预测值^2 - 真实值^2|**。

代码实现如下:

y_ = model(x).squeeze(-1)
sign = torch.sign(y_)
l = torch.abs(sign * y_ ** 2 - y ** 2).sum()

效果图如下,除了尾部数据仍然有偏差,拟合效果大幅提升。

邢将军法dev

邢将军法test

基于邢将军的启发,我们可以设计更多损失函数,这些损失函数都表现不错。

  • | 预测值 - 真实值|
  • |sign(预测值) * 预测值^2 - 真实值^2|
  • | 预测值^3 - 真实值^3|

尾部偏差

sr同学认为预拟合的曲线在越靠近10的地方,|斜率|越来越大,最终极限为无穷大。在数据生成阶段,x是等距采样的,而y值的变化并非均匀。越靠近10,y值的变化越明显,特征越稀疏,不利于模型学习。

sr同学认为 x^2 + y^2 = 100 中,x与y是对称的,也就是说x与y可以互换。如果将生成的数据(x, y)对调成(y, x),则原来靠近0的地方数据更密集,现在靠近10的地方数据更密集。让模型再学习(y, x),尾部的偏差就能够拟合地更好。

代码很简单,将原训练代码的x与y互换即可。

for e in range(epoch):
    # (x, y) -> (y, x) 增加靠近10的数据量,让模型再靠近10的地方拟合地更好
    total_l = []
    for x, y in train_loader:
        # 对调
        x, y = y, x
        # 调整数据形式 x -> [batch_size, 1], y -> [batch_size]
        x = x.unsqueeze(-1).type(torch.float32)
        y = y.type(torch.float32)

        # 均方误差 (预测值 - 真实值)^2
        # l = loss(model(x).squeeze(-1), y)

        # 邢将军法 |sign(预测值) * 预测值^2 - 真实值^2|
        y_ = model(x).squeeze(-1)
        sign = torch.sign(y_)
        l = torch.abs(sign * y_ ** 2 - y ** 2).sum()

        # |预测值 - 真实值|
        # l = torch.abs(model(x).squeeze(-1) - y).sum()

        # |预测值^3 - 真实值^3|
        # l = torch.abs(model(x).squeeze(-1) ** 3 - y ** 3).sum()

        l.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_l.append(l.item())
        print(f"reverse  epoch {e+1}, avg_loss {sum(total_l) / len(total_l)}")

效果图如下,尾部数据也得到有效的拟合。

sr法dev

sr法test

总结

  1. 经验上,对于回归问题,我们倾向于使用均方误差作为损失函数;对于分类问题,我们倾向于使用交叉熵作为损失函数。然而,我们可以根据实际问题,设计出更好的损失函数,让模型收敛更快,效果更好。
  2. 神经网络的学习能力很强,但需要大量数据多次训练。如果模型在某些地方表现的不够好,增加这方面的数据,模型就能自动学习到特征,表现得更好。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

神经网络拟合曲线及讨论 的相关文章

  • IDEA 创建Servlet项目

    1 打开IDEA xff0c 点击Create New project创建一个一个新项目 2 点击Java Enterprise xff0c 然后选择Web Application xff0c 点击Next 3 设置项目名 xff0c 项目
  • 数据库接口类和接口实现类

    数据库接口类 xff08 BasicDAO java xff09 xff1a 实现对数据库的直接增删查改的interface接口 span class token keyword import span java span class to
  • Linux安装Anaconda

    Anaconda是一个开源的Python发行版本 xff0c 其包含了conda Python等180多个科学包及其依赖项 一 安装Anaconda 1 下载Anaconda安装包 xff08 我的位置是hadoop的家目录 xff0c 即
  • Windows 安装Maven3.6.1

    Win10 安装Maven3 6 1 xff0c 并为IntelliJ IDEA配置本地maven 一 安装Maven二 配置Maven本地仓库三 为IntelliJ IDEA配置本地maven 一 安装Maven 1 前提安装好jdk 2
  • 使用gorm创建casbin数据库报错

    1 报错 span class token operator span github span class token punctuation span com span class token operator span casbin s
  • Java 操作HBase

    Java 操作HBase 思路 1 建立连接 2 针对表的操作 xff08 创建表 删除表 判断表是否存在 使用 禁用表 列出表 xff09 3 针对数据的操作 xff08 添加 删除 修改 查看 xff09 4 关闭连接 HBase常用的
  • strtok和strtok_s函数使用说明

    看了很多高赞CSDN文章和百度百科 xff0c 越看越晕 xff0c 浪费好多时间 xff0c 特此记录 先介绍strtok xff0c 后边给个strtok s的例子 注意 xff1a 这两个函数必要连续调用多次才能实现分割和输出功能 x
  • chmod修改权限的用法

    一 chmod作用 xff1a 修改文件 目录的权限 二 语法 xff1a chmod 对谁操作 操作符 赋予的权限 文件名 三 操作对象 xff1a u 用户user xff0c 表现文件或目录的所有者 g 用户组group xff0c
  • 一篇文章入门Stm32CubeMX在freertos系统下进行uart串口通讯

    相信大部分人早期入门STM32系列单片机都是从各种例程入手的 xff0c STM32单片机繁多的寄存器已经不允许我们像学51系列单片机一样直接操作寄存器了 xff08 如果你记忆力好 xff0c 或者愿意花很多时间翻芯片手册查看对应寄存器的
  • Note: Python_Matplotlib绘制平滑曲线和散点图

    给出横坐标纵坐标点 xff0c 即可连线绘图 import matplotlib 调用绘图工具包 给出x y点坐标 x y 61 1 2 3 4 5 6 5 9 3 4 7 5 绘图 matplotlib pyplot plot x y 这
  • Word论文中设置正文中的引用参考文献 按住Ctrl键+单击鼠标右键 实现跳转到论文参考文献的对应位置

    Word论文中设置正文中的引用参考文献 按住Ctrl键 43 单击鼠标右键 实现跳转到论文参考文献的对应位置 首先要确保文中参考文献排版是插入的编号 xff0c 而不是自己手敲的 1 2 在正文要引用参考文献的位置 点击上方菜单栏的 插入
  • NVM 安装node.js后没有npm

    我们在使用NVM管理工具安装一个新的node后 xff0c 发现没有npm可以使用 参考文档 是因为在使用NVM安装node的时候不会默认安装npm xff0c 所以需要我们自己下载后放到nvm对应的node目录下面 npm下载地址 xff
  • idea项目设置鼠标右键点击文件夹通过IDEA打开

    每次打开idea项目是每次都要打开idea再手动选择项目 xff0c 直接设置成右键打开会很方便 效果图 xff1a 1 首先 win 43 R 输入regedit 打开注册表 2 打开注册表后找到如下路径 xff1a 计算机 HKEY L
  • Java利用Stream统计List中每个元素的个数

    1 传统HashMap 新建HashMap然后for循环List去统计每个元素出现的次数的方法实现 public static Map lt String Integer gt frequencyOfListElements List it
  • Git操作一直要求输入用户名和密码

    通过如下命令配置 xff1a git config global credential helper store git config global user email git config global user name 配置好后再去
  • Linux下的简单线程池

    问题描述 xff1a 在我们的日常生活中 xff0c 如果我们现在要浏览一个网页或者频繁的打开一个执行时间较短的任务 xff0c 如果每次调用都创建一个线程 xff0c 使用结束后就立即释放 xff0c 那么这样的开销对于操作系统来说有点太
  • 对于MYSQL中左对齐右对齐的实现

    在查询表的时候 xff0c 因为是表格的形式 会想要让其左对齐和右对齐的形式 能够看起来舒服一点 书上写的是ltrim rtrim方法 但是具体实现起来并不是很理想 左对齐很快 一开始表格的显示形式就是左对齐 或者用ltrim 右对齐的话
  • 记录罗技键盘从win切换mac的经历

    罗技蓝牙键盘ALT和WIN键 OPT和CMD键 如何对调 今天一直正常使用的罗技K380蓝牙键盘 不知道怎么抽风了 opt键和cmd键位置对调了 也就是windows环境下alt键和win键对调了 在使用复制粘贴快捷键的时候 特别不方便 而
  • IDEA中@author模板的设置

    在设置中查找Editor中的File and Code Templates 具体如下图所示 Created by IntelliJ IDEA 64 Author USER 64 create DATE TIME
  • 解决Win10搜索框没有反应

    刚发现电脑搜索突然不好使 xff0c 这个办法一下就解决了 在状态栏左下角的搜索框搜索OneNote没有任何反应 xff0c 对 xff0c 就是这个地方 最后在另一篇博客上找到了答案 xff0c 那篇博客也是在知乎找到的答案 xff0c

随机推荐

  • mac终端走代理

    mac终端走代理 mac即使打开了代理可以正常上网 xff0c 但终端默认不走代理 xff0c 需要手动配置终端代理 mac终端走代理的方法 span class token operator span 方法一 xff1a xff08 推荐
  • 从数据集CLEVR来看视觉推理的发展

    一 视觉推理的发展 视觉推理 Visual Reasoning 概念的兴起是在Li Fei Fei组提出的 CLEVR 数据集后 xff0c 被大家广泛认识并且越来越多的人开始研究 xff0c 大家提出的各种模型都是为了让机器或者是神经网络
  • Xmanager安装与使用攻略

    文章目录 前言一 工具二 步骤 前言 Xmanager 是一款可以在自己的办公电脑Windows机器下 xff0c 用于远程连接控制服务器Linux UNIX的管理工具 本经验介绍如何在windows上安装xmanager 一 工具 XMa
  • tensorflow的tensor张量如何转化为numpy数组?

    比方说 xff0c a是一个已经定义好的tensor张量 那么直接 xff1a a numpy 即可
  • vnc登录不上解决办法

    备忘 xff01 环境介绍 xff1a ubantu16 04 xff0c 安装了anaconda xff0c QT xff0c cmake xff0c 乱七八糟一堆东西 环境变量也改了很多 xff0c 不知为何会影响vnc桌面的启动 解决
  • js删除数组中的指定对象

    文章目录 实现效果封装工具函数完整demo 实现效果 封装工具函数 思路就是 xff0c 遍历取到每个对象和对应下标 xff0c 通过自定义的函数判断该对象是否删除 xff0c span class token comment 删除数组中指
  • 用OpenCV储存视频时遇到的问题

    用 MJPG 格式储存 34 avi 34 格式时报错 cv2 error OpenCV 3 4 1 io opencv modules videoio src container avi cpp 737 error 215 pos lt
  • debian 10执行提示service: command not found(找不到service命令)解决方法

    debian 10 用 root 执行提示 bash service command not found xff08 找不到 service 命令 xff09 解决方法 问题 想要执行 service xff0c 发现找不到命令 xff0c
  • IAR平台进行编译时常见错误:

    1 IAR编辑时出现如下错误 xff1a Near constant start address 43 size must be less than 错误原因是 xff1a 代码对应的Device 芯片选型错误 解决方法 xff1a 将Ge
  • PHPexcel报出错误‘break‘ not in the ‘loop‘ or ‘switch‘ context

    今天本地改代码改完做测试发现现在的文件中打开是 break 39 not in the 39 loop 39 or 39 switch 39 context 这样的 xff1b 当时一脸懵逼 xff0c 这是一个老项目最近也没动啊怎么回事
  • Linux Centos7 xfsdump文件系统的备份和恢复

    xfs提供了 xfsdump 和 xfsrestore工具 xff0c 协助备份xfs文件系统中的数据 xfsdump按 inode顺序备份一个xfs文件系统 CentOS7默认文件系统是xfs xff0c CentOS6默认文件系统是ex
  • TIM基本定时器——定时

    1 定时器功能 xff1a 定时 输出比较 输入捕获 互补输出 分类 xff1a 基本定时器 xff08 定时 xff09 通用定时器 xff08 定时 输出比较 输入捕获 xff09 高级定时器 xff08 定时 输出比较 输入捕获 互补
  • ubuntu 释放空间的7种简单方法

    从我们的理想中 xff0c 我们无意间暴露了自己的缺陷 让 罗斯唐 Linux系统空间不足 xff1f 您可以通过以下几种方式清理系统 xff0c 释放Ubuntu和其他基于Ubuntu的Linux发行版上的空间 随着时间的流逝 xff0c
  • OpenCV-Python画虚线

    问题背景 使用OpenCV Python处理图像时 xff0c 有函数cv line 函数可以快速画出直线 xff0c 本以为使用该函数修改参数可以快速画出虚线等特殊直线 xff0c 查阅OpenCV文档可以看到 xff0c cv line
  • IP地址分类及其范围

    IP地址分类 xff08 A类 B类 C类 D类 E类 xff09 IP地址由四段组成 xff0c 每个字段是一个字节 xff0c 8位 xff0c 最大值是255 xff0c IP地址由两部分组成 xff0c 即网络地址和主机地址 网络地
  • 在windows下conda换源时,使用conda install 安装库时,出现:Anaconda An HTTP error...的解决方法

    在windows下conda换源时 xff0c 使用conda install 安装库时 xff0c 出现 xff1a Anaconda An HTTP error occurred when trying to retrieve this
  • 华为1288H V5服务器做RAID_超详细图文教程

    目录 一 服务器介绍 二 服务器RAID 三 开始配置RAID 四 多RAID 如有需要 五 检查RAID信息 六 IBMC远程管理 结束 一 服务器介绍 型号 xff1a 华为1288H V5 服务器1U xff1b 图解 xff1a 这
  • VsCode下载,使用国内镜像秒下载

    还在因为vscode官方下载慢而头疼嘛 xff0c 按这个步骤来直接起飞兄弟萌 首先进入vscode官方网站然后选择对应版本下载然后进入浏览器下载页面复制下载链接粘贴到地址栏 将地址中的 stable前换成vscode cdn azure
  • Ubuntu常见问题 | 解压中文乱码问题

    文章目录 环境复现BUG原因解决 环境 Ubuntu16 04LTS 复现 右击 zip压缩文件 左击提取到此处 BUG 解压出来的文件名 文件夹名只要有中文 xff0c 中文就会变成乱码 原因 Windows下的中文编码规则与Linux下
  • 神经网络拟合曲线及讨论

    神经网络拟合曲线及讨论 问题说明 神经网络能否拟合x 2 43 y 2 61 100在第一象限的曲线 xff1f 设计思路 第一象限的曲线方程如下所示 xff1a y 61 100