📚 [Archive] CS & AI 스터디/[PR] 논문 리뷰 📝

🎨 Generative Adversarial Nets 논문 구현

히주 2025. 1. 6. 22:05

 

  안녕하세요! 최근 생성형 AI의 시작인 위대한 Generative Adversarial Nets 논문 리뷰를 했었는데요! 아래 링크에 GAN 논문 리뷰를 링크해 놓았으니 한번씩 읽어보시는 것을 추천합니다! GAN의 개념에 대해서 간단히 설명하자면 생성자(Generator)와 판별자(Discriminator)라는 두 신경망이 서로 경쟁하며 학습하는 모델입니다. 이 때, 생성자(G)는 랜덤한 노이즈에서 진짜 같은 데이터를 생성하려고 합니다.반면에 판별자(D)는 입력 데이터가 진짜(실제 데이터)인지 가짜(생성자가 만든 데이터)인지 구분하려고 합니다. 이번 시간에는 MNIST 데이터셋을 사용해 흑백 손글씨 이미지를 생성하는 GAN을 구현하려고 합니다!

 

 

 

생성형 AI의 시작 : Generative Adversarial Nets 논문 리뷰

생성형 AI의 시작을 본격적으로 알린 모델은 2014년 발표된 논문인 "Generative Adversarial Nets (GANs)"입니다. 이 논문은 두 개의 신경망(생성자와 판별자)이 경쟁적으로 학습하는 방식으로 이전까지의

yiheeju.tistory.com

 


 

🛠 실습 환경 안내

  • OS: Windows 11
  • WSL: Ubuntu 22.04
  • 패키지 관리: Miniconda (미니콘다) 설치 후 가상환경 구성
  • IDE: Visual Studio Code (VSCode)

WSL 환경에서 Miniconda를 사용해 가상 환경을 구성하고 해당 환경에서 GAN 논문을 직접 구현 및 실습하였습니다.

본 페이지에서는 모델 구현에 중점을 둘 예정이라 실습 환경 설정에 관해서는 저의 블로그 항목에서 '[VSCODE] AI 개발환경 구축 🖥️' 에 들어가셔서 차근차근히 준비하시면 됩니다!

 

아래 링크를 참고하세요!

 

실패 없이 Window 환경에서 WSL 2와 VScode 통합하기

안녕하세요! 😊 환경 구축이란 게 쉽지 않고 작은 실수만으로 제대로 설정하기 어려울 수 있습니다. 저도 초반에는 많이 맸어요.🥹 그러나 이 가이드에 나열된 과정을 차근차근 따라오시면 비

yiheeju.tistory.com

 

🔧 구현 환경 준비

실습을 진행할 파일을 생성하고 열었다면 터미널 창에 아래 코드를 입력해주시면 됩니다.

1) 가상환경 생성

conda create -n gan_99 python=3.10 --y

다음과 같이 나온다면 성공!

 

2) 가상환경 활성화

conda activate gan

(base)에서 (gan)이라고 나온다면 성공!

 

3) 필요한 라이브러리 설치

PyTorch 설치 링크

 

Start Locally

Start Locally

pytorch.org

 

conda install pytorch torchvision torchaudio cpuonly -c pytorch # 주의!
# 위의 링크에서 알맞은 파이토치 버전으로 다운받아서 실행하세요.

conda install matplotlib

 

 

참고‼️
(파일명).ipynb로 파일을 만드신 다음 코드를 쳤을 때 가상환경을 선택하라 하면 우리가 생성한 가상환경을 선택해주세요. 다음으로  ipykernel을 설치하라는 팝업이 뜬다면 설치를 누르고 기다리시면 코드 실행이 됩니다. 

📦 꼭! 실습 항목 및 가상환경 설정을 참고하여 실습 준비해주세요!

 

📂 데이터 준비

우리는 MNIST 데이터셋을 사용해 흑백 손글씨 이미지를 생성하는 GAN을 구현할 예정입니다.

import torch
from torchvision import datasets, transforms

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=data_transform),
    batch_size=64,
    shuffle=True
)

 

 

🖼️받아온 데이터 시각화하기

import matplotlib.pyplot as plt


# 배치에서 하나의 이미지 가져오기
data_iter = iter(dataloader)
images, labels = next(data_iter)

# 첫 번째 이미지 시각화 (Numpy 변환 없이 PyTorch 텐서 사용)
image = images[0].squeeze(0)  # (1, 28, 28) -> (28, 28)

plt.imshow(image, cmap='gray')  # 텐서를 바로 시각화
plt.title(f'Label: {labels[0].item()}')
plt.axis('off')
plt.show()

 

🏗️ 생성자(Generator) 설계

생성자는 랜덤 노이즈를 받아 이미지를 생성하는 역할을 합니다.

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),  # 입력: 100차원 노이즈
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 28 * 28),  # MNIST 이미지는 28x28
            nn.Tanh()  # 픽셀값을 -1~1로 스케일링
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)  # MNIST 형태로 변환
        return img

 

🏢 판별자(Discriminator) 설계

판별자는 이미지를 입력으로 받아 진짜 또는 가짜를 판별합니다.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 진짜(1) 또는 가짜(0) 확률을 반환
        )

    def forward(self, img):
        img = img.view(img.size(0), -1)  # 28x28 이미지를 1D로 변환
        return self.model(img)

 

⚙️손실 함수 및 최적화 기법

  • 손실 함수는 Binary Cross Entropy (BCE)를 사용합니다.
  • Adam 옵티마이저를 사용해 두 네트워크를 최적화합니다.
import torch.optim as optim

adversarial_loss = nn.BCELoss()

generator = Generator()
discriminator = Discriminator()

optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

 

🚀 학습 과정

GAN은 두 단계로 학습합니다.

  1. 판별자(Discriminator) 학습: 진짜 이미지를 진짜로 가짜 이미지를 가짜로 판별하는 능력 향상.
  2. 생성자(Generator) 학습: 판별자가 가짜 이미지를 진짜로 오판하도록 가짜 이미지를 개선.
num_epochs = 50
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        
        # 진짜와 가짜 라벨 정의
        real = torch.ones(imgs.size(0), 1)
        fake = torch.zeros(imgs.size(0), 1)

        # ----- 판별자 학습 -----
        optimizer_D.zero_grad()
        
        real_imgs = imgs
        output = discriminator(real_imgs)
        real_loss = adversarial_loss(output, real)

        z = torch.randn(imgs.size(0), 100)
        gen_imgs = generator(z)
        output = discriminator(gen_imgs.detach())
        fake_loss = adversarial_loss(output, fake)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ----- 생성자 학습 -----
        optimizer_G.zero_grad()
        
        z = torch.randn(imgs.size(0), 100)
        gen_imgs = generator(z)
        output = discriminator(gen_imgs)
        g_loss = adversarial_loss(output, real)

        g_loss.backward()
        optimizer_G.step()

    print(f"[Epoch {epoch}/{num_epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

시간 이슈로 50에포크로 짧게 학습을 시켰습니다.
현재 판별자(D)와 생성자(G) 손실 패턴을 보면 균형을 이루고 있어서 GAN이 과적합이나 붕괴 현상 없이 정상적으로 학습되고 있는 것으로 보입니다. 다만 에포크를 짧게 설정해 생성자의 성능이 충분히 발전되지 못한 것 같습니다...

 

 

🖼️ 결과물 시각화

import torchvision

z = torch.randn(64, 100)
gen_imgs = generator(z).detach().cpu()
grid = torchvision.utils.make_grid(gen_imgs, nrow=8, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
plt.show()