C++/Python机器学习—感知机(二分类)

2023-10-31

 一、Python

 

import numpy as np
import matplotlib.pyplot as plt

# 定义预测函数
def predict(x, w, b):
    # 计算特征向量和权重向量的点积
    dot_product = np.dot(x, w)
    # 计算预测值
    y_pred = dot_product + b
    # 如果预测值大于0,则返回1,否则返回-1
    return 1 if y_pred > 0 else -1

# 定义训练函数
def train(X, y, w, b, learning_rate, epochs):
    # 获取数据集大小(n)和特征维度(m)
    n, m = X.shape 
    for i in range(epochs): # 迭代训练
        for j in range(n): # 遍历数据集
            y_pred = predict(X[j], w, b) # 预测结果
            if y_pred != y[j]: # 判断是否需要更新权重和偏置
                # 更新权重和偏置
                w += learning_rate * y[j] * X[j]
                b += learning_rate * y[j]
    # 返回更新后的权重和偏置
    return w, b

# 定义生成数据函数
def generate_data(num_examples):
    # 初始化特征矩阵和标签向量
    X = np.zeros((num_examples, 2))
    y = np.zeros(num_examples)
    for i in range(num_examples): # 生成数据
        # 从标准正态分布中生成两个随机数作为特征
        feature = np.random.normal(0.0, 1.0, size=2)
        X[i] = feature
        # 计算特征向量和权重向量的点积
        dot_product = feature[0] * 2 + feature[1] * 3
        # 如果点积大于0,则标签为1,否则为-1
        label = 1 if dot_product > 0 else -1
        y[i] = label # 添加标签
    # 返回特征矩阵和标签向量
    return X, y

# 生成数据
num_examples = 1000
X, y = generate_data(num_examples)

# 初始化权重和偏置
w = np.zeros(2)
b = 0

learning_rate = 0.1
epochs = 1000
# 训练模型
w, b = train(X, y, w, b, learning_rate, epochs)
# 输出训练后的权重和偏置
print("w:", w)
print("b:", b)

# 可视化数据
# 绘制散点图
plt.scatter(X[:,0], X[:,1], c=y)
plt.title("Generated data")
plt.xlabel("x1")
plt.ylabel("x2")

# 绘制决策边界
# 获取x1和x2的最小值和最大值
x1_min, x1_max = X[:,0].min() - 1, X[:,0].max() + 1
x2_min, x2_max = X[:,1].min() - 1, X[:,1].max() + 1
# 生成网格点
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, 0.1),
                       np.arange(x2_min, x2_max, 0.1))
# 预测网格点的标签
Z = np.array([predict(np.array([x1, x2]), w, b) for x1, x2 in np.c_[xx1.ravel(), xx2.ravel()]])
Z = Z.reshape(xx1.shape)
# 绘制等高线图
plt.contourf(xx1, xx2, Z, alpha=0.4)
plt.show()

二、C++

#include <iostream>
#include <vector>
#include "math.h"
#include <random>

using namespace std;

// 定义预测函数
int predict(vector<double>& x, vector<double>& w, double b) {
    double dot_product = 0;
    for (int i = 0; i < x.size(); i++) {
        dot_product += x[i] * w[i];
    }
    double y_pred = dot_product + b;
    return y_pred > 0 ? 1 : -1;
}

// 定义训练函数
void train(vector<vector<double>>& X, vector<int>& y, vector<double>& w, double& b, double learning_rate, int epochs) {
    int n = X.size(); // 获取数据集大小
    int m = X[0].size(); // 获取特征维度
    for (int i = 0; i < epochs; i++) { // 迭代训练
        for (int j = 0; j < n; j++) { // 遍历数据集
            int y_pred = predict(X[j], w, b); // 预测结果
            if (y_pred != y[j]) { // 判断是否需要更新权重和偏置
                for (int k = 0; k < m; k++) {
                    w[k] += learning_rate * y[j] * X[j][k];
                }
                b += learning_rate * y[j];
            }
        }
    }
}

// 定义生成数据函数
void generate_data(int num_examples, vector<vector<double>>& X, vector<int>& y) {
    default_random_engine generator; // 定义随机数生成器
    normal_distribution<double> distribution(0.0, 1.0); // 定义正态分布
    for (int i = 0; i < num_examples; i++) { // 生成数据
        vector<double> feature;
        feature.push_back(distribution(generator));
        feature.push_back(distribution(generator));
        X.push_back(feature);
        double dot_product = feature[0] * 2 + feature[1] * 3; // 计算点积
        int label = dot_product > 0 ? 1 : -1; // 根据点积确定标签
        y.push_back(label); // 添加标签
    }
}

int main() {
    // 生成数据
    vector<vector<double>> X;
    vector<int> y;
    int num_examples = 1000;
    generate_data(num_examples, X, y);
    
    // 初始化权重和偏置
    vector<double> w = {0, 0};
    double b = 0;
    
    double learning_rate = 0.1;
    int epochs = 1000;
    // 训练模型
    train(X, y, w, b, learning_rate, epochs);
    // 输出训练结果
    cout << "w: " << w[0] << ", " << w[1] << ", b: " << b << std::endl;
    
    // 随机生成测试数据
    default_random_engine generator;
    normal_distribution<double> distribution(0.0, 1.0);
    vector<double> x_test = {distribution(generator), distribution(generator)};
    // 预测结果
    int y_pred = predict(x_test, w, b);
    // 输出预测结果
    cout << "Input values: " << x_test[0] << ", " << x_test[1] << std::endl;
    cout << "Predicted label: " << y_pred << std::endl;

    return 0;
}

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

C++/Python机器学习—感知机(二分类) 的相关文章

  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • Cython 和类的构造函数

    我对 Cython 使用默认构造函数有疑问 我的 C 类 Node 如下 Node h class Node public Node std cerr lt lt calling no arg constructor lt lt std e
  • 加快网络抓取速度

    我正在使用一个非常简单的网络抓取工具抓取 23770 个网页scrapy 我对 scrapy 甚至 python 都很陌生 但设法编写了一个可以完成这项工作的蜘蛛 然而 它确实很慢 爬行 23770 个页面大约需要 28 小时 我看过scr
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • 如何断言 Unittest 上的可迭代对象不为空?

    向服务提交查询后 我会收到一本字典或一个列表 我想确保它不为空 我使用Python 2 7 我很惊讶没有任何assertEmpty方法为unittest TestCase类实例 现有的替代方案看起来并不正确 self assertTrue
  • 将自定义元数据添加到 jpeg 文件

    我正在开发一个图像处理项目 C 我需要在处理完成后将自定义元数据写入 jpeg 文件 我怎样才能做到这一点 有没有可用的图书馆可以做到这一点 如果您正在谈论 EXIF 元数据 您可能需要查看exiv2 http www exiv2 org
  • 根据列 value_counts 过滤数据框(pandas)

    我是第一次尝试熊猫 我有一个包含两列的数据框 user id and string 每个 user id 可能有多个字符串 因此会多次出现在数据帧中 我想从中导出另一个数据框 一个只有那些user ids列出至少有 2 个或更多string
  • 为什么 Pickle 协议 4 中的 Pickle 文件是协议 3 中的两倍,而速度却没有任何提升?

    我正在测试 Python 3 4 我注意到 pickle 模块有一个新协议 因此 我对 2 个协议进行了基准测试 def test1 pickle3 open pickle3 wb for i in range 1000000 pickle
  • clang 实例化后静态成员初始化

    这样的代码可以用 GCC 编译 但 clang 3 5 失败 include
  • 在本地网络上运行 Bokeh 服务器

    我有一个简单的 Bokeh 应用程序 名为app py如下 contents of app py from bokeh client import push session from bokeh embed import server do
  • Discord.net 无法在 Linux 上运行

    我正在尝试让在 Linux VPS 上运行的 Discord net 中编码的不和谐机器人 我通过单声道运行 但我不断收到此错误 Unhandled Exception System Exception Connection lost at
  • 实现 XGboost 自定义目标函数

    我正在尝试使用 XGboost 实现自定义目标函数 在 R 中 但我也使用 python 所以有关 python 的任何反馈也很好 我创建了一个返回梯度和粗麻布的函数 它工作正常 但是当我尝试运行 xgb train 时它不起作用 然后 我
  • C++ 复制初始化和直接初始化,奇怪的情况

    在继续阅读本文之前 请阅读在 C 中 复制初始化和直接初始化之间有区别吗 https stackoverflow com questions 1051379 is there a difference in c between copy i
  • 如何使我的表单标题栏遵循 Windows 深色主题?

    我已经下载了Windows 10更新包括黑暗主题 文件资源管理器等都是深色主题 但是当我创建自己的 C 表单应用程序时 标题栏是亮白色的 如何使我自己的桌面应用程序遵循我在 Windows 中设置的深色主题 你需要调用DwmSetWindo
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s
  • mysql-connector-c++ - “get_driver_instance”不是“sql::mysql”的成员

    我是 C 的初学者 我认为学习的唯一方法就是接触一些代码 我正在尝试构建一个连接到 mysql 数据库的程序 我在 Linux 上使用 g 没有想法 我运行 make 这是我的错误 hello cpp 38 error get driver
  • cv2.VideoWriter:请求一个元组作为 Size 参数,然后拒绝它

    我正在使用 OpenCV 4 0 和 Python 3 7 创建延时视频 构造 VideoWriter 对象时 文档表示 Size 参数应该是一个元组 当我给它一个元组时 它拒绝它 当我尝试用其他东西替换它时 它不会接受它 因为它说参数不是
  • 防止索引超出范围错误

    我想编写对某些条件的检查 而不必使用 try catch 并且我想避免出现 Index Out of Range 错误的可能性 if array Element 0 Object Length gt 0 array Element 1 Ob
  • 使用按位运算符相乘

    我想知道如何使用按位运算符将一系列二进制位相乘 但是 我有兴趣这样做来查找二进制值的十进制小数值 这是我正在尝试做的一个例子 假设 1010010 我想使用每个单独的位 以便将其计算为 1 2 1 0 2 2 1 2 3 0 2 4 虽然我

随机推荐

  • mysql索引之B+树

    1 概述 提到B 树就不得不提及二叉树 平衡二叉树和B树这三种数据结构了 B 树就是从他们三个演化来的 众所周知B 树是一种常见的数据结构 被广泛应用于数据库和文件系统等领域 B 树的设计目标是保持树的平衡性 以提供稳定的性能 并且适用于大
  • 链表插入详解

    单链表速成 增与删 众所周知 链表是数据结构的重中之重 但有许多朋友对此并不感冒 甚至想骂 本文主要介绍小编对于链表的喜与悲 乐于忧 先上图 添加结点 单链表结点类型声明 typedef int ElemType 假设ElemType为自定
  • python捕获异常时,打印异常的类型、报错文件、与报错所在的行

    捕获异常 异常的完整代码是 try raise Exception wa except print 报错 else print 没有报错 finally print 程序关闭 得到结果 报错 程序关闭 一般程序里的 try与except是一
  • 如何优化一个肽预测模型

    要优化一个肽预测模型 首先需要考虑的是输入数据的质量 确保输入的数据是完整的 正确的 而不是噪声数据 此外 还需要考虑模型的训练方式 比如是否使用正则化和提前停止来确保不会过拟合训练数据 最后 应该尝试在模型中使用不同的参数来改善模型的性能
  • 02 Java基本数据结构之队列实现

    系列文章目录 01 Java基本数据结构之栈实现 02 Java基本数据结构之队列实现 03 Java基本数据结构之优先级队列 04 Java基本数据结构之链表 如有错误 还请指出 文章目录 系列文章目录 前言 一 队列 简述 二 栈 数组
  • 【Hyperledger Fabric】学习笔记1—— 区块链介绍

    目录 1 区块链介绍 1 1 区块链技术起源 1 1 1 区块链技术 1 1 2 区块链技术发展 1 2 区块链核心技术 1 2 1 定义 1 2 2 区块链技术原理 1 2 3 区块链工作过程 1 3 区块链开发平台 1 3 1 公有链平
  • GIT使用教程(十五步吃透,全网最详细)

    一 安装GIT 到官网下载GIT https git scm com downloads 二 创建仓库 在要创建仓库的文件夹空白地方点击右键 在弹出的菜单中点击 GIT Bash Here 然后初始化仓库 git init 成功后该文件夹中
  • MySQL配置和设置问题小结

    问题1 root Tony ts tian bin mysqladmin uroot password kaka123 mysqladmin connect to server at localhost failed ERROR 1045
  • [4G+5G专题-144]: 测试-频谱分析仪工作原理与测试结果分析

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 https blog csdn net HiWangWenBing article details 123222945 目录 前言 第1章
  • RSA/数字证书/签名原理详解

    文中首先解释了加密解密的一些基础知识和概念 然后通过一个加密通信过程的例子说明了加密算法的作用 以及数字证书的出现所起的作用 接着对数字证书做一个详细的解释 并讨论一下windows中数字证书的管理 最后演示使用makecert生成数字证书
  • 优惠卷测试案例

    提示 过期优惠卷 不同等级的用户 叠加使用 退款 支付失败 取消支付 退款中 订单信息 网络问题 退货 兼容性 优惠券是否可以正常使用 外观是否与UI保持一致 部分商品是否能正常使用 购买商品的时候会不会提示使用优惠券 优惠券是否能分享 分
  • git锁住如何解决GitLab: Your account has been blocked.

    今天用gitpush和pull的时候出现了一个问题 报了一个错误 GitLab Your account has been blocked 然后我怀疑是账号错误 然后发现账号密码对 后来发现是两个git账号同时占用了一个目录 强制删除目录下
  • 数据结构【堆】的认识及建立

    目录 一 堆 1 什么是堆 2 堆的存储方式 二 堆的建立与存储 三 堆的应用 1 堆排序 2 对顶堆 一 堆 1 什么是堆 堆 Heap 是一种特殊的完全二叉树结构 其中最大堆 Max Heap 或最小堆 Min Heap 的每个节点的键
  • Maven-Failed to parse POMs

    Maven Failed to parse POMs 错误描述信息 产生错误的原因 解决办法 依赖关系 错误描述信息 ERROR Failed to parse POMs hudson remoting ProxyException hud
  • mmdetection学习&训练测试自己的数据集

    一 本机使用环境 商汤科技和香港中文大学联合开源的深度学习目标检测工具箱mmdetection源码地址 Ubuntu16 04 Cuda9 0 cudnn7 5 Python3 6 GCC 7 2 Anaconda3 二 环境配置 官方配置
  • 无法连接 SQL Server 不可用或不存在 无法连接, SQL Server 不存在或拒绝网络访问..请问这是怎么回事?...

    远程连接sql server 2000服务器的解决方案 一 看ping 服务器IP能否ping通 这个实际上是看和远程sql server 2000服务器的物理连接是否存在 如果不行 请检查网络 查看配置 当然得确保远程sql server
  • CUDA 6.0在 VS 2010下的安装和配置

    CUDA 6 0在 VS 2010下的安装和配置 安装前准备 CUDA 6 0 安装包 下载地址 https developer nvidia com cuda downloads VS 2010 安装 这个直接下个免费的就行 Visual
  • 信息打点-公众号服务&Github监控&供应链&网盘泄漏&证书图标邮箱资产

    文章目录 微信公众号 获取 三方服务 Github监控 开发 配置 源码 网盘资源搜索 全局文件机密 敏感目录文件 目录扫描 爬虫 网络空间进阶 证书 图标 邮箱 实战案例四则 技术分享打击方位 微信公众号 获取 三方服务 1 获取微信公众
  • Linux C利用Socket套接字进行服务器与多个客户端进行通讯

    http blog csdn net returningprodigal article details 51916754 服务器端 html view plain copy print include
  • C++/Python机器学习—感知机(二分类)

    一 Python import numpy as np import matplotlib pyplot as plt 定义预测函数 def predict x w b 计算特征向量和权重向量的点积 dot product np dot x