二、代码实战
clc; clear

%% Load the training and test sets
% Training file: predictors are read from field XTrain, labels from YTrain.
s = load("HumanActivityTrain.mat");
XTrain = s.XTrain;
TTrain = s.YTrain;

% Test file: predictors are read from field XTest, labels from YTest.
s1 = load("HumanActivityTest.mat");
XTest = s1.XTest;
TTest = s1.YTest;

% Dataset sizes, taken from the training data. XTrain and TTrain are
% indexed with {}, so both are cell arrays; categories() is applied to the
% first label cell, so the label cells are expected to hold categoricals.
numObservations = numel(XTrain);
numFeatures = size(XTrain{1},1);    % row count of the first training sequence
classes = categories(TTrain{1});
numClasses = numel(classes);
%% Build the TCN network
% Hyperparameters shared by every residual block.
numFilters = 64;          % output channels of each convolution
filterSize = 5;           % kernel width of the dilated convolutions
dropoutFactor = 0.005;    % probability passed to spatialDropoutLayer
numBlocks = 4;            % number of stacked residual blocks

% The graph starts at the sequence input layer (symmetric rescaling).
layer = sequenceInputLayer(numFeatures,Normalization="rescale-symmetric",Name="input");
lgraph = layerGraph(layer);
outputName = layer.Name;  % layer that the next block attaches to

for i = 1:numBlocks
    % Dilation doubles per block: 1, 2, 4, 8, ...
    dilationFactor = 2^(i-1);

    % Main branch: two causal dilated convolutions, each followed by layer
    % normalization and spatial dropout (ReLU after the second
    % normalization), ending in a two-input addition layer "add_i".
    layers = [
        convolution1dLayer(filterSize,numFilters,DilationFactor=dilationFactor,Padding="causal",Name="conv1_"+i)
        layerNormalizationLayer
        spatialDropoutLayer(dropoutFactor)
        convolution1dLayer(filterSize,numFilters,DilationFactor=dilationFactor,Padding="causal")
        layerNormalizationLayer
        reluLayer
        spatialDropoutLayer(dropoutFactor)
        additionLayer(2,Name="add_"+i)];
    lgraph = connectLayers(addLayers(lgraph,layers),outputName,"conv1_"+i);

    % Residual branch feeding the second input of the addition layer.
    if i > 1
        % Later blocks pass the previous block's output straight through.
        lgraph = connectLayers(lgraph,outputName,"add_" + i + "/in2");
    else
        % The first block's skip path goes through a 1x1 convolution with
        % numFilters outputs, so its channel count matches the main branch.
        layer = convolution1dLayer(1,numFilters,Name="convSkip");
        lgraph = addLayers(lgraph,layer);
        lgraph = connectLayers(lgraph,outputName,"convSkip");
        lgraph = connectLayers(lgraph,"convSkip","add_" + i + "/in2");
    end

    % The next block (or the head) attaches to this block's addition layer.
    outputName = "add_" + i;
end

% Classification head: fully connected -> softmax -> classification output.
layers = [
    fullyConnectedLayer(numClasses,Name="fc")
    softmaxLayer
    classificationLayer];
lgraph = connectLayers(addLayers(lgraph,layers),outputName,"fc");
%% Plot the network structure
% Open a new figure and draw the layer graph under a title.
figure; plot(lgraph); title("Temporal Convolutional Network");
%% Set the training options
% Adam solver, 60 epochs, one observation per mini-batch, a live
% training-progress plot, and verbose command-window output turned off.
options = trainingOptions("adam", ...
    "MaxEpochs", 60, ...
    "MiniBatchSize", 1, ...
    "Plots", "training-progress", ...
    "Verbose", 0);
%% Train the network
% Fit the layer graph to the training sequences and labels with the
% training options held in `options`.
net = trainNetwork(XTrain, TTrain, lgraph, options);
%% Evaluate the network on the test set
% Classify one sequence per mini-batch (the same mini-batch size used for
% training), so sequences are never padded to a common length.
YPred = classify(net,XTest,MiniBatchSize=1);

% Pool the per-time-step predictions and targets of EVERY test sequence.
% Each element is flattened to a column before stacking, so the result does
% not depend on whether the sequences are stored as rows or columns.
YAll = cellfun(@(y) y(:), YPred, UniformOutput=false);
TAll = cellfun(@(t) t(:), TTest, UniformOutput=false);
YAll = vertcat(YAll{:});
TAll = vertcat(TAll{:});

figure
confusionchart(TAll,YAll)
% Fraction of correctly classified time steps over the whole test set.
accuracy = mean(YAll == TAll)
仿真结果：
部分知识来源于网络，如有侵权请联系作者删除~
今天的分享就到这里了，后续想了解智能算法、机器学习、深度学习和信号处理相关理论的可以后台私信哦~希望大家多多转发点赞加收藏，你们的支持就是我源源不断的创作动力！
作 者 | 华 夏
编 辑 | 华 夏
校 对 | 华 夏