麻雀算法SSA优化LSTM超参数

2023-11-08

前言

  1. LSTM 航空乘客预测单步预测的两种情况。 简单运用LSTM 模型进行预测分析。
  2. 加入注意力机制的LSTM 对航空乘客预测采用了目前市面上比较流行的注意力机制,将两者进行结合预测。
  3. 多层 LSTM 对航空乘客预测 简单运用多层的LSTM 模型进行预测分析。
  4. 双向LSTM 对航空乘客预测双向LSTM网络对其进行预测。
  5. MLP多层感知器 对航空乘客预测简化版 使用MLP 对航空乘客预测
  6. CNN + LSTM 航空乘客预测采用的CNN + LSTM网络对其进行预测。
  7. ConvLSTM 航空乘客预测采用ConvLSTM 航空乘客预测
  8. LSTM的输入格式和输出个数说明 中对单步和多步的输入输出格式进行了解释
  9. LSTM 单变量多步预测航空乘客简单版
  10. LSTM 单变量多步预测航空乘客复杂版
  11. LSTM 多变量单步预测空气质量(1—》1) 用LSTM 前一个数据点的多变量预测下一个时间点的空气质量
  12. LSTM 多变量单步预测空气质量(3 —》1) 用LSTM 前三个数据点的多变量预测下一个时间点的空气质量

本文主要是采用麻雀算法SSA优化LSTM超参数

程序

麻雀搜索算法是2020提出的一种新的优化算法,在此不对具体原理进行分析,针对代码实操.

SSA

麻雀算法代码简介

class SSA():
    def __init__(self, func, n_dim=None, pop_size=20, max_iter=50, lb=-512, ub=512, verbose=False):
        self.func = func
        self.n_dim = n_dim  # dimension of particles, which is the number of variables of func
        self.pop = pop_size  # number of particles
        P_percent = 0.2  # # 生产者的人口规模占总人口规模的20%
        D_percent = 0.1  # 预警者的人口规模占总人口规模的10%
        self.pNum = round(self.pop * P_percent)  # 生产者的人口规模占总人口规模的20%
        self.warn = round(self.pop * D_percent)  # 预警者的人口规模占总人口规模的10%

        self.max_iter = max_iter  # max iter
        self.verbose = verbose  # print the result of each iter or not

        self.lb, self.ub = np.array(lb) * np.ones(self.n_dim), np.array(ub) * np.ones(self.n_dim)
        assert self.n_dim == len(self.lb) == len(self.ub), 'dim == len(lb) == len(ub) is not True'
        assert np.all(self.ub > self.lb), 'upper-bound must be greater than lower-bound'

        self.X = np.random.uniform(low=self.lb, high=self.ub, size=(self.pop, self.n_dim))

        self.Y = [self.func(self.X[i]) for i in range(len(self.X))]  # y = f(x) for all particles
        self.pbest_x = self.X.copy()  # personal best location of every particle in history
        self.pbest_y = [np.inf for i in range(self.pop)]  # best image of every particle in history
        self.gbest_x = self.pbest_x.mean(axis=0).reshape(1, -1)  # global best location for all particles
        self.gbest_y = np.inf  # global best y for all particles
        self.gbest_y_hist = []  # gbest_y of every iteration
        self.update_pbest()
        self.update_gbest()
        #
        # record verbose values
        self.record_mode = False
        self.record_value = {'X': [], 'V': [], 'Y': []}
        self.best_x, self.best_y = self.gbest_x, self.gbest_y  # history reasons, will be deprecated
        self.idx_max = 0
        self.x_max = self.X[self.idx_max, :]
        self.y_max = self.Y[self.idx_max]

    def cal_y(self, start, end):
        # calculate y for every x in X
        for i in range(start, end):
            self.Y[i] = self.func(self.X[i])
        # return self.Y

    def update_pbest(self):
        '''
        personal best
        '''
        for i in range(len(self.Y)):
            if self.pbest_y[i] > self.Y[i]:
                self.pbest_x[i] = self.X[i]
                self.pbest_y[i] = self.Y[i]

    def update_gbest(self):
        idx_min = self.pbest_y.index(min(self.pbest_y))
        if self.gbest_y > self.pbest_y[idx_min]:
            self.gbest_x = self.X[idx_min, :].copy()
            self.gbest_y = self.pbest_y[idx_min]

    def find_worst(self):
        self.idx_max = self.Y.index(max(self.Y))
        self.x_max = self.X[self.idx_max, :]
        self.y_max = self.Y[self.idx_max]

    def update_finder(self):
        r2 = np.random.rand(1)  # 预警值
        self.idx = sorted(enumerate(self.Y), key=lambda x: x[1])
        self.idx = [self.idx[i][0] for i in range(len(self.idx))]
        # 这一部位为发现者(探索者)的位置更新
        if r2 < 0.8:  # 预警值较小,说明没有捕食者出现
            for i in range(self.pNum):
                r1 = np.random.rand(1)
                self.X[self.idx[i], :] = self.X[self.idx[i], :] * np.exp(-(i) / (r1 * self.max_iter))  # 对自变量做一个随机变换
                self.X = np.clip(self.X, self.lb, self.ub)  # 对超过边界的变量进行去除
                # X[idx[i], :] = Bounds(X[idx[i], :], lb, ub)  # 对超过边界的变量进行去除
                # fit[sortIndex[0, i], 0] = func(X[sortIndex[0, i], :])  # 算新的适应度值
        elif r2 >= 0.8:  # 预警值较大,说明有捕食者出现威胁到了种群的安全,需要去其它地方觅食
            for i in range(self.pNum):
                Q = np.random.rand(1)  # 也可以替换成  np.random.normal(loc=0, scale=1.0, size=1)
                self.X[self.idx[i], :] = self.X[self.idx[i], :] + Q * np.ones(
                    (1, self.n_dim))  # Q是服从正态分布的随机数。L表示一个1×d的矩阵
                self.X = np.clip(self.X, self.lb, self.ub)  # 对超过边界的变量进行去除
                # X[idx[i], :] = Bounds(X[sortIndex[0, i], :], lb, ub)
                # fit[sortIndex[0, i], 0] = func(X[sortIndex[0, i], :])
        self.cal_y(0, self.pNum)

    def update_follower(self):
        #  这一部位为加入者(追随者)的位置更新
        for ii in range(self.pop - self.pNum):
            i = ii + self.pNum
            A = np.floor(np.random.rand(1, self.n_dim) * 2) * 2 - 1
            best_idx = self.Y[0:self.pNum].index(min(self.Y[0:self.pNum]))
            bestXX = self.X[best_idx, :]
            if i > self.pop / 2:
                Q = np.random.rand(1)
                self.X[self.idx[i], :] = Q * np.exp((self.x_max - self.X[self.idx[i], :]) / np.square(i))
            else:
                self.X[self.idx[i], :] = bestXX + np.dot(np.abs(self.X[self.idx[i], :] - bestXX),
                                                         1 / (A.T * np.dot(A, A.T))) * np.ones((1, self.n_dim))
        self.X = np.clip(self.X, self.lb, self.ub)  # 对超过边界的变量进行去除
        # X[self.idx[i],:] = Bounds(X[self.idx[i],lb,ub)
        # fit[self.idx[i],0] = func(X[self.idx[i], :])
        self.cal_y(self.pNum, self.pop)

    def detect(self):
        arrc = np.arange(self.pop)
        c = np.random.permutation(arrc)  # 随机排列序列
        b = [self.idx[i] for i in c[0: self.warn]]
        e = 10e-10
        for j in range(len(b)):
            if self.Y[b[j]] > self.gbest_y:
                self.X[b[j], :] = self.gbest_y + np.random.rand(1, self.n_dim) * np.abs(self.X[b[j], :] - self.gbest_y)
            else:
                self.X[b[j], :] = self.X[b[j], :] + (2 * np.random.rand(1) - 1) * np.abs(
                    self.X[b[j], :] - self.x_max) / (self.func(self.X[b[j]]) - self.y_max + e)
            # X[sortIndex[0, b[j]], :] = Bounds(X[sortIndex[0, b[j]], :], lb, ub)
            # fit[sortIndex[0, b[j]], 0] = func(X[sortIndex[0, b[j]]])
            self.X = np.clip(self.X, self.lb, self.ub)  # 对超过边界的变量进行去除
            self.Y[b[j]] = self.func(self.X[b[j]])

    def run(self, max_iter=None):
        self.max_iter = max_iter or self.max_iter
        for iter_num in range(self.max_iter):
            self.update_finder()  # 更新发现者位置
            self.find_worst()  # 取出最大的适应度值和最差适应度的X
            self.update_follower()  # 更新跟随着位置
            self.update_pbest()
            self.update_gbest()
            self.detect()
            self.update_pbest()
            self.update_gbest()
            self.gbest_y_hist.append(self.gbest_y)
        return self.best_x, self.best_y

LSTM

def build_model(neurons1, neurons2, dropout):
    X_train, y_train, X_test, y_test = process_data()
    # X_train, y_train = create_dataset(X_train, y_train, steps)
    # X_test, y_test = create_dataset(X_test, y_test, steps)
    nb_features = X_train.shape[2]
    input1 = X_train.shape[1]
    model1 = Sequential()
    model1.add(LSTM(
        input_shape=(input1, nb_features),
        units=neurons1,
        return_sequences=True))
    model1.add(Dropout(dropout))

    model1.add(LSTM(
        units=neurons2,
        return_sequences=False))
    model1.add(Dropout(dropout))

    model1.add(Dense(units=1))
    model1.add(Activation("linear"))
    model1.compile(loss='mse', optimizer='Adam', metrics='mae')
    return model1, X_train, y_train, X_test, y_test

优化超参数

if __name__ == '__main__':
    '''
    神经网络第一层神经元个数
    神经网络第二层神经元个数
    dropout比率
    batch_size
    '''
    neurons1 = 64
    neurons2 = 64
    dropout = 0.01
    batch_size = 32
    model, X_train, y_train, X_test, y_test = build_model(neurons1, neurons2, dropout)
    history1 = model.fit(X_train, y_train, epochs=150, batch_size=batch_size, validation_split=0.2, verbose=1,
                         callbacks=[EarlyStopping(monitor='val_loss', patience=9, restore_best_weights=True)])
    # 测试集预测
    y_score = model.predict(X_test)
    # 反归一化
    y_score = scaler.inverse_transform(y_score.reshape(-1, 1))
    y_test = scaler.inverse_transform(y_test.reshape(-1, 1))

    print("==========evaluation==============\n")
    from sklearn.metrics import mean_squared_error
    from sklearn.metrics import mean_absolute_error #平方绝对误差
    import math

    MAE = mean_absolute_error(y_test, y_score)
    print('MAE: %.4f ' % MAE)
    RMSE = math.sqrt(mean_squared_error(y_test, y_score))
    print('RMSE: %.4f ' % (RMSE))
    

总结

  1. SSA在一定范围内可以优化LSTM 的超参数,对算力要求有点大
  2. SSA优化算法有一定的局限性,如何利用其优势至关重要
  3. LSTM的超参数可以部分优化,能够节约时间和节省算力资源

备注:
需要源代码和数据集,或者想要沟通交流,请私聊,谢谢.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

麻雀算法SSA优化LSTM超参数 的相关文章

随机推荐

  • matlab读取sheet1_matlab读取excel数据的方法步骤详解

    在Excel中录入好数据以后经常需要被matlab读取 具体该如何读取呢 下面是由学习啦小编分享的matlab读取excel数据的方法 以供大家阅读和学习 matlab读取excel数据的方法 matlab读取Excel数据步骤1 如果数据
  • Vue/JS自定义指令:实现元素滑动、移动端适配以及边界处理

    核心属性 Element clientWidth 元素可视宽度 Element clientHeight 元素可视高度 MouseEvent clientX 鼠标相对于浏览器左上顶点的水平坐标 MouseEvent clientY 鼠标相对
  • 微服务如何治理

    微服务远程调用可能有如下问题 注册中心宕机 服务提供者B有节点宕机 服务消费者A和注册中心之间的网络不通 服务提供者B和注册中心之间的网络不通 服务消费者A和服务提供者B之间的网络不通 服务提供者B有些节点性能变慢 服务提供者B短时间内出现
  • 每日学习07:Comparable接口的CompareTo的用法

    接口 Comparable 此接口强行对实现它的每个类的对象进行整体排序 这种排序被称为类的自然排序 类的 compareTo 方法被称为它的自然比较方法 字符串 数组列表 数组 所有可以 排序 的类都实现了java lang Compar
  • ThinkPHP5.1开发企业微信支付到零钱

    Wxpay php
  • npm启动vue应用开发服务器过程分析

    关于 npm run serve 命令启动vue应用开发环境的过程分析 1 npm run 命令执行时 npm run 命令执行时 会把 node modules bin目录添加到执行环境的PATH变量中 全局的没有安装的包 在node m
  • 本地IDEA中使用SonarQube扫描代码

    文章目录 背景 步骤 安装插件 配置 使用 背景 为了提高效率 在走代码CICD流程里的Sonarqube之前 先在本地提前进行一次扫描和修复 步骤 安装插件 2种方式 在IDE的插件管理中心安装名为 SonarQube Community
  • 爬虫高级应用(15. 基于Charles抓包软件抓取手机APP数据)

    目录 写在前面 配置安装Charles 安装Charles 下载相关证书 电脑证书 手机证书 设置代理 实操案例 抓取手机APP爱吾游戏宝盒数据 写在前面 移动App多使用异步的方式从服务端获取数据 抓取数据之前 要先分析移动App用于获取
  • 线性代数 --- 线性代数基本定理下(四个基本子空间他们两两正交,且互为正交补)

    正交子空间 前面我们已经知道了 两个向量的内积为0是勾股定理的另一种表现形式 现在我们来研究一下两个子空间之间的正交 虽然 我很不喜欢一上来就先给个定义 但我这里还是要给 sorry 现有两个子空间V和W 如果V中的任何一个向量v和W中的任
  • deepsort算法原理以及代码解析

    概述 前边我们讲了sort算法的原理 并且指出了它的不足 IDsw过大 为了解决该问题 17年时候sort算法的团队又提出了DeepSort算法 Deepsort在原来Sort算法的基础上 改进了以下内容 使用级联匹配算法 针对每一个检测器
  • .NET通用开发框架

    在开源中国社区 简单整理了下比较好的 NET通用开发框架 一个好的通用框架大概包括 开源 扩展性好 灵活性好 复用性好 维护性好 易测试 易发布 易部署 快速业务搭建 或业务集成 通用性强 参考资料多 持续技术支持 社区疑难问题建设 NET
  • 顺序表的基本操作(初始化、插入、删除、查询、扩容、打印、清空等)

    顺序表的基本操作 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构 一般情况下采用数组存储 在数组上完成数据的增删查改等基本操作 初始化 初始化结构体 开辟空间 void SeqListInit SeqList ps size
  • TypeScript 中的 any、unknown、never 和 void

    any any 表示 任意类型 它是任意类型的父类 任意类型的值都可以赋予给 any 类型 编译不会报错 let anything any 前端西瓜哥 let flag boolean true anything flag anything
  • 数据库实体关系模型 --- ER Model

    数据库实体关系图 The Entity Relationship Model ER Model ER模型的作用 ER模型的基本组成 E R 图 ER图的基本组成 不同的键 Key 超码 superkey 候选码 candidate key
  • 微服务多模块:Springboot+Security+Redis+Gateway+OpenFeign+Nacos+JWT (附源码)仅需一招,520彻底拿捏你

    可能有些人会觉得这篇似曾相识 没错 这篇是由原文章进行二次开发的 前阵子有些事情 但最近看到评论区说原文章最后实现的是单模块的验证 由于过去太久也懒得验证 所以重新写了一个完整的可以跑得动的一个 OK 回到正题 以下是真正对应的微服务多模块
  • python习题及答案/4.16

    文章目录 1 从键盘输入两个数 求它们的和并输出 2 从键盘输入三个数到a b c中 按公式值输出 3 输出 Python语言简单易学 4 使用函数求特殊a串数列和 5 使用函数求素数和 6 使用函数统计指定数字的个数 1 从键盘输入两个数
  • 以太坊学习笔记(三)——搭建以太坊私链

    以太坊私链的搭建可以直接通过下载程序进行安装 也可以通过编译源码安装 本文介绍通过编译源码进行安装 编译源码 1 准备环境 我们下载的是go语言的源码 首先需要正确的安装go语言环境 如何正确安装go语言环境 大家可以去网上找教程 2 下载
  • AndroidO audio系统之AudioPolicyService分析(三)

    1 AudioPolicyService基础 AudioPolicy在Android系统中主要负责Audio 策略 相关的问题 它和AudioFlinger一起组成了Android Audio系统的两个服务 一个负责管理audio的 路由
  • QStringList 常用方法

    QStringList类 常用方法 定义一个字符串链表 QStringList weekList 往链表中添加元素 weekList lt lt 星期一 lt lt 星期二 lt lt 星期三 lt lt 星期四 weekList lt l
  • 麻雀算法SSA优化LSTM超参数

    前言 LSTM 航空乘客预测单步预测的两种情况 简单运用LSTM 模型进行预测分析 加入注意力机制的LSTM 对航空乘客预测采用了目前市面上比较流行的注意力机制 将两者进行结合预测 多层 LSTM 对航空乘客预测 简单运用多层的LSTM 模