GRU PyTorch — Clasificación de noticias (AG News)
Clasificación multiclase de noticias con GRU bidireccional en PyTorch sobre el dataset AG News (4 categorías).
GRU en PyTorch: clasificación de noticias con AG News
En este notebook construiremos un pipeline completo y didáctico para entrenar una red GRU (Gated Recurrent Unit) en PyTorch sobre una tarea real de NLP: clasificación de titulares de noticias en 4 categorías.
Categorías de AG News: World, Sports, Business, Sci/Tech.
Objetivo del notebook
Aprender, paso a paso, a:
- Preparar texto para una red recurrente (tokenización, vocabulario, padding/truncado).
- Diseñar una arquitectura basada en Embedding + GRU bidireccional + clasificador.
- Entrenar el modelo con validación y buenas prácticas (grad clipping, early stopping).
- Evaluar con métricas y visualizaciones (curvas de entrenamiento, matriz de confusión, reporte por clase).
- Interpretar resultados y proponer mejoras.
Fundamento matemático/computacional (conexión con la teoría del submódulo GRU)
En una GRU no hay cell state separado como en LSTM; hay un único estado oculto $h_t$, actualizado por dos puertas:
$$ r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) \quad\text{(reset gate)} $$ $$ z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z) \quad\text{(update gate)} $$ $$ \tilde{h}t = \tanh\left(W_h x_t + U_h (r_t \odot h{t-1}) + b_h\right) $$ $$ h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t $$
Interpretación práctica:
- Reset gate $r_t$: decide cuánto pasado usar para construir el estado candidato.
- Update gate $z_t$: decide cuánto conservar del estado anterior y cuánto sobrescribir con información nueva.
- La última ecuación es una interpolación entre memoria previa e información candidata, lo que ayuda al flujo de gradiente y simplifica frente a LSTM.
En este notebook usaremos una tarea many-to-one: una secuencia de tokens (titular) y una sola salida (clase).
Modelo y dataset usados
- Dataset: AG News (descargado como CSV).
- Modelo:
Embedding -> BiGRU (2 capas) -> Dropout -> Capa lineal. - Framework: PyTorch.
Nota: para mantener tiempos razonables, entrenaremos con un subconjunto configurable. Si quieres más rendimiento, aumenta épocas/tamaño de datos o usa GPU.
# ============================================================
# 0) Imports y configuración global
# ============================================================
import csv
import io
import os
import random
import re
import time
from collections import Counter, OrderedDict
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
f1_score,
)
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (10, 4)
# ---------- Reemplazos ligeros de torchtext (deprecado) ----------
def basic_english_tokenizer(text: str) -> list[str]:
"""Tokenizador básico: minúsculas, separa puntuación, divide por espacios."""
text = text.lower()
text = re.sub(r"([.!?,;:\"'()\[\]{}\-/])", r" \1 ", text)
return text.split()
tokenizer = basic_english_tokenizer
class SimpleVocab:
"""Vocabulario ligero compatible con la interfaz usada en el notebook."""
def __init__(self, counter: Counter, min_freq: int = 1, specials: list[str] | None = None):
self.itos = list(specials or [])
self.stoi = {tok: i for i, tok in enumerate(self.itos)}
for tok, freq in counter.most_common():
if freq >= min_freq and tok not in self.stoi:
self.stoi[tok] = len(self.itos)
self.itos.append(tok)
self._default_index = 0
def set_default_index(self, idx: int):
self._default_index = idx
def __getitem__(self, token: str) -> int:
return self.stoi.get(token, self._default_index)
def __len__(self) -> int:
return len(self.itos)
def download_ag_news(root: str = "./data/ag_news"):
"""Descarga AG News CSV si no existe y devuelve rutas (train, test)."""
root = Path(root)
root.mkdir(parents=True, exist_ok=True)
base_url = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv"
paths = {}
for split in ("train", "test"):
fpath = root / f"{split}.csv"
if not fpath.exists():
import urllib.request
url = f"{base_url}/{split}.csv"
print(f"Descargando {url} ...")
urllib.request.urlretrieve(url, fpath)
print(f" -> {fpath}")
paths[split] = fpath
return paths["train"], paths["test"]
def load_ag_news_csv(csv_path: str) -> list[tuple[int, str]]:
"""Lee CSV de AG News y devuelve lista de (label_0indexed, text)."""
data = []
with open(csv_path, "r", encoding="utf-8") as f:
reader = csv.reader(f)
for row in reader:
label = int(row[0]) - 1 # 1-4 → 0-3
text = row[1] + " " + row[2] # title + description
data.append((label, text))
return data
# ---------- Reproducibilidad ----------
def set_seed(seed: int = 42):
"""Fija semillas para reproducibilidad."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Dispositivo:", device)
Dispositivo: cuda
1) Carga de datos
Vamos a cargar AG News (descargando los CSV si es necesario) y transformarlo a una lista en memoria para facilitar EDA y preprocesado.
La etiqueta original de AG News viene en ${1,2,3,4}$; aquí la convertimos a ${0,1,2,3}$ para usarla directamente con CrossEntropyLoss.
# ============================================================
# 1) Carga de AG_NEWS
# ============================================================
LABELS = ["World", "Sports", "Business", "Sci/Tech"]
train_csv, test_csv = download_ag_news()
train_raw = load_ag_news_csv(train_csv)
test_raw = load_ag_news_csv(test_csv)
print(f"Train size: {len(train_raw):,}")
print(f"Test size: {len(test_raw):,}")
print("Ejemplo:")
print(train_raw[0][0], "->", train_raw[0][1][:140], "...")
Train size: 120,000 Test size: 7,600 Ejemplo: 2 -> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green ag ...
2) EDA (análisis exploratorio básico)
Antes de entrenar una red recurrente conviene inspeccionar:
- balance de clases,
- longitud de textos,
- palabras frecuentes.
Esto guía decisiones como max_len, tamaño de vocabulario y regularización.
# ============================================================
# 2.1 Distribución de clases
# ============================================================
train_labels = [y for y, _ in train_raw]
class_counts = Counter(train_labels)
plt.figure(figsize=(7, 4))
plt.bar([LABELS[i] for i in range(4)], [class_counts[i] for i in range(4)], color=["#4c78a8", "#f58518", "#54a24b", "#e45756"])
plt.title("Distribución de clases en train")
plt.ylabel("Número de titulares")
plt.xticks(rotation=15)
plt.show()
print("Conteos por clase:")
for i, c in class_counts.items():
print(f" {LABELS[i]}: {c:,}")
Conteos por clase: Business: 30,000 Sci/Tech: 30,000 Sports: 30,000 World: 30,000
# ============================================================
# 2.2 Longitud de secuencias (número de tokens por titular)
# ============================================================
train_lengths = [len(tokenizer(text)) for _, text in train_raw]
plt.figure(figsize=(8, 4))
plt.hist(train_lengths, bins=50, color="#4c78a8", alpha=0.8)
plt.title("Distribución de longitudes de titulares (tokens)")
plt.xlabel("Longitud")
plt.ylabel("Frecuencia")
plt.show()
for q in [50, 75, 90, 95, 99]:
print(f"Percentil {q:>2}: {np.percentile(train_lengths, q):.1f} tokens")
Percentil 50: 44.0 tokens Percentil 75: 52.0 tokens Percentil 90: 60.0 tokens Percentil 95: 71.0 tokens Percentil 99: 99.0 tokens
# ============================================================
# 2.3 Palabras más frecuentes (aprox.)
# ============================================================
word_counter = Counter()
for _, text in train_raw[:20000]: # muestreo para acelerar el conteo
word_counter.update(tokenizer(text))
print("Top 20 tokens más frecuentes:")
print(word_counter.most_common(20))
Top 20 tokens más frecuentes:
[('.', 40471), ('the', 34785), (',', 28681), ('-', 25116), ('to', 19728), ('a', 18733), ('of', 16978), ('in', 16956), (';', 12808), ('and', 11751), ('s', 11134), ('on', 9431), ('for', 8358), ('(', 7695), (')', 7659), ('#39', 6658), ("'", 6554), ('that', 4535), ('with', 4390), ('at', 4346)]
3) Preprocesado para GRU
Decisiones de diseño (coherentes con la teoría del submódulo):
- usamos índices enteros para tokens (
Embeddingaprende representaciones densas), - aplicamos
paddingpara lotes uniformes, - fijamos una longitud máxima
MAX_LEN(truncado/padding).
Esto permite pasar lotes de forma eficiente a nn.GRU.
# ============================================================
# 3.1 Construcción de vocabulario
# ============================================================
# Para demo rápida puedes limitar la cantidad de ejemplos de entrenamiento
MAX_TRAIN_SAMPLES = 50000 # sube a 120000 para usar todo AG News
train_used = train_raw[:MAX_TRAIN_SAMPLES]
specials = ["<pad>", "<unk>"]
counter = Counter()
for _, text in train_used:
counter.update(tokenizer(text))
# min_freq evita vocabulario enorme y reduce ruido
tok_vocab = SimpleVocab(counter, min_freq=2, specials=specials)
tok_vocab.set_default_index(tok_vocab["<unk>"])
PAD_IDX = tok_vocab["<pad>"]
UNK_IDX = tok_vocab["<unk>"]
VOCAB_SIZE = len(tok_vocab)
print("Vocab size:", VOCAB_SIZE)
print("PAD_IDX:", PAD_IDX, "UNK_IDX:", UNK_IDX)
Vocab size: 31186 PAD_IDX: 0 UNK_IDX: 1
# ============================================================
# 3.2 Dataset PyTorch: tokenización + numerización + padding
# ============================================================
MAX_LEN = 35 # cercano al p95 observado normalmente en titulares AG News
def encode_text(text: str, max_len: int = MAX_LEN):
"""Tokeniza, convierte a ids y aplica truncado/padding."""
ids = [tok_vocab[token] for token in tokenizer(text)]
ids = ids[:max_len]
if len(ids) < max_len:
ids += [PAD_IDX] * (max_len - len(ids))
return ids
class AGNewsDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
label, text = self.data[idx]
x = torch.tensor(encode_text(text), dtype=torch.long)
y = torch.tensor(label, dtype=torch.long)
return x, y
# Split train/val manual (reproducible)
val_ratio = 0.15
n_total = len(train_used)
indices = np.arange(n_total)
np.random.shuffle(indices)
n_val = int(n_total * val_ratio)
val_idx = indices[:n_val]
train_idx = indices[n_val:]
train_split = [train_used[i] for i in train_idx]
val_split = [train_used[i] for i in val_idx]
train_ds = AGNewsDataset(train_split)
val_ds = AGNewsDataset(val_split)
# Test completo (se podría submuestrear en entornos lentos)
test_ds = AGNewsDataset(test_raw)
BATCH_SIZE = 128
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")
Train batches: 333 Val batches: 59 Test batches: 60
4) Arquitectura GRU
Usaremos una arquitectura estándar para clasificación de texto:
Embedding: convierte ids en vectores densos.GRUbidireccional y apilada (num_layers=2).- Tomamos el estado oculto final de ambas direcciones y lo concatenamos.
Dropout+Linearpara clasificar 4 clases.
Esto conecta directamente con lo visto en teoría: nn.GRU devuelve (output, h_n) y no hay cell state separado.
# ============================================================
# 4) Definición del modelo GRU
# ============================================================
class GRUClassifier(nn.Module):
def __init__(
self,
vocab_size: int,
embed_dim: int,
hidden_dim: int,
num_classes: int,
num_layers: int = 2,
dropout: float = 0.3,
bidirectional: bool = True,
pad_idx: int = 0,
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
self.gru = nn.GRU(
input_size=embed_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0.0,
bidirectional=bidirectional,
)
self.dropout = nn.Dropout(dropout)
self.bidirectional = bidirectional
direction_factor = 2 if bidirectional else 1
self.classifier = nn.Linear(hidden_dim * direction_factor, num_classes)
def forward(self, x):
# x: (batch, seq_len)
emb = self.dropout(self.embedding(x)) # (batch, seq_len, embed_dim)
# GRU devuelve output y h_n
_, h_n = self.gru(emb)
# h_n shape: (num_layers * num_directions, batch, hidden_dim)
if self.bidirectional:
# Última capa forward y backward
h_last = torch.cat([h_n[-2], h_n[-1]], dim=1)
else:
h_last = h_n[-1]
logits = self.classifier(self.dropout(h_last))
return logits
model = GRUClassifier(
vocab_size=VOCAB_SIZE,
embed_dim=128,
hidden_dim=128,
num_classes=4,
num_layers=2,
dropout=0.3,
bidirectional=True,
pad_idx=PAD_IDX,
).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model)
print(f"Parámetros entrenables: {n_params:,}")
GRUClassifier( (embedding): Embedding(31186, 128, padding_idx=0) (gru): GRU(128, 128, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True) (dropout): Dropout(p=0.3, inplace=False) (classifier): Linear(in_features=256, out_features=4, bias=True) ) Parámetros entrenables: 4,487,428
5) Entrenamiento
Entrenaremos con:
CrossEntropyLoss(clasificación multiclase),Adam,- gradient clipping para estabilizar RNNs,
- early stopping sobre pérdida de validación.
Además guardaremos historial para visualizar curvas loss/accuracy de train/val.
# ============================================================
# 5) Bucle de entrenamiento y validación
# ============================================================
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
def run_epoch(model, loader, optimizer=None, clip_grad=1.0):
"""Ejecuta una época de train o validación.
Si optimizer es None, se ejecuta en modo evaluación.
"""
is_train = optimizer is not None
model.train() if is_train else model.eval()
losses, preds_all, y_all = [], [], []
for x_batch, y_batch in loader:
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)
with torch.set_grad_enabled(is_train):
logits = model(x_batch)
loss = criterion(logits, y_batch)
if is_train:
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
optimizer.step()
losses.append(loss.item())
preds_all.extend(torch.argmax(logits, dim=1).detach().cpu().numpy())
y_all.extend(y_batch.detach().cpu().numpy())
epoch_loss = float(np.mean(losses))
epoch_acc = accuracy_score(y_all, preds_all)
return epoch_loss, epoch_acc
EPOCHS = 12
PATIENCE = 3
history = {
"train_loss": [],
"val_loss": [],
"train_acc": [],
"val_acc": [],
}
best_val_loss = float("inf")
best_state = None
patience_counter = 0
start_time = time.time()
for epoch in range(1, EPOCHS + 1):
train_loss, train_acc = run_epoch(model, train_loader, optimizer=optimizer)
val_loss, val_acc = run_epoch(model, val_loader, optimizer=None)
history["train_loss"].append(train_loss)
history["val_loss"].append(val_loss)
history["train_acc"].append(train_acc)
history["val_acc"].append(val_acc)
print(
f"Epoch {epoch:02d}/{EPOCHS} | "
f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}"
)
# Early stopping por val_loss
if val_loss < best_val_loss:
best_val_loss = val_loss
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= PATIENCE:
print("Early stopping activado.")
break
elapsed = time.time() - start_time
print(f"Tiempo total de entrenamiento: {elapsed/60:.1f} min")
# Restauramos mejor estado validación
if best_state is not None:
model.load_state_dict(best_state)
print("Modelo restaurado al mejor checkpoint de validación.")
Epoch 01/12 | train_loss=0.8272 train_acc=0.6634 | val_loss=0.5516 val_acc=0.8092 Epoch 02/12 | train_loss=0.4614 train_acc=0.8343 | val_loss=0.4082 val_acc=0.8616 Epoch 03/12 | train_loss=0.3566 train_acc=0.8722 | val_loss=0.4090 val_acc=0.8691 Epoch 04/12 | train_loss=0.3025 train_acc=0.8933 | val_loss=0.3447 val_acc=0.8844 Epoch 05/12 | train_loss=0.2538 train_acc=0.9107 | val_loss=0.3360 val_acc=0.8907 Epoch 06/12 | train_loss=0.2195 train_acc=0.9218 | val_loss=0.3343 val_acc=0.8949 Epoch 07/12 | train_loss=0.1941 train_acc=0.9294 | val_loss=0.3517 val_acc=0.8951 Epoch 08/12 | train_loss=0.1661 train_acc=0.9399 | val_loss=0.3408 val_acc=0.8971 Epoch 09/12 | train_loss=0.1503 train_acc=0.9458 | val_loss=0.3652 val_acc=0.8972 Early stopping activado. Tiempo total de entrenamiento: 0.2 min Modelo restaurado al mejor checkpoint de validación.
# ============================================================
# 5.1 Curvas de entrenamiento: loss y accuracy
# ============================================================
epoch_axis = np.arange(1, len(history["train_loss"]) + 1)
fig, ax = plt.subplots(1, 2, figsize=(13, 4))
ax[0].plot(epoch_axis, history["train_loss"], marker="o", label="train")
ax[0].plot(epoch_axis, history["val_loss"], marker="o", label="val")
ax[0].set_title("Loss vs Epoch")
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("Cross-Entropy Loss")
ax[0].legend()
ax[1].plot(epoch_axis, history["train_acc"], marker="o", label="train")
ax[1].plot(epoch_axis, history["val_acc"], marker="o", label="val")
ax[1].set_title("Accuracy vs Epoch")
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("Accuracy")
ax[1].legend()
plt.tight_layout()
plt.show()
6) Evaluación en test
Ahora medimos generalización sobre el conjunto de test con:
- accuracy,
- macro-F1,
- reporte de clasificación por clase,
- matriz de confusión.
Esto permite detectar clases más difíciles y posibles sesgos.
# ============================================================
# 6) Evaluación final en test
# ============================================================
model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
for x_batch, y_batch in test_loader:
x_batch = x_batch.to(device)
logits = model(x_batch)
preds = torch.argmax(logits, dim=1).cpu().numpy()
all_preds.extend(preds)
all_targets.extend(y_batch.numpy())
test_acc = accuracy_score(all_targets, all_preds)
test_f1_macro = f1_score(all_targets, all_preds, average="macro")
print(f"Test Accuracy : {test_acc:.4f}")
print(f"Test Macro-F1 : {test_f1_macro:.4f}")
print("Classification report:")
print(classification_report(all_targets, all_preds, target_names=LABELS, digits=4))
Test Accuracy : 0.8875
Test Macro-F1 : 0.8876
Classification report:
precision recall f1-score support
World 0.9139 0.8774 0.8953 1900
Sports 0.9362 0.9505 0.9433 1900
Business 0.8361 0.8647 0.8502 1900
Sci/Tech 0.8656 0.8574 0.8614 1900
accuracy 0.8875 7600
macro avg 0.8880 0.8875 0.8876 7600
weighted avg 0.8880 0.8875 0.8876 7600
# ============================================================
# 6.1 Matriz de confusión
# ============================================================
cm = confusion_matrix(all_targets, all_preds)
plt.figure(figsize=(6, 5))
sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=LABELS,
yticklabels=LABELS,
)
plt.title("Matriz de confusión (test)")
plt.xlabel("Predicción")
plt.ylabel("Etiqueta real")
plt.tight_layout()
plt.show()
7) Ejemplos cualitativos (aciertos y errores)
Además de métricas agregadas, conviene revisar ejemplos concretos para entender:
- cuándo la GRU funciona bien,
- qué patrones confunden al modelo,
- cómo mejorar el preprocesado o la arquitectura.
# ============================================================
# 7) Mostrar ejemplos acertados/erróneos
# ============================================================
# Construimos predicciones sobre un subconjunto del test con texto original
sample_size = 300
sample_data = test_raw[:sample_size]
x_sample = torch.tensor([encode_text(text) for _, text in sample_data], dtype=torch.long).to(device)
y_sample = np.array([label for label, _ in sample_data])
texts_sample = [text for _, text in sample_data]
with torch.no_grad():
logits = model(x_sample)
pred_sample = torch.argmax(logits, dim=1).cpu().numpy()
correct_idx = np.where(pred_sample == y_sample)[0]
wrong_idx = np.where(pred_sample != y_sample)[0]
print(f"Aciertos en muestra: {len(correct_idx)}/{sample_size}")
print(f"Errores en muestra: {len(wrong_idx)}/{sample_size}")
print("--- Ejemplos de aciertos ---")
for i in correct_idx[:3]:
print(f"Real={LABELS[y_sample[i]]} | Pred={LABELS[pred_sample[i]]}")
print(texts_sample[i][:180], "\n")
print("--- Ejemplos de errores ---")
for i in wrong_idx[:3]:
print(f"Real={LABELS[y_sample[i]]} | Pred={LABELS[pred_sample[i]]}")
print(texts_sample[i][:180], "\n")
Aciertos en muestra: 271/300 Errores en muestra: 29/300 --- Ejemplos de aciertos --- Real=Business | Pred=Business Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul. Real=Sci/Tech | Pred=Sci/Tech The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the #36;10 million Real=Sci/Tech | Pred=Sci/Tech Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better p --- Ejemplos de errores --- Real=Sci/Tech | Pred=Sports Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he Real=Sci/Tech | Pred=Business Card fraud unit nets 36,000 cards In its first two years, the UK's dedicated card fraud unit, has recovered 36,000 stolen cards and 171 arrests - and estimates it saved 65m. Real=Sci/Tech | Pred=World Super ant colony hits Australia A giant 100km colony of ants which has been discovered in Melbourne, Australia, could threaten local insect species.
8) Mini experimento guiado (ablation): GRU bidireccional vs unidireccional
Para reforzar el aprendizaje, proponemos comparar una variante unidireccional manteniendo el resto igual.
Puedes ejecutar esta celda para observar si la bidireccionalidad aporta mejora en accuracy/F1.
# ============================================================
# 8) Experimento opcional (rápido): cambiar bidirectional=False
# ============================================================
# Esta celda está pensada como ejercicio.
# Para ahorrar tiempo por defecto no se ejecuta entrenamiento completo aquí.
# Puedes descomentar y lanzar una versión corta (3-4 épocas).
# quick_model = GRUClassifier(
# vocab_size=VOCAB_SIZE,
# embed_dim=128,
# hidden_dim=128,
# num_classes=4,
# num_layers=2,
# dropout=0.3,
# bidirectional=False,
# pad_idx=PAD_IDX,
# ).to(device)
#
# quick_opt = torch.optim.Adam(quick_model.parameters(), lr=1e-3)
# for e in range(4):
# tr_l, tr_a = run_epoch(quick_model, train_loader, optimizer=quick_opt)
# va_l, va_a = run_epoch(quick_model, val_loader, optimizer=None)
# print(f"[UNI-GRU] epoch={e+1} tr_acc={tr_a:.4f} val_acc={va_a:.4f}")
Conclusiones y siguientes pasos
Qué hemos aprendido
- Una GRU permite modelar dependencias temporales con menos complejidad que LSTM.
- En texto, la combinación
Embedding + (Bi)GRUes una base sólida para clasificación secuencial. - El flujo completo (EDA → preprocesado → entrenamiento → evaluación → análisis cualitativo) es clave para construir modelos fiables.
Sugerencias para profundizar
- Embeddings preentrenados (GloVe/FastText) en vez de embeddings aleatorios.
- Probar PackedSequence para manejar longitudes variables sin padding fijo.
- Ajustar hiperparámetros:
hidden_dim,num_layers,dropout,lr. - Comparar GRU vs LSTM en igualdad de condiciones (parámetros/tiempo).
- Añadir attention sobre salidas de GRU.
- Pasar de titulares a textos más largos (p.ej. IMDB completo).
Si haces estas extensiones, tendrás una visión muy completa de cuándo y por qué usar GRU en problemas reales de NLP.