안녕하세요! 최근 생성형 AI의 시작인 위대한 Generative Adversarial Nets 논문 리뷰를 했었는데요! 아래 링크에 GAN 논문 리뷰를 링크해 놓았으니 한번씩 읽어보시는 것을 추천합니다! GAN의 개념에 대해서 간단히 설명하자면 생성자(Generator)와 판별자(Discriminator)라는 두 신경망이 서로 경쟁하며 학습하는 모델입니다. 이 때, 생성자(G)는 랜덤한 노이즈에서 진짜 같은 데이터를 생성하려고 합니다.반면에 판별자(D)는 입력 데이터가 진짜(실제 데이터)인지 가짜(생성자가 만든 데이터)인지 구분하려고 합니다. 이번 시간에는 MNIST 데이터셋을 사용해 흑백 손글씨 이미지를 생성하는 GAN을 구현하려고 합니다!
판별자(Discriminator) 학습: 진짜 이미지를 진짜로 가짜 이미지를 가짜로 판별하는 능력 향상.
생성자(Generator) 학습: 판별자가 가짜 이미지를 진짜로 오판하도록 가짜 이미지를 개선.
num_epochs = 50
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(dataloader):
# 진짜와 가짜 라벨 정의
real = torch.ones(imgs.size(0), 1)
fake = torch.zeros(imgs.size(0), 1)
# ----- 판별자 학습 -----
optimizer_D.zero_grad()
real_imgs = imgs
output = discriminator(real_imgs)
real_loss = adversarial_loss(output, real)
z = torch.randn(imgs.size(0), 100)
gen_imgs = generator(z)
output = discriminator(gen_imgs.detach())
fake_loss = adversarial_loss(output, fake)
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_D.step()
# ----- 생성자 학습 -----
optimizer_G.zero_grad()
z = torch.randn(imgs.size(0), 100)
gen_imgs = generator(z)
output = discriminator(gen_imgs)
g_loss = adversarial_loss(output, real)
g_loss.backward()
optimizer_G.step()
print(f"[Epoch {epoch}/{num_epochs}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
시간 이슈로 50에포크로 짧게 학습을 시켰습니다. 현재 판별자(D)와 생성자(G) 손실 패턴을 보면 균형을 이루고 있어서 GAN이 과적합이나 붕괴 현상 없이 정상적으로 학습되고 있는 것으로 보입니다. 다만 에포크를 짧게 설정해 생성자의 성능이 충분히 발전되지 못한 것 같습니다...