Transfer Learning paso a paso con PyTorch
Aprenderás a reutilizar modelos preentrenados en ImageNet para resolver tu propio problema de clasificación de imágenes. Desde explorar y descargar backbones hasta feature extraction, fine-tuning, evaluación e inferencia.
Requisitos previos
- Python 3.9+ con PyTorch y torchvision instalados
- Haber completado el tutorial de PyTorch básico
- Conocer qué es una CNN (consulta la teoría de convoluciones)
- Conocer la teoría de clasificación y transfer learning
- Opcional: GPU con CUDA (acelera mucho, pero funciona en CPU)
Explorar modelos preentrenados
El Transfer Learning consiste en reutilizar un modelo que ya fue entrenado en un dataset grande (típicamente ImageNet, con 1.2 millones de imágenes y 1,000 clases) y adaptarlo a tu problema específico. En lugar de aprender desde cero, aprovechas las features jerárquicas que el modelo ya ha extraído: bordes, texturas, formas, y patrones complejos.
PyTorch ofrece modelos preentrenados a través de torchvision.models.
Adicionalmente, la librería timm (PyTorch Image Models) ofrece
más de 1,200 modelos con pesos preentrenados. Veamos qué hay disponible:
import torchvision.models as models
# Ver todos los modelos disponibles con pesos preentrenados
available = models.list_models(module=models)
print(f"Modelos en torchvision: {len(available)}")
print(available[:15]) # Primeros 15
# Con timm (instalar: pip install timm)
import timm
timm_models = timm.list_models(pretrained=True)
print(f"\nModelos en timm: {len(timm_models)}")
# Buscar modelos de una familia específica
resnet_models = timm.list_models('resnet*', pretrained=True)
efficientnet_models = timm.list_models('efficientnet*', pretrained=True)
print(f"ResNets disponibles: {len(resnet_models)}")
print(f"EfficientNets disponibles: {len(efficientnet_models)}")
Galería de modelos populares
Haz clic en cualquier modelo para ver cómo se carga en PyTorch. Cada tarjeta muestra los parámetros, accuracy top-1 en ImageNet y características clave:
# 👆 Haz clic en un modelo de la galería para ver cómo cargarlo
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
- Pocos datos (<1,000 imágenes): ResNet-18 o MobileNetV2 (menos parámetros = menos overfitting)
- Datos moderados (1K-10K): ResNet-50 o EfficientNet-B0 (buen balance)
- Muchos datos (>10K) y GPU: EfficientNet-B3/B4 o ConvNeXt (máximo accuracy)
- Despliegue en móvil: MobileNetV2 o EfficientNet-B0 (pocos FLOPs)
torchvision.models es la opción oficial de PyTorch. Ventajas:
- Mantenido por el equipo de PyTorch
- API estable y bien documentada
- Pesos multi-versión con
Weightsenum (v2 incluye mejores transforms) - Sin dependencias extra
timm (PyTorch Image Models) es la librería de Ross Wightman. Ventajas:
- +1,200 modelos con pesos preentrenados
- Modelos de última generación (ConvNeXt V2, EfficientNet V2, MetaFormer, etc.)
- API unificada:
timm.create_model('nombre', pretrained=True) - Pesos entrenados con mejores recetas de entrenamiento
Recomendación: usa torchvision si el modelo que quieres está disponible.
Si necesitas algo más específico o moderno, usa timm.
Error 1: API obsoleta de pesos
Antes de torchvision 0.13, se usaba pretrained=True. Ahora se usa weights=:
# ❌ Obsoleto (funciona pero da warning)
model = models.resnet50(pretrained=True)
# ✅ Correcto (torchvision >= 0.13)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
Error 2: No verificar la resolución de entrada
Cada modelo espera una resolución específica. ResNet usa 224×224, pero EfficientNet-B4 usa 380×380. Los pesos incluyen las transforms correctas:
weights = models.ResNet50_Weights.IMAGENET1K_V2
preprocess = weights.transforms() # Incluye resize, crop, normalización
Descargar y cargar un backbone
Vamos a trabajar con ResNet-50 como ejemplo principal. Al cargar un modelo con pesos preentrenados, PyTorch descarga automáticamente los pesos (~98 MB) y los almacena en caché:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights
# Cargar ResNet-50 con pesos de ImageNet (V2 = mejores pesos)
weights = ResNet50_Weights.IMAGENET1K_V2
backbone = models.resnet50(weights=weights)
print(f"Modelo cargado: ResNet-50")
print(f"Parámetros totales: {sum(p.numel() for p in backbone.parameters()):,}")
print(f"Output del backbone: {backbone.fc.in_features} features")
print(f"Clases originales: {backbone.fc.out_features}")
# Las transforms recomendadas para este modelo
preprocess = weights.transforms()
print(f"\nTransforms recomendadas:\n{preprocess}")
ResNet50_Weights.IMAGENET1K_V2 — la versión V2 usa mejores recetas de entrenamiento y da +1.5% accuracy vs V1.~/.cache/torch/hub/checkpoints/.backbone.fc.in_features — dimensión del vector de features antes de la capa final (2048 en ResNet-50).weights.transforms() — transforms que usaron los autores al entrenar. Incluye resize a 232, center crop a 224 y normalización con media/std de ImageNet.Verificar que funciona: predicción con ImageNet
Antes de modificar nada, comprobemos que el modelo funciona con una imagen de prueba. Esto te servirá como sanity check:
from PIL import Image
from torchvision.transforms import functional as F
# Descargar una imagen de ejemplo (o usa una tuya)
import urllib.request
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg"
urllib.request.urlretrieve(url, "test_dog.jpg")
img = Image.open("test_dog.jpg")
# Preprocesar con las transforms del modelo
batch = preprocess(img).unsqueeze(0) # (3,224,224) → (1,3,224,224)
# Predecir
backbone.eval()
with torch.no_grad():
logits = backbone(batch)
probs = torch.softmax(logits, dim=1)
top5 = torch.topk(probs, 5)
# Mostrar top-5 predicciones
categories = weights.meta["categories"]
for i in range(5):
idx = top5.indices[0][i].item()
prob = top5.values[0][i].item() * 100
print(f" {categories[idx]:30s} → {prob:.1f}%")
Alternativa: cargar con timm
import timm
# timm tiene una API unificada para todos los modelos
model_timm = timm.create_model('resnet50', pretrained=True)
# Para ver la configuración del modelo
data_config = timm.data.resolve_model_data_config(model_timm)
print(data_config)
# {'input_size': (3, 224, 224), 'interpolation': 'bicubic',
# 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), ...}
# También puedes crear el modelo ya sin la cabeza de clasificación
backbone_timm = timm.create_model('resnet50', pretrained=True, num_classes=0)
# Ahora backbone_timm produce directamente un vector de 2048 features
Los pesos se descargan en estas ubicaciones por defecto:
- torchvision:
~/.cache/torch/hub/checkpoints/ - timm:
~/.cache/huggingface/hub/(o~/.cache/torch/hub/checkpoints/)
Puedes cambiar la ruta con la variable de entorno TORCH_HOME:
export TORCH_HOME=/ruta/personalizada
# O en Python:
import os
os.environ['TORCH_HOME'] = '/ruta/personalizada'
Para entornos offline (servidores sin internet), descarga los pesos previamente y cárgalos manualmente:
state_dict = torch.load("resnet50_v2.pth", weights_only=True)
model = models.resnet50()
model.load_state_dict(state_dict)
Entender la arquitectura: backbone + head
Un modelo de clasificación CNN se divide en dos partes fundamentales: el backbone (extractor de features) y el head (clasificador). La clave del transfer learning es reutilizar el backbone y reemplazar el head.
El diagrama muestra la estrategia de feature extraction: se congela todo el backbone (🔒) y solo se entrena el head (🔓). En el fine-tuning, también se descongelan las últimas capas del backbone.
Inspeccionar la estructura del modelo
# Ver las capas principales de ResNet-50
for name, module in backbone.named_children():
params = sum(p.numel() for p in module.parameters())
print(f"{name:12s} → {params:>10,} params ({type(module).__name__})")
Reemplazar el head
El head original clasifica en 1,000 clases de ImageNet. Necesitamos reemplazarlo por uno que clasifique en nuestras N clases. En este tutorial usaremos un problema de clasificación binaria (perros vs gatos) como ejemplo:
NUM_CLASSES = 2 # perros vs gatos (o tu número de clases)
# El head original
print(f"Head original: {backbone.fc}")
# Opción A: head simple (una sola capa lineal)
backbone.fc = nn.Linear(backbone.fc.in_features, NUM_CLASSES)
# Opción B: head con más capacidad (recomendado para few-shot)
backbone.fc = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, NUM_CLASSES),
)
print(f"Head nuevo: {backbone.fc}")
print(f"Parámetros del head: {sum(p.numel() for p in backbone.fc.parameters()):,}")
Head simple (nn.Linear(2048, N)):
- Menos parámetros → menos riesgo de overfitting
- Ideal cuando tienes muchos datos o cuando las features del backbone ya son muy buenas para tu dominio
- Entrenamiento más rápido
Head multicapa (con capas ocultas + Dropout):
- Permite aprender combinaciones no lineales de features
- Útil cuando el dominio es diferente a ImageNet (médicas, satélite, microscopía)
- El Dropout (0.3-0.5) ayuda a regularizar
Regla práctica: empieza con head simple. Si el accuracy se estanca, prueba multicapa.
Tabla comparativa de heads por modelo
| Modelo | Atributo del head | Features in | Cómo reemplazar |
|---|---|---|---|
| ResNet-* | model.fc |
512 / 2048 | model.fc = nn.Linear(N_in, N_cls) |
| VGG-* | model.classifier[6] |
4096 | model.classifier[6] = nn.Linear(4096, N_cls) |
| EfficientNet-* | model.classifier[1] |
1280 | model.classifier[1] = nn.Linear(1280, N_cls) |
| MobileNetV2 | model.classifier[1] |
1280 | model.classifier[1] = nn.Linear(1280, N_cls) |
| DenseNet-* | model.classifier |
1024 | model.classifier = nn.Linear(1024, N_cls) |
| ConvNeXt-* | model.classifier[2] |
768 / 1024 | model.classifier[2] = nn.Linear(N_in, N_cls) |
print(model) o
model.named_children() para encontrar la capa correcta.
Preparar tu dataset personalizado
Para transfer learning necesitas tus propias imágenes etiquetadas.
La forma más sencilla es organizar las imágenes en carpetas donde
cada carpeta es una clase. PyTorch las carga automáticamente
con ImageFolder.
dataset/
├── train/
│ ├── gatos/ # Clase 0
│ │ ├── cat_001.jpg
│ │ ├── cat_002.jpg
│ │ └── ...
│ └── perros/ # Clase 1
│ ├── dog_001.jpg
│ ├── dog_002.jpg
│ └── ...
├── val/
│ ├── gatos/
│ └── perros/
└── test/
├── gatos/
└── perros/
Transforms y Data Augmentation
Las transforms son cruciales en transfer learning. Para entrenamiento aplicamos data augmentation (para regularizar y que el modelo generalice mejor). Para validación y test, solo redimensionamos y normalizamos.
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
# ── Transforms para ENTRENAMIENTO (con augmentation) ─────
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2,
saturation=0.2, hue=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# ── Transforms para VALIDACIÓN / TEST (sin augmentation) ─
val_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# ── Crear datasets ────────────────────────────────────────
train_dataset = datasets.ImageFolder("dataset/train", transform=train_transforms)
val_dataset = datasets.ImageFolder("dataset/val", transform=val_transforms)
test_dataset = datasets.ImageFolder("dataset/test", transform=val_transforms)
print(f"Clases encontradas: {train_dataset.classes}")
print(f"Mapping: {train_dataset.class_to_idx}")
print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")
# ── Crear DataLoaders ─────────────────────────────────────
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
shuffle=False, num_workers=4, pin_memory=True)
RandomResizedCrop(224) — recorta aleatoriamente entre el 80-100% de la imagen y redimensiona a 224×224.RandomHorizontalFlip — voltea la imagen horizontalmente con 50% de probabilidad. No uses flip vertical si no tiene sentido para tu dominio (ej: texto).ColorJitter — variaciones aleatorias de brillo, contraste, saturación y hue. Simula diferentes condiciones de iluminación.Normalize — obligatorio. Usa las mismas media y std que se usaron para preentrenar en ImageNet.Resize(256) + CenterCrop(224) es el pipeline estándar de ImageNet. Sin augmentation.num_workers=4 — procesos paralelos para cargar imágenes. Evita que la CPU sea el cuello de botella. pin_memory=True acelera la transferencia CPU→GPU.# Alternativa rápida: usar un dataset de torchvision
from torchvision.datasets import OxfordIIITPet
# Descarga automáticamente ~800 MB
train_dataset = OxfordIIITPet(root="./data", split="trainval",
target_types="category",
transform=train_transforms, download=True)
test_dataset = OxfordIIITPet(root="./data", split="test",
target_types="category",
transform=val_transforms, download=True)
NUM_CLASSES = 37 # 37 razas de perros y gatos
print(f"Clases: {len(train_dataset.classes)} razas de mascotas")
Principio: la augmentation debe generar variaciones que podrían ocurrir en el mundo real.
- Siempre útiles: flip horizontal, pequeños recortes, variación de color
- Con cuidado: rotaciones grandes (solo si tiene sentido), flip vertical (no para texto o caras)
- Avanzado:
transforms.RandAugment(),transforms.TrivialAugmentWide(), MixUp, CutMix
Si tienes muy pocos datos (<500 imágenes):
- Augmentation agresiva (rotaciones, elastic deformations)
- Considera
transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET) - MixUp y CutMix pueden ayudar mucho
Errores comunes:
- Augmentation en validación/test (nunca hacerlo, sesga las métricas)
- Olvidar la normalización con la media/std de ImageNet
- Usar augmentation demasiado agresiva que destruye la información relevante
Error: FileNotFoundError: Found no valid file for the classes
→ Las imágenes no están en subdirectorios. Revisa que la estructura sea train/clase/imagen.jpg.
Error: RuntimeError: stack expects each tensor to be equal size
→ Las imágenes tienen tamaños diferentes y no estás usando Resize/CenterCrop.
Error: num_workers > 0 crashes en Windows
→ En Windows necesitas if __name__ == '__main__': antes de invocar el DataLoader.
Alternativamente, usa num_workers=0.
Error: Imágenes corruptas o con canales extras (RGBA, grayscale)
→ Añade transforms.Lambda(lambda x: x.convert('RGB')) al principio de las transforms,
o filtra las imágenes antes de entrenar.
Feature extraction: backbone congelado
La primera estrategia de transfer learning es feature extraction: congelamos todos los pesos del backbone y solo entrenamos el head nuevo. El backbone actúa como un extractor de features fijo.
¿Por qué funciona? Porque las features aprendidas en ImageNet (bordes, texturas, formas) son universales y útiles para casi cualquier problema de visión.
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights
# 1. Cargar modelo preentrenado
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# 2. CONGELAR todo el backbone
for param in model.parameters():
param.requires_grad = False
# 3. Reemplazar el head (se crea con requires_grad=True por defecto)
NUM_CLASSES = 2
model.fc = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, NUM_CLASSES),
)
# Verificar qué se entrena y qué no
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params
print(f"Parámetros totales: {total_params:>12,}")
print(f"Parámetros congelados: {frozen_params:>12,} ({frozen_params/total_params*100:.1f}%)")
print(f"Parámetros entrenables: {trainable_params:>12,} ({trainable_params/total_params*100:.1f}%)")
# Mover a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
param.requires_grad = False — congela cada parámetro. PyTorch no calculará gradientes para ellos → no se actualizarán durante el entrenamiento.requires_grad=True por defecto. Estos son los únicos parámetros que se entrenarán.Configurar el optimizador (solo head)
Es importante que el optimizador solo reciba los parámetros entrenables. Si le pasas todos, funciona igualmente (los congelados no tienen gradientes), pero es más limpio y eficiente filtrar:
import torch.optim as optim
# Solo los parámetros que requieren gradiente
optimizer = optim.Adam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=1e-3, # Learning rate relativamente alto para el head
weight_decay=1e-4, # Regularización L2
)
# Loss function
criterion = nn.CrossEntropyLoss()
# Scheduler: reducir LR cuando se estanca la mejora
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=3, verbose=True
)
print(f"Optimizador: Adam (lr=1e-3)")
print(f"Parámetros en el optimizer: {sum(p.numel() for group in optimizer.param_groups for p in group['params']):,}")
filter(lambda p: p.requires_grad, ...) — filtra solo los parámetros del head. Más limpio que pasar todos.lr=1e-3 — para feature extraction usamos un learning rate relativamente alto (1e-3). El head empieza con pesos aleatorios y necesita aprender rápido.weight_decay=1e-4 — regularización L2. Evita que los pesos del head crezcan demasiado.ReduceLROnPlateau — si la val_loss no mejora en 3 epochs, reduce el lr a la mitad. Muy útil para convergencia fina.La decisión depende de dos factores: cuántos datos tienes y cuánto se parece tu dominio a ImageNet:
| Dominio similar a ImageNet | Dominio diferente | |
|---|---|---|
| Pocos datos (<1K) | Feature extraction ✅ | Feature extraction + augmentation agresiva |
| Datos moderados (1K-10K) | Feature extraction o fine-tuning parcial | Fine-tuning de las últimas capas |
| Muchos datos (>10K) | Fine-tuning completo | Fine-tuning completo (quizá desde scratch) |
Regla práctica: empieza siempre con feature extraction. Si el accuracy no es suficiente, pasa a fine-tuning.
Entrenar el head
Ahora escribimos el training loop con validación en cada epoch. El loop monitoriza tanto la loss como el accuracy, y guarda el mejor modelo según la validation loss (early stopping manual):
import time
from pathlib import Path
def train_one_epoch(model, loader, criterion, optimizer, device):
"""Entrena un epoch completo y devuelve loss y accuracy."""
model.train()
running_loss, correct, total = 0.0, 0, 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
correct += outputs.argmax(1).eq(labels).sum().item()
total += labels.size(0)
return running_loss / total, correct / total
@torch.no_grad()
def validate(model, loader, criterion, device):
"""Evalúa en el conjunto de validación."""
model.eval()
running_loss, correct, total = 0.0, 0, 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item() * images.size(0)
correct += outputs.argmax(1).eq(labels).sum().item()
total += labels.size(0)
return running_loss / total, correct / total
# ── Entrenamiento ─────────────────────────────────────────
EPOCHS = 15
best_val_loss = float('inf')
patience_counter = 0
PATIENCE = 5 # early stopping: parar si no mejora en 5 epochs
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
save_dir = Path("checkpoints")
save_dir.mkdir(exist_ok=True)
print("Epoch Train Loss Train Acc Val Loss Val Acc LR Tiempo")
print("─" * 75)
for epoch in range(EPOCHS):
t0 = time.time()
train_loss, train_acc = train_one_epoch(
model, train_loader, criterion, optimizer, device
)
val_loss, val_acc = validate(
model, val_loader, criterion, device
)
# Learning rate scheduler
scheduler.step(val_loss)
current_lr = optimizer.param_groups[0]['lr']
elapsed = time.time() - t0
# Guardar historial
history['train_loss'].append(train_loss)
history['val_loss'].append(val_loss)
history['train_acc'].append(train_acc)
history['val_acc'].append(val_acc)
# Early stopping + guardar mejor modelo
marker = ""
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
torch.save(model.state_dict(), save_dir / "best_model.pth")
marker = " ★ saved"
else:
patience_counter += 1
if patience_counter >= PATIENCE:
print(f"\n⏹ Early stopping en epoch {epoch+1} (sin mejora en {PATIENCE} epochs)")
break
print(f"{epoch+1:3d}/{EPOCHS} {train_loss:.4f} {train_acc:.4f} "
f"{val_loss:.4f} {val_acc:.4f} {current_lr:.1e} {elapsed:.1f}s{marker}")
model.train() — activa Dropout y BatchNorm en modo entrenamiento.@torch.no_grad() — decorador que desactiva autograd. Más eficiente que with torch.no_grad(): para funciones completas.PATIENCE = 5 — si la val_loss no mejora en 5 epochs consecutivos, paramos. Evita overentrenar.Con solo el head entrenado, ya alcanzamos ~94% de accuracy en clasificación binaria. Impresionante, considerando que entrenamos menos del 5% de los parámetros del modelo.
Si el accuracy no es suficiente con backbone congelado, antes de pasar a fine-tuning prueba:
- Head más grande: Añade más capas ocultas (512→256→N) con ReLU y Dropout
- Más augmentation: RandAugment, AutoAugment, más variaciones de color
- Otro backbone: Prueba EfficientNet o ConvNeXt, que suelen dar mejores features
- Diferente learning rate: Prueba 1e-4 a 5e-3
- Más epochs: Aumenta la paciencia del early stopping
Si nada de esto funciona, es hora del fine-tuning (paso siguiente).
Problema: Loss no baja
- Verifica que las transforms incluyen la normalización de ImageNet
- Comprueba que las etiquetas son correctas (
class_to_idx) - Asegúrate de que el head está descongelado (
requires_grad=True)
Problema: Val loss sube desde el principio (overfitting inmediato)
- Reduce el learning rate
- Aumenta el Dropout en el head (0.5 en vez de 0.3)
- Añade más data augmentation
- Reduce la complejidad del head
Problema: CUDA out of memory
- Reduce el batch_size (16, 8, o incluso 4)
- Verifica que el backbone está congelado (no se almacenan activaciones para backprop)
- Usa un modelo más pequeño (ResNet-18, MobileNetV2)
Fine-tuning del backbone
El fine-tuning consiste en descongelar parte (o todo) el backbone y reentrenarlo con un learning rate muy bajo. Esto permite que las features se adapten a tu dominio específico sin destruir lo aprendido en ImageNet.
La clave es usar learning rates diferenciados: un lr bajo para el backbone (que ya tiene pesos buenos) y un lr más alto para el head (que necesita aprender más).
Estrategia recomendada: descongelado progresivo
La técnica más robusta es el progressive unfreezing: primero entrenas solo el head (paso anterior), luego descongelas las últimas capas del backbone y reduces el learning rate. Esto evita destruir las features de las capas tempranas:
# ── Fase 2: Fine-tuning ───────────────────────────────────
# 1. Cargar el mejor modelo de feature extraction
model.load_state_dict(torch.load("checkpoints/best_model.pth", weights_only=True))
# 2. Descongelar las últimas capas del backbone
# Estrategia: descongelar layer4 (y opcionalmente layer3)
for name, param in model.named_parameters():
param.requires_grad = False # Congelar todo primero
# Descongelar layer4 + head
for name, param in model.named_parameters():
if "layer4" in name or "fc" in name:
param.requires_grad = True
# Verificar
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Entrenables ahora: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")
# 3. Optimizador con LEARNING RATES DIFERENCIADOS
# - Backbone (layer4): lr bajo para no destruir features
# - Head: lr más alto para que siga aprendiendo
optimizer_ft = optim.Adam([
{"params": [p for n, p in model.named_parameters()
if "layer4" in n and p.requires_grad],
"lr": 1e-5}, # ← LR bajo para el backbone
{"params": model.fc.parameters(),
"lr": 5e-4}, # ← LR más alto para el head
], weight_decay=1e-4)
# 4. Nuevo scheduler
scheduler_ft = optim.lr_scheduler.CosineAnnealingLR(
optimizer_ft, T_max=10, eta_min=1e-6
)
print("\nGrupos del optimizador:")
for i, group in enumerate(optimizer_ft.param_groups):
n_params = sum(p.numel() for p in group['params'])
print(f" Grupo {i}: lr={group['lr']:.1e}, params={n_params:,}")
layer4 (las capas más profundas del backbone) y el fc (head).CosineAnnealingLR — reduce el learning rate siguiendo una curva coseno. Más suave que ReduceLROnPlateau.Entrenar la fase de fine-tuning
# ── Fase 2: Training loop de fine-tuning ──────────────────
FT_EPOCHS = 10
best_val_loss_ft = float('inf')
patience_counter_ft = 0
print("\n📌 Fine-tuning (layer4 + head)")
print("Epoch Train Loss Train Acc Val Loss Val Acc Tiempo")
print("─" * 65)
for epoch in range(FT_EPOCHS):
t0 = time.time()
train_loss, train_acc = train_one_epoch(
model, train_loader, criterion, optimizer_ft, device
)
val_loss, val_acc = validate(model, val_loader, criterion, device)
scheduler_ft.step()
elapsed = time.time() - t0
marker = ""
if val_loss < best_val_loss_ft:
best_val_loss_ft = val_loss
patience_counter_ft = 0
torch.save(model.state_dict(), "checkpoints/best_model_finetuned.pth")
marker = " ★ saved"
else:
patience_counter_ft += 1
if patience_counter_ft >= 5:
print(f"\n⏹ Early stopping en epoch {epoch+1}")
break
print(f"{epoch+1:3d}/{FT_EPOCHS} {train_loss:.4f} {train_acc:.4f} "
f"{val_loss:.4f} {val_acc:.4f} {elapsed:.1f}s{marker}")
Con fine-tuning, subimos de 94% → 97.6% accuracy. La adaptación
de las features de layer4 al dominio específico ha supuesto una
mejora significativa.
- Si el lr del backbone es demasiado alto, destruirás las features preentrenadas (catastrophic forgetting)
- Si descongelas demasiadas capas con pocos datos, overfitting
- Monitoriza siempre la val_loss: si sube mientras train_loss baja → overfitting
1. Progressive unfreezing (gradual):
Descongela bloque por bloque en epochs sucesivos:
# Epoch 1-5: solo head
# Epoch 6-10: head + layer4
# Epoch 11-15: head + layer4 + layer3
layers_to_unfreeze = ['layer4', 'layer3', 'layer2']
for i, layer_name in enumerate(layers_to_unfreeze):
# Descongelar la siguiente capa cada 5 epochs
for name, param in model.named_parameters():
if layer_name in name:
param.requires_grad = True
2. Discriminative fine-tuning (ULMFiT-style):
Cada capa recibe un lr diferente, decreciendo exponencialmente hacia las capas tempranas:
base_lr = 1e-3
param_groups = []
for i, (name, layer) in enumerate(reversed(list(model.named_children()))):
lr = base_lr * (0.3 ** i) # Cada capa anterior: lr × 0.3
param_groups.append({'params': layer.parameters(), 'lr': lr})
3. LORA / LoRA (Low-Rank Adaptation):
En lugar de fine-tunear todos los pesos, inserta matrices de bajo rango.
Usado principalmente en LLMs pero aplicable a CNNs. Implementación: peft (Hugging Face).
Evaluar el modelo
La evaluación final se hace en el test set (que el modelo nunca ha visto, ni siquiera para decidir early stopping). Además del accuracy, calcularemos métricas detalladas: precision, recall, F1-score y la matriz de confusión.
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
# 1. Cargar el mejor modelo fine-tuned
model.load_state_dict(
torch.load("checkpoints/best_model_finetuned.pth", weights_only=True)
)
model.eval()
# 2. Recoger todas las predicciones del test set
all_preds = []
all_labels = []
all_probs = []
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
outputs = model(images)
probs = torch.softmax(outputs, dim=1)
all_preds.extend(outputs.argmax(1).cpu().numpy())
all_labels.extend(labels.numpy())
all_probs.extend(probs.cpu().numpy())
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)
# 3. Accuracy global
test_acc = (all_preds == all_labels).mean()
print(f"🎯 Test Accuracy: {test_acc:.4f} ({test_acc*100:.1f}%)")
print(f" Correctas: {(all_preds == all_labels).sum()} / {len(all_labels)}")
# 4. Classification report (precision, recall, F1)
class_names = test_dataset.classes # ['gatos', 'perros']
print(f"\n📊 Classification Report:")
print(classification_report(all_labels, all_preds, target_names=class_names))
# 5. Matriz de confusión
cm = confusion_matrix(all_labels, all_preds)
print(f"Confusion Matrix:")
print(cm)
torch.softmax(outputs, dim=1) — convierte logits en probabilidades. Útil para análisis de confianza.classification_report de sklearn — calcula precision, recall y F1 por clase. Esencial para datasets desbalanceados.Análisis de errores: imágenes más confusas
Una buena práctica es inspeccionar las imágenes donde el modelo se equivoca o tiene menos confianza. Esto revela patrones de error (imágenes borrosas, ángulos raros, razas ambiguas...):
# Encontrar las predicciones incorrectas
errors_idx = np.where(all_preds != all_labels)[0]
print(f"Errores totales: {len(errors_idx)} / {len(all_labels)}")
# Encontrar las predicciones con menor confianza
confidence = all_probs.max(axis=1)
least_confident_idx = np.argsort(confidence)[:10] # 10 menos confiables
print(f"\n🔍 Top-10 predicciones menos confiables:")
for idx in least_confident_idx:
pred_class = class_names[all_preds[idx]]
real_class = class_names[all_labels[idx]]
conf = confidence[idx] * 100
status = "✅" if all_preds[idx] == all_labels[idx] else "❌"
print(f" {status} Imagen {idx}: pred={pred_class} ({conf:.1f}%) | real={real_class}")
# Métricas adicionales
from sklearn.metrics import roc_auc_score
if NUM_CLASSES == 2:
auc = roc_auc_score(all_labels, all_probs[:, 1])
print(f"\n📈 AUC-ROC: {auc:.4f}")
Comparativa: features extraction vs fine-tuning
| Estrategia | Test Accuracy | Params entrenados | Tiempo/epoch | GPU RAM |
|---|---|---|---|---|
| Feature extraction (head) | 94.0% | 1.05M (4.3%) | ~12s | ~2 GB |
| Fine-tuning (layer4 + head) | 97.6% | 16.0M (65.2%) | ~18s | ~5 GB |
| Desde cero (sin pretrained) | ~85% | 24.6M (100%) | ~20s | ~6 GB |
1. Test-Time Augmentation (TTA):
Aplica varias augmentations a la misma imagen y promedia las predicciones:
# TTA: predecir con la imagen original + flip horizontal
preds_original = model(img)
preds_flipped = model(torch.flip(img, dims=[3]))
final_pred = (preds_original + preds_flipped) / 2
2. Calibración de probabilidades:
Las probabilidades de softmax no siempre están bien calibradas.
Usa sklearn.calibration.CalibratedClassifierCV o temperature scaling.
3. Visualizar con GradCAM:
GradCAM muestra qué parte de la imagen activa la predicción.
Puedes usar la librería pytorch-grad-cam:
pip install pytorch-grad-cam
from pytorch_grad_cam import GradCAM
cam = GradCAM(model=model, target_layers=[model.layer4[-1]])
Inferencia y exportación
Una vez entrenado y evaluado el modelo, queremos usarlo para clasificar imágenes nuevas. Vamos a ver cómo hacer inferencia, guardar el modelo completo y exportarlo a ONNX para despliegue en producción.
Inferencia sobre una imagen nueva
from PIL import Image
def predict_image(model, image_path, transform, class_names, device):
"""Clasifica una imagen y devuelve clase y confianza."""
img = Image.open(image_path).convert("RGB")
img_tensor = transform(img).unsqueeze(0).to(device) # (1, 3, 224, 224)
model.eval()
with torch.no_grad():
logits = model(img_tensor)
probs = torch.softmax(logits, dim=1)
confidence, pred_idx = probs.max(dim=1)
pred_class = class_names[pred_idx.item()]
confidence = confidence.item() * 100
return pred_class, confidence, probs[0].cpu().numpy()
# Usar el modelo
class_names = ['gatos', 'perros']
pred, conf, probs = predict_image(
model, "nueva_imagen.jpg", val_transforms, class_names, device
)
print(f"Predicción: {pred} (confianza: {conf:.1f}%)")
print(f"Probabilidades: {dict(zip(class_names, [f'{p:.3f}' for p in probs]))}")
Inferencia en batch
from pathlib import Path
def predict_batch(model, image_dir, transform, class_names, device):
"""Clasifica todas las imágenes de un directorio."""
model.eval()
results = []
image_paths = list(Path(image_dir).glob("*.jpg")) + \
list(Path(image_dir).glob("*.png"))
for path in image_paths:
img = Image.open(path).convert("RGB")
img_tensor = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
logits = model(img_tensor)
probs = torch.softmax(logits, dim=1)
confidence, pred_idx = probs.max(dim=1)
results.append({
"file": path.name,
"class": class_names[pred_idx.item()],
"confidence": confidence.item() * 100,
})
return results
# Ejemplo
results = predict_batch(model, "nuevas_imagenes/", val_transforms, class_names, device)
for r in results:
print(f" {r['file']:25s} → {r['class']} ({r['confidence']:.1f}%)")
Guardar el modelo completo
# Checkpoint completo: para poder reentrenar o compartir
checkpoint = {
"model_state_dict": model.state_dict(),
"class_names": class_names,
"num_classes": NUM_CLASSES,
"backbone": "resnet50",
"input_size": 224,
"normalize_mean": [0.485, 0.456, 0.406],
"normalize_std": [0.229, 0.224, 0.225],
"test_accuracy": test_acc,
"training_info": {
"feature_extraction_epochs": EPOCHS,
"finetuning_epochs": FT_EPOCHS,
"strategy": "progressive_unfreezing",
}
}
torch.save(checkpoint, "model_final.pth")
print("✅ Checkpoint completo guardado")
# Para cargar después:
ckpt = torch.load("model_final.pth", weights_only=False)
model_loaded = models.resnet50(weights=None) # Sin pesos preentrenados
model_loaded.fc = nn.Sequential(
nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.3),
nn.Linear(512, ckpt["num_classes"]),
)
model_loaded.load_state_dict(ckpt["model_state_dict"])
model_loaded.eval()
print(f"✅ Modelo cargado: {ckpt['backbone']}, acc={ckpt['test_accuracy']:.3f}")
Exportar a ONNX (producción)
Para desplegar el modelo en producción (servidor web, app móvil, edge device), exportarlo a ONNX es la opción más portable. ONNX es compatible con ONNX Runtime, TensorRT, OpenVINO, CoreML y más:
# Exportar a ONNX
model.eval()
dummy_input = torch.randn(1, 3, 224, 224).to(device)
torch.onnx.export(
model,
dummy_input,
"model_transfer.onnx",
input_names=["image"],
output_names=["logits"],
dynamic_axes={
"image": {0: "batch_size"},
"logits": {0: "batch_size"},
},
opset_version=17,
)
print("✅ Modelo exportado a ONNX")
# Verificar con ONNX Runtime
# pip install onnxruntime
import onnxruntime as ort
session = ort.InferenceSession("model_transfer.onnx")
ort_input = {"image": dummy_input.cpu().numpy()}
ort_output = session.run(None, ort_input)
print(f"Output shape: {ort_output[0].shape}") # (1, 2)
- TorchScript:
model_scripted = torch.jit.script(model)— para C++ o servidores PyTorch - TensorRT:
torch_tensorrt.compile(model)— máxima velocidad en GPU NVIDIA - CoreML:
ct.convert(traced_model)— para iOS/macOS - TFLite: convierte ONNX → TFLite para Android
1. torch.compile() (PyTorch 2.x):
# Compila y optimiza el modelo (hasta 2× más rápido)
model_compiled = torch.compile(model, mode="reduce-overhead")
2. Half precision (FP16):
# Convierte a FP16 (la mitad de memoria, ~2× más rápido en GPU)
model_fp16 = model.half()
with torch.no_grad():
output = model_fp16(input_tensor.half())
3. Quantización (INT8):
# Quantización dinámica (CPU): 2-4× más rápido, ~1% menos accuracy
model_quantized = torch.quantization.quantize_dynamic(
model.cpu(), {nn.Linear}, dtype=torch.qint8
)
4. Batching: Procesa múltiples imágenes a la vez. Es más eficiente que una a una, especialmente en GPU.
Script completo y referencias
Aquí tienes el script completo que integra todo lo que hemos visto: carga del backbone, preparación de datos, feature extraction, fine-tuning, evaluación e inferencia. Cópialo y ejecútalo directamente.
📄 Script: transfer_learning.py
"""
Transfer Learning paso a paso con PyTorch.
Clasificación binaria (gatos vs perros) con ResNet-50 preentrenado.
"""
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.models import ResNet50_Weights
# ── Config ──────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 2
BATCH_SIZE = 32
FE_EPOCHS = 15 # Feature extraction epochs
FT_EPOCHS = 10 # Fine-tuning epochs
FE_LR = 1e-3 # LR para feature extraction
FT_LR_BACKBONE = 1e-5 # LR para backbone en fine-tuning
FT_LR_HEAD = 5e-4 # LR para head en fine-tuning
PATIENCE = 5
DATA_DIR = Path("dataset")
SAVE_DIR = Path("checkpoints")
SAVE_DIR.mkdir(exist_ok=True)
print(f"Device: {DEVICE}")
# ── 1. Datos ────────────────────────────────────────────
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_ds = datasets.ImageFolder(DATA_DIR / "train", train_transforms)
val_ds = datasets.ImageFolder(DATA_DIR / "val", val_transforms)
test_ds = datasets.ImageFolder(DATA_DIR / "test", val_transforms)
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
class_names = train_ds.classes
print(f"Clases: {class_names} | Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")
# ── 2. Modelo ───────────────────────────────────────────
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
for p in model.parameters():
p.requires_grad = False
model.fc = nn.Sequential(
nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.3),
nn.Linear(512, NUM_CLASSES),
)
model = model.to(DEVICE)
# ── Helpers ─────────────────────────────────────────────
def train_one_epoch(model, loader, criterion, optimizer):
model.train()
loss_sum, correct, total = 0.0, 0, 0
for x, y in loader:
x, y = x.to(DEVICE), y.to(DEVICE)
out = model(x); loss = criterion(out, y)
optimizer.zero_grad(); loss.backward(); optimizer.step()
loss_sum += loss.item() * x.size(0)
correct += out.argmax(1).eq(y).sum().item(); total += y.size(0)
return loss_sum / total, correct / total
@torch.no_grad()
def evaluate(model, loader, criterion):
model.eval()
loss_sum, correct, total = 0.0, 0, 0
for x, y in loader:
x, y = x.to(DEVICE), y.to(DEVICE)
out = model(x); loss = criterion(out, y)
loss_sum += loss.item() * x.size(0)
correct += out.argmax(1).eq(y).sum().item(); total += y.size(0)
return loss_sum / total, correct / total
def run_training(model, optimizer, scheduler, epochs, tag, save_name):
criterion = nn.CrossEntropyLoss()
best_vl = float('inf'); wait = 0
print(f"\n{'='*60}\n📌 {tag}\n{'='*60}")
for ep in range(epochs):
t0 = time.time()
tl, ta = train_one_epoch(model, train_loader, criterion, optimizer)
vl, va = evaluate(model, val_loader, criterion)
scheduler.step(vl) if hasattr(scheduler, 'step') else None
mk = ""
if vl < best_vl:
best_vl = vl; wait = 0
torch.save(model.state_dict(), SAVE_DIR / save_name); mk = " ★"
else:
wait += 1
if wait >= PATIENCE:
print(f" ⏹ Early stopping ep {ep+1}"); break
print(f" {ep+1:2d}/{epochs} tl={tl:.4f} ta={ta:.4f} vl={vl:.4f} va={va:.4f} {time.time()-t0:.1f}s{mk}")
model.load_state_dict(torch.load(SAVE_DIR / save_name, weights_only=True))
return best_vl
# ── 3. Feature extraction ──────────────────────────────
opt1 = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
lr=FE_LR, weight_decay=1e-4)
sch1 = optim.lr_scheduler.ReduceLROnPlateau(opt1, patience=3, factor=0.5)
run_training(model, opt1, sch1, FE_EPOCHS, "Feature Extraction (head only)", "best_fe.pth")
# ── 4. Fine-tuning ─────────────────────────────────────
for n, p in model.named_parameters():
p.requires_grad = "layer4" in n or "fc" in n
opt2 = optim.Adam([
{"params": [p for n, p in model.named_parameters() if "layer4" in n and p.requires_grad], "lr": FT_LR_BACKBONE},
{"params": model.fc.parameters(), "lr": FT_LR_HEAD},
], weight_decay=1e-4)
sch2 = optim.lr_scheduler.CosineAnnealingLR(opt2, T_max=FT_EPOCHS, eta_min=1e-6)
run_training(model, opt2, sch2, FT_EPOCHS, "Fine-tuning (layer4 + head)", "best_ft.pth")
# ── 5. Evaluación final ────────────────────────────────
criterion = nn.CrossEntropyLoss()
test_loss, test_acc = evaluate(model, test_loader, criterion)
print(f"\n🎯 Test Accuracy: {test_acc:.4f} ({test_acc*100:.1f}%)")
# ── 6. Guardar modelo final ────────────────────────────
torch.save({
"model_state_dict": model.state_dict(),
"class_names": class_names, "num_classes": NUM_CLASSES,
"backbone": "resnet50", "test_accuracy": test_acc,
}, SAVE_DIR / "model_final.pth")
print("💾 Modelo final guardado")
Resumen del flujo completo
Referencias y recursos
- Paper Deep Residual Learning for Image Recognition — He et al., 2015. El paper fundacional de ResNet.
- Paper How transferable are features in deep neural networks? — Yosinski et al., 2014. Estudio seminal sobre transferibilidad de features por capa.
- Paper EfficientNet: Rethinking Model Scaling for CNNs — Tan & Le, 2019. Escalado compuesto de profundidad, anchura y resolución.
- Paper Universal Language Model Fine-tuning for Text Classification (ULMFiT) — Howard & Ruder, 2018. Técnicas de fine-tuning progresivo y discriminative LR.
- Paper MobileNetV2: Inverted Residuals and Linear Bottlenecks — Sandler et al., 2018.
- Paper Densely Connected Convolutional Networks (DenseNet) — Huang et al., 2017.
- Docs torchvision.models — PyTorch Documentation — Lista completa de modelos preentrenados con pesos y transforms.
- Docs Transfer Learning for Computer Vision Tutorial — PyTorch — Tutorial oficial de PyTorch sobre transfer learning.
- Repo timm (PyTorch Image Models) — Ross Wightman. +1,200 modelos con pesos preentrenados y training recipes.
- Repo pytorch-grad-cam — Implementación de GradCAM, GradCAM++, ScoreCAM y más para visualizar decisiones del modelo.
- Blog CS231n: Transfer Learning — Notas del curso de Stanford sobre transfer learning en CNNs.
- Blog Transfer Learning — Sebastian Ruder — Overview completo de métodos de transfer learning (NLP y CV).
- Docs ONNX Runtime Documentation — Para despliegue optimizado de modelos exportados a ONNX.