基于长短期记忆神经网络LSTM的预测模型(matlab实现)

🌼 希望是附丽于存在的,有存在,便有希望,有希望,便是光明。 ——鲁迅

1.普通循环神经网络

循环神经网络(Recurrent Neural Networks)简称RNN,是一种能够处理时间序列数据的神经网络模型,可以自然的拟合时间和数据之间的关系。循环神经网络作为一种早期深度学习算法中的一种,在各领域有广泛的应用范围,如预测,语音识别等。

相比于传统的前馈性神经网络,循环神经网络的最大的特点是增加了“记忆”的优势,通过信号的双向传播和隐藏层的循环结构,循环神经网络能够综合预测前的多个信息,分析输入信号前后的相关性,能有效提高时间序列方向预测的准确性。

却有一定的缺陷:

但是循环神经网络可能在训练的过程中出现指数的衰减或增长,损失函数经常遇见梯度消失或梯度爆炸的问题。对于这种传统的循环神经网络,随着训练数据集量的增大或数据前后时间间隔过大,那么新输入神经网络的数据会答题前期训练确定的隐藏层数目,有效的信息就会被“忘记”。

2.长短期记忆神经网络

2.1长短期记忆神经网络概述

长短期记忆神经网络(LSTM)是循环神经网络中的一种。

由于普通的神经网络出现如上的问题,长短期记忆神经网络来源于对循环神经网络的改进和优化,有效解决了循环神经网络中出现的梯度消失或梯度爆炸等问题,可用于处理更复杂的时间序列问题。

长短期记忆神经网络通过独特设计的遗忘门、输入们和输出门,在设计的开始就默认可以学习到长期的信息,以此实现输入神经网络的数据长记忆性。

2.2 长短期记忆神经网络原理

LSTM运行的关键状态就是细胞状态,传送带从上方贯穿运行,只有少量的线性交互,信息在交互时有很少的流失,LSTM通过精心设计的“门”结构来增加或减少到细胞状态的信息,“门”是一种选择性让信息通过的方法, 这个方法包括sigmoid神经网络层和一个按位的乘法操作。

其中sigmoid层输出值的范围位[0,1],用来表示每个部分有多少信息通过,0表示不允许任何通过,1表示允许任何通过。

LSTM原理图如下:

2.3 长短期记忆神经网络结构

LSTM模型同样由输入层、隐藏层和输出层三部分构成。其结构如下图:

在LSTM结构中,隐藏层的神经元内部设置了输入们、遗忘门和输出门,LSTM模型的核心是输入门、遗忘门和输出们以及单元状态。
  • 第一个是遗忘门,也就是决定从细胞状态中丢弃什么信息,遗忘门是一个sigmoid函数,是以当前时刻的输入和上一时刻的输出,读取0到1之间的一个数值作用于上一时刻输出的细胞状态。

  • 第二个是输入门,也就是决定有多少新信息保存到细胞状态,输入们是sigmoid层函数决定了更新哪部分的信息,而tanh层会生成一个新的候选值向量加入到细胞状态中,这两层结合起来更新细胞状态的值。

  • 第三个是输出门,是根据输入值和当前单元状态共同决定输出,输出门采用sigmoid函数决定了输出哪部分的信息,单元状态通过tanh函数处理后才输出,然后与sigmoid函数处理的输出相乘,最后输出结果。

3.matlab代码实现

%%  创建模型
layers = [
    sequenceInputLayer(30)               % 建立输入层
    
    lstmLayer(4, 'OutputMode', 'last')  % LSTM层
    reluLayer                           % Relu激活层
    
    fullyConnectedLayer(1)              % 全连接层
    regressionLayer];                   % 回归层

%%  参数设置
options = trainingOptions('adam', ...      % Adam 梯度下降算法
    'MiniBatchSize', 30, ...               % 批大小
    'MaxEpochs', 2000, ...                 % 最大迭代次数
    'InitialLearnRate', 1e-2, ...          % 初始学习率为
    'LearnRateSchedule', 'piecewise', ...  % 学习率下降
    'LearnRateDropFactor', 0.5, ...        % 学习率下降因子
    'LearnRateDropPeriod', 800, ...        % 经过 800 次训练后 学习率为 0.01 * 0.5
    'Shuffle', 'every-epoch', ...          % 每次训练打乱数据集
    'Plots', 'training-progress', ...      % 画出曲线
    'Verbose', false);
%%  训练模型
net = trainNetwork(p_train, t_train, layers, options);

%%  查看网络结构
analyzeNetwork(net)
此处为部分代码,计算相关评价指标,见: http://t.csdn.cn/X6S46

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年4月5日
下一篇 2023年4月5日

相关推荐