1.GAN概述
自 2014年GAN(生成对抗网络) 诞生以来,它已经成为AI生成领域的核心技术之一。无论是生成图像、合成音乐,还是生成文本,我们几乎都能看到GAN的身影。今天就带你简单实现一个GAN,轻松入门!
GAN结合了计算图和博弈论的思想:两个模型互相竞争,一个负责“造假”,另一个负责“辨别真假”。只要模型足够强大,经过多轮训练,它们就能生成以假乱真的数据!
GAN工作机制
生成器 (G):类似“造假者”,它生成看似真实的数据。
鉴别器 (D):像一个“侦探”,它的任务是判断输入的数据是真是假。
生成器G并不会直接看到原始数据,它只能依赖鉴别器D的反馈不断改进。训练过程中,G看不到任何真实样本,只能根据D的判断一步步摸索如何欺骗D,而D则不断变强以甄别G的伪造品。
2.实现代码:GAN的最简版本
下面是一段基础的GAN实现代码,使用了 PyTorch:
#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
# Plotting is optional: probe for matplotlib once and remember the outcome.
try:
    from matplotlib import pyplot as plt
except ImportError:
    print("Will skip plotting; matplotlib is not available.")
    matplotlib_is_available = False
else:
    matplotlib_is_available = True
# Parameters of the target ("real") normal distribution that train() samples from.
data_mean = 4
data_stddev = 1.25
# ### Uncomment only one of these to define what data is actually sent to the Discriminator
# Each option is a (name, preprocess, d_input_func) triple:
#   name         - label printed once at startup;
#   preprocess   - maps a (1, n) data tensor to the Discriminator's input;
#   d_input_func - maps the minibatch size to the Discriminator's input width.
# The lambdas defer the lookup of get_moments / decorate_with_diffs (defined
# further down this file) until preprocess is actually called.
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
#(name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
# get_moments always yields 4 values, so the input width is 4 regardless of n.
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)
print("Using data [%s]" % (name))
# Sampler for the target ("real") data: a normal distribution.
def get_distribution_sampler(mu, sigma):
    """Build a sampler that draws ``n`` real samples from N(mu, sigma).

    The returned callable maps ``n`` to a ``torch.Tensor`` of shape ``(1, n)``
    filled by ``np.random.normal``, so it follows numpy's global RNG state.
    """
    def sample(n):
        draws = np.random.normal(mu, sigma, (1, n))  # Gaussian
        return torch.Tensor(draws)

    return sample
# Sampler for the generator's input noise: a uniform distribution.
def get_generator_input_sampler():
    """Build the generator's noise sampler.

    The returned callable maps ``(m, n)`` to ``torch.rand(m, n)``: uniform
    values in [0, 1), deliberately _NOT_ Gaussian.
    """
    def sample(m, n):
        return torch.rand(m, n)

    return sample
# Generator network: maps input noise to fake samples.
class Generator(nn.Module):
    """Three-layer MLP: ``map1 -> f -> map2 -> f -> map3``.

    Args:
        input_size: width of each noise vector.
        hidden_size: width of both hidden layers.
        output_size: width of each generated sample.
        f: activation applied after ``map1`` and ``map2`` (not after ``map3``).
    """

    def __init__(self, input_size, hidden_size, output_size, f):
        super().__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        for hidden in (self.map1, self.map2):
            x = self.f(hidden(x))
        # The output layer is purely linear (no activation).
        return self.map3(x)
# Discriminator network: scores its input as real vs. fake.
class Discriminator(nn.Module):
    """Three-layer MLP with ``f`` applied after every layer, output included.

    Args:
        input_size: width of the (preprocessed) input.
        hidden_size: width of both hidden layers.
        output_size: width of the decision output.
        f: activation applied after ``map1``, ``map2`` and ``map3``.
    """

    def __init__(self, input_size, hidden_size, output_size, f):
        super().__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        for layer in (self.map1, self.map2, self.map3):
            x = self.f(layer(x))
        return x
def extract(v):
    """Return the elements of tensor *v* as a flat Python list, in logical order.

    The values are detached first, so tensors that require grad are accepted.
    Only the elements of *v* itself are returned. A view (slice, transpose)
    yields just its own elements rather than its whole underlying storage,
    which is what the deprecated ``v.data.storage().tolist()`` returned.
    A 0-d tensor yields a one-element list.
    """
    return v.detach().reshape(-1).tolist()
def stats(d):
    """Return ``[mean, std]`` of the numeric sequence *d*.

    The std is numpy's default population std (ddof=0).
    """
    values = np.asarray(d)
    return [values.mean(), values.std()]
def get_moments(d):
    """Summarize *d* by its first four moments, taken over all its elements.

    Returns a 1-D tensor ``[mean, std, skewness, excess_kurtosis]``.
    ``std`` is the population (biased) std. ``excess_kurtosis`` has 3
    subtracted, so it is about 0 for Gaussian data. The result stays
    differentiable with respect to *d*.
    """
    mean = d.mean()
    centered = d - mean
    std = centered.pow(2.0).mean().pow(0.5)
    z = centered / std
    skew = z.pow(3.0).mean()
    excess_kurtosis = z.pow(4.0).mean() - 3.0
    return torch.stack((mean, std, skew, excess_kurtosis))
def decorate_with_diffs(data, exponent, remove_raw_data=False):
    """Augment a 2-D tensor with powered deviations from each row's mean.

    Args:
        data: 2-D tensor of shape ``(rows, cols)``.
        exponent: power applied to ``data - row_mean``.
        remove_raw_data: if true, return only the deviations.

    Returns:
        ``(data - row_mean) ** exponent`` concatenated after ``data`` along
        dim 1 (shape ``(rows, 2 * cols)``), or just the deviations
        (shape ``(rows, cols)``) when *remove_raw_data* is true.
    """
    # The mean is detached, so gradients flow only through `data` itself.
    # keepdim=True gives shape (rows, 1), which broadcasts each row's own mean
    # across that row and keeps data's dtype and device.
    mean = data.detach().mean(dim=1, keepdim=True)
    diffs = torch.pow(data - mean, exponent)
    if remove_raw_data:
        return diffs
    return torch.cat([data, diffs], 1)
# Model training entry point.
def train():
    """Train a small 1-D GAN whose generator learns to mimic N(data_mean, data_stddev).

    Each epoch runs ``d_steps`` discriminator updates, then ``g_steps``
    generator updates. Losses and sample statistics are printed every
    ``print_interval`` epochs. If matplotlib is available, a histogram of the
    last generated minibatch is shown at the end.

    Reads the module-level settings ``data_mean``, ``data_stddev``,
    ``preprocess``, ``d_input_func`` and ``matplotlib_is_available``/``plt``.
    """
    # Network sizes.
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5     # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10    # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    # Optimisation settings.
    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9
    num_epochs = 5000
    print_interval = 100
    d_steps = 20
    g_steps = 20

    # Last-seen losses and batches, reported in the periodic log line.
    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    # Samplers for real data and for generator noise.
    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()

    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    # D's input width depends on the preprocessing option selected at module level.
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy on D's sigmoid output
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    for epoch in range(num_epochs):
        for _ in range(d_steps):
            # 1. Train D on real and generated data.
            D.zero_grad()

            # 1A: real data should be classified as genuine (label 1).
            d_real_data = d_sampler(d_input_size)
            d_real_decision = D(preprocess(d_real_data))
            # Targets match the decision's shape. D returns (1,) for the
            # moments option but (1, 1) for the raw/diff options, and BCELoss
            # rejects a target whose size differs from the input's.
            d_real_error = criterion(d_real_decision, torch.ones_like(d_real_decision))
            d_real_error.backward()  # compute/store gradients, but don't change params

            # 1B: generated data should be classified as fake (label 0).
            d_gen_input = gi_sampler(minibatch_size, g_input_size)
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, torch.zeros_like(d_fake_decision))
            d_fake_error.backward()
            d_optimizer.step()  # Only optimizes D's parameters, using the stored gradients

            dre, dfe = d_real_error.item(), d_fake_error.item()

        for _ in range(g_steps):
            # 2. Train G so that D labels its output as genuine.
            G.zero_grad()
            gen_input = gi_sampler(minibatch_size, g_input_size)
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, torch.ones_like(dg_fake_decision))
            g_error.backward()
            # Only G's parameters are stepped. The gradients this leaves on D
            # are cleared by D.zero_grad() before D's next update.
            g_optimizer.step()
            ge = g_error.item()

        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s), Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))

    # Visualise the last generated minibatch.
    if matplotlib_is_available:
        print("Plotting the generated distribution...")
        values = extract(g_fake_data)
        print(" Values: %s" % (str(values)))
        plt.hist(values, bins=50)
        plt.xlabel('Value')
        plt.ylabel('Count')
        plt.title('Histogram of Generated Distribution')
        plt.grid(True)
        plt.show()
# Run the (long) training loop only when executed as a script, not on import.
if __name__ == "__main__":
    train()
代码输出结果
总结
从编程的角度来看,GAN的实现并不复杂:
数据与噪声:用 `numpy` 的正态分布(均值 4、标准差 1.25)采样真实数据;用 `torch.rand` 生成 [0, 1) 上的均匀分布噪声作为G的输入(注意G的输入并不是高斯噪声);构建两个网络:G用于生成数据,D用于判断数据真假;
训练步骤:
先训练D判断输入数据是否真实;
再训练G欺骗D,让D认为G生成的假数据是真实的。
由于真实数据的采样、G的输入噪声以及网络参数的初始化都是随机的,且代码没有固定随机种子,所以每次运行的结果都会有所不同。
想要了解更多内容,可在小程序搜索🔍AI Pulse,获取更多最新内容。