BigData
예시로 보는 *PyTorch 기반의 "DDPM(Denoising Diffusion Probabilistic Model)"* - 이미지생성모델
IT오이시이
2025. 8. 29. 12:03
728x90
예시로 보는 PyTorch 기반의 "DDPM(Denoising Diffusion Probabilistic Model)"
- 핵심 원리(노이즈 추가 및 제거)를 이해하기 쉽고, 실제로 학습 및 샘플 생성까지 코드
PyTorch 기반의 DDPM(Denoising Diffusion Probabilistic Model)은 이미지 생성 분야에서 가장 강력하고 안정적인 모델 중 하나로 평가받고 있어요. 아래에 DDPM의 핵심 개념과 PyTorch 구현 예시 입니다.
🧠 DDPM 핵심 개념 요약
| 작동 단계 | 설명 |
| Forward Process (q) | 원본 이미지에 점진적으로 Gaussian 노이즈를 추가해 완전히 파괴된 이미지 (x_T)로 변환 |
| Reverse Process (p) | 파괴된 이미지 (x_T)에서 점진적으로 노이즈를 제거해 원본 이미지 (x_0)를 복원 |
| 목표 함수 | Variational Inference 기반의 ELBO 최적화 → 실제로는 MSE 기반의 간단한 loss로 근사 |
| 특징 | - Markov Chain 기반 - 시간 단계 (t)에 따라 노이즈 스케줄링 - 학습 안정성 높고, 고품질 이미지 생성 가능 |
🎯 알고리즘 활용
- 이미지 크기: Diffusion 모델은 고해상도 이미지에 강하지만, 학습 시간과 메모리 요구가 큽니다.
- 노이즈 단계 수: 일반적으로 1000단계가 사용되며, 줄이면 속도는 빨라지지만 품질이 저하될 수 있어요.
- 조건부 생성: 텍스트, 클래스 레이블 등을 조건으로 넣어 조건부 DDPM도 구현 가능!
🛠 PyTorch 기반 구현
1. 기본 구성 요소
- UNet 모델: 이미지의 다양한 해상도에서 특징을 추출하고 복원하는 구조
- Beta Schedule: 시간 단계별 노이즈 강도 조절 (linear, cosine 등)
- Gaussian Diffusion 클래스: forward/reverse 과정 정의
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
model = Unet(
dim=64,
dim_mults=(1, 2, 4, 8)
)
diffusion = GaussianDiffusion(
model,
image_size=128,
timesteps=1000, # 노이즈 단계 수
loss_type='l1' # 또는 'l2'
)
2. 학습 루프 예시
import torch
for epoch in range(num_epochs):
for images in dataloader:
images = images.to(device)
loss = diffusion(images)
loss.backward()
optimizer.step()
optimizer.zero_grad()
3. 샘플링 (이미지 생성)
sampled_images = diffusion.sample(batch_size=16)
📦 참고 구현 리포지토리
- DDPM-PyTorch GitHub 구현에서는 MNIST 기반으로 DDPM을 학습하고 샘플링하는 전체 파이프라인을 제공합니다.
- Lucidrains의 denoising-diffusion-pytorch는 HuggingFace 스타일의 모듈화된 구현으로, 다양한 실험에 적합해요.
주요 알고리즘의 유형
- mnist/cifar10 등 단순 이미지 생성: 구조가 단순해서 이론 설명과 실습에 모두 적합.
- 가장 널리 사용되는 구현 패턴: UNet 단일 구조, 시간 인코딩, forward/reverse 루프로 모듈화.
- 실제 실습/시각화 가능: 학습 후 샘플 이미지 생성이 바로 가능해 블로그 컨텐츠에 시각적 재미 추가.
대표 "최소 구현형 Diffusion 모델" 오픈소스
- minDiffusion (superminddpm.py): 200줄 내외, self-contained로 테스트/학습/생성까지 구현.
- simple-diffusion: PyTorch 기반 UNet 구조, Oxford Flowers 데이터셋 예시.
- PyTorch Diffusion Model Tutorial: 다양한 구조 예시 및 학습 코드 참고 가능.
예제 코드 스케치
아래는 가장 기본적인 DDPM 핵심부 샘플 코드 예시입니다.
실습 상세 코드는 깃허브 minDiffusion 등을 참고해서 일부 발췌합니다.
import torch
import torch.nn as nn
class SimpleUnet(nn.Module):
# 최소 구조의 Unet
def __init__(self): super().__init__()
# ...레이어 정의 생략(깃허브 참고)...
def noise_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
return torch.linspace(beta_start, beta_end, T)
T = 1000
betas = noise_schedule(T)
alphas = 1. - betas
alpha_bars = torch.cumprod(alphas, dim=0)
def q_sample(x0, t, noise):
return torch.sqrt(alpha_bars[t]) * x0 + torch.sqrt(1 - alpha_bars[t]) * noise
# 학습 루프 개요
model = SimpleUnet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
for x0 in dataloader:
t = torch.randint(0, T, (x0.shape,))
noise = torch.randn_like(x0)
xt = q_sample(x0, t, noise)
loss = ((model(xt, t) - noise) ** 2).mean()
loss.backward()
optimizer.step()
이 코드는 구조와 컨셉 설명에 최적화되어 있고, 완성 코드는 해당 깃허브 저장소를 참고해서 전체 노트북 혹은 py 파일로 링크 걸면 됩니다.
728x90
반응형