深度卷积对抗生成网络matlab实战

文摘   科技   2024-06-27 21:44   贵州  

    今天给大家分享深度卷积对抗生成网络的matlab实战,主要从算法原理和代码实战展开。需要了解更多算法代码的,可以点击文章左下角的阅读全文,进行获取哦~需要了解智能算法、机器学习、深度学习和信号处理相关理论的可以后台私信哦,下一期分享的内容就是你想了解的内容~


这里给参加数学建模比赛的同学推荐一些数学建模比赛书籍,亲测好用。

上述书籍中,数学建模算法与应用》主要介绍常见的数学模型在理解数学模型原理的基础上,我们可以借助《MATLAB智能算法30个案例分析》《MATLAB机器学习30个案例分析》中的代码实战进行学习,进而不断积累知识,方能在数学建模比赛中取得良好成绩!


一、原理

    深度卷积对抗生成网络 (DCGAN)将GAN与CNN相结合,奠定几乎所有GAN的基本网络架构。DCGAN极大地提升了原始GAN训练的稳定性以及生成结果质量。

     DCGAN网络设计中采用了当时对CNN比较流行的改进方案:

1、将空间池化层用卷积层替代,这种替代只需要将卷积的步长stride设置为大于1的数值。改进的意义是下采样过程不再是固定的抛弃某些位置的像素值,而是可以让网络自己去学习下采样方式。

2、将全连接层去除

3、采用BN层,BN的全称是Batch Normalization,是一种用于常用于卷积层后面的归一化方法,起到帮助网络的收敛等作用。作者实验中发现对所有的层都使用BN会造成采样的震荡(我也不理解什么是采样的震荡,我猜是生成图像趋于同样的模式或者生成图像质量忽高忽低)和网络不稳定。

4、在生成器中除输出层使用Tanh(Sigmoid)激活函数,其余层全部使用ReLu激活函数。

5、在判别器所有层都使用LeakyReLU激活函数,防止梯度稀消失。

      下面是DCGAN的生成器网络架构图。

二、代码实战

clear all; close all; clc;%% Deep Convolutional Generative Adversarial Network%% Load Dataload('mnistAll.mat')trainX = preprocess(mnist.train_images); trainY = mnist.train_labels;testX = preprocess(mnist.test_images); testY = mnist.test_labels;%% Settingssettings.latentDim = 100;settings.batch_size = 32; settings.image_size = [28,28,1]; settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;settings.beta2 = 0.999; settings.maxepochs = 50;%% GeneratorparamsGen.FCW1 = dlarray(initializeGaussian([128*7*7,...    settings.latentDim]));paramsGen.FCb1 = dlarray(zeros(128*7*7,1,'single'));paramsGen.TCW1 = dlarray(initializeGaussian([3,3,128,128]));paramsGen.TCb1 = dlarray(zeros(128,1,'single'));paramsGen.BNo1 = dlarray(zeros(128,1,'single'));paramsGen.BNs1 = dlarray(ones(128,1,'single'));paramsGen.TCW2 = dlarray(initializeGaussian([3,3,64,128]));paramsGen.TCb2 = dlarray(zeros(64,1,'single'));paramsGen.BNo2 = dlarray(zeros(64,1,'single'));paramsGen.BNs2 = dlarray(ones(64,1,'single'));paramsGen.CNW1 = dlarray(initializeGaussian([3,3,64,1]));paramsGen.CNb1 = dlarray(zeros(1,1,'single'));stGen.BN1 = []; stGen.BN2 = [];
%% DiscriminatorparamsDis.CNW1 = dlarray(initializeGaussian([3,3,1,32]));paramsDis.CNb1 = dlarray(zeros(32,1,'single'));paramsDis.CNW2 = dlarray(initializeGaussian([3,3,32,64]));paramsDis.CNb2 = dlarray(zeros(64,1,'single'));paramsDis.BNo1 = dlarray(zeros(64,1,'single'));paramsDis.BNs1 = dlarray(ones(64,1,'single'));paramsDis.CNW3 = dlarray(initializeGaussian([3,3,64,128]));paramsDis.CNb3 = dlarray(zeros(128,1,'single'));paramsDis.BNo2 = dlarray(zeros(128,1,'single'));paramsDis.BNs2 = dlarray(ones(128,1,'single'));paramsDis.CNW4 = dlarray(initializeGaussian([3,3,128,256]));paramsDis.CNb4 = dlarray(zeros(256,1,'single'));paramsDis.BNo3 = dlarray(zeros(256,1,'single'));paramsDis.BNs3 = dlarray(ones(256,1,'single'));paramsDis.FCW1 = dlarray(initializeGaussian([1,256*4*4]));paramsDis.FCb1 = dlarray(zeros(1,1,'single'));stDis.BN1 = []; stDis.BN2 = []; stDis.BN3 = [];
% average Gradient and average Gradient squared holdersavgG.Dis = []; avgGS.Dis = []; avgG.Gen = []; avgGS.Gen = [];%% TrainnumIterations = floor(size(trainX,4)/settings.batch_size);out = false; epoch = 0; global_iter = 0;

%% modelGradientsfunction [GradGen,GradDis,stGen,stDis]=modelGradients(x,z,paramsGen,... paramsDis,stGen,stDis)[fake_images,stGen] = Generator(z,paramsGen,stGen);d_output_real = Discriminator(x,paramsDis,stDis);[d_output_fake,stDis] = Discriminator(fake_images,paramsDis,stDis);
% Loss due to true or notd_loss = -mean(.9*log(d_output_real+eps)+log(1-d_output_fake+eps));g_loss = -mean(log(d_output_fake+eps));
% For each network, calculate the gradients with respect to the loss.GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);GradDis = dlgradient(d_loss,paramsDis);end%% progressplotfunction progressplot(paramsGen,stGen,settings)r = 5; c = 5;noise = gpdl(randn([settings.latentDim,r*c]),'CB');gen_imgs = Generator(noise,paramsGen,stGen);gen_imgs = reshape(gen_imgs,28,28,[]);
fig = gcf;if ~isempty(fig.Children) delete(fig.Children)end
I = imtile(gatext(gen_imgs));I = rescale(I);imagesc(I)title("Generated Images")colormap gray
drawnow;end

实验结果:

图1 DCGAN生成的手写数字

图2 真实的手写数字
     由图1、2可知,训练好的DCGAN生成的手写数字与真实的手写数字差别很小。

(完整代码点击文章左下角“阅读原文”获取!!!)



    部分知识来源于网络,如有侵权请联系作者删除~


    今天的分享就到这里了,后续想了解智能算法、机器学习、深度学习和信号处理相关理论的可以后台私信哦~希望大家多多转发点赞加收藏,你们的支持就是我源源不断的创作动力!


作 者 | 华 夏

编 辑 | 华 夏

校 对 | 华 夏


matlab学习之家
分享学习matlab建模知识和matlab编程知识
 最新文章