本期的分享内容是基于贝叶斯优化BILSTM的时序预测实战。主要从原理和代码实战展开。
一、原理
1、LSTM原理
LSTM 由 RNN 神经网络演变而来,在传统 RNN 神 经网络输入层、隐藏层和输出层的结构基础上,改进了 隐藏层的结构。通过引入门机制来控制信息传递的路径,并通过门对经过网络的信息进行有选择的记忆或删除.如图 1 所示为一个典型的 LSTM 记忆块结构,每个记忆块都有三个“门控”结构,包括输入门、输出门和遗忘门.
图 1 LSTM 结构示意图
LSTM根据下列公式将输入数据进行更新:
2、BiLSTM
BiLSTM 是对 LSTM 的改进,其隐藏层由前向 LSTM 与后 向 LSTM 共 同 构成。其 结 构 如 图 2 所 示, BiLSTM 记忆网络有两个方向的传输层,前传层沿前向 训练时间序列,后传层沿后向训练时间序列,前传和后传层都与输出层连接。
图2 BiLSTM 结构示意图
3、贝叶斯优化BiLSTM模型
贝 叶 斯 优 化 ( Bayesian optimization,BO) 选取模型超参数组合。贝叶斯优化基 于贝叶斯定理,运用概率代理模型拟合真实目标函数, 并根据拟合的结果选择接下来最有可能的点进行评估,利用历史信息合理减少评估次数,提高搜索效率。贝叶斯优化的框架主要包含两个核心部分: 概率 代理模型和采集函数。概率代理模型包含先验概率模 型和观测模型; 采集函数是根据后验概率分布构造的函数,选用合适的概率代理模型和采集函数才能获得更好的优化效果。贝叶斯优化BiLSTM模型的算法流程图。
二、代码实战
clc; clear; close all;
%% --------------- Prepair Data
[opt,data] = PrepareData(opt,data);
%% --------------- Find Best LSTM Parameters with Bayesian Optimization
[opt,data] = OptimizeLSTM(opt,data);
%% --------------- Evaluate Data
[opt,data] = EvaluationData(opt,data);
%% ---------------------------- Local Functions ---------------------------
function data = loadData(opt)
[chosenfile,chosendirectory] = uigetfile({'*.xlsx';'*.csv'},...
'Select Excel time series Data sets','data.xlsx');
filePath = [chosendirectory chosenfile];
if filePath ~= 0
data.DataFileName = chosenfile;
data.CompleteData = readtable(filePath);
if size(data.CompleteData,2)>1
warning('Input data should be an excel file with only one column!');
disp('Operation Failed... '); pause(.9);
disp('Reloading data. '); pause(.9);
data.x = [];
data.isDataRead = false;
return;
end
data.seriesdataHeder = data.CompleteData.Properties.VariableNames(1,:);
data.seriesdata = table2array(data.CompleteData(:,:));
disp('Input data successfully read.');
data.isDataRead = true;
data.seriesdata = PreInput(data.seriesdata);
figure('Name','InputData','NumberTitle','off');
plot(data.seriesdata); grid minor;
title({['Mean = ' num2str(mean(data.seriesdata)) ', STD = ' num2str(std(data.seriesdata)) ];});
if strcmpi(opt.dataPreprocessMode,'None')
data.x = data.seriesdata;
elseif strcmpi(opt.dataPreprocessMode,'Data Normalization')
data.x = DataNormalization(data.seriesdata);
figure('Name','NormilizedInputData','NumberTitle','off');
plot(data.x); grid minor;
title({['Mean = ' num2str(mean(data.x)) ', STD = ' num2str(std(data.x)) ];});
elseif strcmpi(opt.dataPreprocessMode,'Data Standardization')
data.x = DataStandardization(data.seriesdata);
figure('Name','NormilizedInputData','NumberTitle','off');
plot(data.x); grid minor;
title({['Mean = ' num2str(mean(data.x)) ', STD = ' num2str(std(data.x)) ];});
end
else
warning(['In order to train network, please load data.' ...
'Input data should be an excel file with only one column!']);
disp('Operation Cancel.');
data.isDataRead = false;
end
end
function data = PreInput(data)
if iscell(data)
for i=1:size(data,1)
for j=1:size(data,2)
if strcmpi(data{i,j},'#NULL!')
tempVars(i,j) = NaN; %#ok
else
tempVars(i,j) = str2num(data{i,j}); %#ok
end
end
end
data = tempVars;
end
end
T = size(x,2);
MaxDelay = max(Delays);
Range = MaxDelay+1:T;
X= [];
for d = Delays
X=[X; x(:,Range-d)];
end
Y = x(:,Range);
data.X = X;
data.Y = Y;
end
% partitioning input data
function data = dataPartitioning(opt,data)
data.XTr = [];
data.YTr = [];
data.XTs = [];
data.YTs = [];
numTrSample = round(opt.trPercentage*size(data.X,2));
data.XTr = data.X(:,1:numTrSample);
data.YTr = data.Y(:,1:numTrSample);
data.XTs = data.X(:,numTrSample+1:end);
data.YTs = data.Y(:,numTrSample+1:end);
disp(['Time Series data divided to ' num2str(opt.trPercentage*100) '% Train data and ' num2str((1-opt.trPercentage)*100) '% Test data']);
end
% Prepare input data for LSTM network.
function data = LSTMInput(data)
for i=1:size(data.XTr,2)
XTr{i,1} = data.XTr(:,i);
YTr(i,1) = data.YTr(:,i);
end
for i=1:size(data.XTs,2)
XTs{i,1} = data.XTs(:,i);
YTs(i,1) = data.YTs(:,i);
end
data.XTr = XTr;
data.YTr = YTr;
data.XTs = XTs;
data.YTs = YTs;
data.XVl = XTs;
data.YVl = YTs;
disp('Time Series data prepared as suitable LSTM Input data.');
end
% --------------- Evaluate Data ---------------
% ---------------------------------------------
function [opt,data] = EvaluationData(opt,data)
if opt.isUseOptimizer
OptimizedParams = evalin('base', 'OptimizedParams');
% find best Net
[valBest,indxBest] = sort(str2double(extractAfter(strrep(fieldnames(OptimizedParams),'_','.'),'Error')));
data.BiLSTM.Net = OptimizedParams.(['ValidationError' strrep(num2str(valBest(1)),'.','_')]).Net;
if opt.isSaveBestOptimizedValue
fileName = ['BestNet ' num2str(valBest(1)) ' ' char(datetime('now','Format','yyyy.MM.dd HH.mm')) '.mat'];
Net = data.BiLSTM.Net;
save(fileName,'Net')
end
elseif ~opt.isUseOptimizer
[chosenfile,chosendirectory] = uigetfile({'*.mat'},...
'Select Net File','BestNet.mat');
if chosenfile==0
error('Please Select saved Network File or set isUseOptimizer: true');
end
filePath = [chosendirectory chosenfile];
Net = load(filePath);
data.BiLSTM.Net = Net.Net;
end
data.BiLSTM.TrainOutputs = deNorm(data.seriesdata,predict(data.BiLSTM.Net,data.XTr,'MiniBatchSize',opt.miniBatchSize),opt.dataPreprocessMode);
data.BiLSTM.TrainTargets = deNorm(data.seriesdata,data.YTr,opt.dataPreprocessMode);
data.BiLSTM.TestOutputs = deNorm(data.seriesdata,predict(data.BiLSTM.Net,data.XTs,'MiniBatchSize',opt.miniBatchSize),opt.dataPreprocessMode);
data.BiLSTM.TestTargets = deNorm(data.seriesdata,data.YTs,opt.dataPreprocessMode);
data.BiLSTM.AllDataTargets = [data.BiLSTM.TrainTargets data.BiLSTM.TestTargets];
data.BiLSTM.AllDataOutputs = [data.BiLSTM.TrainOutputs data.BiLSTM.TestOutputs];
data = PlotResults(data,'Tr',...
data.BiLSTM.TrainOutputs, ...
data.BiLSTM.TrainTargets);
data = plotReg(data,'Tr',data.BiLSTM.TrainTargets,data.BiLSTM.TrainOutputs);
data = PlotResults(data,'Ts',....
data.BiLSTM.TestOutputs, ...
data.BiLSTM.TestTargets);
data = plotReg(data,'Ts',data.BiLSTM.TestTargets,data.BiLSTM.TestOutputs);
data = PlotResults(data,'All',...
data.BiLSTM.AllDataOutputs, ...
data.BiLSTM.AllDataTargets);
data = plotReg(data,'All',data.BiLSTM.AllDataTargets,data.BiLSTM.AllDataOutputs);
disp('Bi-LSTM network performance evaluated.');
end
function vars = deNorm(data,stdData,deNormMode)
if iscell(stdData(1,1))
for i=1:size(stdData,1)
tmp(i,:) = stdData{i,1}';
end
stdData = tmp;
end
if strcmpi(deNormMode,'Data Normalization')
for i=1:size(data,2)
vars(:,i) = (stdData(:,i).*(max(data(:,i))-min(data(:,i)))) + min(data(:,i));
end
vars = vars';
elseif strcmpi(deNormMode,'Data Standardization')
for i=1:size(data,2)
x.mu(1,i) = mean(data(:,i),'omitnan');
x.sig(1,i) = std (data(:,i),'omitnan');
vars(:,i) = ((stdData(:,i).* x.sig(1,i))+ x.mu(1,i));
end
vars = vars';
else
vars = stdData';
return;
end
end
% plot the output of networks and real output on test and train data
function data = PlotResults(data,firstTitle,Outputs,Targets)
Errors = Targets - Outputs;
MSE = mean(Errors.^2);
RMSE = sqrt(MSE);
NRMSE = RMSE/mean(Targets);
ErrorMean = mean(Errors);
ErrorStd = std(Errors);
rankCorre = RankCorre(Targets,Outputs);
if strcmpi(firstTitle,'tr')
Disp1Name = 'OutputGraphEvaluation_TrainData';
Disp2Name = 'ErrorEvaluation_TrainData';
Disp3Name = 'ErrorHistogram_TrainData';
elseif strcmpi(firstTitle,'ts')
Disp1Name = 'OutputGraphEvaluation_TestData';
Disp2Name = 'ErrorEvaluation_TestData';
Disp3Name = 'ErrorHistogram_TestData';
elseif strcmpi(firstTitle,'all')
Disp1Name = 'OutputGraphEvaluation_ALLData';
Disp2Name = 'ErrorEvaluation_ALLData';
Disp3Name = 'ErrorHistogram_AllData';
end
figure('Name',Disp1Name,'NumberTitle','off');
plot(1:length(Targets),Targets,...
1:length(Outputs),Outputs);grid minor
legend('Targets','Outputs','Location','best') ;
title(['Rank Correlation = ' num2str(rankCorre)]);
figure('Name',Disp2Name,'NumberTitle','off');
plot(Errors);grid minor
title({['MSE = ' num2str(MSE) ', RMSE = ' num2str(RMSE)...
' NRMSE = ' num2str(NRMSE)] ;});
xlabel(['Error Per Sample']);
figure('Name',Disp3Name,'NumberTitle','off');
histogram(Errors);grid minor
title(['Error Mean = ' num2str(ErrorMean) ', Error StD = ' num2str(ErrorStd)]);
xlabel(['Error Histogram']);
if strcmpi(firstTitle,'tr')
data.Err.MSETr = MSE;
data.Err.STDTr = ErrorStd;
data.Err.NRMSETr = NRMSE;
data.Err.rankCorreTr = rankCorre;
elseif strcmpi(firstTitle,'ts')
data.Err.MSETs = MSE;
data.Err.STDTs = ErrorStd;
data.Err.NRMSETs = NRMSE;
data.Err.rankCorreTs = rankCorre;
elseif strcmpi(firstTitle,'all')
data.Err.MSEAll = MSE;
data.Err.STDAll = ErrorStd;
data.Err.NRMSEAll = NRMSE;
data.Err.rankCorreAll = rankCorre;
end
end
% find rank correlation between network output and real data
function [r]=RankCorre(x,y)
x=x';
y=y';
% Find the data length
N = length(x);
% Get the ranks of x
R = crank(x)';
for i=1:size(y,2)
% Get the ranks of y
S = crank(y(:,i))';
% Calculate the correlation coefficient
r(i) = 1-6*sum((R-S).^2)/N/(N^2-1); %#ok
end
end
function r=crank(x)
u = unique(x);
[~,z1] = sort(x);
[~,z2] = sort(z1);
r = (1:length(x))';
r=r(z2);
for i=1:length(u)
s=find(u(i)==x);
r(s,1) = mean(r(s));
end
end
% plot the regression line of output and real value
function data = plotReg(data,Title,Targets,Outputs)
if strcmpi(Title,'tr')
DispName = 'RegressionGraphEvaluation_TrainData';
elseif strcmpi(Title,'ts')
DispName = 'RegressionGraphEvaluation_TestData';
elseif strcmpi(Title,'all')
DispName = 'RegressionGraphEvaluation_ALLData';
end
figure('Name',DispName,'NumberTitle','off');
x = Targets';
y = Outputs';
format long
b1 = x\y;
yCalc1 = b1*x;
scatter(x,y,'MarkerEdgeColor',[0 0.4470 0.7410],'LineWidth',.7);
hold('on');
plot(x,yCalc1,'Color',[0.8500 0.3250 0.0980]);
xlabel('Prediction');
ylabel('Target');
grid minor
% xgrid = 'on';
% disp.YGrid = 'on';
X = [ones(length(x),1) x];
b = X\y;
yCalc2 = X*b;
plot(x,yCalc2,'-.','MarkerSize',4,"LineWidth",.1,'Color',[0.9290 0.6940 0.1250])
legend('Data','Fit','Y=T','Location','best');
%
Rsq2 = 1 - sum((y - yCalc1).^2)/sum((y - mean(y)).^2);
if strcmpi(Title,'tr')
data.Err.RSqur_Tr = Rsq2;
title(['Train Data, R^2 = ' num2str(Rsq2)]);
elseif strcmpi(Title,'ts')
data.Err.RSqur_Ts = Rsq2;
title(['Test Data, R^2 = ' num2str(Rsq2)]);
elseif strcmpi(Title,'all')
data.Err.RSqur_All = Rsq2;
title(['All Data, R^2 = ' num2str(Rsq2)]);
end
结果展示
完整代码链接:https://mbd.pub/o/bread/ZJaXm51w
[1]汤义勤,邹宏亮,蒋旭,唐佳杰,赵洁,何育钦.基于VMD和贝叶斯优化LSTM的母线负荷预测方法[J].电网与清洁能源,2023,39(02):46-52+59.
人划线