人工智能(pytorch)搭建模型23-pytorch搭建生成對抗網絡(GAN):手寫數字生成的項目應用

慈雲數據 8個月前 (03-12) 技術支持 91 0

大家好,我是微學AI,今天給大家介紹一下人工智能(pytorch)搭建模型23-pytorch搭建生成對抗網絡(GAN):手寫數字生成的項目應用。生成對抗網絡(GAN)是一種強大的生成模型,在手寫數字生成方面具有廣泛的應用前景。通過生成逼真的手寫數字圖像,GAN可以用于數據增強、圖像修複、風格遷移等任務,提高模型的性能和泛化能力。生成對抗網絡在手寫數字生成領域具有廣泛的應用前景。主要應用場景包括數據增強、圖像修複、風格遷移和跨領域生成。數據增強可以通過生成逼真的手寫數字圖像,爲訓練數據集提供更多的樣本,提高模型的泛化能力。

一、項目背景

随着深度學習技術的不斷發展,生成模型在計算機視覺自然語言處理等領域取得了顯著的成果。生成對抗網絡(GAN)作爲一種新興的生成模型,近年來備受關注。在手寫數字生成方面,GAN可以生成逼真的手寫數字圖像,爲數據增強、圖像修複等任務提供有力支持。

二、生成對抗網絡原理

生成對抗網絡(GAN)由Goodfellow等人于2014年提出,它由兩個神經網絡——生成器(Generator)和判别器(Discriminator)——組成。生成器的目标是生成逼真的假樣本,而判别器的目标是區分真實樣本和生成器生成的假樣本。在訓練過程中,生成器和判别器相互競争,不斷調整參數,以達到納什均衡。

GAN的目标是最小化以下價值函數:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] Gmin​Dmax​V(D,G)=Ex∼pdata​(x)​[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))]

其中, G G G表示生成器, D D D表示判别器, x x x表示真實樣本, z z z表示生成器的輸入噪聲, p data p_{\text{data}} pdata​表示真實數據分布, p z p_z pz​表示噪聲分布。

在這裏插入圖片描述

三、生成對抗網絡應用場景

生成對抗網絡(GAN)在手寫數字生成領域的應用具有廣泛的前景。以下是幾個主要的應用場景:

1.數據增強:通過生成逼真的手寫數字圖像,GAN可以爲訓練數據集提供更多的樣本,提高模型的泛化能力。

2. 圖像修複:GAN可以用于修複損壞或缺失的手寫數字圖像,提高圖像的質量和可讀性

3. 風格遷移:GAN可以将一種手寫風格轉換爲另一種風格,爲個性化手寫數字生成提供可能。

4. 跨領域生成:GAN可以實現不同手寫數字數據集之間的轉換,爲多任務學習提供支持。

四、生成對抗網絡實現手寫數字生成

下面我将利用pytorch深度學習框架構建生成對抗網絡的生成器模型Generator、判别器模型Discriminator。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
# 超參數設置
batch_size = 128
learning_rate = 0.0002
num_epochs = 80
# 數據預處理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# 下載并加載訓練數據
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
# 定義生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28*28),
            nn.Tanh()
        )
    def forward(self, x):
        return self.model(x).view(x.size(0), 1, 28, 28)
# 定義判别器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.model(x)
# 初始化模型
generator = Generator()
discriminator = Discriminator()
# 損失函數和優化器
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=learning_rate)
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate)
# 訓練模型
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # 确保标簽的大小與當前批次的數據大小一緻
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)
        # 訓練判别器
        optimizerD.zero_grad()
        real_outputs = discriminator(images)
        d_loss_real = criterion(real_outputs, real_labels)
        z = torch.randn(images.size(0), 100)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizerD.step()
        # 訓練生成器
        optimizerG.zero_grad()
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizerG.step()
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')
    # 保存生成器生成的圖片
    save_image(fake_images.data[:25], './fake_images/fake_images-{}.png'.format(epoch+1), nrow=5, normalize=True)
# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

最後我們打開fake_images/文件夾,可以看到生成手寫圖片的過程:

在這裏插入圖片描述

五、總結

本項目利用生成對抗網絡(GAN)實現了手寫數字的生成。通過訓練生成器和判别器,我們成功生成了逼真的手寫數字圖像。這些生成的圖像可以應用于數據增強、圖像修複、風格遷移等領域,爲手寫數字識别等相關任務提供有力支持。

微信掃一掃加客服

微信掃一掃加客服

點擊啓動AI問答
Draggable Icon