使用torch以及tensorflow训练一个最简单网络的基本步骤

2023-10-31

torch: 

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

x = torch.Tensor.unsqueeze(torch.Tensor.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand()


class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__() #继承init
        #定义每层的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        #正向传播输入值,神经网络分析出输出值
        x = F.relu(self.hidden(x)) #激励函数(隐藏层的线性值)
        x = self.predict(x) #输出值
        return x

net = Net(n_feature=1, n_hidden=10, n_output=1)

print(net)

optimizer = torch.optim.SGD(net.parameters(), lr=0.2) #传入net的所有参数,学习率
loss_func = torch.nn.MSELoss() #预测值和真实值的误差计算公式(均方差)

for t in range(100):
    prediction = net(x)

    loss = loss_func(prediction, y)

 

 

tensorflow:

#--coding:utf-8--

import tensorflow as tf
from numpy.random import RandomState

batch_size = 8

#定义神经网络参数
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')

#前向传播过程
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

#定义损失函数和反向传播算法
y = tf.sigmoid(y)
cross_entropy = -tf.reduce_mean(
    y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)) + (1 - y) * tf.log(tf.clip_by_value(1 - y, 1e-10, 1.0))
)

train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    print(sess.run(w1))
    print(sess.run(w2))

    #开始训练
    STEPS = 5000
    for i in range(STEPS):
        start = (i * batch_size) % dataset_size
        end = min(start + batch_size, dataset_size)

        sess.run(train_step,
                 feed_dict={x: X[start:end], y_: Y[start:end]})

        if i % 1000 == 0:
            total_cross_entropy = sess.run(
                cross_entropy, feed_dict={x: X, y_: Y})
            print("After %d training steps, cross entropy on all data is %g" % (i, total_cross_entropy))

    print(sess.run(w1))
    print(sess.run(w2))

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用torch以及tensorflow训练一个最简单网络的基本步骤 的相关文章

随机推荐

  • 三元运算符 使用

    三元运算符 三元表达式判断闰年 var b 2012 var year b 4 0 b 100 0 闰年 平年 console log year 判断奇数偶数 var a prompt 输入你要判断的数 var a 3 var res a
  • 使用docker部署springboot项目并连接上mysql数据库

    使用docker部署springboot项目并连接上mysql数据库 预览 http 8 142 6 23 screen 项目开源地址 前端vue https gitee com gaohan888 echarts learning tre
  • 旧版vue-cli脚手架Webpack3项目如何升级Webpack4

    vue cli脚手架出到了4 3 1版本 目前主推通过create命令来新建项目 与过去的vue cli2的init命令不同的是 create命令脚手架建完的项目webpack为4 而init采用的模板中引用的webpack版本还是3 单独
  • 程序人生-Hello’s P2P

    第一章 概述 1 1 Hello 简介 1 1 1 P2P Program to Process 从程序到进程 P2P指Hello c从源程序到进程的过程 Hello c经过预处理器的编译预处理 得到预编译文件Hello i Hello i
  • Java多对象的内存情况分析

    这种情况指的是在一个类中创建了多个对象 最先创建的对象直接指向类 后面创建的对象则指向第一个创建的对象 那么针对这种情况就会出现如下情况 1 照旧生成栈内存和堆内存 但是堆内存只会生成一个包含类中所有属性和方法的内存地址 2 因为后面创建的
  • 前端面试题集锦(6)

    目录 1 常见的兼容问题有哪些 1 1 获取标签节点 1 2 获取卷去的高度 1 3 获取样式 1 4 事件侦听器 1 5 事件解绑 1 6 事件对象的获取 1 7 阻止默认行为 1 8 阻止事件冒泡 1 9 获取精准的目标元素 1 10
  • Eclipse的Team菜单中没有SVN选项的解决方法

    Eclipse开发项目时想使用SVN来管理 但是发现Team gt Share Project菜单中没有SVN选项 只有一个GIT选项 如下图 解决方法 1 菜单栏Help gt Eclipse Marketplace 2 打开如下对话框
  • SQL 常用&高级 教程

    用SELECT INTO 或INSERT INTO复制表结构 数据 MySQL 数据库不支持 SELECT INTO 语句 但支持 INSERT INTO SELECT MySQL可以使用以下语句来 1 拷贝表结构及数据 CREATE TA
  • 【学习笔记】R数据科学(R for Data Science)—第3章 使用dplyr进行数据转换

    dplyr包是tidyverse中的一个核心R包 dplyr的5个核心函数 按值筛选观测 filter 对行进行重新排序 arrange 按名称选取变量 select 使用现有变量的函数创建新变量 mutate 将多个值总结为一个摘要统计量
  • 设置文本阴影和溢出效果

    一 文本阴影效果 方法一 显示字体时 根据要求 为文字阴影添加颜色以增强网页的吸引力 这时就需要用到CSS3样式中的text shadow属性 text shadow 阴影水平偏移值 可正负 阴影垂直偏移值 可正负 阴影模糊值 阴影颜色 后
  • 为什么 i&1 可以判断奇偶

    记录一下看到过几次但是总会遗忘的知识点 是位运算 在计算机里是只认识二进制的 我们人类用的一般是十进制 而二进制有个特点就是每一位上要么是0要么是1 还有一个特点是如果哪个位置是1 那一位的值就是2n 这个符号表示次方 n就是这个1所处的位
  • 吴恩达机器学习python代码练习三(多类别分类)

    import numpy as np import pandas as pd import matplotlib pyplot as plt import scipy io as sio from scipy optimize import
  • 使用ddt实现unittest的参数化测试

    0 前言 本文介绍如何使用ddt库来完成unitest的参数化设置 ddt的github地址 ddt的官方文档 1 为什么需要参数化 我们在写单测中 需要考虑到各种场景 通过输入各种场景的值执行目的的方法 来判断输出是否是我们所期待的值 如
  • Android 中WebView的使用详解

    博主前些天发现了一个巨牛的人工智能学习网站 通俗易懂 风趣幽默 忍不住也分享一下给大家 点击跳转到网站 前言 通过WebView控件可以实现加载网页的效果 加载URL 网络或者本地assets文件夹下的html文件 加载html代码 Nat
  • 测开学习技能清单

    一 代码语言 打好语言基础 python java 底层语言主要掌握java 更高级的语法可以选择python去学习 领域预演 DSL shell SQL Docker shell 是指一种应用程序 这个应用程序提供了一个界面 用户通过这个
  • 在Lumia 950 XL上运行Windows 10 ARM64,是种什么体验?

    本文于2019年02月01日首发于IT之家 地址 点击这里 2019年1月 据IT之家报道 微软Lumia 950 XL刷Windows 10 ARM64项目取得了巨大进展 显卡驱动已经成功运行 随后 适用于Lumia 950 XL的WiF
  • MYSQL 数据存在 (多条件同时满足)则更新,不存在则添加

    需求 提交数据时 数据不存在则添加 数据存在则更新 此处判断数据是否存在需要满足2个条件 cid date 如果两者同时满足的情况下 才更新数据 否则添加数据 表结构 使用的方法是 on duplicate key update INSER
  • 蓝桥杯 全球变暖 bfs学习

    全球变暖 你有一张某海域NxN像素的照片 表示海洋 表示陆地 如下所示 其中 上下左右 四个方向上连在一起的一片陆地组成一座岛屿 例如上图就有2座岛屿 由于全球变暖导致了海面上升 科学家预测未来几十年 岛屿边缘一个像素的范围会被海水淹没 具
  • 芜湖今年小升初计算机考试,刚刚!芜湖幼升小、小升初网上报名时间定了!附报名流程和具体安排...

    就在今天 芜湖发布了 关于做好2021年芜湖市义务教育网上报名审核工作的通知 其中明确幼升小和小升初的网上报名时间 家长们赶紧来看看 这则重要通知还说了哪些关于报名的重要信息吧 公办义务教育学校网上报名工作安排 民办义务教育学校网上报名工作
  • 使用torch以及tensorflow训练一个最简单网络的基本步骤

    torch import torch import torch nn functional as F import matplotlib pyplot as plt x torch Tensor unsqueeze torch Tensor