一文弄懂生成式对抗网络

文摘   科技   2024-07-19 07:30   江苏  
点击蓝字
 
关注我们










01


引言



在人工智能的广阔领域中,有一项非凡的创新能够模仿人类的创造力并生成逼真的数据样本,因而脱颖而出:生成对抗网络(GANs)。GANs 诞生于2014 年发表的开创性论文《生成对抗网络》(Generative Adversarial Networks)中,它彻底改变了数据生成领域。


在本文中,我们将开启一段通往 GANs 世界的奇妙之旅,探索它们的数学基础、实际应用以及对人工智能的深远影响。






02


什么是生成式对抗网络?


GANs 的核心是一个引人入胜的概念:两个神经网络之间的对决,就像一场猫捉老鼠的游戏中,这两个神经网络都在努力战胜对方。第一个主角是 "生成器",它是一位生成大师,负责创建与真实数据无异的合成数据样本。第二位主角是 "判别器",他是一名警觉的侦探,受过区分真假数据样本的训练。他们组合在一起跳起了对抗性的舞蹈,各自将对方逼到能力的极限。


GAN 的精髓可以用一个简单而优雅的数学框架来概括。让我们深入了解一下其数学表述:

  • 生成器:
    • 生成器旨在学习基础数据分布,并生成与真实数据非常相似的合成数据样本。
    • 它由一个神经网络表示,将随机噪声作为输入,并输出合成数据样本。

    • 从数学上讲,生成器力求使生成的数据分布与真实数据分布之间的偏差最小。

  • 判别器:
    • 判别器作为对手,负责区分真实和生成的虚假数据样本。
    • 它可以用神经网络(通常是二元分类器)来表示,经过训练后可将输入样本分为真样本和假样本。

    • 从数学上讲,判别器的目标是最大限度地提高其正确分类真假样本的能力,从而有效地学习近似数据分布。




03


 训练过程


GANs 的训练过程就像一场充满欺骗和发现的迷人舞蹈:

  • 生成器以一个天真的学徒身份开始了它的旅程,生成粗糙的真实数据仿制品。
  • 与此同时,判别器会仔细检查每个样本,敏锐地察觉到哪怕是最细微的欺骗迹象。
  • 随着战斗的进行,生成器学会了制作越来越令人信服的赝品,而判别器则演变成了一个可怕的对手。
  • 最终,会达到一种微妙的平衡,即生成器生成的合成样本与真实数据几乎没有区别,而判别器则越来越难以将它们区分开来。






04


  GANs的应用


GAN 的多功能性是无止境的,它横跨不同领域的无数应用:

  • 图像生成:GAN 可以生成逼真的人脸、风景和物体图像,从而激发数字艺术和设计的创造力。

  • 图像到图像的转换:GAN 可以将图像从一个领域转换到另一个领域,实现风格转换、色彩化和语义分割。

  • 文本生成:GAN 可以生成自然语言文本,为讲故事、对话生成和内容创作开辟了新的领域。

  • 异常检测:GAN 可通过学习正常数据分布并识别偏离正常数据分布的情况来检测数据中的异常情况,可应用于欺诈检测和网络安全领域。

  • 药物发现:GANs可以设计出具有所需特性的新型分子结构,从而加速药物发现和分子设计。





05


 实际举例


使用 GAN 进行异常检测包括训练 GAN 来学习正常数据分布,然后使用生成器生成合成数据样本。异常被识别为严重偏离所学分布的数据样本。下面是一个使用 GANs 进行异常检测的简化 Python 代码示例:
import numpy as npimport matplotlib.pyplot as pltfrom tensorflow.keras.layers import Input, Densefrom tensorflow.keras.models import Modelfrom tensorflow.keras.optimizers import Adam
# Generate synthetic data (normal and anomalous)def generate_data(n_normal, n_anomalous): normal_data = np.random.normal(loc=0, scale=1, size=(n_normal, 2)) anomalous_data = np.random.normal(loc=3, scale=1, size=(n_anomalous, 2)) return normal_data, anomalous_data
# Define and compile the GAN modeldef build_gan(latent_dim=2): # Generator generator_input = Input(shape=(latent_dim,)) generator_output = Dense(2, activation='linear')(generator_input) generator = Model(generator_input, generator_output) # Discriminator discriminator_input = Input(shape=(2,)) discriminator_output = Dense(1, activation='sigmoid')(discriminator_input) discriminator = Model(discriminator_input, discriminator_output) discriminator.compile(optimizer=Adam(lr=0.001), loss='binary_crossentropy') # Combined model (GAN) gan_output = discriminator(generator_output) gan = Model(generator_input, gan_output) gan.compile(optimizer=Adam(lr=0.001), loss='binary_crossentropy') return generator, discriminator, gan
# Train the GANdef train_gan(generator, discriminator, gan, normal_data, epochs=1000, batch_size=64): for epoch in range(epochs): # Train discriminator real_data = normal_data[np.random.randint(0, len(normal_data), size=batch_size)] real_labels = np.ones((batch_size, 1)) fake_data = generator.predict(np.random.normal(0, 1, (batch_size, 2))) fake_labels = np.zeros((batch_size, 1)) discriminator_loss_real = discriminator.train_on_batch(real_data, real_labels) discriminator_loss_fake = discriminator.train_on_batch(fake_data, fake_labels) discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake) # Train generator (via GAN) # Here we train the entire GAN to train the generator noise = np.random.normal(0, 1, (batch_size, 2)) gan_loss = gan.train_on_batch(noise, np.ones((batch_size, 1))) # Print progress if epoch % 100 == 0: print(f'Epoch {epoch}/{epochs} | Discriminator Loss: {discriminator_loss} | GAN Loss: {gan_loss}')
# Generate synthetic datanormal_data, anomalous_data = generate_data(n_normal=1000, n_anomalous=50)
# Build and compile GAN modelgenerator, discriminator, gan = build_gan()
# Train GAN modeltrain_gan(generator, discriminator, gan, normal_data)
# Generate synthetic data using the trained generatorsynthetic_data = generator.predict(np.random.normal(0, 1, (1000, 2)))
# Plot dataplt.scatter(normal_data[:, 0], normal_data[:, 1], label='Normal Data', color='blue', alpha=0.5)plt.scatter(anomalous_data[:, 0], anomalous_data[:, 1], label='Anomalous Data', color='red', alpha=0.5)plt.scatter(synthetic_data[:, 0], synthetic_data[:, 1], label='Synthetic Data', color='green', alpha=0.5)plt.legend()plt.xlabel('Feature 1')plt.ylabel('Feature 2')plt.title('Anomaly Detection using GAN')plt.show()

在此代码中:

  • 我们首先生成合成数据,其中正常数据样本取自标准正态分布,异常数据样本取自不同的分布。

  • 我们使用 Keras 定义了带有生成器和判别器的 GAN 架构。

  • 我们使用正常数据训练 GAN 模型,其中判别器学习如何区分真实数据和合成数据样本,生成器学习如何生成与正常数据相似的合成数据。

  • 训练结束后,我们使用生成器生成合成数据样本。

  • 最后,我们在散点图上将正常数据、异常数据和合成数据可视化,以观察使用 GANs 进行异常检测的效果。

结果如下:

综上所述,我们成功地训练了生成器网络,使其在接收随机噪声作为输入时,能够生成与原始数据分布一致的数据点。我们网络的这一功能可用于合成数据生成器。

此外,判别器还可用于异常检测任务,因为它有能力检查给定数据点是否属于原始分布。




06


 结论


生成对抗网络是人类创造力和人工智能力量的见证。随着我们不断揭开生成对抗网络的神秘面纱,我们即将迎来AIGC的新时代,在这个时代里,机器将成为人类想象力的共同创造者。





点击上方小卡片关注我




添加个人微信,进专属粉丝群!


AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
 最新文章