🏭 Caso de Uso

VAE para Clasificación No Supervisada en MNIST

Autoencoder Variacional (VAE) con TensorFlow/Keras: aprendizaje de representaciones latentes, clustering con K-Means, generación de imágenes e interpolación en espacio latente.

🐍 Python 📓 Jupyter Notebook

Autoencoder Variacional (VAE) con TensorFlow/Keras en MNIST

Objetivo del notebook

En este notebook construiremos un Autoencoder Variacional (VAE) de principio a fin para trabajar con el dataset de MNIST y cubrir tres metas principales:

  1. Aprender una representación latente útil de imágenes de dígitos manuscritos sin usar etiquetas durante el entrenamiento del VAE.
  2. Usar ese espacio latente para hacer clasificación no supervisada mediante clustering (K-Means).
  3. Demostrar capacidades generativas del modelo:
    • Generar imágenes nuevas desde muestras aleatorias del espacio latente.
    • Interpolar imágenes entre dos dígitos reales para observar transiciones suaves.

Fundamentos matemáticos y computacionales

Un autoencoder clásico aprende una función de compresión y descompresión:

$$ x \to z \to \hat{x} $$

minimizando el error de reconstrucción $|x-\hat{x}|$. Esto reconstruye bien, pero no garantiza que el espacio latente sea suave ni muestreable: puede haber "huecos" donde el decoder produce resultados sin sentido.

El VAE añade una formulación probabilística que resuelve este problema:

  • Prior latente: $p(z)=\mathcal{N}(0, I)$ — asumimos que los códigos latentes siguen una normal estándar.
  • Encoder aproximado: $q_\phi(z|x)=\mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x))$ — el encoder produce los parámetros de una distribución, no un punto fijo.
  • Decoder: $p_\theta(x|z)$ — reconstruye la imagen a partir de una muestra del espacio latente.

y optimiza la ELBO (Evidence Lower Bound):

$$ \log p(x) \ge \mathbb{E}{q\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x)|p(z)) $$

En la práctica minimizamos la pérdida total del VAE, que se descompone en dos términos con roles opuestos y complementarios:

$$ \mathcal{L}{VAE}=\underbrace{\mathcal{L}{rec}}{\text{fidelidad de reconstrucción}}+\beta\underbrace{D{KL}}_{\text{regularización latente}} $$

"Castigos y recompensas" (intuición didáctica)

  • Recompensa (implícita): reconstruir bien la imagen (baja $\mathcal{L}_{rec}$).
  • Castigo 1 (reconstrucción): si el decoder no reconstruye detalles, sube $\mathcal{L}_{rec}$ — el modelo pierde información en la compresión.
  • Castigo 2 (variacional): si la distribución latente se aleja de $\mathcal{N}(0,I)$, sube el término KL — el modelo "memoriza" en vez de generalizar.

En otras palabras, el VAE negocia entre:

  • ser fiel a cada imagen (reconstrucción), y
  • mantener un espacio latente ordenado y continuo (regularización KL).

Este equilibrio es clave para poder generar nuevas imágenes coherentes e interpolar con transiciones suaves entre dígitos. Si solo optimizáramos reconstrucción, el espacio latente podría tener zonas vacías donde el decoder genera ruido; si solo optimizáramos KL, todas las imágenes colapsarían a la misma representación.


Modelos y dataset usados

  • Dataset: MNIST (70.000 imágenes 28×28 en escala de grises, 10 clases de dígitos).
  • Modelo principal: VAE denso (fully-connected) con espacio latente de 2 dimensiones (ideal para visualizar y explorar).
  • Modelo auxiliar para no supervisado: K-Means sobre las coordenadas $z_{mean}$ del espacio latente.
  • Métricas de clustering: Accuracy (vía mapeo por mayoría), NMI (Normalized Mutual Information) y ARI (Adjusted Rand Index).
[1]
# Instalación opcional (descomenta si tu entorno no tiene dependencias)
# !pip install -q tensorflow scikit-learn matplotlib seaborn

import os
# Forzar ejecución en CPU (evita errores del autotuner XLA/Triton en GPU)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
[2]

# Imports principales
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from sklearn.metrics import confusion_matrix

# Configuración visual y semilla para reproducibilidad
sns.set(style="whitegrid")
np.random.seed(42)
tf.random.set_seed(42)

print("TensorFlow:", tf.__version__)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1773755238.700838 3536069 port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
I0000 00:00:1773755238.729284 3536069 cpu_feature_guard.cc:227] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1773755239.429059 3536069 port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
TensorFlow: 2.21.0

1) Carga de datos y pequeño EDA (Exploratory Data Analysis)

[3]

# Cargamos MNIST desde Keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

print("Train:", x_train.shape, y_train.shape)
print("Test :", x_test.shape, y_test.shape)
Train: (60000, 28, 28) (60000,)
Test : (10000, 28, 28) (10000,)
[4]

# Estadísticos básicos de los píxeles (antes de normalizar)
print("Rango de pixel train:", x_train.min(), "->", x_train.max())
print("Media train:", x_train.mean(), "Desviación:", x_train.std())

# Distribución de clases (solo EDA; no se usan etiquetas para entrenar el VAE)
unique, counts = np.unique(y_train, return_counts=True)
for d, c in zip(unique, counts):
    print(f"Dígito {d}: {c} muestras")
Rango de pixel train: 0 -> 255
Media train: 33.318421449829934 Desviación: 78.56748998339798
Dígito 0: 5923 muestras
Dígito 1: 6742 muestras
Dígito 2: 5958 muestras
Dígito 3: 6131 muestras
Dígito 4: 5842 muestras
Dígito 5: 5421 muestras
Dígito 6: 5918 muestras
Dígito 7: 6265 muestras
Dígito 8: 5851 muestras
Dígito 9: 5949 muestras
[5]

# Visualizamos ejemplos por clase para entender variabilidad intra-clase
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
axes = axes.ravel()

for digit in range(10):
    idx = np.where(y_train == digit)[0][0]
    axes[digit].imshow(x_train[idx], cmap='gray')
    axes[digit].set_title(f"Clase {digit}")
    axes[digit].axis('off')

plt.suptitle("Primer ejemplo encontrado por clase")
plt.tight_layout()
plt.show()
Output
[6]

# Histograma de intensidades para entender contraste global del dataset
plt.figure(figsize=(8,4))
plt.hist(x_train.ravel(), bins=50, color='slateblue', alpha=0.85)
plt.title("Distribución de intensidades de píxel (train)")
plt.xlabel("Intensidad")
plt.ylabel("Frecuencia")
plt.show()
Output
[7]

# Preprocesado para la red:
# 1) normalizar a [0,1]
# 2) aplanar a vector de 784
x_train_f = x_train.astype("float32") / 255.0
x_test_f  = x_test.astype("float32") / 255.0

x_train_f = x_train_f.reshape((-1, 784))
x_test_f = x_test_f.reshape((-1, 784))

# Separamos validación desde train
x_val_f = x_train_f[-10000:]
y_val = y_train[-10000:]
x_train_f = x_train_f[:-10000]
y_train_small = y_train[:-10000]

print("Train final:", x_train_f.shape)
print("Val final  :", x_val_f.shape)
print("Test final :", x_test_f.shape)
Train final: (50000, 784)
Val final  : (10000, 784)
Test final : (10000, 784)

2) Definición del VAE (encoder, muestreo y decoder)

El VAE consta de tres componentes:

  • Encoder: recibe la imagen aplanada (784 dims) y produce dos vectores: $\mu$ (z_mean) y $\log \sigma^2$ (z_log_var), que parametrizan la distribución latente $q_\phi(z|x)$.
  • Capa de muestreo: aplica el reparameterization trick ($z = \mu + \sigma \odot \epsilon$, con $\epsilon \sim \mathcal{N}(0,I)$) para permitir la retropropagación a través del muestreo estocástico.
  • Decoder: recibe la muestra $z$ y reconstruye la imagen con activación sigmoide (salida en $[0,1]$).
[8]

# Capa de muestreo (reparameterization trick)
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

latent_dim = 2
input_dim = 784
hidden_dim = 256

# Encoder
encoder_inputs = keras.Input(shape=(input_dim,))
x = layers.Dense(hidden_dim, activation="relu")(encoder_inputs)
x = layers.Dense(hidden_dim // 2, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
E0000 00:00:1773755240.481584 3536069 cuda_platform.cc:52] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
I0000 00:00:1773755240.481604 3536069 cuda_diagnostics.cc:160] env: CUDA_VISIBLE_DEVICES="-1"
I0000 00:00:1773755240.481611 3536069 cuda_diagnostics.cc:163] CUDA_VISIBLE_DEVICES is set to -1 - this hides all GPUs from CUDA
I0000 00:00:1773755240.481618 3536069 cuda_diagnostics.cc:171] verbose logging is disabled. Rerun with verbose logging (usually --v=1 or --vmodule=cuda_diagnostics=1) to get more diagnostic output from this module
I0000 00:00:1773755240.481618 3536069 cuda_diagnostics.cc:176] retrieving CUDA diagnostic information for host: tnp01-4090
I0000 00:00:1773755240.481621 3536069 cuda_diagnostics.cc:183] hostname: tnp01-4090
I0000 00:00:1773755240.481747 3536069 cuda_diagnostics.cc:190] libcuda reported version is: 580.126.9
I0000 00:00:1773755240.481755 3536069 cuda_diagnostics.cc:194] kernel reported version is: 580.126.9
I0000 00:00:1773755240.481756 3536069 cuda_diagnostics.cc:284] kernel version seems to match DSO: 580.126.9
Model: "encoder"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 784)       │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense)       │ (None, 256)       │    200,960 │ input_layer[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_1 (Dense)     │ (None, 128)       │     32,896 │ dense[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ z_mean (Dense)      │ (None, 2)         │        258 │ dense_1[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ z_log_var (Dense)   │ (None, 2)         │        258 │ dense_1[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ sampling (Sampling) │ (None, 2)         │          0 │ z_mean[0][0],     │
│                     │                   │            │ z_log_var[0][0]   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 234,372 (915.52 KB)
 Trainable params: 234,372 (915.52 KB)
 Non-trainable params: 0 (0.00 B)
[9]

# Decoder
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(hidden_dim // 2, activation="relu")(latent_inputs)
x = layers.Dense(hidden_dim, activation="relu")(x)
decoder_outputs = layers.Dense(input_dim, activation="sigmoid")(x)

decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
Model: "decoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer_1 (InputLayer)      │ (None, 2)              │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 128)            │           384 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 256)            │        33,024 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_4 (Dense)                 │ (None, 784)            │       201,488 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 234,896 (917.56 KB)
 Trainable params: 234,896 (917.56 KB)
 Non-trainable params: 0 (0.00 B)
[10]

# Modelo VAE custom para controlar losses y métricas de train/val
class VAE(keras.Model):
    def __init__(self, encoder, decoder, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta

        # Métricas rastreadas por Keras
        self.total_loss_tracker = keras.metrics.Mean(name="loss")
        self.rec_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.bin_acc = keras.metrics.BinaryAccuracy(name="binary_accuracy")

    @property
    def metrics(self):
        return [self.total_loss_tracker, self.rec_loss_tracker, self.kl_loss_tracker, self.bin_acc]

    def _compute_losses(self, data):
        """Calcula las tres pérdidas del VAE sobre un batch."""
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)

        # BCE por píxel: binary_crossentropy ya promedia sobre la última dimensión,
        # así que multiplicamos por el nº de features para obtener la suma sobre píxeles.
        bce_per_sample = keras.losses.binary_crossentropy(data, reconstruction)  # (batch,)
        n_features = tf.cast(tf.shape(data)[-1], tf.float32)
        rec_loss = tf.reduce_mean(bce_per_sample) * n_features

        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
        )
        total_loss = rec_loss + self.beta * kl_loss
        return total_loss, rec_loss, kl_loss, reconstruction

    def train_step(self, data):
        # Keras puede pasar (x, None) cuando se usa validation_data=(x, None)
        if isinstance(data, tuple):
            data = data[0]

        with tf.GradientTape() as tape:
            total_loss, rec_loss, kl_loss, reconstruction = self._compute_losses(data)

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.rec_loss_tracker.update_state(rec_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.bin_acc.update_state(data, reconstruction)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.rec_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "binary_accuracy": self.bin_acc.result(),
        }

    def test_step(self, data):
        if isinstance(data, tuple):
            data = data[0]

        total_loss, rec_loss, kl_loss, reconstruction = self._compute_losses(data)

        self.total_loss_tracker.update_state(total_loss)
        self.rec_loss_tracker.update_state(rec_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.bin_acc.update_state(data, reconstruction)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.rec_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "binary_accuracy": self.bin_acc.result(),
        }

vae = VAE(encoder, decoder, beta=1.0)
vae.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))

3) Entrenamiento y análisis de métricas (train/val)

Entrenamos el VAE durante 20 épocas con batches de 128 muestras. Usamos Adam como optimizador con learning rate $10^{-3}$. La pérdida total es la suma de la BCE de reconstrucción y el término KL de regularización, con $\beta=1.0$.

Las curvas posteriores desglosan la evolución de cada componente de la pérdida, lo que permite diagnosticar el equilibrio entre reconstrucción y regularización.

[11]

# Entrenamos el VAE
history = vae.fit(
    x_train_f,
    epochs=20,
    batch_size=128,
    validation_data=(x_val_f, None),
    verbose=1
)
Epoch 1/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - binary_accuracy: 0.7912 - kl_loss: 5.6189 - loss: 198.4038 - reconstruction_loss: 192.7850 - val_binary_accuracy: 0.7938 - val_kl_loss: 4.4163 - val_loss: 173.1055 - val_reconstruction_loss: 168.6893
Epoch 2/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - binary_accuracy: 0.7947 - kl_loss: 4.7273 - loss: 169.1500 - reconstruction_loss: 164.4226 - val_binary_accuracy: 0.7956 - val_kl_loss: 4.8210 - val_loss: 163.9979 - val_reconstruction_loss: 159.1769
Epoch 3/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7950 - kl_loss: 5.0213 - loss: 163.7342 - reconstruction_loss: 158.7129 - val_binary_accuracy: 0.7965 - val_kl_loss: 5.1862 - val_loss: 160.4791 - val_reconstruction_loss: 155.2929
Epoch 4/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - binary_accuracy: 0.7954 - kl_loss: 5.3023 - loss: 160.3815 - reconstruction_loss: 155.0792 - val_binary_accuracy: 0.7965 - val_kl_loss: 5.6590 - val_loss: 157.4631 - val_reconstruction_loss: 151.8041
Epoch 5/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - binary_accuracy: 0.7956 - kl_loss: 5.5337 - loss: 157.3770 - reconstruction_loss: 151.8434 - val_binary_accuracy: 0.7962 - val_kl_loss: 5.7158 - val_loss: 155.2885 - val_reconstruction_loss: 149.5727
Epoch 6/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - binary_accuracy: 0.7958 - kl_loss: 5.6697 - loss: 155.2385 - reconstruction_loss: 149.5689 - val_binary_accuracy: 0.7960 - val_kl_loss: 5.8044 - val_loss: 153.6642 - val_reconstruction_loss: 147.8598
Epoch 7/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - binary_accuracy: 0.7961 - kl_loss: 5.8049 - loss: 153.5451 - reconstruction_loss: 147.7402 - val_binary_accuracy: 0.7965 - val_kl_loss: 5.7747 - val_loss: 152.0217 - val_reconstruction_loss: 146.2470
Epoch 8/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7964 - kl_loss: 5.8814 - loss: 152.1057 - reconstruction_loss: 146.2243 - val_binary_accuracy: 0.7968 - val_kl_loss: 6.1041 - val_loss: 150.8130 - val_reconstruction_loss: 144.7089
Epoch 9/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7967 - kl_loss: 5.9718 - loss: 150.7322 - reconstruction_loss: 144.7604 - val_binary_accuracy: 0.7975 - val_kl_loss: 6.0538 - val_loss: 149.5502 - val_reconstruction_loss: 143.4964
Epoch 10/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7969 - kl_loss: 6.0359 - loss: 149.6004 - reconstruction_loss: 143.5645 - val_binary_accuracy: 0.7968 - val_kl_loss: 6.2248 - val_loss: 148.6927 - val_reconstruction_loss: 142.4679
Epoch 11/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7972 - kl_loss: 6.0973 - loss: 148.7120 - reconstruction_loss: 142.6147 - val_binary_accuracy: 0.7977 - val_kl_loss: 6.2033 - val_loss: 147.8698 - val_reconstruction_loss: 141.6665
Epoch 12/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - binary_accuracy: 0.7974 - kl_loss: 6.1383 - loss: 147.8659 - reconstruction_loss: 141.7275 - val_binary_accuracy: 0.7977 - val_kl_loss: 6.1280 - val_loss: 147.1283 - val_reconstruction_loss: 141.0003
Epoch 13/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7975 - kl_loss: 6.1957 - loss: 147.1447 - reconstruction_loss: 140.9490 - val_binary_accuracy: 0.7977 - val_kl_loss: 6.3721 - val_loss: 146.6411 - val_reconstruction_loss: 140.2690
Epoch 14/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7977 - kl_loss: 6.2241 - loss: 146.5570 - reconstruction_loss: 140.3329 - val_binary_accuracy: 0.7974 - val_kl_loss: 6.2339 - val_loss: 146.1645 - val_reconstruction_loss: 139.9305
Epoch 15/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7977 - kl_loss: 6.2521 - loss: 146.2076 - reconstruction_loss: 139.9555 - val_binary_accuracy: 0.7974 - val_kl_loss: 6.4387 - val_loss: 145.7203 - val_reconstruction_loss: 139.2816
Epoch 16/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7979 - kl_loss: 6.2832 - loss: 145.5795 - reconstruction_loss: 139.2962 - val_binary_accuracy: 0.7979 - val_kl_loss: 6.2645 - val_loss: 145.3710 - val_reconstruction_loss: 139.1065
Epoch 17/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7981 - kl_loss: 6.3140 - loss: 145.0197 - reconstruction_loss: 138.7058 - val_binary_accuracy: 0.7984 - val_kl_loss: 6.3254 - val_loss: 144.9554 - val_reconstruction_loss: 138.6300
Epoch 18/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7981 - kl_loss: 6.3426 - loss: 144.6932 - reconstruction_loss: 138.3506 - val_binary_accuracy: 0.7985 - val_kl_loss: 6.2431 - val_loss: 144.8329 - val_reconstruction_loss: 138.5898
Epoch 19/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7982 - kl_loss: 6.3620 - loss: 144.4323 - reconstruction_loss: 138.0702 - val_binary_accuracy: 0.7985 - val_kl_loss: 6.4447 - val_loss: 144.4792 - val_reconstruction_loss: 138.0346
Epoch 20/20
391/391 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - binary_accuracy: 0.7982 - kl_loss: 6.3714 - loss: 144.1684 - reconstruction_loss: 137.7970 - val_binary_accuracy: 0.7988 - val_kl_loss: 6.4698 - val_loss: 143.9036 - val_reconstruction_loss: 137.4338
[12]

# Curvas de entrenamiento: loss y "accuracy" (binary_accuracy)
h = history.history
epochs = range(1, len(h['loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Loss total
axes[0].plot(epochs, h['loss'], label='train_loss')
axes[0].plot(epochs, h['val_loss'], label='val_loss')
axes[0].set_title('Loss total vs epoch')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()

# Accuracy de reconstrucción
axes[1].plot(epochs, h['binary_accuracy'], label='train_binary_accuracy')
axes[1].plot(epochs, h['val_binary_accuracy'], label='val_binary_accuracy')
axes[1].set_title('Binary accuracy vs epoch')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].legend()

plt.tight_layout()
plt.show()
Output
[13]

# Curvas separadas de reconstrucción y KL para entender "castigos y recompensas"
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

axes[0].plot(epochs, h['reconstruction_loss'], label='train_rec_loss')
axes[0].plot(epochs, h['val_reconstruction_loss'], label='val_rec_loss')
axes[0].set_title('Término de reconstrucción (recompensa por fidelidad)')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Recon loss')
axes[0].legend()

axes[1].plot(epochs, h['kl_loss'], label='train_kl_loss')
axes[1].plot(epochs, h['val_kl_loss'], label='val_kl_loss')
axes[1].set_title('Término KL (castigo por alejarse de N(0,I))')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('KL loss')
axes[1].legend()

plt.tight_layout()
plt.show()
Output

4) Clasificación no supervisada con clustering en el espacio latente

Una vez entrenado el VAE, extraemos las coordenadas $z_{mean}$ (la media de la distribución latente) para cada imagen y aplicamos K-Means con 10 clusters. Al no usar etiquetas durante el clustering, evaluamos la calidad del agrupamiento con métricas externas: Accuracy por mapeo de mayoría, NMI y ARI. Estas métricas nos dicen cuánto se alinean los clusters encontrados con las clases reales de dígitos.

[14]

# Obtenemos representaciones latentes medias (z_mean) para train/val/test
z_train_mean, _, _ = encoder.predict(x_train_f, batch_size=256, verbose=0)
z_val_mean, _, _ = encoder.predict(x_val_f, batch_size=256, verbose=0)
z_test_mean, _, _ = encoder.predict(x_test_f, batch_size=256, verbose=0)

print(z_train_mean.shape, z_val_mean.shape, z_test_mean.shape)
(50000, 2) (10000, 2) (10000, 2)
[15]

# Entrenamos KMeans sin etiquetas sobre el espacio latente
kmeans = KMeans(n_clusters=10, random_state=42, n_init=20)
train_clusters = kmeans.fit_predict(z_train_mean)
val_clusters = kmeans.predict(z_val_mean)
test_clusters = kmeans.predict(z_test_mean)

# Mapeo cluster->etiqueta por mayoría en train
cluster_to_label = {}
for c in range(10):
    idx = np.where(train_clusters == c)[0]
    if len(idx) == 0:
        cluster_to_label[c] = 0
    else:
        vals, cnts = np.unique(y_train_small[idx], return_counts=True)
        cluster_to_label[c] = vals[np.argmax(cnts)]

# Predicción final en val/test usando el mapeo aprendido
val_pred = np.array([cluster_to_label[c] for c in val_clusters])
test_pred = np.array([cluster_to_label[c] for c in test_clusters])

val_acc = np.mean(val_pred == y_val)
test_acc = np.mean(test_pred == y_test)

val_nmi = normalized_mutual_info_score(y_val, val_clusters)
test_nmi = normalized_mutual_info_score(y_test, test_clusters)

val_ari = adjusted_rand_score(y_val, val_clusters)
test_ari = adjusted_rand_score(y_test, test_clusters)

print(f"Val  ACC (mapeo mayoría): {val_acc:.4f}")
print(f"Test ACC (mapeo mayoría): {test_acc:.4f}")
print(f"Val  NMI: {val_nmi:.4f} | Test NMI: {test_nmi:.4f}")
print(f"Val  ARI: {val_ari:.4f} | Test ARI: {test_ari:.4f}")
Val  ACC (mapeo mayoría): 0.5647
Test ACC (mapeo mayoría): 0.5653
Val  NMI: 0.5557 | Test NMI: 0.5553
Val  ARI: 0.3686 | Test ARI: 0.3687
[16]

# Visualización del espacio latente 2D con color por clase real (solo para análisis)
plt.figure(figsize=(8,6))
scatter = plt.scatter(z_test_mean[:,0], z_test_mean[:,1], c=y_test, s=8, cmap='tab10', alpha=0.7)
plt.colorbar(scatter, ticks=range(10), label='Etiqueta real')
plt.title('Espacio latente (z_mean) del test set')
plt.xlabel('z[0]')
plt.ylabel('z[1]')
plt.show()
Output
[17]

# Matriz de confusión para inspeccionar qué dígitos se confunden más
cm = confusion_matrix(y_test, test_pred)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Matriz de confusión (clasificación no supervisada + mapeo por mayoría)')
plt.xlabel('Predicción')
plt.ylabel('Real')
plt.show()
Output

5) Capacidad generativa: crear nuevas imágenes desde el prior

Una ventaja fundamental del VAE frente a un autoencoder clásico es que podemos generar imágenes nuevas muestreando directamente del prior $z \sim \mathcal{N}(0, I)$ y pasando las muestras por el decoder. Gracias a la regularización KL, el espacio latente está "lleno" de representaciones válidas y cada punto genera una imagen plausible.

[18]

# Generación de nuevas imágenes muestreando z ~ N(0, I)
n = 12
z_samples = np.random.normal(size=(n, latent_dim)).astype('float32')
generated = decoder.predict(z_samples, verbose=0).reshape(-1, 28, 28)

fig, axes = plt.subplots(2, 6, figsize=(10,4))
for i, ax in enumerate(axes.ravel()):
    ax.imshow(generated[i], cmap='gray')
    ax.axis('off')
plt.suptitle('Nuevos dígitos generados por el VAE')
plt.tight_layout()
plt.show()
Output

6) Interpolación en espacio latente entre imágenes reales

Otra capacidad del VAE es la interpolación suave: dados dos dígitos reales, codificamos cada uno a su $z_{mean}$ y recorremos el segmento que los une en el espacio latente. El decoder traduce cada punto intermedio en una imagen, produciendo una transición gradual y coherente entre ambos dígitos. Esto demuestra que el espacio latente es continuo y semánticamente ordenado.

[19]

# Elegimos dos imágenes reales de test y obtenemos sus z_mean
idx_a, idx_b = 0, 1
x_a = x_test_f[idx_a:idx_a+1]
x_b = x_test_f[idx_b:idx_b+1]

za, _, _ = encoder.predict(x_a, verbose=0)
zb, _, _ = encoder.predict(x_b, verbose=0)

# Interpolación lineal en latent space
steps = 12
alphas = np.linspace(0, 1, steps)
z_interp = np.array([(1-a)*za[0] + a*zb[0] for a in alphas], dtype='float32')

x_interp = decoder.predict(z_interp, verbose=0).reshape(-1, 28, 28)

fig, axes = plt.subplots(1, steps, figsize=(16,2.5))
for i, ax in enumerate(axes):
    ax.imshow(x_interp[i], cmap='gray')
    ax.set_title(f"{alphas[i]:.1f}", fontsize=8)
    ax.axis('off')

plt.suptitle(f"Interpolación latente entre test[{idx_a}] (label={y_test[idx_a]}) y test[{idx_b}] (label={y_test[idx_b]})")
plt.tight_layout()
plt.show()
Output

7) Demostración del "agente" entrenado

Aquí tratamos al VAE + K-Means como un pequeño agente no supervisado que ejecuta un pipeline completo de inferencia:

  1. Observa una imagen de entrada (dígito manuscrito).
  2. La proyecta al espacio latente mediante el encoder, obteniendo $z_{mean}$.
  3. Decide un clúster asignado por K-Means sobre $z_{mean}$.
  4. Asigna una etiqueta interpretable mediante el mapeo de mayoría aprendido en train.
  5. Reconstruye la imagen a través del decoder para verificar la calidad de la compresión.

Este flujo muestra cómo un sistema generativo (VAE) puede servir de base para tareas discriminativas sin supervisión directa.

[20]

# Demostración con muestras aleatorias de test
rng = np.random.default_rng(123)
indices = rng.choice(len(x_test_f), size=8, replace=False)

fig, axes = plt.subplots(8, 3, figsize=(8, 18))

for row, idx in enumerate(indices):
    x = x_test_f[idx:idx+1]
    y_true = y_test[idx]

    z_m, _, z_s = encoder.predict(x, verbose=0)
    recon = decoder.predict(z_s, verbose=0).reshape(28, 28)

    cluster = kmeans.predict(z_m)[0]
    y_hat = cluster_to_label[cluster]

    axes[row, 0].imshow(x.reshape(28,28), cmap='gray')
    axes[row, 0].set_title(f"Original\nreal={y_true}")
    axes[row, 0].axis('off')

    axes[row, 1].imshow(recon, cmap='gray')
    axes[row, 1].set_title("Reconstrucción")
    axes[row, 1].axis('off')

    axes[row, 2].text(0.05, 0.8, f"z_mean = {np.round(z_m[0], 2)}", fontsize=10)
    axes[row, 2].text(0.05, 0.55, f"cluster = {cluster}", fontsize=10)
    axes[row, 2].text(0.05, 0.3, f"pred = {y_hat}", fontsize=12, weight='bold')
    axes[row, 2].axis('off')

plt.suptitle("Demostración del agente no supervisado entrenado", y=1.002)
plt.tight_layout()
plt.show()
Output

Conclusiones

  • El VAE logra aprender un espacio latente continuo de solo 2 dimensiones que captura la estructura esencial de los 10 dígitos de MNIST, sin necesidad de etiquetas durante el entrenamiento.
  • La combinación de los términos de reconstrucción (BCE) y divergencia KL implementa un equilibrio efectivo entre fidelidad a cada imagen y regularidad del espacio latente. Este sistema de "castigos complementarios" es lo que permite que el VAE no solo comprima, sino que genere e interpole con sentido.
  • El clustering con K-Means sobre las coordenadas $z_{mean}$ demuestra que la representación aprendida tiene estructura semántica: dígitos similares se agrupan en regiones cercanas del espacio latente, permitiendo una clasificación no supervisada razonable.
  • Las capacidades generativas del VAE quedan demostradas con la generación de nuevos dígitos muestreando del prior $\mathcal{N}(0,I)$ y con la interpolación suave entre pares de dígitos reales, donde se observan transiciones graduales y coherentes.
  • El uso de un espacio latente de 2 dimensiones facilita la visualización directa, pero limita la capacidad del modelo. Aumentar la dimensionalidad latente mejoraría las métricas de clustering y la calidad de reconstrucción, a costa de perder la interpretabilidad visual directa.

Sugerencias para seguir experimentando

  1. Probar un $\beta$-VAE variando $\beta$ (p.ej. 0.5, 2.0, 4.0) y comparar calidad de reconstrucción vs disentanglement del espacio latente.
  2. Usar un VAE convolucional (Conv2D/Conv2DTranspose) para mejorar la nitidez de las muestras generadas.
  3. Hacer annealing del KL (aumentar $\beta$ gradualmente durante el entrenamiento) para estabilizar el aprendizaje en las primeras épocas.
  4. Sustituir K-Means por GMM o HDBSCAN en el espacio latente para evaluar si un clustering más flexible mejora la asignación de clases.
  5. Repetir la práctica en Fashion-MNIST para comprobar la robustez del enfoque en un dataset con mayor variabilidad visual intra-clase.