🏭 Caso de Uso

GRU PyTorch — Clasificación de noticias (AG News)

Clasificación multiclase de noticias con GRU bidireccional en PyTorch sobre el dataset AG News (4 categorías).

🐍 Python 📓 Jupyter Notebook

GRU en PyTorch: clasificación de noticias con AG News

En este notebook construiremos un pipeline completo y didáctico para entrenar una red GRU (Gated Recurrent Unit) en PyTorch sobre una tarea real de NLP: clasificación de titulares de noticias en 4 categorías.

Categorías de AG News: World, Sports, Business, Sci/Tech.


Objetivo del notebook

Aprender, paso a paso, a:

  1. Preparar texto para una red recurrente (tokenización, vocabulario, padding/truncado).
  2. Diseñar una arquitectura basada en Embedding + GRU bidireccional + clasificador.
  3. Entrenar el modelo con validación y buenas prácticas (grad clipping, early stopping).
  4. Evaluar con métricas y visualizaciones (curvas de entrenamiento, matriz de confusión, reporte por clase).
  5. Interpretar resultados y proponer mejoras.

Fundamento matemático/computacional (conexión con la teoría del submódulo GRU)

En una GRU no hay cell state separado como en LSTM; hay un único estado oculto $h_t$, actualizado por dos puertas:

$$ r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) \quad\text{(reset gate)} $$ $$ z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z) \quad\text{(update gate)} $$ $$ \tilde{h}t = \tanh\left(W_h x_t + U_h (r_t \odot h{t-1}) + b_h\right) $$ $$ h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t $$

Interpretación práctica:

  • Reset gate $r_t$: decide cuánto pasado usar para construir el estado candidato.
  • Update gate $z_t$: decide cuánto conservar del estado anterior y cuánto sobrescribir con información nueva.
  • La última ecuación es una interpolación entre memoria previa e información candidata, lo que ayuda al flujo de gradiente y simplifica frente a LSTM.

En este notebook usaremos una tarea many-to-one: una secuencia de tokens (titular) y una sola salida (clase).


Modelo y dataset usados

  • Dataset: AG News (descargado como CSV).
  • Modelo: Embedding -> BiGRU (2 capas) -> Dropout -> Capa lineal.
  • Framework: PyTorch.

Nota: para mantener tiempos razonables, entrenaremos con un subconjunto configurable. Si quieres más rendimiento, aumenta épocas/tamaño de datos o usa GPU.

[1]
# ============================================================
# 0) Imports y configuración global
# ============================================================

import csv
import io
import os
import random
import re
import time
from collections import Counter, OrderedDict
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
)

sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (10, 4)


# ---------- Reemplazos ligeros de torchtext (deprecado) ----------

def basic_english_tokenizer(text: str) -> list[str]:
    """Tokenizador básico: minúsculas, separa puntuación, divide por espacios."""
    text = text.lower()
    text = re.sub(r"([.!?,;:\"'()\[\]{}\-/])", r" \1 ", text)
    return text.split()

tokenizer = basic_english_tokenizer


class SimpleVocab:
    """Vocabulario ligero compatible con la interfaz usada en el notebook."""

    def __init__(self, counter: Counter, min_freq: int = 1, specials: list[str] | None = None):
        self.itos = list(specials or [])
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        for tok, freq in counter.most_common():
            if freq >= min_freq and tok not in self.stoi:
                self.stoi[tok] = len(self.itos)
                self.itos.append(tok)
        self._default_index = 0

    def set_default_index(self, idx: int):
        self._default_index = idx

    def __getitem__(self, token: str) -> int:
        return self.stoi.get(token, self._default_index)

    def __len__(self) -> int:
        return len(self.itos)


def download_ag_news(root: str = "./data/ag_news"):
    """Descarga AG News CSV si no existe y devuelve rutas (train, test)."""
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)
    base_url = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv"
    paths = {}
    for split in ("train", "test"):
        fpath = root / f"{split}.csv"
        if not fpath.exists():
            import urllib.request
            url = f"{base_url}/{split}.csv"
            print(f"Descargando {url} ...")
            urllib.request.urlretrieve(url, fpath)
            print(f"  -> {fpath}")
        paths[split] = fpath
    return paths["train"], paths["test"]


def load_ag_news_csv(csv_path: str) -> list[tuple[int, str]]:
    """Lee CSV de AG News y devuelve lista de (label_0indexed, text)."""
    data = []
    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        for row in reader:
            label = int(row[0]) - 1  # 1-4 → 0-3
            text = row[1] + " " + row[2]  # title + description
            data.append((label, text))
    return data


# ---------- Reproducibilidad ----------

def set_seed(seed: int = 42):
    """Fija semillas para reproducibilidad."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Dispositivo:", device)
Dispositivo: cuda

1) Carga de datos

Vamos a cargar AG News (descargando los CSV si es necesario) y transformarlo a una lista en memoria para facilitar EDA y preprocesado.

La etiqueta original de AG News viene en ${1,2,3,4}$; aquí la convertimos a ${0,1,2,3}$ para usarla directamente con CrossEntropyLoss.

[2]
# ============================================================
# 1) Carga de AG_NEWS
# ============================================================

LABELS = ["World", "Sports", "Business", "Sci/Tech"]

train_csv, test_csv = download_ag_news()

train_raw = load_ag_news_csv(train_csv)
test_raw = load_ag_news_csv(test_csv)

print(f"Train size: {len(train_raw):,}")
print(f"Test  size: {len(test_raw):,}")
print("Ejemplo:")
print(train_raw[0][0], "->", train_raw[0][1][:140], "...")
Train size: 120,000
Test  size: 7,600
Ejemplo:
2 -> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green ag ...

2) EDA (análisis exploratorio básico)

Antes de entrenar una red recurrente conviene inspeccionar:

  • balance de clases,
  • longitud de textos,
  • palabras frecuentes.

Esto guía decisiones como max_len, tamaño de vocabulario y regularización.

[3]
# ============================================================
# 2.1 Distribución de clases
# ============================================================

train_labels = [y for y, _ in train_raw]
class_counts = Counter(train_labels)

plt.figure(figsize=(7, 4))
plt.bar([LABELS[i] for i in range(4)], [class_counts[i] for i in range(4)], color=["#4c78a8", "#f58518", "#54a24b", "#e45756"])
plt.title("Distribución de clases en train")
plt.ylabel("Número de titulares")
plt.xticks(rotation=15)
plt.show()

print("Conteos por clase:")
for i, c in class_counts.items():
    print(f"  {LABELS[i]}: {c:,}")
Output
Conteos por clase:
  Business: 30,000
  Sci/Tech: 30,000
  Sports: 30,000
  World: 30,000
[7]
# ============================================================
# 2.2 Longitud de secuencias (número de tokens por titular)
# ============================================================

train_lengths = [len(tokenizer(text)) for _, text in train_raw]

plt.figure(figsize=(8, 4))
plt.hist(train_lengths, bins=50, color="#4c78a8", alpha=0.8)
plt.title("Distribución de longitudes de titulares (tokens)")
plt.xlabel("Longitud")
plt.ylabel("Frecuencia")
plt.show()

for q in [50, 75, 90, 95, 99]:
    print(f"Percentil {q:>2}: {np.percentile(train_lengths, q):.1f} tokens")
Output
Percentil 50: 44.0 tokens
Percentil 75: 52.0 tokens
Percentil 90: 60.0 tokens
Percentil 95: 71.0 tokens
Percentil 99: 99.0 tokens
[8]
# ============================================================
# 2.3 Palabras más frecuentes (aprox.)
# ============================================================

word_counter = Counter()
for _, text in train_raw[:20000]:  # muestreo para acelerar el conteo
    word_counter.update(tokenizer(text))

print("Top 20 tokens más frecuentes:")
print(word_counter.most_common(20))
Top 20 tokens más frecuentes:
[('.', 40471), ('the', 34785), (',', 28681), ('-', 25116), ('to', 19728), ('a', 18733), ('of', 16978), ('in', 16956), (';', 12808), ('and', 11751), ('s', 11134), ('on', 9431), ('for', 8358), ('(', 7695), (')', 7659), ('#39', 6658), ("'", 6554), ('that', 4535), ('with', 4390), ('at', 4346)]

3) Preprocesado para GRU

Decisiones de diseño (coherentes con la teoría del submódulo):

  • usamos índices enteros para tokens (Embedding aprende representaciones densas),
  • aplicamos padding para lotes uniformes,
  • fijamos una longitud máxima MAX_LEN (truncado/padding).

Esto permite pasar lotes de forma eficiente a nn.GRU.

[9]
# ============================================================
# 3.1 Construcción de vocabulario
# ============================================================

# Para demo rápida puedes limitar la cantidad de ejemplos de entrenamiento
MAX_TRAIN_SAMPLES = 50000  # sube a 120000 para usar todo AG News

train_used = train_raw[:MAX_TRAIN_SAMPLES]

specials = ["<pad>", "<unk>"]

counter = Counter()
for _, text in train_used:
    counter.update(tokenizer(text))

# min_freq evita vocabulario enorme y reduce ruido
tok_vocab = SimpleVocab(counter, min_freq=2, specials=specials)
tok_vocab.set_default_index(tok_vocab["<unk>"])

PAD_IDX = tok_vocab["<pad>"]
UNK_IDX = tok_vocab["<unk>"]
VOCAB_SIZE = len(tok_vocab)

print("Vocab size:", VOCAB_SIZE)
print("PAD_IDX:", PAD_IDX, "UNK_IDX:", UNK_IDX)
Vocab size: 31186
PAD_IDX: 0 UNK_IDX: 1
[10]
# ============================================================
# 3.2 Dataset PyTorch: tokenización + numerización + padding
# ============================================================

MAX_LEN = 35  # cercano al p95 observado normalmente en titulares AG News


def encode_text(text: str, max_len: int = MAX_LEN):
    """Tokeniza, convierte a ids y aplica truncado/padding."""
    ids = [tok_vocab[token] for token in tokenizer(text)]
    ids = ids[:max_len]
    if len(ids) < max_len:
        ids += [PAD_IDX] * (max_len - len(ids))
    return ids


class AGNewsDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label, text = self.data[idx]
        x = torch.tensor(encode_text(text), dtype=torch.long)
        y = torch.tensor(label, dtype=torch.long)
        return x, y


# Split train/val manual (reproducible)
val_ratio = 0.15
n_total = len(train_used)
indices = np.arange(n_total)
np.random.shuffle(indices)

n_val = int(n_total * val_ratio)
val_idx = indices[:n_val]
train_idx = indices[n_val:]

train_split = [train_used[i] for i in train_idx]
val_split = [train_used[i] for i in val_idx]

train_ds = AGNewsDataset(train_split)
val_ds = AGNewsDataset(val_split)

# Test completo (se podría submuestrear en entornos lentos)
test_ds = AGNewsDataset(test_raw)

BATCH_SIZE = 128

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")
print(f"Test batches:  {len(test_loader)}")
Train batches: 333
Val batches:   59
Test batches:  60

4) Arquitectura GRU

Usaremos una arquitectura estándar para clasificación de texto:

  • Embedding: convierte ids en vectores densos.
  • GRU bidireccional y apilada (num_layers=2).
  • Tomamos el estado oculto final de ambas direcciones y lo concatenamos.
  • Dropout + Linear para clasificar 4 clases.

Esto conecta directamente con lo visto en teoría: nn.GRU devuelve (output, h_n) y no hay cell state separado.

[11]
# ============================================================
# 4) Definición del modelo GRU
# ============================================================

class GRUClassifier(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_classes: int,
        num_layers: int = 2,
        dropout: float = 0.3,
        bidirectional: bool = True,
        pad_idx: int = 0,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )

        self.dropout = nn.Dropout(dropout)
        self.bidirectional = bidirectional
        direction_factor = 2 if bidirectional else 1
        self.classifier = nn.Linear(hidden_dim * direction_factor, num_classes)

    def forward(self, x):
        # x: (batch, seq_len)
        emb = self.dropout(self.embedding(x))  # (batch, seq_len, embed_dim)

        # GRU devuelve output y h_n
        _, h_n = self.gru(emb)

        # h_n shape: (num_layers * num_directions, batch, hidden_dim)
        if self.bidirectional:
            # Última capa forward y backward
            h_last = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            h_last = h_n[-1]

        logits = self.classifier(self.dropout(h_last))
        return logits


model = GRUClassifier(
    vocab_size=VOCAB_SIZE,
    embed_dim=128,
    hidden_dim=128,
    num_classes=4,
    num_layers=2,
    dropout=0.3,
    bidirectional=True,
    pad_idx=PAD_IDX,
).to(device)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model)
print(f"Parámetros entrenables: {n_params:,}")
GRUClassifier(
  (embedding): Embedding(31186, 128, padding_idx=0)
  (gru): GRU(128, 128, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (classifier): Linear(in_features=256, out_features=4, bias=True)
)
Parámetros entrenables: 4,487,428

5) Entrenamiento

Entrenaremos con:

  • CrossEntropyLoss (clasificación multiclase),
  • Adam,
  • gradient clipping para estabilizar RNNs,
  • early stopping sobre pérdida de validación.

Además guardaremos historial para visualizar curvas loss/accuracy de train/val.

[14]
# ============================================================
# 5) Bucle de entrenamiento y validación
# ============================================================

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def run_epoch(model, loader, optimizer=None, clip_grad=1.0):
    """Ejecuta una época de train o validación.
    Si optimizer es None, se ejecuta en modo evaluación.
    """
    is_train = optimizer is not None
    model.train() if is_train else model.eval()

    losses, preds_all, y_all = [], [], []

    for x_batch, y_batch in loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        with torch.set_grad_enabled(is_train):
            logits = model(x_batch)
            loss = criterion(logits, y_batch)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
                optimizer.step()

        losses.append(loss.item())
        preds_all.extend(torch.argmax(logits, dim=1).detach().cpu().numpy())
        y_all.extend(y_batch.detach().cpu().numpy())

    epoch_loss = float(np.mean(losses))
    epoch_acc = accuracy_score(y_all, preds_all)
    return epoch_loss, epoch_acc


EPOCHS = 12
PATIENCE = 3

history = {
    "train_loss": [],
    "val_loss": [],
    "train_acc": [],
    "val_acc": [],
}

best_val_loss = float("inf")
best_state = None
patience_counter = 0

start_time = time.time()

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = run_epoch(model, train_loader, optimizer=optimizer)
    val_loss, val_acc = run_epoch(model, val_loader, optimizer=None)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)

    print(
        f"Epoch {epoch:02d}/{EPOCHS} | "
        f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
        f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}"
    )

    # Early stopping por val_loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping activado.")
            break

elapsed = time.time() - start_time
print(f"Tiempo total de entrenamiento: {elapsed/60:.1f} min")

# Restauramos mejor estado validación
if best_state is not None:
    model.load_state_dict(best_state)
    print("Modelo restaurado al mejor checkpoint de validación.")
Epoch 01/12 | train_loss=0.8272 train_acc=0.6634 | val_loss=0.5516 val_acc=0.8092
Epoch 02/12 | train_loss=0.4614 train_acc=0.8343 | val_loss=0.4082 val_acc=0.8616
Epoch 03/12 | train_loss=0.3566 train_acc=0.8722 | val_loss=0.4090 val_acc=0.8691
Epoch 04/12 | train_loss=0.3025 train_acc=0.8933 | val_loss=0.3447 val_acc=0.8844
Epoch 05/12 | train_loss=0.2538 train_acc=0.9107 | val_loss=0.3360 val_acc=0.8907
Epoch 06/12 | train_loss=0.2195 train_acc=0.9218 | val_loss=0.3343 val_acc=0.8949
Epoch 07/12 | train_loss=0.1941 train_acc=0.9294 | val_loss=0.3517 val_acc=0.8951
Epoch 08/12 | train_loss=0.1661 train_acc=0.9399 | val_loss=0.3408 val_acc=0.8971
Epoch 09/12 | train_loss=0.1503 train_acc=0.9458 | val_loss=0.3652 val_acc=0.8972
Early stopping activado.
Tiempo total de entrenamiento: 0.2 min
Modelo restaurado al mejor checkpoint de validación.
[15]
# ============================================================
# 5.1 Curvas de entrenamiento: loss y accuracy
# ============================================================

epoch_axis = np.arange(1, len(history["train_loss"]) + 1)

fig, ax = plt.subplots(1, 2, figsize=(13, 4))

ax[0].plot(epoch_axis, history["train_loss"], marker="o", label="train")
ax[0].plot(epoch_axis, history["val_loss"], marker="o", label="val")
ax[0].set_title("Loss vs Epoch")
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("Cross-Entropy Loss")
ax[0].legend()

ax[1].plot(epoch_axis, history["train_acc"], marker="o", label="train")
ax[1].plot(epoch_axis, history["val_acc"], marker="o", label="val")
ax[1].set_title("Accuracy vs Epoch")
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("Accuracy")
ax[1].legend()

plt.tight_layout()
plt.show()
Output

6) Evaluación en test

Ahora medimos generalización sobre el conjunto de test con:

  • accuracy,
  • macro-F1,
  • reporte de clasificación por clase,
  • matriz de confusión.

Esto permite detectar clases más difíciles y posibles sesgos.

[17]
# ============================================================
# 6) Evaluación final en test
# ============================================================

model.eval()
all_preds, all_targets = [], []

with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch = x_batch.to(device)
        logits = model(x_batch)
        preds = torch.argmax(logits, dim=1).cpu().numpy()

        all_preds.extend(preds)
        all_targets.extend(y_batch.numpy())

test_acc = accuracy_score(all_targets, all_preds)
test_f1_macro = f1_score(all_targets, all_preds, average="macro")

print(f"Test Accuracy : {test_acc:.4f}")
print(f"Test Macro-F1 : {test_f1_macro:.4f}")

print("Classification report:")
print(classification_report(all_targets, all_preds, target_names=LABELS, digits=4))
Test Accuracy : 0.8875
Test Macro-F1 : 0.8876
Classification report:
              precision    recall  f1-score   support

       World     0.9139    0.8774    0.8953      1900
      Sports     0.9362    0.9505    0.9433      1900
    Business     0.8361    0.8647    0.8502      1900
    Sci/Tech     0.8656    0.8574    0.8614      1900

    accuracy                         0.8875      7600
   macro avg     0.8880    0.8875    0.8876      7600
weighted avg     0.8880    0.8875    0.8876      7600

[18]
# ============================================================
# 6.1 Matriz de confusión
# ============================================================

cm = confusion_matrix(all_targets, all_preds)

plt.figure(figsize=(6, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=LABELS,
    yticklabels=LABELS,
)
plt.title("Matriz de confusión (test)")
plt.xlabel("Predicción")
plt.ylabel("Etiqueta real")
plt.tight_layout()
plt.show()
Output

7) Ejemplos cualitativos (aciertos y errores)

Además de métricas agregadas, conviene revisar ejemplos concretos para entender:

  • cuándo la GRU funciona bien,
  • qué patrones confunden al modelo,
  • cómo mejorar el preprocesado o la arquitectura.
[20]
# ============================================================
# 7) Mostrar ejemplos acertados/erróneos
# ============================================================

# Construimos predicciones sobre un subconjunto del test con texto original
sample_size = 300
sample_data = test_raw[:sample_size]

x_sample = torch.tensor([encode_text(text) for _, text in sample_data], dtype=torch.long).to(device)
y_sample = np.array([label for label, _ in sample_data])
texts_sample = [text for _, text in sample_data]

with torch.no_grad():
    logits = model(x_sample)
    pred_sample = torch.argmax(logits, dim=1).cpu().numpy()

correct_idx = np.where(pred_sample == y_sample)[0]
wrong_idx = np.where(pred_sample != y_sample)[0]

print(f"Aciertos en muestra: {len(correct_idx)}/{sample_size}")
print(f"Errores  en muestra: {len(wrong_idx)}/{sample_size}")

print("--- Ejemplos de aciertos ---")
for i in correct_idx[:3]:
    print(f"Real={LABELS[y_sample[i]]} | Pred={LABELS[pred_sample[i]]}")
    print(texts_sample[i][:180], "\n")

print("--- Ejemplos de errores ---")
for i in wrong_idx[:3]:
    print(f"Real={LABELS[y_sample[i]]} | Pred={LABELS[pred_sample[i]]}")
    print(texts_sample[i][:180], "\n")
Aciertos en muestra: 271/300
Errores  en muestra: 29/300
--- Ejemplos de aciertos ---
Real=Business | Pred=Business
Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul. 

Real=Sci/Tech | Pred=Sci/Tech
The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million  

Real=Sci/Tech | Pred=Sci/Tech
Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better p 

--- Ejemplos de errores ---
Real=Sci/Tech | Pred=Sports
Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he  

Real=Sci/Tech | Pred=Business
Card fraud unit nets 36,000 cards In its first two years, the UK's dedicated card fraud unit, has recovered 36,000 stolen cards and 171 arrests - and estimates it saved 65m. 

Real=Sci/Tech | Pred=World
Super ant colony hits Australia A giant 100km colony of ants  which has been discovered in Melbourne, Australia, could threaten local insect species. 

8) Mini experimento guiado (ablation): GRU bidireccional vs unidireccional

Para reforzar el aprendizaje, proponemos comparar una variante unidireccional manteniendo el resto igual.

Puedes ejecutar esta celda para observar si la bidireccionalidad aporta mejora en accuracy/F1.

[21]
# ============================================================
# 8) Experimento opcional (rápido): cambiar bidirectional=False
# ============================================================

# Esta celda está pensada como ejercicio.
# Para ahorrar tiempo por defecto no se ejecuta entrenamiento completo aquí.
# Puedes descomentar y lanzar una versión corta (3-4 épocas).

# quick_model = GRUClassifier(
#     vocab_size=VOCAB_SIZE,
#     embed_dim=128,
#     hidden_dim=128,
#     num_classes=4,
#     num_layers=2,
#     dropout=0.3,
#     bidirectional=False,
#     pad_idx=PAD_IDX,
# ).to(device)
#
# quick_opt = torch.optim.Adam(quick_model.parameters(), lr=1e-3)
# for e in range(4):
#     tr_l, tr_a = run_epoch(quick_model, train_loader, optimizer=quick_opt)
#     va_l, va_a = run_epoch(quick_model, val_loader, optimizer=None)
#     print(f"[UNI-GRU] epoch={e+1} tr_acc={tr_a:.4f} val_acc={va_a:.4f}")

Conclusiones y siguientes pasos

Qué hemos aprendido

  • Una GRU permite modelar dependencias temporales con menos complejidad que LSTM.
  • En texto, la combinación Embedding + (Bi)GRU es una base sólida para clasificación secuencial.
  • El flujo completo (EDA → preprocesado → entrenamiento → evaluación → análisis cualitativo) es clave para construir modelos fiables.

Sugerencias para profundizar

  1. Embeddings preentrenados (GloVe/FastText) en vez de embeddings aleatorios.
  2. Probar PackedSequence para manejar longitudes variables sin padding fijo.
  3. Ajustar hiperparámetros: hidden_dim, num_layers, dropout, lr.
  4. Comparar GRU vs LSTM en igualdad de condiciones (parámetros/tiempo).
  5. Añadir attention sobre salidas de GRU.
  6. Pasar de titulares a textos más largos (p.ej. IMDB completo).

Si haces estas extensiones, tendrás una visión muy completa de cuándo y por qué usar GRU en problemas reales de NLP.