Time-Series Forecasting in Practice with a Bayesian-Optimized BiLSTM

Digest   2024-06-25 20:56   Guizhou

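This post covers time-series forecasting with a Bayesian-optimized BiLSTM, in two parts: the underlying principles and a hands-on code walkthrough.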


I. Principles

1. LSTM

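LSTM evolved from the recurrent neural network (RNN). It keeps the RNN's input-layer, hidden-layer and output-layer structure but redesigns the hidden layer: gating mechanisms control the path along which information flows, so the network selectively remembers or discards the information passing through it. Figure 1 shows a typical LSTM memory block. Each block has three gates: an input gate, an output gate and a forget gate.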

Figure 1. Structure of an LSTM memory block

LSTM updates its state from the input data according to the following equations:
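The equations themselves did not survive in this copy of the post. For reference, the standard LSTM update is usually written as below. This is the common textbook formulation and may differ in notation from the original figure. The symbols are assumptions of this sketch: $x_t$ is the input, $h_t$ the hidden state, $c_t$ the cell state, $W$, $U$ and $b$ are weights and biases, $\sigma$ is the logistic sigmoid, and $\odot$ is the element-wise product.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f x_t + U_f h_{t-1} + b_f\right) && \text{(forget gate)} \\
i_t &= \sigma\left(W_i x_t + U_i h_{t-1} + b_i\right) && \text{(input gate)} \\
o_t &= \sigma\left(W_o x_t + U_o h_{t-1} + b_o\right) && \text{(output gate)} \\
\tilde{c}_t &= \tanh\left(W_c x_t + U_c h_{t-1} + b_c\right) && \text{(candidate state)} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{(cell state)} \\
h_t &= o_t \odot \tanh(c_t) && \text{(hidden state)}
\end{aligned}
```

The forget gate scales the previous cell state, the input gate scales the new candidate, and the output gate decides how much of the cell state is exposed as the hidden state.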

2. BiLSTM

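BiLSTM is an improvement on LSTM: its hidden layer consists of a forward LSTM and a backward LSTM. As shown in Figure 2, the network has two transmission layers running in opposite directions. The forward layer processes the time series in chronological order, the backward layer processes it in reverse order, and both layers are connected to the output layer.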
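A minimal sketch of such a network in MATLAB is shown here, assuming the Deep Learning Toolbox is installed. The layer sizes, the synthetic data and the training settings are illustrative assumptions, not the configuration used by the author.

```matlab
% Minimal BiLSTM regression network (sketch; requires Deep Learning Toolbox).
% numFeatures, numHiddenUnits and the synthetic data are illustrative assumptions.
rng(0);
numFeatures    = 3;    % e.g. number of lagged values fed to the network
numHiddenUnits = 16;   % hidden units per direction

layers = [
    sequenceInputLayer(numFeatures)
    bilstmLayer(numHiddenUnits,'OutputMode','last')  % forward + backward LSTM
    fullyConnectedLayer(1)
    regressionLayer];

% Synthetic data: an N-by-1 cell array of numFeatures-by-1 columns
% and an N-by-1 vector of scalar targets.
N   = 200;
XTr = cell(N,1);
YTr = zeros(N,1);
for i = 1:N
    XTr{i} = rand(numFeatures,1);
    YTr(i) = sum(XTr{i});          % toy target: sum of the inputs
end

options = trainingOptions('adam','MaxEpochs',20,'MiniBatchSize',32,'Verbose',false);
net     = trainNetwork(XTr,YTr,layers,options);
YPred   = predict(net,XTr,'MiniBatchSize',32);   % N-by-1 predictions
```

`bilstmLayer` concatenates the forward and backward hidden states, so the fully connected layer receives `2*numHiddenUnits` inputs.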

Figure 2. Structure of a BiLSTM

3. Bayesian-optimized BiLSTM model

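Bayesian optimization (BO) is used to select the model's hyperparameter combination. Based on Bayes' theorem, it fits a probabilistic surrogate model to the true objective function and uses the fitted result to choose the most promising point to evaluate next. By exploiting the history of past evaluations, it reduces the number of evaluations needed and makes the search more efficient. The framework has two core components: a probabilistic surrogate model and an acquisition function. The surrogate model comprises a prior model and an observation model, and the acquisition function is built from the posterior distribution. Good optimization results depend on choosing a suitable surrogate model and acquisition function. (Figure: flowchart of the Bayesian-optimized BiLSTM algorithm.)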
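The loop described above can be sketched with MATLAB's `bayesopt` function, assuming the Statistics and Machine Learning Toolbox is installed. The variable names, search ranges and the toy objective are assumptions made for illustration. In a real run the objective would train a BiLSTM with the proposed hyperparameters and return its validation error.

```matlab
% Bayesian optimization of two hyperparameters (sketch; requires
% Statistics and Machine Learning Toolbox). The objective is a toy
% stand-in for "train a BiLSTM and return its validation error".
rng(0);
vars = [optimizableVariable('numHiddenUnits',[8 256],'Type','integer'), ...
        optimizableVariable('learnRate',[1e-4 1e-1],'Transform','log')];

% bayesopt passes a one-row table whose columns are named after the variables.
toyValidationError = @(p) (log10(p.learnRate) + 2).^2 ...
                        + ((p.numHiddenUnits - 64)/64).^2;

results = bayesopt(toyValidationError,vars, ...
    'AcquisitionFunctionName','expected-improvement-plus', ...
    'MaxObjectiveEvaluations',30, ...
    'IsObjectiveDeterministic',true, ...
    'Verbose',0,'PlotFcn',[]);

best = results.XAtMinObjective;   % table holding the best observed point
fprintf('units = %d, learning rate = %.4g, error = %.4g\n', ...
    best.numHiddenUnits,best.learnRate,results.MinObjective);
```

Here the Gaussian-process model inside `bayesopt` plays the role of the probabilistic surrogate, and `'expected-improvement-plus'` is one of the available acquisition functions. The toy objective has its minimum at 64 hidden units and a learning rate of 0.01, so the search should drift toward that region.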

II. Code walkthrough


clc; clear; close all;

% NOTE: the settings struct opt must be defined before this point;
% its definition is not included in this excerpt.

%% --------------- Prepare Data
[opt,data] = PrepareData(opt,data);

%% --------------- Find Best LSTM Parameters with Bayesian Optimization
[opt,data] = OptimizeLSTM(opt,data);

%% --------------- Evaluate Data
[opt,data] = EvaluationData(opt,data);

%% ---------------------------- Local Functions ---------------------------
% Load a single-column time series from an Excel/CSV file chosen by the user,
% plot it, and apply the preprocessing selected in opt.dataPreprocessMode.
function data = loadData(opt)
[chosenfile,chosendirectory] = uigetfile({'*.xlsx';'*.csv'},...
    'Select Excel time series Data sets','data.xlsx');
if ~isequal(chosenfile,0)   % uigetfile returns 0 when the dialog is cancelled
    filePath = fullfile(chosendirectory,chosenfile);
    data.DataFileName = chosenfile;
    data.CompleteData = readtable(filePath);
    if size(data.CompleteData,2)>1
        warning('Input data should be an excel file with only one column!');
        disp('Operation Failed... '); pause(.9);
        disp('Reloading data. ');     pause(.9);
        data.x = [];
        data.isDataRead = false;
        return;
    end
    data.seriesdataHeder = data.CompleteData.Properties.VariableNames(1,:);
    data.seriesdata = table2array(data.CompleteData(:,:));
    disp('Input data successfully read.');
    data.isDataRead = true;
    data.seriesdata = PreInput(data.seriesdata);

    figure('Name','InputData','NumberTitle','off');
    plot(data.seriesdata); grid minor;
    title(['Mean = ' num2str(mean(data.seriesdata)) ', STD = ' num2str(std(data.seriesdata))]);

    if strcmpi(opt.dataPreprocessMode,'None')
        data.x = data.seriesdata;
    elseif strcmpi(opt.dataPreprocessMode,'Data Normalization')
        data.x = DataNormalization(data.seriesdata);
        figure('Name','NormalizedInputData','NumberTitle','off');
        plot(data.x); grid minor;
        title(['Mean = ' num2str(mean(data.x)) ', STD = ' num2str(std(data.x))]);
    elseif strcmpi(opt.dataPreprocessMode,'Data Standardization')
        data.x = DataStandardization(data.seriesdata);
        figure('Name','NormalizedInputData','NumberTitle','off');
        plot(data.x); grid minor;
        title(['Mean = ' num2str(mean(data.x)) ', STD = ' num2str(std(data.x))]);
    end
else
    warning(['In order to train network, please load data. ' ...
        'Input data should be an excel file with only one column!']);
    disp('Operation Cancel.');
    data.isDataRead = false;
end
end

% Convert cell-array input (text cells) to a numeric matrix.
% Cells containing '#NULL!' become NaN.
function data = PreInput(data)
if iscell(data)
    tempVars = NaN(size(data));   % preallocate; '#NULL!' cells stay NaN
    for i=1:size(data,1)
        for j=1:size(data,2)
            if ~strcmpi(data{i,j},'#NULL!')
                % str2double avoids the eval that str2num performs
                tempVars(i,j) = str2double(data{i,j});
            end
        end
    end
    data = tempVars;
end
end
% ... (the start of this function, including its header and the definitions of
% x and Delays, is not part of this excerpt)
T = size(x,2);
MaxDelay = max(Delays);
Range = MaxDelay+1:T;
X = [];
for d = Delays
    X = [X; x(:,Range-d)]; %#ok<AGROW>
end
Y = x(:,Range);
data.X = X;
data.Y = Y;
end

% partitioning input data
function data = dataPartitioning(opt,data)
data.XTr = [];
data.YTr = [];
data.XTs = [];
data.YTs = [];
numTrSample = round(opt.trPercentage*size(data.X,2));
data.XTr = data.X(:,1:numTrSample);
data.YTr = data.Y(:,1:numTrSample);
data.XTs = data.X(:,numTrSample+1:end);
data.YTs = data.Y(:,numTrSample+1:end);
disp(['Time Series data divided to ' num2str(opt.trPercentage*100) ...
    '% Train data and ' num2str((1-opt.trPercentage)*100) '% Test data']);
end

% Prepare input data for LSTM network.
function data = LSTMInput(data)
numTr = size(data.XTr,2);
numTs = size(data.XTs,2);
XTr = cell(numTr,1);  YTr = zeros(numTr,1);
XTs = cell(numTs,1);  YTs = zeros(numTs,1);
for i=1:numTr
    XTr{i,1} = data.XTr(:,i);
    YTr(i,1) = data.YTr(:,i);
end
for i=1:numTs
    XTs{i,1} = data.XTs(:,i);
    YTs(i,1) = data.YTs(:,i);
end
data.XTr = XTr;
data.YTr = YTr;
data.XTs = XTs;
data.YTs = YTs;
data.XVl = XTs;   % the test set is also used as the validation set
data.YVl = YTs;
disp('Time Series data prepared as suitable LSTM Input data.');
end

% --------------- Evaluate Data ---------------
% ---------------------------------------------
function [opt,data] = EvaluationData(opt,data)
if opt.isUseOptimizer
    OptimizedParams = evalin('base', 'OptimizedParams');
    % find best Net: each field name encodes a validation error
    % ('ValidationError' followed by the value with '.' replaced by '_')
    names = fieldnames(OptimizedParams);
    [valBest,indxBest] = sort(str2double(extractAfter(strrep(names,'_','.'),'Error')));
    % index by the sorted position instead of rebuilding the field name
    data.BiLSTM.Net = OptimizedParams.(names{indxBest(1)}).Net;
    if opt.isSaveBestOptimizedValue
        fileName = ['BestNet ' num2str(valBest(1)) ' ' ...
            char(datetime('now','Format','yyyy.MM.dd HH.mm')) '.mat'];
        Net = data.BiLSTM.Net;
        save(fileName,'Net')
    end
else
    [chosenfile,chosendirectory] = uigetfile({'*.mat'},...
        'Select Net File','BestNet.mat');
    if isequal(chosenfile,0)
        error('Please Select saved Network File or set isUseOptimizer: true');
    end
    filePath = fullfile(chosendirectory,chosenfile);
    Net = load(filePath);
    data.BiLSTM.Net = Net.Net;
end

data.BiLSTM.TrainOutputs = deNorm(data.seriesdata, ...
    predict(data.BiLSTM.Net,data.XTr,'MiniBatchSize',opt.miniBatchSize), ...
    opt.dataPreprocessMode);
data.BiLSTM.TrainTargets = deNorm(data.seriesdata,data.YTr,opt.dataPreprocessMode);
data.BiLSTM.TestOutputs = deNorm(data.seriesdata, ...
    predict(data.BiLSTM.Net,data.XTs,'MiniBatchSize',opt.miniBatchSize), ...
    opt.dataPreprocessMode);
data.BiLSTM.TestTargets = deNorm(data.seriesdata,data.YTs,opt.dataPreprocessMode);
data.BiLSTM.AllDataTargets = [data.BiLSTM.TrainTargets data.BiLSTM.TestTargets];
data.BiLSTM.AllDataOutputs = [data.BiLSTM.TrainOutputs data.BiLSTM.TestOutputs];

data = PlotResults(data,'Tr',...
    data.BiLSTM.TrainOutputs, ...
    data.BiLSTM.TrainTargets);
data = plotReg(data,'Tr',data.BiLSTM.TrainTargets,data.BiLSTM.TrainOutputs);
data = PlotResults(data,'Ts',...
    data.BiLSTM.TestOutputs, ...
    data.BiLSTM.TestTargets);
data = plotReg(data,'Ts',data.BiLSTM.TestTargets,data.BiLSTM.TestOutputs);
data = PlotResults(data,'All',...
    data.BiLSTM.AllDataOutputs, ...
    data.BiLSTM.AllDataTargets);
data = plotReg(data,'All',data.BiLSTM.AllDataTargets,data.BiLSTM.AllDataOutputs);
disp('Bi-LSTM network performance evaluated.');
end

% Map normalized/standardized values back to the scale of the original series.
function vars = deNorm(data,stdData,deNormMode)
if iscell(stdData(1,1))
    for i=1:size(stdData,1)
        tmp(i,:) = stdData{i,1}'; %#ok<AGROW>
    end
    stdData = tmp;
end
if strcmpi(deNormMode,'Data Normalization')
    for i=1:size(data,2)
        vars(:,i) = (stdData(:,i).*(max(data(:,i))-min(data(:,i)))) + min(data(:,i)); %#ok<AGROW>
    end
    vars = vars';
elseif strcmpi(deNormMode,'Data Standardization')
    for i=1:size(data,2)
        x.mu(1,i) = mean(data(:,i),'omitnan');
        x.sig(1,i) = std (data(:,i),'omitnan');
        vars(:,i) = ((stdData(:,i).* x.sig(1,i))+ x.mu(1,i)); %#ok<AGROW>
    end
    vars = vars';
else
    vars = stdData';
    return;
end
end

% plot the output of networks and real output on test and train data
function data = PlotResults(data,firstTitle,Outputs,Targets)
Errors = Targets - Outputs;
MSE = mean(Errors.^2);
RMSE = sqrt(MSE);
NRMSE = RMSE/mean(Targets);
ErrorMean = mean(Errors);
ErrorStd = std(Errors);
rankCorre = RankCorre(Targets,Outputs);
if strcmpi(firstTitle,'tr')
    Disp1Name = 'OutputGraphEvaluation_TrainData';
    Disp2Name = 'ErrorEvaluation_TrainData';
    Disp3Name = 'ErrorHistogram_TrainData';
elseif strcmpi(firstTitle,'ts')
    Disp1Name = 'OutputGraphEvaluation_TestData';
    Disp2Name = 'ErrorEvaluation_TestData';
    Disp3Name = 'ErrorHistogram_TestData';
elseif strcmpi(firstTitle,'all')
    Disp1Name = 'OutputGraphEvaluation_ALLData';
    Disp2Name = 'ErrorEvaluation_ALLData';
    Disp3Name = 'ErrorHistogram_AllData';
end

figure('Name',Disp1Name,'NumberTitle','off');
plot(1:length(Targets),Targets,...
    1:length(Outputs),Outputs);
grid minor
legend('Targets','Outputs','Location','best');
title(['Rank Correlation = ' num2str(rankCorre)]);

figure('Name',Disp2Name,'NumberTitle','off');
plot(Errors);
grid minor
title(['MSE = ' num2str(MSE) ', RMSE = ' num2str(RMSE)...
    ', NRMSE = ' num2str(NRMSE)]);
xlabel('Error Per Sample');

figure('Name',Disp3Name,'NumberTitle','off');
histogram(Errors);
grid minor
title(['Error Mean = ' num2str(ErrorMean) ', Error StD = ' num2str(ErrorStd)]);
xlabel('Error Histogram');

if strcmpi(firstTitle,'tr')
    data.Err.MSETr = MSE;
    data.Err.STDTr = ErrorStd;
    data.Err.NRMSETr = NRMSE;
    data.Err.rankCorreTr = rankCorre;
elseif strcmpi(firstTitle,'ts')
    data.Err.MSETs = MSE;
    data.Err.STDTs = ErrorStd;
    data.Err.NRMSETs = NRMSE;
    data.Err.rankCorreTs = rankCorre;
elseif strcmpi(firstTitle,'all')
    data.Err.MSEAll = MSE;
    data.Err.STDAll = ErrorStd;
    data.Err.NRMSEAll = NRMSE;
    data.Err.rankCorreAll = rankCorre;
end
end

% find rank correlation between network output and real data
function [r]=RankCorre(x,y)
x=x';
y=y';
% Find the data length
N = length(x);
% Get the ranks of x
R = crank(x)';
for i=1:size(y,2)
    % Get the ranks of y
    S = crank(y(:,i))';
    % Calculate the correlation coefficient
    r(i) = 1-6*sum((R-S).^2)/N/(N^2-1); %#ok
end
end

% Rank the elements of x; tied values receive the mean of their ranks.
function r=crank(x)
u = unique(x);
[~,z1] = sort(x);
[~,z2] = sort(z1);
r = (1:length(x))';
r=r(z2);
for i=1:length(u)
    s=find(u(i)==x);
    r(s,1) = mean(r(s));
end
end

% plot the regression line of output and real value
function data = plotReg(data,Title,Targets,Outputs)
if strcmpi(Title,'tr')
    DispName = 'RegressionGraphEvaluation_TrainData';
elseif strcmpi(Title,'ts')
    DispName = 'RegressionGraphEvaluation_TestData';
elseif strcmpi(Title,'all')
    DispName = 'RegressionGraphEvaluation_ALLData';
end
figure('Name',DispName,'NumberTitle','off');
x = Targets';
y = Outputs';
format long
b1 = x\y;            % least-squares slope of a line through the origin
yCalc1 = b1*x;
scatter(x,y,'MarkerEdgeColor',[0 0.4470 0.7410],'LineWidth',.7);
hold('on');
plot(x,yCalc1,'Color',[0.8500 0.3250 0.0980]);
% x holds the targets and y the network outputs, so label the axes accordingly
xlabel('Target');
ylabel('Prediction');
grid minor
% xgrid = 'on';
% disp.YGrid = 'on';
X = [ones(length(x),1) x];
b = X\y;             % least-squares line with intercept
yCalc2 = X*b;
plot(x,yCalc2,'-.','MarkerSize',4,'LineWidth',.1,'Color',[0.9290 0.6940 0.1250])
% the third curve is the fit with intercept, not the identity line Y=T
legend('Data','Fit through origin','Fit with intercept','Location','best');
% R^2 of the through-origin fit
Rsq2 = 1 - sum((y - yCalc1).^2)/sum((y - mean(y)).^2);
if strcmpi(Title,'tr')
    data.Err.RSqur_Tr = Rsq2;
    title(['Train Data, R^2 = ' num2str(Rsq2)]);
elseif strcmpi(Title,'ts')
    data.Err.RSqur_Ts = Rsq2;
    title(['Test Data, R^2 = ' num2str(Rsq2)]);
elseif strcmpi(Title,'all')
    data.Err.RSqur_All = Rsq2;
    title(['All Data, R^2 = ' num2str(Rsq2)]);
end
end
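To make the lag-window step in the listing above concrete, here is a small sketch that applies the same indexing to a toy series and then splits it chronologically, as `dataPartitioning` does. The series `1:6`, the delays `[1 2]` and the split ratio are illustrative assumptions.

```matlab
% Lag-window construction on a toy series (same indexing as the listing above).
x      = 1:6;        % 1-by-T row vector (illustrative data)
Delays = [1 2];      % use x(t-1) and x(t-2) to predict x(t)

T        = size(x,2);
MaxDelay = max(Delays);
Range    = MaxDelay+1:T;          % time steps that have all required lags

X = [];
for d = Delays
    X = [X; x(:,Range-d)];        %#ok<AGROW> one row per delay
end
Y = x(:,Range);

disp(X)   % rows: [2 3 4 5] and [1 2 3 4]
disp(Y)   % [3 4 5 6]

% Chronological split, as in dataPartitioning (trPercentage is an assumed value).
trPercentage = 0.75;
numTrSample  = round(trPercentage*size(X,2));   % 3 of the 4 samples
XTr = X(:,1:numTrSample);      YTr = Y(:,1:numTrSample);
XTs = X(:,numTrSample+1:end);  YTs = Y(:,numTrSample+1:end);
```

Each column of `X` is one input sample and the matching entry of `Y` is its target. The split keeps time order, so the test samples always come after the training samples.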


Results
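The figure titles produced by `PlotResults`, `RankCorre` and `plotReg` report the quantities below. The notation is an assumption introduced here for readability: $t_k$ are the targets, $y_k$ the network outputs, $N$ the number of samples, $d_k$ the difference between the rank of $t_k$ and the rank of $y_k$ (ties receive averaged ranks, as in `crank`), and $b_1$ the slope of the least-squares line through the origin.

```latex
\begin{aligned}
e_k &= t_k - y_k \\
\mathrm{MSE} &= \frac{1}{N}\sum_{k=1}^{N} e_k^2, \qquad
\mathrm{RMSE} = \sqrt{\mathrm{MSE}}, \qquad
\mathrm{NRMSE} = \frac{\mathrm{RMSE}}{\bar{t}} \\
r_s &= 1 - \frac{6\sum_{k=1}^{N} d_k^2}{N\left(N^2 - 1\right)} \\
R^2 &= 1 - \frac{\sum_{k=1}^{N}\left(y_k - b_1 t_k\right)^2}{\sum_{k=1}^{N}\left(y_k - \bar{y}\right)^2}
\end{aligned}
```

Because NRMSE divides by the mean target $\bar{t}$, it becomes unstable when the series has a mean close to zero.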


Full code: https://mbd.pub/o/bread/ZJaXm51w










Author | 华夏
Editor | 华夏
Proofreader | 华夏

