Generative Adversarial Networks (GANs) are a class of deep learning models proposed by Ian Goodfellow and his colleagues in 2014. A GAN consists of two parts: a generator and a discriminator. The generator aims to produce data that looks as realistic as possible, while the discriminator aims to distinguish real data from the fake data produced by the generator. The two networks compete with each other: the generator tries to fool the discriminator, and the discriminator keeps improving its ability to tell real from fake.
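More formally, this competition is usually written as the minimax objective from the original 2014 paper, where D(x) is the discriminator's estimated probability that x is real, G(z) is a sample generated from noise z, p_data is the real-data distribution, and p_z is the noise distribution:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The training loop below follows this objective in its standard alternating form: the discriminator is updated to maximize V, and the generator is updated to make D(G(z)) large.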
In Python, GANs can be implemented with deep learning frameworks such as TensorFlow or PyTorch. The following is a simple GAN implementation using the PyTorch framework:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the generator network
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, output_dim),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized images
        )

    def forward(self, z):
        return self.model(z)

# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.01),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid()  # probability that the input is real
        )

    def forward(self, x):
        return self.model(x)

# Hyperparameters
input_dim = 100        # dimension of the input noise vector
output_dim = 28 * 28   # dimension of the generated image (MNIST)
batch_size = 64
learning_rate = 0.0002
epochs = 100

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))  # scale pixels to [-1, 1]
])
train_dataset = datasets.MNIST(root='./mnist_data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the generator and discriminator
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)

# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Train the GAN
for epoch in range(epochs):
    for i, (images, _) in enumerate(train_loader):
        # Real and fake labels
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        # Train the discriminator on real images
        optimizer_d.zero_grad()
        outputs = discriminator(images.view(images.size(0), -1))
        loss_d_real = criterion(outputs, real_labels)
        loss_d_real.backward()

        # Train the discriminator on fake images;
        # detach() keeps this pass from updating the generator
        z = torch.randn(images.size(0), input_dim)
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach().view(images.size(0), -1))
        loss_d_fake = criterion(outputs, fake_labels)
        loss_d_fake.backward()

        loss_d = loss_d_real + loss_d_fake
        optimizer_d.step()

        # Train the generator: it wants the discriminator to label fakes as real
        optimizer_g.zero_grad()
        outputs = discriminator(fake_images.view(images.size(0), -1))
        loss_g = criterion(outputs, real_labels)
        loss_g.backward()
        optimizer_g.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}')
In this example, we use the MNIST dataset: the generator maps a 100-dimensional noise vector to a 28x28 image, and the discriminator tries to distinguish real images from generated ones. During training, the generator and discriminator are updated alternately, and over time the generator learns to produce increasingly convincing handwritten digits.
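Once training has finished, new digits can be sampled by feeding fresh noise vectors through the generator. The following is a minimal sketch that reuses the generator, input_dim, and the [-1, 1] normalization defined above; it assumes matplotlib is installed for display, which is not part of the training code itself:

# Sample new digits from the trained generator (sketch; assumes matplotlib is available)
import matplotlib.pyplot as plt

generator.eval()                        # use BatchNorm's running statistics
with torch.no_grad():                   # no gradients needed for sampling
    z = torch.randn(16, input_dim)      # 16 fresh noise vectors
    samples = generator(z)              # shape (16, 784), values in [-1, 1]
    samples = (samples + 1) / 2         # undo the Normalize((0.5,), (0.5,)) step
    samples = samples.view(-1, 28, 28)  # reshape back to 28x28 images

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for ax, img in zip(axes.flatten(), samples):
    ax.imshow(img.numpy(), cmap='gray')
    ax.axis('off')
plt.show()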