今天给大家分享基于小波分析的squeezenet网络时间序列分类MATLAB实战。需要了解更多算法代码的,可以点击文章左下角的阅读全文,进行获取哦~需要了解智能算法、机器学习、深度学习和信号处理相关理论的可以后台私信哦,下一期分享的内容就是你想了解的内容~
本次实战使用连续小波变换(CWT)和squeezenet网络对人类心电图(ECG)信号进行分类。为加快训练速度,采用squeezenet网络进行迁移学习。对图像识别进行了预训练,以基于时频表示对ECG波形进行分类。本次使用的数据可从PhysioNet公开获得。
%%基于小波分析的squeezenet网络时间序列分类
clc
clear
%% 数据处理
unzip(fullfile(tempdir,"physionet_ECG_data-main.zip"),tempdir)
unzip(fullfile(tempdir,"physionet_ECG_data-main","ECGData.zip"), ...
fullfile(tempdir,"physionet_ECG_data-main"))
load(fullfile(tempdir,"physionet_ECG_data-main","ECGData.mat"))
parentDir = tempdir;
dataDir = "data";
helperCreateECGDirectories(ECGData,parentDir,dataDir)
helperPlotReps(ECGData)
%%
Fs = 128;
fb = cwtfilterbank(SignalLength=1000, ...
SamplingFrequency=Fs, ...
VoicesPerOctave=12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;
figure
pcolor(t,frq,abs(cfs))
set(gca,"yscale","log")
shading interp
axis tight
title("Scalogram")
xlabel("Time (s)")
ylabel("Frequency (Hz)")
%%
helperCreateRGBfromTF(ECGData,parentDir,dataDir)
allImages = imageDatastore(fullfile(parentDir,dataDir), ...
"IncludeSubfolders",true, ...
"LabelSource","foldernames");
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,"randomized");
disp("Number of training images: "+num2str(numel(imgsTrain.Files)))
disp("Number of validation images: "+num2str(numel(imgsValidation.Files)))
%% 下载squeezenet
sqz = squeezenet;
lgraphSqz = layerGraph(sqz);
disp("Number of Layers: "+num2str(numel(lgraphSqz.Layers)))
lgraphSqz.Layers(1)
lgraphSqz.Layers(end-5:end)
tmpLayer = lgraphSqz.Layers(end-5);
newDropoutLayer = dropoutLayer(0.6,"Name","new_dropout");
%% 修改squeezenet最后5层架构,实现模型迁移
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newDropoutLayer);
numClasses = numel(categories(imgsTrain.Labels));
tmpLayer = lgraphSqz.Layers(end-4);
newLearnableLayer = convolution2dLayer(1,numClasses, ...
"Name","new_conv", ...
"WeightLearnRateFactor",10, ...
"BiasLearnRateFactor",10);
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newLearnableLayer);
tmpLayer = lgraphSqz.Layers(end);
newClassLayer = classificationLayer("Name","new_classoutput");
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newClassLayer);
lgraphSqz.Layers(end-5:end)
%% 数据增强
augimgsTrain = augmentedImageDatastore([227 227],imgsTrain);
augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);
%% 设置网络参数
ilr = 3e-4;
miniBatchSize = 10;
maxEpochs = 15;
valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize);
opts = trainingOptions("sgdm", ...
MiniBatchSize=miniBatchSize, ...
MaxEpochs=maxEpochs, ...
InitialLearnRate=ilr, ...
ValidationData=augimgsValidation, ...
ValidationFrequency=valFreq, ...
Verbose=1, ...
Plots="training-progress");
trainedSN = trainNetwork(augimgsTrain,lgraphSqz,opts);
trainedSN.Layers(end)
[YPred,probs] = classify(trainedSN,augimgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp("SqueezeNet Accuracy: "+num2str(100*accuracy)+"%")
部分知识来源于网络,如有侵权请联系作者删除~
今天的分享就到这里了,后续想了解智能算法、机器学习、深度学习和信号处理相关理论的可以后台私信哦~希望大家多多转发点赞加收藏,你们的支持就是我源源不断的创作动力!
作 者 | 华 夏
编 辑 | 华 夏
校 对 | 华 夏