Entrenando una GAN con PyTorch
Construiremos una Generative Adversarial Network desde cero con PyTorch: diseñaremos el Generador y el Discriminador, implementaremos el training loop adversarial, visualizaremos la evolución del entrenamiento y escalaremos a una DCGAN con convoluciones. Cada línea de código está explicada.
Requisitos previos
- Python 3.9+ y PyTorch 2.x instalados
- Conceptos básicos de redes neuronales: forward pass, loss, backpropagation
- Haber leído la teoría de GANs (juego minimax, generador, discriminador)
- Familiaridad con convoluciones (para la parte de DCGAN)
- Opcional: GPU con CUDA (acelera el entrenamiento significativamente)
¿Qué vamos a construir?
Vamos a implementar una Generative Adversarial Network (GAN) completa desde cero con PyTorch. Una GAN es un sistema de dos redes neuronales que compiten entre sí: un Generador (G) que crea imágenes falsas a partir de ruido aleatorio, y un Discriminador (D) que intenta distinguir las imágenes reales de las falsas. Este juego adversarial fue propuesto por Ian Goodfellow et al. (2014) y revolucionó la IA generativa.
1.1 El juego en una frase
El Generador quiere engañar al Discriminador produciendo imágenes lo más realistas posible. El Discriminador quiere no ser engañado, clasificando correctamente cada imagen como real o generada. Cuando este juego se equilibra, el Generador produce imágenes indistinguibles de las reales.
1.2 Lo que construiremos
1.3 Nuestro plan
- Setup — Instalación, imports, configuración de hiperparámetros y device.
- Dataset — Cargar MNIST, normalizar y crear DataLoaders.
- Generador — Red MLP que transforma ruido z en una imagen 28×28.
- Discriminador — Red MLP que clasifica imágenes como reales o falsas.
- Training loop — Alternancia D/G, Binary Cross-Entropy, optimizadores.
- Visualización — Grids épocas, curvas de loss, latent space interpolation.
- DCGAN — Upgrade completo con convoluciones.
- Debugging — Problemas reales y cómo resolverlos.
- Referencias — Papers, repos, documentación y siguientes pasos.
Setup: instalación, imports y configuración
2.1 Instalación
# CPU (funciona en cualquier máquina)
pip install torch torchvision matplotlib
# GPU CUDA 12.1
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
2.2 Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import os
import time
# Reproducibilidad
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch {torch.__version__} | Device: {device}")
make_grid — utilidad de torchvision para crear mosaicos de imágenes. Perfecto para visualizar las generaciones.2.3 Hiperparámetros
Centralizamos todos los hiperparámetros al inicio del script. Esto facilita la experimentación: cambiar un valor aquí afecta a todo el pipeline.
# ── Hiperparámetros ──────────────────────────────────────
LATENT_DIM = 100 # Dimensión del vector de ruido z
IMG_SIZE = 28 # MNIST: 28x28
IMG_CHANNELS = 1 # Escala de grises
IMG_PIXELS = IMG_SIZE * IMG_SIZE * IMG_CHANNELS # 784
BATCH_SIZE = 128
EPOCHS = 50
LR_G = 2e-4 # Learning rate del Generador
LR_D = 2e-4 # Learning rate del Discriminador
BETA1 = 0.5 # Adam beta1 (estándar en GANs)
BETA2 = 0.999 # Adam beta2
# Directorio para guardar resultados
os.makedirs("gan_results", exist_ok=True)
LATENT_DIM = 100 — el tamaño del vector de ruido que alimenta al Generador. 100 es el valor estándar desde el paper original. Valores típicos: 64-256.LR = 2e-4 y BETA1 = 0.5 — estos valores vienen directamente del paper de DCGAN (Radford et al., 2016) y son el estándar de facto para entrenar GANs con Adam.
El optimizador Adam con beta1=0.5 (en vez del default 0.9) fue
identificado como clave en el paper de DCGAN. La razón:
beta1=0.9lleva demasiado momentum, lo que puede causar oscilaciones en el entrenamiento adversarial.beta1=0.5reduce el momentum, estabilizando las actualizaciones del Generador y Discriminador.- Alternativa: SGD funciona pero converge mucho más lento. RMSprop es otra opción popular (usado en WGAN).
Dataset: preparar MNIST
Cargamos el dataset MNIST y lo normalizamos al rango [-1, 1].
Este rango es clave: la última capa del Generador usará tanh,
que produce valores en [-1, 1], así que las imágenes reales deben estar
en el mismo rango para que la comparación tenga sentido.
# ── Transformaciones ─────────────────────────────────────
# ToTensor() convierte [0, 255] → [0, 1]
# Normalize((0.5,), (0.5,)) mapea [0, 1] → [-1, 1]
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# ── Descargar y cargar MNIST ─────────────────────────────
train_dataset = datasets.MNIST(
root='./data',
train=True,
transform=transform,
download=True
)
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True, # Descartar último batch incompleto
num_workers=2,
pin_memory=True if device.type == 'cuda' else False,
)
print(f"Dataset: {len(train_dataset):,} imágenes")
print(f"Batches por epoch: {len(train_loader)}")
print(f"Imagen shape: {train_dataset[0][0].shape}")
print(f"Rango: [{train_dataset[0][0].min():.1f}, {train_dataset[0][0].max():.1f}]")
Normalize((0.5,), (0.5,)) aplica (x - 0.5) / 0.5 que mapea [0,1] → [-1,1]. ¿Por qué no [0,1]? Porque tanh produce [-1,1] y queremos que G y el dataset vivan en el mismo espacio.drop_last=True — descarta el último batch si tiene menos de 128 imágenes. Esto evita problemas de shape en el training loop.pin_memory=True con GPU acelera la transferencia CPU→GPU al "fijar" la memoria del DataLoader.3.1 Visualizar muestras reales
Siempre es buena práctica visualizar los datos antes de entrenar. Creamos una función reutilizable para mostrar grids de imágenes:
def show_images(images, nrow=8, title=""):
"""Muestra un grid de imágenes (tensor normalizado a [-1,1])."""
# Desnormalizar: [-1, 1] → [0, 1]
images = (images + 1) / 2
images = images.clamp(0, 1)
grid = make_grid(images, nrow=nrow, padding=2)
plt.figure(figsize=(10, 10))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap='gray')
plt.title(title, fontsize=14)
plt.axis('off')
plt.tight_layout()
plt.show()
# Visualizar un batch de imágenes reales
real_batch = next(iter(train_loader))[0][:64]
show_images(real_batch, title="MNIST — Imágenes reales")
matplotlib las muestre correctamente.make_grid crea un mosaico de imágenes en un solo tensor. nrow=8 = 8 imágenes por fila.grid.permute(1,2,0) — PyTorch usa (C, H, W) pero matplotlib espera (H, W, C).tanh del Generador produce
valores en [-1, 1] de forma natural, con gradientes saludables alrededor de 0.
Si usáramos [0, 1] con sigmoid, los gradientes se saturan en los
extremos (0 y 1), dificultando el aprendizaje. Siempre que uses tanh
en G, normaliza tus datos a [-1, 1].
Sí. Cambia solo la línea del dataset y ajusta IMG_SIZE y IMG_CHANNELS:
- Fashion-MNIST:
datasets.FashionMNIST(...)— misma estructura que MNIST pero con ropa (más difícil). - CIFAR-10:
datasets.CIFAR10(...)— 32×32, 3 canales RGB. Necesitarás ajustarIMG_SIZE=32yIMG_CHANNELS=3. - CelebA:
datasets.CelebA(...)— rostros 218×178. Requiere resize y más capacidad en G y D. - Custom:
datasets.ImageFolder('path/to/images')— cualquier directorio con imágenes organizadas en subcarpetas.
Para datasets más complejos, recomendamos la DCGAN del paso 8 desde el inicio.
El Generador: de ruido a imágenes
El Generador toma un vector de ruido z ∈ ℝ¹⁰⁰ muestreado de una
distribución normal estándar y lo transforma en una imagen de 28×28 píxeles.
Es, esencialmente, una red que aprende a mapear puntos del espacio latente
a imágenes plausibles.
class Generator(nn.Module):
"""
Generador MLP: z (100,) → imagen (1, 28, 28).
Arquitectura: Linear → LeakyReLU → BN, repetido, → Tanh.
"""
def __init__(self, latent_dim=LATENT_DIM, img_pixels=IMG_PIXELS):
super().__init__()
self.net = nn.Sequential(
# Bloque 1: 100 → 256
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm1d(256),
# Bloque 2: 256 → 512
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm1d(512),
# Bloque 3: 512 → 1024
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.BatchNorm1d(1024),
# Capa de salida: 1024 → 784, activación Tanh
nn.Linear(1024, img_pixels),
nn.Tanh(), # Output en [-1, 1]
)
def forward(self, z):
"""z: (batch, latent_dim) → img: (batch, 1, 28, 28)"""
flat = self.net(z) # (B, 784)
img = flat.view(-1, IMG_CHANNELS, IMG_SIZE, IMG_SIZE) # (B, 1, 28, 28)
return img
# Crear generador
G = Generator().to(device)
# Verificar shape
z_test = torch.randn(4, LATENT_DIM, device=device)
fake_test = G(z_test)
print(f"Generador - Input: {z_test.shape} → Output: {fake_test.shape}")
print(f"Parámetros G: {sum(p.numel() for p in G.parameters()):,}")
Linear → LeakyReLU → BatchNorm1d. La BatchNorm estabiliza el entrenamiento — es uno de los trucos clave de DCGAN, y funciona también en MLPs.LeakyReLU(0.2) en vez de ReLU. En GANs, LeakyReLU permite que los gradientes fluyan aunque la neurona esté en la zona negativa. El slope 0.2 es estándar.nn.Tanh() — la activación final del Generador produce valores en [-1, 1], que es el rango de nuestras imágenes normalizadas.flat.view(-1, 1, 28, 28) — reshape del vector plano de 784 valores a una imagen con forma (C, H, W). El -1 infiere el batch_size automáticamente.En las GANs, el gradiente necesita fluir a través de toda la red. Con ReLU estándar, las neuronas inactivas (output = 0) producen gradiente 0, creando "neuronas muertas" que nunca se recuperan. Esto es especialmente problemático en la fase inicial del entrenamiento cuando el Generador produce basura.
LeakyReLU(0.2) permite un gradiente pequeño (0.2× el input) en la
zona negativa, manteniendo el flujo de gradientes. Alternativas:
- ELU: suave en la zona negativa, pero más costoso.
- GELU: usado en Transformers, pero raro en GANs.
- SELU: autoregulante, pero requiere inicialización especial.
Para imágenes más complejas, puedes añadir conexiones residuales al Generador. Cada bloque se convierte en:
output = block(x) + projection(x)
El projection es un nn.Linear que ajusta las dimensiones.
Esto permite al Generador tener más capas sin que el gradiente se degrade.
Lo veremos en más detalle en el paso 8 con la DCGAN.
El Discriminador: real vs fake
El Discriminador es un clasificador binario: recibe una imagen y produce una probabilidad de que sea real (cercana a 1) o generada (cercana a 0). Es lo opuesto al Generador: aquí partimos de 784 píxeles y comprimimos hasta un escalar.
class Discriminator(nn.Module):
"""
Discriminador MLP: imagen (1, 28, 28) → probabilidad [0, 1].
Arquitectura: Flatten → Linear → LeakyReLU → Dropout, repetido, → Sigmoid.
"""
def __init__(self, img_pixels=IMG_PIXELS):
super().__init__()
self.net = nn.Sequential(
# Flatten: (B, 1, 28, 28) → (B, 784)
nn.Flatten(),
# Bloque 1: 784 → 512
nn.Linear(img_pixels, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
# Bloque 2: 512 → 256
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
# Bloque 3: 256 → 128
nn.Linear(256, 128),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
# Salida: 128 → 1 (probabilidad)
nn.Linear(128, 1),
nn.Sigmoid(),
)
def forward(self, img):
"""img: (batch, 1, 28, 28) → p: (batch, 1)"""
return self.net(img)
# Crear discriminador
D = Discriminator().to(device)
# Verificar shape
pred_test = D(fake_test)
print(f"Discriminador - Input: {fake_test.shape} → Output: {pred_test.shape}")
print(f"Parámetros D: {sum(p.numel() for p in D.parameters()):,}")
nn.Flatten() convierte (B, 1, 28, 28) → (B, 784). El MLP necesita inputs 1D.Dropout(0.3) — regularización crucial en el Discriminador. Sin Dropout, D aprende demasiado rápido y el Generador no puede seguirle el ritmo. 0.3 es un buen punto de inicio.nn.Sigmoid() — la salida es una probabilidad en [0, 1]. 1 = "creo que es real", 0 = "creo que es fake".5.1 Inicialización de pesos
La inicialización de los pesos puede marcar la diferencia entre una GAN que converge y una que colapsa. Usaremos la inicialización recomendada en DCGAN:
def weights_init(m):
"""Inicialización de pesos según DCGAN paper."""
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
if m.bias is not None:
nn.init.constant_(m.bias.data, 0)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# Aplicar a G y D
G.apply(weights_init)
D.apply(weights_init)
print("✓ Pesos inicializados (N(0, 0.02))")
model.apply(fn) aplica recursivamente la función a todos los sub-módulos del modelo.| Red | Capa | Parámetros |
|---|---|---|
| Generador | Linear(100, 256) + BN | 26,368 |
| Linear(256, 512) + BN | 132,096 | |
| Linear(512, 1024) + BN | 526,336 | |
| Linear(1024, 784) | 803,600 | |
| Total G | ~1.07M | |
| Discriminador | Linear(784, 512) | 401,920 |
| Linear(512, 256) | 131,328 | |
| Linear(256, 128) | 32,896 | |
| Linear(128, 1) | 129 | |
| Total D | ~533K | |
El Generador es ~2x más grande que el Discriminador. Esto es normal: G tiene que aprender a generar (tarea difícil), D solo tiene que clasificar (tarea más sencilla). Si D fuera demasiado potente, el G nunca aprendería.
Training loop: el juego adversarial
Este es el corazón de toda GAN: el algoritmo de entrenamiento alternante. En cada iteración del batch, primero entrenamos el Discriminador para que mejore su capacidad de distinguir real de fake, y luego entrenamos el Generador para que mejore su capacidad de engañar al Discriminador.
6.1 Loss y optimizadores
# ── Loss: Binary Cross-Entropy ───────────────────────────
criterion = nn.BCELoss()
# ── Optimizadores separados para G y D ───────────────────
optimizer_G = optim.Adam(G.parameters(), lr=LR_G, betas=(BETA1, BETA2))
optimizer_D = optim.Adam(D.parameters(), lr=LR_D, betas=(BETA1, BETA2))
# ── Labels fijos ─────────────────────────────────────────
real_label = 1.0
fake_label = 0.0
# ── Vector z fijo para visualización ─────────────────────
# Lo usaremos para ver cómo evolucionan las generaciones
fixed_noise = torch.randn(64, LATENT_DIM, device=device)
BCELoss — Binary Cross-Entropy, la loss original de GANs. Mide cuánto se equivoca D al clasificar real/fake.fixed_noise — siempre generamos imágenes a partir del mismo ruido para poder comparar la evolución del Generador entre épocas.6.2 El loop de entrenamiento
# ── Logging ──────────────────────────────────────────────
G_losses, D_losses = [], []
D_real_acc, D_fake_acc = [], []
print("Iniciando entrenamiento...")
start_time = time.time()
for epoch in range(EPOCHS):
g_loss_epoch, d_loss_epoch = 0, 0
d_real_epoch, d_fake_epoch = 0, 0
for i, (real_imgs, _) in enumerate(train_loader):
batch_size = real_imgs.size(0)
real_imgs = real_imgs.to(device)
# ═══════════════════════════════════════════════════
# FASE 1: Entrenar el Discriminador
# Objetivo: maximizar log(D(x)) + log(1 - D(G(z)))
# ═══════════════════════════════════════════════════
D.zero_grad()
# ── 1a. Imágenes REALES → D debe decir "real" (1) ──
labels_real = torch.full((batch_size, 1), real_label,
device=device)
output_real = D(real_imgs)
loss_D_real = criterion(output_real, labels_real)
# ── 1b. Imágenes FAKE → D debe decir "fake" (0) ────
z = torch.randn(batch_size, LATENT_DIM, device=device)
fake_imgs = G(z)
labels_fake = torch.full((batch_size, 1), fake_label,
device=device)
output_fake = D(fake_imgs.detach()) # ¡DETACH! No queremos gradientes en G
loss_D_fake = criterion(output_fake, labels_fake)
# ── 1c. Loss total de D y backward ──────────────────
loss_D = (loss_D_real + loss_D_fake) / 2
loss_D.backward()
optimizer_D.step()
# ═══════════════════════════════════════════════════
# FASE 2: Entrenar el Generador
# Objetivo: maximizar log(D(G(z))) → "engañar a D"
# ═══════════════════════════════════════════════════
G.zero_grad()
# ── 2a. Generar fakes y pedir a D que diga "real" ──
z = torch.randn(batch_size, LATENT_DIM, device=device)
fake_imgs = G(z)
labels_real_for_G = torch.full((batch_size, 1), real_label,
device=device)
output_G = D(fake_imgs) # SIN detach: queremos gradientes en G
loss_G = criterion(output_G, labels_real_for_G)
loss_G.backward()
optimizer_G.step()
# ── Acumular métricas ────────────────────────────────
g_loss_epoch += loss_G.item()
d_loss_epoch += loss_D.item()
d_real_epoch += output_real.mean().item()
d_fake_epoch += output_fake.mean().item()
# ── Promedios del epoch ──────────────────────────────────
n = len(train_loader)
G_losses.append(g_loss_epoch / n)
D_losses.append(d_loss_epoch / n)
D_real_acc.append(d_real_epoch / n)
D_fake_acc.append(d_fake_epoch / n)
elapsed = time.time() - start_time
if (epoch + 1) % 5 == 0 or epoch == 0:
print(f"Epoch [{epoch+1:3d}/{EPOCHS}] | "
f"D_loss: {D_losses[-1]:.4f} | G_loss: {G_losses[-1]:.4f} | "
f"D(x): {D_real_acc[-1]:.3f} | D(G(z)): {D_fake_acc[-1]:.3f} | "
f"Time: {elapsed:.0f}s")
# ── Guardar grid de imágenes cada 10 epochs ─────────────
if (epoch + 1) % 10 == 0 or epoch == 0:
with torch.no_grad():
sample = G(fixed_noise)
sample = (sample + 1) / 2 # [-1,1] → [0,1]
grid = make_grid(sample, nrow=8, padding=2)
plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap='gray')
plt.title(f'Epoch {epoch+1}')
plt.axis('off')
plt.savefig(f'gan_results/epoch_{epoch+1:03d}.png',
bbox_inches='tight', dpi=100)
plt.close()
print(f"\n✓ Entrenamiento completo en {time.time()-start_time:.0f}s")
D.zero_grad() — resetea los gradientes del Discriminador. Fundamental: si no lo haces, los gradientes se acumulan entre iteraciones.fake_imgs.detach() — esto es CRUCIAL. Cuando entrenamos D, no queremos que los gradientes fluyan hacia G a través de las imágenes fake. .detach() corta el grafo computacional en ese punto..detach(): ahora sí queremos que los gradientes fluyan de la loss de G, a través de D, hasta G. Esto es lo que permite a G aprender a engañar a D.log(1 - D(G(z))) (que satura cuando D es bueno), maximizamos log(D(G(z))) usando labels = 1. Esto produce gradientes más fuertes para G al inicio.•
D(x) baja gradualmente desde ~0.9 hacia ~0.5-0.7 (D ya no está seguro de las reales).
•
D(G(z)) sube gradualmente desde ~0.1 hacia ~0.4-0.5 (D confunde las fakes con reales).
•
D_loss se estabiliza alrededor de ln(2) ≈ 0.693 (equilibrio de Nash).
•
G_loss baja progresivamente (G mejora generando).
Si D_loss → 0 rápidamente, D está dominando y G no puede aprender. Si G_loss → 0, G está dominando (posible mode collapse).
El .detach() es uno de los conceptos más confusos de las GANs en PyTorch. Veamos exactamente qué pasa:
Cuando entrenamos D (con detach):
output_fake = D(fake_imgs.detach())
- El grafo computacional es:
fake_imgs_COPY → D → loss_D - Los gradientes fluyen:
loss_D → D - G no se actualiza.
Cuando entrenamos G (sin detach):
output_G = D(fake_imgs)
- El grafo completo es:
z → G → fake_imgs → D → loss_G - Los gradientes fluyen:
loss_G → D → G - Pero solo se actualizan los parámetros de G (porque usamos
optimizer_G.step()). - D está "congelado" en esta fase (sus gradientes se calculan pero no se aplican).
Si olvidaras el .detach() al entrenar D, los gradientes fluirían hasta G
durante la fase de D, actualizando G en la dirección equivocada.
BCEWithLogitsLoss combina sigmoid + BCE en una sola operación,
con mejor estabilidad numérica (evita el log(0) que puede ocurrir con
BCELoss + Sigmoid separados):
- Quita
nn.Sigmoid()de la última capa de D. - Cambia
criterion = nn.BCEWithLogitsLoss(). - El output de D serán logits (no probabilidades), pero la loss se calcula correctamente.
Esta es la práctica recomendada para producción. Con MNIST la diferencia es mínima, pero con datasets más difíciles puede prevenir NaN.
Visualización y monitorización
En una GAN, mirar solo la loss no basta. Necesitas vigilar las curvas de loss de ambas redes, la evolución visual de las generaciones y las métricas de confianza del Discriminador. Si algo va mal, estas visualizaciones te lo dirán antes que los números.
7.1 Curvas de loss
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# ── Panel 1: Losses ──
axes[0].plot(G_losses, label='Generator', color='#6c5ce7', linewidth=2)
axes[0].plot(D_losses, label='Discriminator', color='#fd79a8', linewidth=2)
axes[0].axhline(y=np.log(2), color='rgba(253,203,110,.5)',
linestyle='--', label=f'Equilibrio (ln2≈{np.log(2):.3f})')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('GAN Losses')
axes[0].legend()
axes[0].grid(alpha=0.15)
# ── Panel 2: Confianza de D ──
axes[1].plot(D_real_acc, label='D(x) — reales', color='#00b894', linewidth=2)
axes[1].plot(D_fake_acc, label='D(G(z)) — fakes', color='#e17055', linewidth=2)
axes[1].axhline(y=0.5, color='rgba(255,255,255,.2)', linestyle='--', label='Equilibrio (0.5)')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Probabilidad media')
axes[1].set_title('Confianza del Discriminador')
axes[1].legend()
axes[1].grid(alpha=0.15)
axes[1].set_ylim(0, 1)
plt.tight_layout()
plt.savefig('gan_results/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
• Caso ideal: D_loss se estabiliza en ~0.693 (ln2), G_loss baja gradualmente. D(x) y D(G(z)) convergen hacia 0.5.
• D domina: D_loss → 0, D(x) → 1, D(G(z)) → 0. El G no puede aprender. Solución: reducir la capacidad de D, entrenar G más veces por step de D, o label smoothing.
• Mode collapse: G_loss oscila bruscamente, las imágenes generadas son todas iguales. Solución: paso 9.
7.2 Evolución visual de las generaciones
def show_evolution(G, fixed_noise, epochs_to_show=[1, 10, 25, 50]):
"""Genera imágenes con el mismo ruido guardado en diferentes checkpoints."""
# Esta función requiere haber guardado modelos en cada epoch
# Aquí mostramos cómo generar un grid al final del entrenamiento
G.eval()
with torch.no_grad():
fake = G(fixed_noise)
fake = (fake + 1) / 2 # [-1,1] → [0,1]
G.train()
grid = make_grid(fake, nrow=8, padding=2)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap='gray')
ax.set_title('Imágenes generadas (epoch final)', fontsize=14)
ax.axis('off')
plt.tight_layout()
plt.savefig('gan_results/final_generation.png', dpi=150, bbox_inches='tight')
plt.show()
show_evolution(G, fixed_noise)
7.3 Interpolación en el espacio latente
Una forma potente de evaluar si G ha aprendido una representación suave y continua es interpolar entre dos puntos del espacio latente. Si la transición entre las imágenes generadas es gradual, G ha aprendido una representación rica.
def interpolate_latent(G, z1, z2, n_steps=10):
"""Interpola linealmente entre z1 y z2 en el espacio latente."""
G.eval()
alphas = torch.linspace(0, 1, n_steps, device=device)
interpolations = []
with torch.no_grad():
for alpha in alphas:
z = (1 - alpha) * z1 + alpha * z2
img = G(z.unsqueeze(0))
interpolations.append(img)
images = torch.cat(interpolations, dim=0)
images = (images + 1) / 2
G.train()
return images
# Dos puntos aleatorios del espacio latente
z1 = torch.randn(LATENT_DIM, device=device)
z2 = torch.randn(LATENT_DIM, device=device)
interp = interpolate_latent(G, z1, z2, n_steps=12)
grid = make_grid(interp, nrow=12, padding=2)
plt.figure(figsize=(15, 2))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap='gray')
plt.title('Interpolación en espacio latente: z₁ → z₂', fontsize=12)
plt.axis('off')
plt.tight_layout()
plt.savefig('gan_results/interpolation.png', dpi=150, bbox_inches='tight')
plt.show()
z = (1-α)·z₁ + α·z₂. Para α=0 → z₁, para α=1 → z₂, y valores intermedios interpolan suavemente.7.4 Widget: explorador del espacio latente
Para evaluar formalmente la calidad de una GAN, se usan dos métricas:
- FID (Fréchet Inception Distance): Compara la distribución de features de imágenes reales y generadas usando un modelo Inception preentrenado. Menor FID = mejor calidad. Heusel et al., 2017
- IS (Inception Score): Mide diversidad y calidad usando la distribución de clases predichas por Inception. Mayor IS = mejor. Salimans et al., 2016
Para nuestro proyecto MNIST, la inspección visual es suficiente. Para CIFAR-10+,
usa torchmetrics.image.FrechetInceptionDistance o el paquete
pytorch-fid.
De GAN vanilla a DCGAN
La GAN vanilla con MLPs funciona para MNIST, pero tiene limitaciones serias: las capas fully-connected no respetan la estructura espacial de las imágenes. La DCGAN (Radford et al., 2016) resuelve esto con una arquitectura basada en convoluciones que se convirtió en el estándar para GANs de imágenes.
8.1 Las reglas de DCGAN
El paper de DCGAN estableció 5 reglas arquitectónicas que se convirtieron en el estándar para GANs convolucionales:
| # | Regla | Razón |
|---|---|---|
| 1 | Reemplazar pooling con stride convolutions (D) y fractional-strided convolutions (G) | La red aprende su propio downsampling/upsampling |
| 2 | BatchNorm en G y D (excepto la salida de G y la entrada de D) | Estabiliza el entrenamiento, previene mode collapse |
| 3 | Eliminar capas fully-connected (excepto para z → primer bloque) | Las convs respetan la estructura espacial |
| 4 | ReLU en G (excepto la salida = Tanh), LeakyReLU en D | Gradientes saludables en ambas redes |
| 5 | Adam con lr=0.0002, β₁=0.5 | Momentum bajo para estabilidad adversarial |
8.2 Generador DCGAN
class DCGenerator(nn.Module):
"""
DCGAN Generator: z (100,) → imagen (1, 28, 28).
Usa ConvTranspose2d para upsampling progresivo.
"""
def __init__(self, latent_dim=LATENT_DIM, feature_maps=128):
super().__init__()
self.net = nn.Sequential(
# z: (B, 100) → (B, 256, 7, 7)
nn.ConvTranspose2d(latent_dim, feature_maps * 2, 7, 1, 0, bias=False),
nn.BatchNorm2d(feature_maps * 2),
nn.ReLU(True),
# (B, 256, 7, 7) → (B, 128, 14, 14)
nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps),
nn.ReLU(True),
# (B, 128, 14, 14) → (B, 1, 28, 28)
nn.ConvTranspose2d(feature_maps, IMG_CHANNELS, 4, 2, 1, bias=False),
nn.Tanh(),
)
def forward(self, z):
# z: (B, 100) → (B, 100, 1, 1) para ConvTranspose2d
z = z.view(-1, LATENT_DIM, 1, 1)
return self.net(z)
# Crear y verificar
G_dc = DCGenerator().to(device)
G_dc.apply(weights_init)
z_test = torch.randn(4, LATENT_DIM, device=device)
out = G_dc(z_test)
print(f"DCGAN Generator: {z_test.shape} → {out.shape}")
print(f"Parámetros: {sum(p.numel() for p in G_dc.parameters()):,}")
ConvTranspose2d(100, 256, 7, 1, 0) — "deconvolución" que proyecta z (1×1) a un feature map de 7×7. Kernel=7, stride=1, padding=0.ConvTranspose2d(256, 128, 4, 2, 1) — stride=2 duplica la resolución: 7×7 → 14×14. Kernel=4, padding=1 para output exacto.z.view(-1, 100, 1, 1) — reshape a formato 4D para ConvTranspose2d, que espera (B, C, H, W).8.3 Discriminador DCGAN
class DCDiscriminator(nn.Module):
"""
DCGAN Discriminator: imagen (1, 28, 28) → probabilidad [0, 1].
Usa Conv2d con stride para downsampling progresivo.
"""
def __init__(self, feature_maps=128):
super().__init__()
self.net = nn.Sequential(
# (B, 1, 28, 28) → (B, 128, 14, 14) — sin BN en la primera capa
nn.Conv2d(IMG_CHANNELS, feature_maps, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# (B, 128, 14, 14) → (B, 256, 7, 7)
nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps * 2),
nn.LeakyReLU(0.2, inplace=True),
# (B, 256, 7, 7) → (B, 1, 1, 1)
nn.Conv2d(feature_maps * 2, 1, 7, 1, 0, bias=False),
nn.Sigmoid(),
)
def forward(self, img):
return self.net(img).view(-1, 1) # (B, 1)
# Crear y verificar
D_dc = DCDiscriminator().to(device)
D_dc.apply(weights_init)
pred = D_dc(out.detach())
print(f"DCGAN Discriminator: {out.shape} → {pred.shape}")
print(f"Parámetros: {sum(p.numel() for p in D_dc.parameters()):,}")
stride=2 reduce 28×28 → 14×14.stride=2 reduce 14×14 → 7×7.Conv2d(256, 1, 7) — el kernel de 7×7 colapsa el feature map 7×7 a un escalar 1×1. Es el equivalente a un "global average pooling + linear" pero en una sola operación.8.4 Entrenar la DCGAN
El training loop es idéntico al de la GAN vanilla. Solo sustituimos G y D por las versiones convolucionales:
# ── Usar los modelos DCGAN ───────────────────────────────
G = G_dc
D = D_dc
optimizer_G = optim.Adam(G.parameters(), lr=LR_G, betas=(BETA1, BETA2))
optimizer_D = optim.Adam(D.parameters(), lr=LR_D, betas=(BETA1, BETA2))
# ── Reutilizar el mismo training loop del paso 6 ────────
# (copia el loop anterior, o refactorízalo en una función)
# El resultado con DCGAN típicamente:
# - Converge más rápido (~20 epochs vs ~50)
# - Produce dígitos más nítidos
# - Captura mejor los detalles espaciales (trazos, curvas)
8.5 Comparativa: MLP vs DCGAN
| Aspecto | GAN Vanilla (MLP) | DCGAN (Conv) |
|---|---|---|
| Parámetros totales | ~1.6M | ~623K |
| Estructura espacial | No la respeta (flatten + linear) | Sí (convoluciones locales) |
| Calidad en MNIST | Buena (dígitos reconocibles) | Mejor (más nítida, menos ruido) |
| Escalabilidad | No funciona bien en 64×64+ | Escala a 64×64, 128×128, 256×256 |
| Velocidad de convergencia | ~50 epochs para MNIST | ~20-30 epochs |
| Uso recomendado | Aprendizaje, datos tabulares | Imágenes, producción |
- WGAN / WGAN-GP: Reemplaza BCE con Wasserstein distance. Entrenamiento más estable, sin mode collapse. Arjovsky et al., 2017
- Progressive GAN: Crece progresivamente de 4×4 a 1024×1024. Karras et al., 2018
- StyleGAN / StyleGAN2: Mapping network + AdaIN para control de estilo. El estado del arte en generación de caras. Karras et al., 2019
- BigGAN: GANs a gran escala con class conditioning. Brock et al., 2019
- StyleGAN3: Elimina artefactos de aliasing. Karras et al., 2021
Problemas comunes y soluciones
Entrenar GANs es notoriamente difícil. A diferencia de un clasificador donde la loss baja y listo, en GANs la loss no es una métrica fiable de calidad. Aquí cubrimos los problemas más frecuentes y cómo diagnosticarlos y resolverlos.
9.1 Mode collapse
El mode collapse ocurre cuando el Generador colapsa a generar siempre la misma imagen (o un subconjunto pequeño de imágenes), ignorando la diversidad del dataset real. Es el problema #1 de las GANs.
Diagnóstico:
- Las imágenes generadas son todas (casi) iguales.
- G_loss oscila en vez de bajar gradualmente.
- D_loss baja a ~0 (D reconoce fácilmente el truco).
Soluciones:
- WGAN-GP: Reemplaza BCE por Wasserstein distance con gradient penalty. Es la solución más fiable.
- Minibatch discrimination: Añade features al D que detectan si el batch carece de diversidad.
- Unrolled GAN: El generador optimiza mirando varios pasos adelante del D.
- Feature matching: En vez de maximizar
D(G(z)), minimizar la distancia de features intermedias de D entre reales y fakes.
9.2 Training inestable / oscilaciones
| Síntoma | Causa probable | Solución |
|---|---|---|
| Loss de D → 0 rápidamente | D demasiado fuerte para G | Reducir capacidad de D, label smoothing, más pasos de G por paso de D |
| Loss de G → 0 rápidamente | G encontró un exploit en D | Aumentar capacidad de D, entrenar D más pasos, spectral normalization |
| Losses oscilan sin converger | Learning rate demasiado alto | Reducir LR a 1e-4 o 5e-5, usar LR scheduling |
| NaN en la loss | log(0) en BCELoss | Usar BCEWithLogitsLoss, gradient clipping, label smoothing |
| Imágenes borrosas | G no tiene suficiente capacidad | Aumentar hidden sizes, usar DCGAN, más capas |
| Checkerboard artifacts | ConvTranspose2d con stride desalineado | Usar Upsample + Conv2d en vez de ConvTranspose2d |
9.3 Técnicas de estabilización
9.4 Implementar label smoothing
La técnica más fácil de implementar. Solo cambiamos los labels:
# ── Con label smoothing ──────────────────────────────────
# En vez de:
# labels_real = torch.ones(batch_size, 1) → 1.0
# labels_fake = torch.zeros(batch_size, 1) → 0.0
# Usamos:
labels_real = torch.FloatTensor(batch_size, 1).uniform_(0.9, 1.0).to(device)
labels_fake = torch.FloatTensor(batch_size, 1).uniform_(0.0, 0.1).to(device)
# Esto hace que D no pueda ser "100% seguro", lo que:
# 1. Previene que D domine a G
# 2. Regulariza D de forma implícita
# 3. Reduce el riesgo de gradientes saturados
- Instance noise: Añadir ruido gaussiano a las imágenes que recibe D (tanto reales como fakes). Se reduce gradualmente durante el training (annealing). Esto suaviza la distribución del D y facilita el aprendizaje de G. Sønderby et al., 2017
- Feature matching: En vez de
loss_G = BCE(D(G(z)), 1), minimizar||f(x_real) - f(G(z))||²dondefes una capa intermedia de D. Produce gradientes más informativos. - Historical averaging: Añadir un término de penalización que penaliza cambios bruscos en los parámetros:
||θ - θ_avg||². Estabiliza el entrenamiento.
Referencias y próximos pasos
Hemos construido una GAN completa desde cero — primero con MLPs, luego con convoluciones (DCGAN). Aquí recopilamos las referencias fundamentales del campo y los próximos pasos para seguir aprendiendo.
10.1 Resumen de lo aprendido
| Paso | Concepto clave | Código PyTorch |
|---|---|---|
| Generador | z ~ N(0,1) → Linear/ConvT → Tanh → imagen | nn.Sequential(Linear → LeakyReLU → BN → ... → Tanh) |
| Discriminador | imagen → Linear/Conv → Sigmoid → probabilidad | nn.Sequential(Flatten → Linear → LeakyReLU → Dropout → ... → Sigmoid) |
| Loss | Binary Cross-Entropy adversarial | nn.BCELoss() — D minimiza, G maximiza D(G(z)) |
| Training | Alternancia D/G con detach() | D(fake.detach()) para D, D(fake) para G |
| DCGAN | ConvTranspose2d (G) + Conv2d (D) | 5 reglas: stride convs, BN, no FC, ReLU/LReLU, Adam(β₁=0.5) |
| Debugging | Mode collapse, balance D/G, label smoothing | Monitorizar D(x) y D(G(z)), visual inspection |
10.2 Papers fundamentales
| Paper | Contribución | Año |
|---|---|---|
| Goodfellow et al. — Generative Adversarial Networks | El paper original: formulación minimax, prueba teórica de convergencia | 2014 |
| Radford et al. — Unsupervised Representation Learning with DCGAN | Arquitectura convolucional estándar, trucos de entrenamiento, aritmética latente | 2016 |
| Arjovsky et al. — Wasserstein GAN | Wasserstein distance como loss, weight clipping, entrenamiento estable | 2017 |
| Gulrajani et al. — Improved Training of WGANs (WGAN-GP) | Gradient penalty reemplaza weight clipping, convergencia mejorada | 2017 |
| Miyato et al. — Spectral Normalization for GANs | Normalización espectral de D para estabilidad | 2018 |
| Karras et al. — Progressive Growing of GANs | Crecer resolución progresivamente de 4×4 a 1024×1024 | 2018 |
| Karras et al. — A Style-Based Generator Architecture (StyleGAN) | Mapping network, AdaIN, control de estilo por capas | 2019 |
| Brock et al. — Large Scale GAN Training (BigGAN) | GANs a gran escala, class conditioning, truncation trick | 2019 |
| Salimans et al. — Improved Techniques for Training GANs | Feature matching, minibatch discrimination, Inception Score | 2016 |
| Heusel et al. — GANs Trained by Two Time-Scale Update Rule | FID metric, TTUR para convergencia | 2017 |
10.3 Documentación y repositorios
- PyTorch DCGAN Tutorial (oficial) — DCGAN con CelebA, muy bien documentado
- PyTorch-GAN (GitHub) — Colección de +30 implementaciones de GANs en PyTorch
- StyleGAN3 (NVIDIA) — Implementación oficial de StyleGAN3
- Papers With Code — GANs — Todos los papers de GANs indexados con código
- The GAN Zoo — Lista de más de 500 variantes de GANs con papers
- PyTorch nn.ConvTranspose2d — Documentación oficial de la "deconvolución"
- Deconvolution and Checkerboard Artifacts (Distill) — Artículo visual sobre los artifacts de ConvTranspose2d