인공지능
GAN (Generative Adversarial Network)
by 3in9u
2025. 4. 4.
✅ 기본 개념
- 두 개의 신경망(G, D)이 서로 경쟁하며 학습
- 목표: 진짜 같은 가짜 데이터를 생성하는 것
- 이름 구성:
- Generative: 데이터를 생성하는 G
- Adversarial: 경쟁하는 구조
🧠 GAN Architecture
[노이즈 z] ─▶ 🎨 Generator(G) ─▶ 🖼️ Fake Image ─▶
│ ↓
[Real Image] ─▶ 🛡️ Discriminator(D)
│
└──────────────▶ Real or Fake
- Generator(G): 노이즈 벡터 z로부터 이미지를 생성
- Discriminator(D): 입력이 실제인지 가짜인지 판별
- G는 D를 속이려고 학습하고, D는 G를 잡아내기 위해 학습
→ 게임 이론 기반의 최적화 구조
⚖️ GAN Loss Function
네트워크목표수식 개념
| 네트워크 |
목표 |
수식 개념 |
| D |
진짜는 1, 가짜는 0 |
log(D(x)) + log(1 - D(G(z))) 최대화 |
| G |
가짜를 진짜처럼 |
log(D(G(z))) 최소화 (또는 -log(D(G(z))) 최대화) |
🧪 Vanilla GAN
- 가장 기본적인 GAN 구조
- Fully Connected (MLP)로만 구성
- 낮은 해상도 (보통 28x28 또는 64x64 수준)
🎨 cGAN (Conditional GAN)
✅ 개념
- 조건 y (예: 클래스 레이블)를 함께 입력받아 조건 제어 생성
- 입력: (z, y), 출력도 y에 조건됨
✅ 예시
- "3이라는 숫자 이미지를 생성"
- "여자 얼굴 생성"
- "노란 배경에 빨간 사과 생성"
z + y ─▶ G(z|y) ─▶ fake image
x + y ─▶ D(x|y) ─▶ real or fake
🧱 DCGAN (Deep Convolutional GAN)
✅ 특징
- Vanilla GAN + CNN = DCGAN
- ConvTranspose2D: 업샘플링
- BatchNorm, LeakyReLU: 안정적인 학습
- 이미지 품질 대폭 향상
✅ 구조 요약
- G: Dense → Reshape → ConvTranspose2D (업샘플링)
- D: Conv2D (다운샘플링) → Flatten → Dense
🧬 DCGAN의 벡터 연산 특성
- Latent vector(z)의 의미적 조작이 가능
- 예:
- z_남자 + z_안경 - z_여자 ≈ z_안경 쓴 남자
- 개념 공간에서 의미 있는 연산 가능 (semantic arithmetic)
❗ GAN의 한계 및 문제점
| 문제점 |
설명 |
| 🔳 고해상도 생성 어려움 |
초기 GAN은 64x64 이미지 정도 |
| ⚖️ 학습 불안정 |
G와 D의 경쟁 밸런스 깨지면 collapse |
| 📉 확률 밀도 추정 어려움 |
생성 분포의 명확한 수치화 어려움 |
| 📏 평가 지표 부족 |
이미지 품질 평가가 주관적일 수 있음 |
| ❓ 노이즈 ↔ 이미지 매핑 불분명 |
어떤 z가 어떤 이미지인지 설명 어려움 |
특히 모드 붕괴(mode collapse)는 GAN 학습의 대표적 문제입니다. G가 특정한 결과만 반복적으로 생성할 수 있어 다양성이 사라짐.
✨ 요약: GAN 계열 모델 비교
모델특징주요 개선점
| 모델 |
특징 |
주요 개선점 |
| Vanilla GAN |
기본 구조 |
MLP 기반, 해상도 낮음 |
| DCGAN |
CNN 기반 |
고해상도 이미지 가능 |
| cGAN |
조건 입력 추가 |
클래스/속성 기반 생성 |
| WGAN |
Wasserstein 거리 사용 |
학습 안정성 개선 |
| StyleGAN |
스타일 기반 조절 |
고품질 얼굴 생성 |
| CycleGAN |
도메인 간 이미지 변환 |
예: 말 ↔ 얼룩말 |
📌 실전 예시 코드 (PyTorch 기반 DCGAN G 구조)
import torch.nn as nn
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 1, 4, 2, 1),
nn.Tanh()
)
def forward(self, z):
return self.net(z)