Segmentación de imágenes con CNNs
De la clasificación a la predicción por píxel: segmentación semántica, por instancias y panóptica con FCN, U-Net, DeepLab, Mask R-CNN y más. Funciones de pérdida, métricas y código en PyTorch / TensorFlow.
🎯 ¿Qué es la segmentación de imágenes?
La segmentación de imágenes es la tarea de asignar una etiqueta a cada píxel de una imagen. Mientras que la clasificación responde «¿qué hay en esta imagen?» y la detección responde «¿dónde está cada objeto?» con un rectángulo, la segmentación responde «¿a qué clase pertenece cada píxel?».
El resultado de una segmentación es un mapa de segmentación (o segmentation mask): una imagen del mismo tamaño que la original donde cada píxel tiene un valor entero que representa su clase. Esto permite delimitar los objetos con precisión a nivel de píxel.
Ejemplo cotidiano: cuando tu móvil aplica el efecto «retrato» desenfocando el fondo, está usando un modelo de segmentación para separar la persona (primer plano) del fondo — píxel a píxel.
📊 Clasificación vs. Detección vs. Segmentación
Para entender bien qué aporta la segmentación, comparémosla con las otras dos tareas fundamentales de visión por computador:
🧪 Comparador de tareas de visión
| Propiedad | Clasificación | Detección | Segmentación |
|---|---|---|---|
| Pregunta | ¿Qué hay? | ¿Dónde está cada objeto? | ¿A qué clase pertenece cada píxel? |
| Salida | Una etiqueta (clase) | Bounding boxes + clases | Máscara píxel a píxel |
| Granularidad | Imagen completa | Rectángulos | Píxel individual |
| Forma del objeto | No se conoce | Aproximada (rectángulo) | Exacta (contorno real) |
| Ejemplo de red | ResNet, VGG | YOLO, Faster R-CNN | U-Net, DeepLab |
🧩 Tipos de segmentación
Existen tres tipos principales de segmentación, cada uno con un nivel diferente de detalle:
🧪 Explorador de tipos de segmentación
Segmentación semántica
Asigna una clase a cada píxel, pero no distingue entre instancias del mismo objeto. Si hay tres coches en la imagen, todos sus píxeles reciben la etiqueta «coche», pero no se sabe que son tres coches distintos.
Donde C es el número de clases. Cada píxel de la salida toma un valor entero que indica su clase.
Segmentación por instancias
Detecta y segmenta cada instancia individual de un objeto. Los tres coches del ejemplo anterior recibirían cada uno una máscara distinta: coche_1, coche_2, coche_3. Sin embargo, las regiones de fondo no se clasifican.
Segmentación panóptica
Combina ambas: todos los píxeles reciben una clase (como la semántica) y, además, las instancias de objetos «contables» (things: personas, coches) se distinguen individualmente. Las regiones como cielo, carretera o hierba (stuff) solo reciben la clase semántica sin distinguir instancias.
¿Cuándo usar cada tipo?
- Semántica: conducción autónoma (carretera vs. acera vs. edificios), imágenes médicas (tumor vs. tejido sano).
- Instancias: contar objetos, tracking de personas individuales.
- Panóptica: comprensión completa de la escena (ej. Cityscapes, COCO Panoptic).
🌍 Aplicaciones de la segmentación
La segmentación es una de las tareas más demandadas en visión por computador. Algunas aplicaciones clave:
Imagen médica
Segmentación de tumores en CT/MRI, retina en fundoscopia, células en microscopía. U-Net fue diseñada originalmente para este dominio.
Conducción autónoma
Segmentación de carretera, peatones, semáforos, señales. Crítico para la percepción del entorno en vehículos autónomos.
Teledetección
Clasificación de uso del suelo, detección de deforestación, mapeo de inundaciones desde imágenes satelitales.
Fotografía móvil
Modo retrato (desenfoque de fondo), eliminación de fondo en videollamadas, edición selectiva de fotos.
🔬 Técnicas clásicas de segmentación
Antes de las redes neuronales, la segmentación se basaba en técnicas de procesamiento de imagen que explotan diferencias de color, intensidad o textura. Aunque hoy las redes profundas dominan, estos métodos siguen siendo útiles como preprocesamiento, para datasets pequeños, o cuando la interpretabilidad es clave.
Umbralización (thresholding)
La técnica más simple: se elige un umbral T y cada píxel se clasifica según su intensidad:
Funciona bien cuando hay buen contraste entre el objeto y el fondo. Sin embargo, elegir T manualmente es poco robusto.
Método de Otsu
El método de Otsu (1979) automatiza la elección del umbral. Busca el valor de T que maximiza la varianza entre clases (fondo vs. primer plano):
Donde \omega_0, \omega_1 son las proporciones de píxeles en cada clase y \mu_0, \mu_1 sus medias de intensidad. Se calcula para todos los posibles T \in [0, 255] y se elige el que maximiza \sigma_B^2.
🧪 Simulador de umbralización
Detección de bordes
Los bordes son transiciones abruptas de intensidad que a menudo delimitan objetos. Los operadores clásicos calculan el gradiente de la imagen:
- Sobel: kernels de 3×3 que aproximan las derivadas parciales.
- Canny: pipeline más sofisticado (suavizado Gaussiano → gradiente → supresión de no-máximos → doble umbral + histéresis).
- Laplaciano de Gausianas (LoG): detecta cruces por cero de la segunda derivada.
Limitación: la detección de bordes produce contornos, no regiones etiquetadas. Para obtener una segmentación completa, los contornos deben cerrarse y rellenarse — un problema difícil en imágenes complejas.
Watershed (línea divisoria de aguas)
Inspirado en la topografía: se interpreta la imagen como un relieve donde la intensidad es la altitud. Se «inundan» los valles y las líneas divisorias donde se encuentran aguas de distintas cuencas definen los contornos de los objetos.
Tip: el watershed por sí solo tiende a sobre-segmentar. Se suele usar con marcadores definidos manualmente o con otro método (ej. umbrales + operaciones morfológicas) para reducir el problema.
Superpíxeles
Los superpíxeles agrupan píxeles vecinos con color y textura similar en regiones compactas. No son una segmentación semántica, pero reducen la complejidad de la imagen de millones de píxeles a cientos o miles de regiones homogéneas.
El algoritmo más popular es SLIC (Simple Linear Iterative Clustering), que aplica k-means en el espacio (L, a, b, x, y) combinando color (CIELAB) y posición espacial.
📋 Comparación de técnicas clásicas
| Técnica | Ventajas | Limitaciones |
|---|---|---|
| Umbralización | Muy rápida, simple | Solo funciona con buen contraste; binaria |
| Otsu | Automática, óptima para bi-modal | Asume distribución bimodal de intensidades |
| Detección de bordes | Buena para contornos nítidos | Contornos abiertos; sensible a ruido |
| Watershed | Sigue contornos naturales | Sobre-segmentación; necesita marcadores |
| SLIC (superpíxeles) | Reduce complejidad; preserva bordes | No es segmentación semántica; preproceso |
🧠 ¿Por qué Deep Learning para segmentación?
Las técnicas clásicas funcionan bien en escenarios controlados, pero tienen limitaciones fundamentales en imágenes naturales complejas:
✅ Ventajas del deep learning
- Aprende features automáticamente (no hay que diseñar descriptores)
- Robustez a variaciones de iluminación, escala y oclusión
- Capacidad de manejar múltiples clases simultáneamente
- Resultados estado del arte en todos los benchmarks
⚠️ Requisitos
- Necesita grandes datasets anotados (cada píxel etiquetado)
- Computacionalmente costoso (GPU necesaria)
- Anotación a nivel de píxel es muy cara
- Modelos grandes, más difíciles de interpretar
Dato: anotar una imagen para clasificación lleva ~1 segundo. Para detección (bounding box), ~10 segundos. Para segmentación a nivel de píxel, ~1-5 minutos por imagen. Esto hace que los datasets de segmentación sean mucho más costosos de crear.
🔄 FCN: de clasificación a segmentación
En 2015, Long et al. propusieron las Fully Convolutional Networks (FCN), la primera arquitectura deep learning exitosa para segmentación semántica. La idea revolucionaria fue simple pero potente: reemplazar las capas fully-connected de una red de clasificación por capas convolucionales, permitiendo que la red acepte imágenes de cualquier tamaño y produzca un mapa de probabilidades denso.
La idea clave
Una CNN de clasificación como VGG-16 tiene dos partes:
- Backbone convolucional: extrae features (ej. 7×7×512 en VGG-16)
- Cabeza FC (fully-connected): aplana las features y produce un vector de clases
El problema de las capas FC es que destruyen la información espacial: sabemos qué hay en la imagen, pero no dónde. La solución de FCN fue convertir las capas FC en convoluciones 1×1:
Insight clave: una convolución 1×1 con C filtros produce un tensor de H' \times W' \times C, donde cada posición espacial tiene un vector de C valores — es decir, una predicción por clase en cada posición.
Upsampling: recuperar la resolución
Después de pasar por el backbone, el mapa de features tiene una resolución mucho menor que la imagen original (ej. 7×7 para una entrada de 224×224, un factor ×32). Para obtener una predicción a la resolución original, FCN usa convoluciones transpuestas (transposed convolutions) para hacer upsampling:
Donde s es el stride, p el padding y k el tamaño del kernel. Con s=2, k=4, p=1 se consigue duplicar la resolución espacial.
🧪 Explorador de arquitectura FCN
Variantes: FCN-32s, FCN-16s, FCN-8s
El paper original propuso tres variantes que difieren en la granularidad del upsampling:
| Variante | Skip connections | Upsampling | Calidad |
|---|---|---|---|
| FCN-32s | Ninguna | ×32 directamente | Borrosa, bordes imprecisos |
| FCN-16s | pool4 (×16) | ×2 + suma + ×16 | Mejor resolución de bordes |
| FCN-8s | pool4 + pool3 (×8) | ×2 + suma + ×2 + suma + ×8 | Bordes más finos y detallados |
Skip connections en FCN
Las skip connections son la innovación que hace a FCN-8s muy superior a FCN-32s. La idea es combinar predicciones de distintas profundidades:
- Las capas profundas tienen features semánticamente ricas pero espacialmente burdas (saben qué hay pero no dónde exactamente).
- Las capas superficiales tienen features con alta resolución espacial pero menos semántica (saben dónde pero no qué).
- Al sumar ambas, se obtiene lo mejor de cada una.
Concepto fundamental: las skip connections de FCN son el precursor directo de la arquitectura U-Net. La idea de combinar features de alta y baja resolución se convertirá en un principio de diseño central en segmentación.
📊 Resultados de FCN en PASCAL VOC 2012
| Método | mIoU (%) | Mejora |
|---|---|---|
| SDS (mejor método clásico) | 51.6 | — |
| FCN-32s | 59.4 | +7.8 |
| FCN-16s | 62.4 | +10.8 |
| FCN-8s | 62.7 | +11.1 |
FCN mejoró el estado del arte en ~11 puntos de mIoU, demostrando la superioridad del deep learning para segmentación semántica.
Limitaciones de FCN
A pesar de su impacto, FCN tiene varias limitaciones que motivaron arquitecturas posteriores:
- Upsampling brusco: incluso FCN-8s produce bordes poco definidos porque el upsampling ×8 es grande.
- Sin contexto global: no captura relaciones a larga distancia en la imagen.
- No aprovecha features intermedias: solo usa 2-3 niveles de skip connections.
- Entrenamiento complejo: requiere inicializar desde un modelo pre-entrenado (VGG-16).
Estas limitaciones llevaron al desarrollo de la arquitectura encoder-decoder y, en particular, de U-Net, que veremos en la siguiente sección.
🏗️ Arquitectura Encoder-Decoder
La arquitectura encoder-decoder es el paradigma dominante en segmentación semántica. La idea es intuitiva:
- Encoder (codificador): comprime la imagen progresivamente, extrayendo features cada vez más abstractas y reduciendo la resolución espacial.
- Decoder (decodificador): expande las features comprimidas de vuelta a la resolución original, generando la máscara de segmentación.
H×W×3
↓ resolución
↑ semántica
h×w×D
↑ resolución
mantiene semántica
H×W×C
A diferencia de FCN, donde el upsampling es un paso final brusco (×8, ×16 o ×32), el decoder reconstruye gradualmente la resolución, permitiendo refinamientos progresivos.
🔗 Conexión con Autoencoders
La arquitectura encoder-decoder para segmentación está inspirada en los autoencoders, un tipo de red neuronal que aprende a comprimir y reconstruir datos:
z = f(x)
(comprimido)
\hat{x} = g(z)
La diferencia clave entre un autoencoder y una red de segmentación es:
| Aspecto | Autoencoder | Segmentación |
|---|---|---|
| Objetivo | Reconstruir la entrada | Predecir una máscara de clases |
| Salida | Imagen (misma forma que entrada) | Mapa de C canales (uno por clase) |
| Loss | MSE, L1 (reconstrucción) | Cross-entropy, Dice (clasificación por píxel) |
| Entrenamiento | No supervisado (auto-aprendizaje) | Supervisado (con máscaras ground truth) |
Los autoencoders son un tema amplio con aplicaciones en generación, compresión y representación. Puedes profundizar en ellos en nuestro módulo de IA Generativa → Autoencoders.
🏆 U-Net: la arquitectura estrella
U-Net (Ronneberger et al., 2015) es probablemente la arquitectura más influyente en segmentación. Diseñada originalmente para segmentación de imágenes biomédicas, su elegancia y efectividad la han convertido en el estándar de facto para muchas tareas de segmentación.
¿Por qué se llama «U-Net»?
El nombre viene de la forma de U que tiene la arquitectura cuando se dibuja: el encoder desciende por la izquierda, el bottleneck está en el fondo, y el decoder asciende por la derecha, con conexiones horizontales (skip connections) que cruzan de izquierda a derecha.
🧪 Explorador interactivo de U-Net
Camino contractivo (Encoder)
El encoder de U-Net sigue un patrón clásico de CNN, con 4 bloques de downsampling. Cada bloque contiene:
- Dos convoluciones 3×3 + ReLU (sin padding en el paper original)
- Max pooling 2×2 con stride 2 (reduce resolución a la mitad)
El número de canales se duplica en cada bloque: 64 → 128 → 256 → 512 → 1024 (bottleneck).
Bottleneck
En la base de la U, el bottleneck procesa las features más comprimidas (la resolución más baja y el mayor número de canales: 1024). Consta de dos convoluciones 3×3 + ReLU, sin max pooling.
Camino expansivo (Decoder)
El decoder reconstruye la resolución progresivamente. Cada bloque:
- Upsampling 2×2 (convolución transpuesta) que duplica la resolución y reduce canales a la mitad
- Concatenación con el mapa de features correspondiente del encoder (skip connection)
- Dos convoluciones 3×3 + ReLU
Skip connections: la clave del éxito
Las skip connections son lo que hace a U-Net especial. En lugar de sumar features como en FCN, U-Net las concatena:
🔗 Concatenar (U-Net)
- Preserva toda la información del encoder
- El decoder puede aprender qué información usar
- Más parámetros pero más capacidad expresiva
- Resultado: [\hat{x}_{up};\; x_{enc}] \in \mathbb{R}^{H \times W \times (C_1+C_2)}
➕ Sumar (FCN)
- Fuerza a encoder y decoder a tener los mismos canales
- Pierde información si las features son muy distintas
- Menos parámetros, más eficiente
- Resultado: \hat{x}_{up} + x_{enc} \in \mathbb{R}^{H \times W \times C}
¿Por qué las skip connections son tan importantes? El encoder pierde detalles espaciales finos (bordes, texturas) al reducir resolución. Las skip connections cortocircuitan esa información directamente al decoder, permitiéndole reconstruir bordes precisos. Sin ellas, el decoder tendría que «adivinar» dónde están los bordes a partir del bottleneck comprimido.
Capa final
La última capa de U-Net es una convolución 1×1 que mapea los 64 canales del último bloque del decoder a C canales (uno por clase). Se aplica softmax (o sigmoid para segmentación binaria) para obtener probabilidades:
📐 Dimensiones exactas de U-Net original (572×572)
| Etapa | Operación | Tamaño salida | Canales |
|---|---|---|---|
| Input | — | 572×572 | 1 |
| Enc 1 | 2×Conv 3×3 | 568×568 | 64 |
| Pool 1 | MaxPool 2×2 | 284×284 | 64 |
| Enc 2 | 2×Conv 3×3 | 280×280 | 128 |
| Pool 2 | MaxPool 2×2 | 140×140 | 128 |
| Enc 3 | 2×Conv 3×3 | 136×136 | 256 |
| Pool 3 | MaxPool 2×2 | 68×68 | 256 |
| Enc 4 | 2×Conv 3×3 | 64×64 | 512 |
| Pool 4 | MaxPool 2×2 | 32×32 | 512 |
| Bottleneck | 2×Conv 3×3 | 28×28 | 1024 |
| Up 4 | UpConv 2×2 + Crop&Concat | 56×56 | 512+512 |
| Dec 4 | 2×Conv 3×3 | 52×52 | 512 |
| Up 3 | UpConv 2×2 + Crop&Concat | 104×104 | 256+256 |
| Dec 3 | 2×Conv 3×3 | 100×100 | 256 |
| Up 2 | UpConv 2×2 + Crop&Concat | 200×200 | 128+128 |
| Dec 2 | 2×Conv 3×3 | 196×196 | 128 |
| Up 1 | UpConv 2×2 + Crop&Concat | 392×392 | 64+64 |
| Dec 1 | 2×Conv 3×3 | 388×388 | 64 |
| Output | Conv 1×1 | 388×388 | C |
Nota: el U-Net original usa convoluciones sin padding, por lo que la salida (388×388) es menor que la entrada (572×572). Las implementaciones modernas usan padding='same' para mantener la resolución.
¿Qué hizo especial a U-Net?
- Funciona con pocos datos: diseñada para imagen biomédica donde las muestras anotadas son escasas. Data augmentation agresiva compensa.
- Skip connections por concatenación: preservan toda la información espacial del encoder.
- Decoder simétrico: el decoder es un espejo del encoder, creando una arquitectura balanceada y elegante.
- Loss ponderada: los píxeles de borde entre objetos cercanos reciben mayor peso, mejorando la separación de instancias tocándose.
Donde d_1 y d_2 son las distancias al borde de las dos células más cercanas. Los píxeles entre células cercanas reciben mayor peso, forzando a la red a aprender a separarlas.
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
"""Bloque básico de U-Net: (Conv3x3 → BN → ReLU) × 2"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class UNet(nn.Module):
def __init__(self, in_channels=3, num_classes=2):
super().__init__()
# Encoder (camino contractivo)
self.enc1 = DoubleConv(in_channels, 64)
self.enc2 = DoubleConv(64, 128)
self.enc3 = DoubleConv(128, 256)
self.enc4 = DoubleConv(256, 512)
self.pool = nn.MaxPool2d(2)
# Bottleneck
self.bottleneck = DoubleConv(512, 1024)
# Decoder (camino expansivo)
self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
self.dec4 = DoubleConv(1024, 512) # 512+512 por concat
self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.dec3 = DoubleConv(512, 256) # 256+256 por concat
self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.dec2 = DoubleConv(256, 128) # 128+128 por concat
self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.dec1 = DoubleConv(128, 64) # 64+64 por concat
# Clasificador final (1×1 conv)
self.final = nn.Conv2d(64, num_classes, 1)
def forward(self, x):
# Encoder
e1 = self.enc1(x) # → 64 canales
e2 = self.enc2(self.pool(e1)) # → 128 canales
e3 = self.enc3(self.pool(e2)) # → 256 canales
e4 = self.enc4(self.pool(e3)) # → 512 canales
# Bottleneck
b = self.bottleneck(self.pool(e4)) # → 1024 canales
# Decoder con skip connections
d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
return self.final(d1) # → num_classes canales
# Ejemplo de uso
model = UNet(in_channels=3, num_classes=21) # 21 clases PASCAL VOC
x = torch.randn(1, 3, 256, 256)
out = model(x)
print(f"Input: {x.shape}") # [1, 3, 256, 256]
print(f"Output: {out.shape}") # [1, 21, 256, 256]
import tensorflow as tf
from tensorflow.keras import layers, Model
def double_conv(x, filters):
"""Bloque (Conv3x3 → BN → ReLU) × 2"""
x = layers.Conv2D(filters, 3, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.Conv2D(filters, 3, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
return x
def build_unet(input_shape=(256, 256, 3), num_classes=2):
inputs = layers.Input(shape=input_shape)
# Encoder
e1 = double_conv(inputs, 64)
p1 = layers.MaxPool2D(2)(e1)
e2 = double_conv(p1, 128)
p2 = layers.MaxPool2D(2)(e2)
e3 = double_conv(p2, 256)
p3 = layers.MaxPool2D(2)(e3)
e4 = double_conv(p3, 512)
p4 = layers.MaxPool2D(2)(e4)
# Bottleneck
b = double_conv(p4, 1024)
# Decoder con skip connections (concatenación)
u4 = layers.Conv2DTranspose(512, 2, strides=2, padding='same')(b)
u4 = layers.Concatenate()([u4, e4])
d4 = double_conv(u4, 512)
u3 = layers.Conv2DTranspose(256, 2, strides=2, padding='same')(d4)
u3 = layers.Concatenate()([u3, e3])
d3 = double_conv(u3, 256)
u2 = layers.Conv2DTranspose(128, 2, strides=2, padding='same')(d3)
u2 = layers.Concatenate()([u2, e2])
d2 = double_conv(u2, 128)
u1 = layers.Conv2DTranspose(64, 2, strides=2, padding='same')(d2)
u1 = layers.Concatenate()([u1, e1])
d1 = double_conv(u1, 64)
# Capa final
outputs = layers.Conv2D(num_classes, 1, activation='softmax')(d1)
return Model(inputs, outputs, name='UNet')
model = build_unet(num_classes=21)
model.summary() # ~31M parámetros
🔀 Variantes de U-Net
El éxito de U-Net inspiró una familia de arquitecturas que mantienen la estructura encoder-decoder con skip connections pero introducen mejoras específicas. Veamos las más relevantes.
🧪 Comparador de variantes
Attention U-Net
Attention U-Net (Oktay et al., 2018) introduce attention gates en las skip connections para que el modelo aprenda a enfocarse en las regiones relevantes y suprimir features irrelevantes.
En U-Net estándar, las skip connections pasan todas las features del encoder al decoder. Pero no toda esa información es útil — en imagen médica, por ejemplo, gran parte del fondo es irrelevante. Los attention gates aprenden a «filtrar» qué features pasar:
Donde x_i son las features del encoder, g_i la señal del decoder (gating signal), y \alpha_i \in [0,1] es el coeficiente de atención que pondera cada feature espacial. La salida es \hat{x}_i = \alpha_i \cdot x_i.
Intuición: el decoder «le dice» al encoder qué regiones son importantes en esta etapa, y las features del encoder se ponderan en consecuencia. Es como un mecanismo de «atención» que silencia el ruido de fondo.
U-Net++ (Nested U-Net)
U-Net++ (Zhou et al., 2018) rediseña las skip connections añadiendo bloques convolucionales intermedios entre el encoder y el decoder. En lugar de conexiones directas, crea una red densa de sub-redes anidadas.
La intuición es que el gap semántico entre encoder y decoder puede ser grande (features en escala ×1 vs. features en escala ×16). Los nodos intermedios de U-Net++ reducen progresivamente ese gap:
Donde [\cdot] denota concatenación y \mathcal{H} es un bloque convolucional. Cada nodo recibe features de todos los nodos anteriores en su nivel, más el upsampling del nivel inferior.
🎓 Deep supervision en U-Net++
U-Net++ permite deep supervision: se pueden generar predicciones desde múltiples niveles de la red, no solo desde el final. Esto tiene dos ventajas:
- Entrenamiento más estable: los gradientes llegan mejor a todas las capas.
- Poda en inferencia: si las predicciones de un nivel intermedio son suficientemente buenas, se puede podar el resto de la red para mayor velocidad.
V-Net: segmentación 3D
V-Net (Milletari et al., 2016) adapta la idea de U-Net a datos volumétricos 3D (como CT scans o MRI). Las principales diferencias:
- Convoluciones 3D en lugar de 2D (kernels 5×5×5)
- Skip connections residuales (suma) en lugar de concatenación
- Dice loss como función de pérdida (veremos esto en detalle más adelante)
- Convoluciones strided en lugar de max pooling para downsampling
¿Por qué 3D? En imagen médica, los datos a menudo son volúmenes (ej. una secuencia de cortes de CT con resolución 512×512×200). Procesar cada corte independientemente (2D) pierde la información inter-slice. V-Net procesa el volumen completo, capturando el contexto 3D.
📐 SegNet
SegNet (Badrinarayanan et al., 2017) es otra arquitectura encoder-decoder, pero con un enfoque distinto para el upsampling: usa los índices del max pooling del encoder para guiar el unpooling en el decoder.
(VGG-16)
pooling
indices
(max unpooling)
La ventaja de usar los pooling indices es que es muy eficiente en memoria: no se necesitan almacenar los feature maps completos del encoder (como en U-Net con concatenación), sino solo los índices de las posiciones de los máximos.
| Aspecto | U-Net | SegNet |
|---|---|---|
| Skip info | Feature maps completos (concatenación) | Solo pooling indices (posiciones) |
| Memoria | Alta (almacena features del encoder) | Baja (solo índices enteros) |
| Calidad | Superior (más información disponible) | Buena, pero menos detallada |
| Encoder | Propio (entrenado desde cero) | VGG-16 pre-entrenado |
| Aplicación típica | Imagen médica | Conducción autónoma, scenes |
📊 Resumen comparativo de todas las variantes
| Arquitectura | Año | Skip connection | Innovación principal | Uso típico |
|---|---|---|---|---|
| U-Net | 2015 | Concatenación | Encoder-decoder simétrico | Biomédica |
| SegNet | 2017 | Pooling indices | Eficiencia de memoria | Escenas urbanas |
| Attention U-Net | 2018 | Atención + concat | Attention gates | Biomédica |
| U-Net++ | 2018 | Dense nested | Skip intermedios densos | Biomédica |
| V-Net | 2016 | Residual (suma) | Conv 3D, Dice loss | Volúmenes 3D |
🔍 Convoluciones dilatadas (Atrous)
Antes de ver DeepLab, necesitamos entender la convolución dilatada (o atrous convolution), que es su pieza fundamental.
Una convolución estándar 3×3 tiene un campo receptivo de 3×3. Para capturar contexto más amplio, normalmente apilamos muchas capas o usamos pooling. La convolución dilatada ofrece una alternativa: inserta «huecos» (dilation) entre los pesos del kernel.
Donde r es la tasa de dilatación (dilation rate). Con r=1 tenemos una convolución estándar. Con r=2, el kernel 3×3 tiene un campo receptivo efectivo de 5×5. Con r=4, el campo receptivo es de 9×9.
🧪 Visualizador de convolución dilatada
Ventaja clave: la convolución dilatada aumenta el campo receptivo sin perder resolución (no hay pooling) y sin aumentar parámetros (el kernel sigue teniendo 3×3 = 9 pesos). Es «ver más lejos con la misma lupa».
🏛️ DeepLab: segmentación de alta resolución
La familia DeepLab (Chen et al., Google) es una de las líneas de investigación más influyentes en segmentación semántica. Veamos su evolución:
DeepLab v1 y v2
DeepLab v1 (2015) introdujo el uso de convoluciones dilatadas en un backbone pre-entrenado (VGG-16), manteniendo una resolución de salida de ×8 en lugar de ×32. También incorporó un CRF (Conditional Random Field) como post-procesamiento para refinar bordes.
DeepLab v2 (2017) mejoró con dos innovaciones:
- Backbone más potente: cambió VGG por ResNet-101
- ASPP (Atrous Spatial Pyramid Pooling): aplica convoluciones dilatadas con múltiples tasas en paralelo para capturar contexto a distintas escalas
ASPP: capturando contexto multi-escala
ASPP es la innovación clave de DeepLab. Aplica varias convoluciones dilatadas en paralelo con diferentes tasas de dilatación (r = 6, 12, 18), capturando contexto a múltiples escalas simultáneamente:
DeepLab v3+: el estado del arte
DeepLab v3+ (2018) combinó ASPP con un decoder ligero, creando una arquitectura encoder-decoder donde:
- Encoder: backbone (ResNet/Xception) con convoluciones dilatadas + módulo ASPP
- Decoder: un módulo simple que combina features de baja resolución (del ASPP) con features de alta resolución (del backbone temprano) mediante concatenación y convoluciones
El resultado es un modelo que captura tanto contexto global (ASPP) como detalles finos (skip connection del encoder), logrando resultados excelentes en PASCAL VOC y Cityscapes.
📊 Evolución de DeepLab en PASCAL VOC 2012
| Versión | Backbone | Innovación | mIoU (%) |
|---|---|---|---|
| DeepLab v1 | VGG-16 | Atrous conv + CRF | 71.6 |
| DeepLab v2 | ResNet-101 | ASPP + CRF | 79.7 |
| DeepLab v3 | ResNet-101 | Improved ASPP, no CRF | 85.7 |
| DeepLab v3+ | Xception-65 | Encoder-decoder + ASPP | 87.8 |
🔮 PSPNet: Pyramid Scene Parsing
PSPNet (Zhao et al., 2017) aborda el problema del contexto global con un módulo de Pyramid Pooling. La motivación: cuando un modelo ve solo un trozo de un objeto grande (ej. un barco visto de cerca), puede confundirlo con algo pequeño (ej. un coche) si le falta contexto de la escena completa.
El Pyramid Pooling Module (PPM) aplica average pooling a 4 escalas distintas (1×1, 2×2, 3×3, 6×6), generando representaciones a diferentes niveles de granularidad que se concatenan con las features originales:
(backbone)
+ Concat
Diferencia con ASPP: ASPP usa convoluciones dilatadas (contexto local variable), mientras que PPM usa pooling global a distintas escalas (contexto genuinamente global). Ambos capturan multi-escala, pero de forma complementaria.
🎭 Mask R-CNN: segmentación por instancias
Todas las arquitecturas anteriores realizan segmentación semántica (no distinguen instancias). Mask R-CNN (He et al., 2017) extiende Faster R-CNN para realizar segmentación por instancias: detecta cada objeto y genera una máscara de segmentación para cada uno.
Arquitectura
Mask R-CNN añade una rama de máscara paralela a las ramas existentes de clasificación y regresión de bounding box de Faster R-CNN:
(ResNet+FPN)
(propuestas)
RoIAlign: alineamiento preciso
Una innovación clave de Mask R-CNN es RoIAlign, que reemplaza el RoI Pooling de Faster R-CNN. El RoI Pooling original cuantiza las coordenadas a enteros, causando desalineamientos de 1-2 píxeles — insignificante para bounding boxes pero catastrófico para máscaras a nivel de píxel.
RoIAlign usa interpolación bilineal para muestrear las features en coordenadas continuas, eliminando la cuantización.
Insight: el desacoplamiento de las tres tareas (clasificación, bbox, máscara) es fundamental. La rama de máscara predice C máscaras binarias (una por clase) sin competencia entre clases — la clasificación de clase la decide otra rama. Esto elimina interferencia y mejora la calidad de las máscaras.
🌐 Segmentación panóptica
La segmentación panóptica (Kirillov et al., 2019) unifica segmentación semántica e instancias en un solo marco. Para cada píxel, se predice:
- Clase semántica (para todo: stuff y things)
- ID de instancia (solo para things, es decir, objetos contables)
| Categoría | Ejemplos | Tratamiento |
|---|---|---|
| Things (objetos contables) | Personas, coches, perros, sillas | Clase + ID de instancia |
| Stuff (regiones amorfas) | Cielo, carretera, hierba, pared | Solo clase (sin instancias) |
La métrica principal es PQ (Panoptic Quality), que combina reconocimiento y calidad de segmentación:
Tendencia actual: los modelos más recientes como Mask2Former (2022) y SAM (Segment Anything, 2023) unifican las tres tareas de segmentación con arquitecturas basadas en Transformers, superando a todas las CNNs puras.
📉 Funciones de pérdida para segmentación
La elección de la función de pérdida es crítica en segmentación, especialmente cuando hay desbalance de clases (mucho más fondo que objeto). Veamos las más utilizadas.
Cross-Entropy píxel a píxel
La extensión natural de la cross-entropy de clasificación. Se calcula independientemente para cada píxel y se promedia:
Donde y_{c,h,w} es el ground truth one-hot y \hat{p}_{c,h,w} es la probabilidad predicha para la clase c en la posición (h,w).
Problema: cuando una clase ocupa el 95% de la imagen (ej. fondo), la red puede predecir «todo es fondo» y conseguir una loss baja. La cross-entropy no penaliza esto suficientemente.
Dice Loss
La Dice loss está basada en el coeficiente de Sørensen-Dice, una medida de solapamiento entre dos conjuntos. Mide directamente cuánto se «parecen» la predicción y el ground truth:
\epsilon (ej. 10^{-6}) evita divisiones por cero. La Dice loss es intrínsecamente robusta al desbalance porque normaliza por el tamaño del objeto.
Focal Loss
La Focal loss (Lin et al., 2017) modifica la cross-entropy para reducir la contribución de los píxeles fáciles (clasificados con alta confianza) y enfocarse en los difíciles:
Donde \gamma \geq 0 es el parámetro de enfoque. Con \gamma = 0 es cross-entropy estándar. Con \gamma = 2 (valor típico), un píxel clasificado con \hat{p}_t = 0.9 recibe un peso 100× menor que uno con \hat{p}_t = 0.1.
Tversky Loss
La Tversky loss generaliza la Dice loss con pesos independientes para falsos positivos (FP) y falsos negativos (FN):
Con \alpha = \beta = 0.5 se reduce a Dice loss. Aumentar \beta > \alpha penaliza más los falsos negativos, útil cuando «no detectar» un tumor es peor que una falsa alarma.
🧪 Comparador de funciones de pérdida
💡 Combinaciones comunes de loss
En la práctica, es común combinar varias funciones de pérdida:
La cross-entropy proporciona gradientes estables para el entrenamiento, mientras que la Dice loss optimiza directamente la métrica de solapamiento. Un ratio típico es \lambda_1 = 1, \lambda_2 = 1.
📏 Métricas de evaluación
La accuracy (proporción de píxeles correctos) no es útil en segmentación debido al desbalance de clases. Las métricas estándar son:
IoU (Intersection over Union)
También llamada Jaccard Index. Mide el solapamiento entre la predicción y el ground truth para cada clase:
mIoU (Mean IoU)
La métrica principal en segmentación semántica. Es la media de IoU sobre todas las clases:
Coeficiente Dice / F1
Relación con IoU: \text{Dice} = \frac{2 \cdot \text{IoU}}{1 + \text{IoU}}. El Dice siempre es ≥ IoU para el mismo par (predicción, ground truth).
🧪 Calculadora de IoU y Dice
| Métrica | Rango | Uso principal | Sensibilidad |
|---|---|---|---|
| Pixel Accuracy | [0, 1] | Visión general rápida | Engañada por desbalance |
| IoU / Jaccard | [0, 1] | Benchmark estándar (mIoU) | Más estricta que Dice |
| Dice / F1 | [0, 1] | Imagen médica | Más tolerante que IoU |
| PQ | [0, 1] | Segmentación panóptica | Combina reconocimiento + calidad |
🔧 Data augmentation para segmentación
El data augmentation es especialmente importante en segmentación porque los datasets anotados a nivel de píxel son pequeños. Pero hay una diferencia clave: las transformaciones geométricas deben aplicarse también a la máscara.
| Transformación | ¿Aplica a la máscara? | Notas |
|---|---|---|
| Flip horizontal/vertical | ✅ Sí | Misma transformación |
| Rotación | ✅ Sí | Usar interpolación nearest para máscara |
| Escala / Crop | ✅ Sí | Misma región de recorte |
| Elastic deformation | ✅ Sí | Mismo campo de deformación |
| Cambio de brillo/contraste | ❌ No | Solo afecta a la imagen |
| Gaussian blur | ❌ No | Solo afecta a la imagen |
| Color jitter | ❌ No | Solo afecta a la imagen |
Importante: al aplicar interpolación a la máscara, se debe usar interpolación nearest-neighbor, nunca bilineal. La bilineal crearía valores intermedios (ej. 1.5 entre clase 1 y clase 2) que no tienen sentido como etiquetas.
💻 Ejemplo práctico: entrenamiento de U-Net
Veamos un ejemplo completo de cómo entrenar un U-Net para segmentación binaria (ej. tumor vs. fondo) usando combinación de Cross-Entropy + Dice loss.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
# ── Loss combinada ──────────────────────────────────────────────
class DiceBCELoss(nn.Module):
"""Combina Binary Cross-Entropy y Dice Loss."""
def __init__(self, smooth=1e-6):
super().__init__()
self.smooth = smooth
self.bce = nn.BCEWithLogitsLoss()
def forward(self, logits, targets):
# BCE loss
bce_loss = self.bce(logits, targets)
# Dice loss (sobre probabilidades)
probs = torch.sigmoid(logits)
intersection = (probs * targets).sum(dim=(2, 3))
union = probs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))
dice = (2. * intersection + self.smooth) / (union + self.smooth)
dice_loss = 1 - dice.mean()
return bce_loss + dice_loss
# ── Dataset de segmentación ─────────────────────────────────────
class SegmentationDataset(Dataset):
def __init__(self, images, masks, transform=None):
self.images = images # lista de arrays HxWx3
self.masks = masks # lista de arrays HxW (0 o 1)
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = self.images[idx] # HxWx3, float32
mask = self.masks[idx] # HxW, float32
if self.transform:
# Aplicar MISMA transformación a imagen y máscara
seed = np.random.randint(2147483647)
torch.manual_seed(seed)
image = self.transform(image)
torch.manual_seed(seed)
mask = self.transform(mask.unsqueeze(0)) # 1xHxW
return image, mask
# ── Entrenamiento ───────────────────────────────────────────────
def train_unet(model, train_loader, val_loader, epochs=50, lr=1e-4):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = DiceBCELoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, patience=5, factor=0.5
)
best_dice = 0.0
for epoch in range(epochs):
# ── Train ──
model.train()
train_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
optimizer.zero_grad()
outputs = model(images) # [B, 1, H, W]
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
train_loss += loss.item()
# ── Validation ──
model.eval()
val_dice = 0.0
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
outputs = torch.sigmoid(model(images))
preds = (outputs > 0.5).float()
# Calcular Dice
intersection = (preds * masks).sum(dim=(2, 3))
union = preds.sum(dim=(2, 3)) + masks.sum(dim=(2, 3))
dice = (2. * intersection + 1e-6) / (union + 1e-6)
val_dice += dice.mean().item()
avg_train_loss = train_loss / len(train_loader)
avg_val_dice = val_dice / len(val_loader)
scheduler.step(avg_train_loss)
print(f"Epoch {epoch+1}/{epochs} | "
f"Loss: {avg_train_loss:.4f} | "
f"Val Dice: {avg_val_dice:.4f}")
if avg_val_dice > best_dice:
best_dice = avg_val_dice
torch.save(model.state_dict(), 'best_unet.pth')
print(f" → Nuevo mejor modelo (Dice: {best_dice:.4f})")
print(f"\nMejor Dice en validación: {best_dice:.4f}")
return model
import tensorflow as tf
from tensorflow.keras import backend as K
# ── Dice loss ───────────────────────────────────────────────────
def dice_loss(y_true, y_pred, smooth=1e-6):
y_pred = tf.sigmoid(y_pred)
intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + \
tf.reduce_sum(y_pred, axis=[1, 2, 3])
dice = (2. * intersection + smooth) / (union + smooth)
return 1 - tf.reduce_mean(dice)
def bce_dice_loss(y_true, y_pred):
bce = tf.keras.losses.binary_crossentropy(
y_true, y_pred, from_logits=True
)
return tf.reduce_mean(bce) + dice_loss(y_true, y_pred)
# ── Dice metric ─────────────────────────────────────────────────
def dice_coefficient(y_true, y_pred, smooth=1e-6):
y_pred = tf.cast(tf.sigmoid(y_pred) > 0.5, tf.float32)
intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
union = tf.reduce_sum(y_true + y_pred, axis=[1, 2, 3])
return tf.reduce_mean((2. * intersection + smooth) / (union + smooth))
# ── Compilar y entrenar ─────────────────────────────────────────
model = build_unet(input_shape=(256, 256, 3), num_classes=1)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
loss=bce_dice_loss,
metrics=[dice_coefficient]
)
callbacks = [
tf.keras.callbacks.ReduceLROnPlateau(patience=5, factor=0.5),
tf.keras.callbacks.EarlyStopping(patience=15, restore_best_weights=True),
tf.keras.callbacks.ModelCheckpoint('best_unet.keras', save_best_only=True)
]
history = model.fit(
train_dataset,
validation_data=val_dataset,
epochs=50,
callbacks=callbacks
)
🏁 Comparativa de arquitecturas
Para cerrar, una comparativa completa de todas las arquitecturas de segmentación que hemos visto:
🧪 Comparador de arquitecturas de segmentación
| Arquitectura | Año | Tipo | Parámetros | mIoU (VOC) | Innovación clave |
|---|---|---|---|---|---|
| FCN-8s | 2015 | Semántica | ~134M | 62.7% | Primera red fully-conv |
| U-Net | 2015 | Semántica | ~31M | — | Skip concat + simetría |
| SegNet | 2017 | Semántica | ~29M | — | Pooling indices |
| PSPNet | 2017 | Semántica | ~65M | 85.4% | Pyramid Pooling |
| Mask R-CNN | 2017 | Instancias | ~44M | — | RoIAlign + mask branch |
| DeepLab v3+ | 2018 | Semántica | ~41M | 87.8% | ASPP + decoder |
Resumen: la segmentación ha evolucionado desde FCN (2015) hasta modelos basados en Transformers como Mask2Former y SAM. Pero los principios fundamentales — encoder-decoder, skip connections, multi-escala y losses especializadas — siguen siendo la base sobre la que se construye todo.