1️⃣ GAN이란?
Generative Adversarial Network
두 개의 모델이 적대적으로 경쟁하면서 진짜 같은 데이터를 생성하는 구조
- Generator (G): 가짜 데이터를 만들어내는 생성자
- Discriminator (D): 진짜/가짜를 구분하는 판별자
2️⃣ GAN 학습 흐름
[랜덤 노이즈 z] ─▶ Generator ─▶ [가짜 이미지]
│
┌──────────────────┴──────────────────┐
▼ ▼
[진짜 이미지 (정답)] [가짜 이미지 (G가 만든)]
│ │
└──────────────▶ Discriminator ◀─────┘
│
[진짜일 확률 (0~1)]
3️⃣ Generator 구조
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(), # 비선형성
nn.Linear(256, 784), # 28x28
nn.Tanh() # 픽셀 정규화 (-1 ~ 1)
)
def forward(self, z):
out = self.model(z)
return out.view(-1, 1, 28, 28) # [batch, 1, 28, 28]
- 입력: 100차원 노이즈
- 출력: 28x28 흑백 이미지
4️⃣ Discriminator 구조
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Flatten(), # [1, 28, 28] → [784]
nn.Linear(784, 256),
nn.LeakyReLU(0.2), # 죽은 뉴런 방지
nn.Linear(256, 1),
nn.Sigmoid() # 확률 출력
)
def forward(self, x):
return self.model(x)
- 입력: 이미지
- 출력: 0~1 (진짜일 확률)
5️⃣ 손실 함수 (Criterion)
criterion = nn.BCELoss() # Binary Cross Entropy Loss
- Discriminator: 진짜는 1, 가짜는 0 되게 학습
- Generator: 가짜를 1처럼 보이게 학습 (D를 속이기)
6️⃣ 학습 루프
for epoch in range(epochs):
for real_imgs, _ in dataloader:
batch_size = real_imgs.size(0)
real_imgs = real_imgs.to(device)
# ---------------------
# 1. Discriminator 학습
# ---------------------
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# 진짜 손실
real_output = D(real_imgs)
d_loss_real = criterion(real_output, real_labels)
# 가짜 손실
z = torch.randn(batch_size, 100).to(device)
fake_imgs = G(z).detach() # G는 학습 안 함
fake_output = D(fake_imgs)
d_loss_fake = criterion(fake_output, fake_labels)
# 손실 합치기 + D 업데이트
d_loss = d_loss_real + d_loss_fake
D_optimizer.zero_grad()
d_loss.backward()
D_optimizer.step()
# ---------------------
# 2. Generator 학습
# ---------------------
z = torch.randn(batch_size, 100).to(device)
fake_imgs = G(z)
output = D(fake_imgs)
g_loss = criterion(output, real_labels) # G는 D를 속이려고 함
G_optimizer.zero_grad()
g_loss.backward()
G_optimizer.step()
7️⃣ 결과 시각화
z = torch.randn(16, 100).to(device)
fake_imgs = G(z).detach().cpu()
fig, axs = plt.subplots(4, 4, figsize=(4, 4))
for i in range(4):
for j in range(4):
axs[i, j].imshow(fake_imgs[i * 4 + j][0], cmap='gray')
axs[i, j].axis('off')
plt.tight_layout()
plt.show()
- Generator가 만든 가짜 이미지를 눈으로 확인하는 부분
✅ 손실 함수 정리
이름용도입력정답사용 예
| BCELoss | 이진 분류 | 확률 (0~1) | 0 or 1 | Discriminator |
| BCEWithLogitsLoss | 이진 분류 + 안정성 | 로짓 | 0 or 1 | 대체 사용 |
| CrossEntropyLoss | 다중 분류 | 로짓 | 정수 인덱스 | 숫자 분류 |
| MSELoss | 회귀 | 숫자 | 숫자 | 이미지 비교, 예측값 차이 |
| L1Loss | 회귀 (덜 민감함) | 숫자 | 숫자 | 픽셀 차이, 복원 |
🎁 한 줄 요약
GAN은 Generator와 Discriminator가 서로 속고 속이면서 발전하는 게임
손실 함수(criterion)는 얼마나 틀렸는지 계산하는 기준,
학습 루프는 그걸 바탕으로 모델을 점점 개선시키는 과정!
'인공지능' 카테고리의 다른 글
| MobileNet: 모바일을 위한 경량 CNN (0) | 2025.04.04 |
|---|---|
| Transformer란? (0) | 2025.04.04 |
| 전이 학습 (Transfer Learning) (0) | 2025.03.27 |
| CNN 아키텍처 발전 흐름 요약 (0) | 2025.03.27 |
| 딥러닝 학습 핵심 개념 정리 (0) | 2025.03.27 |