Generative Models | A Simple Code Implementation of a GAN (Generative Adversarial Network)


1. GAN Overview

Since GAN (the Generative Adversarial Network) was introduced in 2014, it has become one of the core technologies of AI-driven generation. Whether the task is generating images, synthesizing music, or producing text, GANs show up almost everywhere. Today we will walk through a simple GAN implementation so you can get started easily!

A GAN borrows ideas from game theory: two models compete with each other, one responsible for "forging" data and the other for "telling real from fake". Given sufficiently powerful models and enough rounds of training, they can generate data that is hard to distinguish from the real thing!

How GAN Works

  • Generator (G): like a "counterfeiter", it produces data that looks real.

  • Discriminator (D): like a "detective", its job is to decide whether the input data is real or fake.

The generator G never sees the original data directly; it can only improve through the feedback it gets from the discriminator D. During training, G gropes its way toward fooling D, like a blind man feeling an elephant, while D keeps getting better at spotting G's forgeries.
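For reference, this tug-of-war is the standard minimax objective from the original 2014 GAN paper (it is a general formulation, not something specific to the toy script below):

$$\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

D is pushed to output 1 on real samples and 0 on fakes, while G is trained so that D(G(z)) moves toward 1.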

2. Implementation: A Minimal GAN

Below is a basic GAN implementation in PyTorch. It trains on samples from a one-dimensional Gaussian (mean 4, standard deviation 1.25), and by default the discriminator is shown the first four moments of each minibatch rather than the raw samples:

#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

matplotlib_is_available = True
try:
    from matplotlib import pyplot as plt
except ImportError:
    print("Will skip plotting; matplotlib is not available.")
    matplotlib_is_available = False

# Parameters of the target data distribution
data_mean = 4
data_stddev = 1.25

# ### Uncomment only one of these to define what data is actually sent to the Discriminator
# (name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
# (name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
# (name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)

print("Using data [%s]" % (name))

# Sampler for the target data distribution (Gaussian)
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

# Sampler for the generator's input noise (uniform)
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

# Generator network
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

def extract(v):
    # Flatten a tensor/Variable into a plain Python list
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def get_moments(d):
    # Return the first 4 moments of the data provided
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final

def decorate_with_diffs(data, exponent, remove_raw_data=False):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)

# Model training function
def train():
    # Network parameters
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5     # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10    # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9

    num_epochs = 5000
    print_interval = 100
    d_steps = 20
    g_steps = 20

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    # Data samplers
    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D
            D.zero_grad()

            # 1A: Train D to recognize real data
            d_real_data = Variable(d_sampler(d_input_size))
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1])))  # ones = true
            d_real_error.backward()  # compute/store gradients, but don't change params

            # 1B: Train D to recognize generated (fake) data
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]

        for g_index in range(g_steps):
            # 2. Train G
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1])))  # Train G to pretend it's genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters

            ge = extract(g_error)[0]

        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s), Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))

    # Visualize the generated distribution
    if matplotlib_is_available:
        print("Plotting the generated distribution...")
        values = extract(g_fake_data)
        print(" Values: %s" % (str(values)))
        plt.hist(values, bins=50)
        plt.xlabel('Value')
        plt.ylabel('Count')
        plt.title('Histogram of Generated Distribution')
        plt.grid(True)
        plt.show()

train()
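As a quick, hypothetical sanity check (not part of the original script), you can inspect what the discriminator actually receives under the default "Only 4 moments" preprocessing by feeding it a batch of real samples:

# Hypothetical check: what D sees for a batch of real N(data_mean, data_stddev) samples
sample = torch.Tensor(np.random.normal(data_mean, data_stddev, (1, 500)))
print(get_moments(sample))  # roughly [4.0, 1.25, ~0, ~0]: mean, std, skew, excess kurtosis

Feeding D these four summary statistics of the whole minibatch, rather than 500 raw values, is what lets such a tiny discriminator judge the distribution as a whole.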

Code Output

Running the script prints the discriminator and generator errors every 100 epochs and finishes by plotting a histogram of the generated distribution.

Summary

From a programming point of view, implementing a GAN is not complicated:

  1. Sample the inputs: the real data comes from a Gaussian via numpy's np.random.normal, while the random noise fed into G is drawn from a uniform distribution (torch.rand);

  2. Build the two networks: G generates data, and D judges whether data is real or fake;

  3. Training procedure:

  • First train D to tell whether its input data is real;

  • Then train G to fool D, so that D classifies G's fake data as real.

Because the generated data is random, the results will differ slightly from run to run.
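As a small follow-up sketch (not part of the original script, and assuming train() is modified to end with return G), this is how you might draw fresh samples from the trained generator and check their statistics:

# Hypothetical usage sketch: assumes train() has been changed to `return G`
G = train()
with torch.no_grad():
    noise = torch.rand(500, 1)                # same kind of uniform noise G saw during training
    samples = G(noise).squeeze().numpy()      # 500 generated values
print("generated mean/std:", samples.mean(), samples.std())  # should approach 4 and 1.25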


