본문 바로가기
인공지능

GAN 전체 정리

by 3in9u 2025. 3. 31.

1️⃣ GAN이란?

Generative Adversarial Network
두 개의 모델이 적대적으로 경쟁하면서 진짜 같은 데이터를 생성하는 구조

  • Generator (G): 가짜 데이터를 만들어내는 생성자
  • Discriminator (D): 진짜/가짜를 구분하는 판별자

2️⃣ GAN 학습 흐름

[랜덤 노이즈 z] ─▶ Generator ─▶ [가짜 이미지]
                                 │
             ┌──────────────────┴──────────────────┐
             ▼                                     ▼
     [진짜 이미지 (정답)]                     [가짜 이미지 (G가 만든)]
             │                                     │
             └──────────────▶ Discriminator ◀─────┘
                                │
                      [진짜일 확률 (0~1)]

3️⃣ Generator 구조

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),                      # 비선형성
            nn.Linear(256, 784),           # 28x28
            nn.Tanh()                      # 픽셀 정규화 (-1 ~ 1)
        )

    def forward(self, z):
        out = self.model(z)
        return out.view(-1, 1, 28, 28)      # [batch, 1, 28, 28]
  • 입력: 100차원 노이즈
  • 출력: 28x28 흑백 이미지

4️⃣ Discriminator 구조

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),                  # [1, 28, 28] → [784]
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),             # 죽은 뉴런 방지
            nn.Linear(256, 1),
            nn.Sigmoid()                   # 확률 출력
        )

    def forward(self, x):
        return self.model(x)
  • 입력: 이미지
  • 출력: 0~1 (진짜일 확률)

5️⃣ 손실 함수 (Criterion)

criterion = nn.BCELoss()  # Binary Cross Entropy Loss
  • Discriminator: 진짜는 1, 가짜는 0 되게 학습
  • Generator: 가짜를 1처럼 보이게 학습 (D를 속이기)

6️⃣ 학습 루프

for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)

        # ---------------------
        # 1. Discriminator 학습
        # ---------------------
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # 진짜 손실
        real_output = D(real_imgs)
        d_loss_real = criterion(real_output, real_labels)

        # 가짜 손실
        z = torch.randn(batch_size, 100).to(device)
        fake_imgs = G(z).detach()  # G는 학습 안 함
        fake_output = D(fake_imgs)
        d_loss_fake = criterion(fake_output, fake_labels)

        # 손실 합치기 + D 업데이트
        d_loss = d_loss_real + d_loss_fake
        D_optimizer.zero_grad()
        d_loss.backward()
        D_optimizer.step()

        # ---------------------
        # 2. Generator 학습
        # ---------------------
        z = torch.randn(batch_size, 100).to(device)
        fake_imgs = G(z)
        output = D(fake_imgs)
        g_loss = criterion(output, real_labels)  # G는 D를 속이려고 함

        G_optimizer.zero_grad()
        g_loss.backward()
        G_optimizer.step()

7️⃣ 결과 시각화

z = torch.randn(16, 100).to(device)
fake_imgs = G(z).detach().cpu()

fig, axs = plt.subplots(4, 4, figsize=(4, 4))
for i in range(4):
    for j in range(4):
        axs[i, j].imshow(fake_imgs[i * 4 + j][0], cmap='gray')
        axs[i, j].axis('off')
plt.tight_layout()
plt.show()
  • Generator가 만든 가짜 이미지를 눈으로 확인하는 부분

✅ 손실 함수 정리

이름용도입력정답사용 예
BCELoss 이진 분류 확률 (0~1) 0 or 1 Discriminator
BCEWithLogitsLoss 이진 분류 + 안정성 로짓 0 or 1 대체 사용
CrossEntropyLoss 다중 분류 로짓 정수 인덱스 숫자 분류
MSELoss 회귀 숫자 숫자 이미지 비교, 예측값 차이
L1Loss 회귀 (덜 민감함) 숫자 숫자 픽셀 차이, 복원

🎁 한 줄 요약

GAN은 Generator와 Discriminator가 서로 속고 속이면서 발전하는 게임
손실 함수(criterion)는 얼마나 틀렸는지 계산하는 기준,
학습 루프는 그걸 바탕으로 모델을 점점 개선시키는 과정!

'인공지능' 카테고리의 다른 글

MobileNet: 모바일을 위한 경량 CNN  (0) 2025.04.04
Transformer란?  (0) 2025.04.04
전이 학습 (Transfer Learning)  (0) 2025.03.27
CNN 아키텍처 발전 흐름 요약  (0) 2025.03.27
딥러닝 학습 핵심 개념 정리  (0) 2025.03.27