BERT — Sentiment analysis (IMDb) and attention maps
Fine-tuning BERT for binary sentiment classification on IMDb reviews, with multi-metric evaluation and visual inspection of self-attention maps.
🧠 BERT for sentiment analysis + attention maps (IMDb)
Notebook objective
This notebook has a dual goal:
- Apply a BERT-style model to a real sentiment analysis task.
- Understand how the model "looks at" the text through its self-attention maps.
The idea is aligned with the theory from the LLMs submodule: the difference between encoder-only architectures (like BERT) and decoder-only ones (like GPT), the use of pre-training and fine-tuning, and interpretive analysis of model decisions.
Models and datasets we will use
Base model
bert-base-uncased (Hugging Face):
- Encoder-only architecture.
- 12 Transformer layers (BERT base).
- 12 attention heads per layer.
- Hidden embedding size of 768.
Dataset
- IMDb (Hugging Face datasets):
  - 50,000 movie reviews in English.
  - Binary labels: 0 (negative), 1 (positive).
  - Standard split: train and test.
Mathematical and computational foundations (didactic overview)
1) Input representation in BERT
For each input token, BERT sums three embeddings:
$$\mathbf{x}_i = \mathbf{e}^{\text{token}}_i + \mathbf{e}^{\text{segment}}_i + \mathbf{e}^{\text{position}}_i$$
- Token embedding: lexical meaning (WordPiece).
- Segment embedding: distinguishes sentence A/B (useful in paired tasks).
- Positional embedding: encodes order in the sequence.
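The sum above can be sketched with three toy lookup tables. This is a minimal NumPy illustration with hypothetical tiny dimensions (BERT base actually uses a 30,522-token vocabulary and d_model = 768):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions (hypothetical; far smaller than BERT's real tables)
vocab_size, n_segments, max_pos, d_model = 100, 2, 16, 8

tok_emb = rng.normal(size=(vocab_size, d_model))   # token embeddings
seg_emb = rng.normal(size=(n_segments, d_model))   # segment A/B embeddings
pos_emb = rng.normal(size=(max_pos, d_model))      # positional embeddings

def bert_input_embedding(token_ids, segment_ids):
    """Sum token, segment and positional embeddings, as BERT does."""
    positions = np.arange(len(token_ids))
    return tok_emb[token_ids] + seg_emb[segment_ids] + pos_emb[positions]

x = bert_input_embedding([5, 17, 3], [0, 0, 0])
print(x.shape)  # (3, 8): one d_model-dimensional vector per token
```

In the real model these three tables live inside `model.bert.embeddings` and the sum is followed by LayerNorm and dropout.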
2) Scaled dot-product attention
Each head builds Q, K and V matrices:
$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Intuitive interpretation:
- Each token "queries" (Q) all the other tokens (K).
- The softmax turns similarities into weights (relative importances).
- Those weights are used to mix the information in V.
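The formula above translates almost line by line into code. A minimal NumPy sketch for a single head (random toy matrices, no masking):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # token-to-token similarities
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(4, 8)) for _ in range(3))
out, w = attention(Q, K, V)
print(out.shape)  # (4, 8): one mixed value vector per token
```

Row i of `w` is exactly the attention distribution of token i that we will later visualize as heatmaps.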
3) Multi-head attention
Instead of a single attention, several heads run in parallel:
$$\text{MultiHead}(Q,K,V)=\text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O$$
Each head can specialize in different patterns: negations (not), intensifiers (very), syntactic structure, etc.
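A minimal NumPy sketch of the Concat(...)·W^O step, splitting d_model into equal slices per head (a simplification; the real implementation uses reshapes instead of a Python loop):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, n_heads):
    """Concat(head_1..head_h) W^O, with per-head scaled dot-product attention."""
    seq_len, d_model = X.shape
    d_head = d_model // n_heads
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    heads = []
    for h in range(n_heads):
        sl = slice(h * d_head, (h + 1) * d_head)   # this head's slice
        scores = Q[:, sl] @ K[:, sl].T / np.sqrt(d_head)
        heads.append(softmax(scores) @ V[:, sl])
    return np.concatenate(heads, axis=-1) @ Wo     # concat + output projection

rng = np.random.default_rng(0)
d_model, seq_len, n_heads = 12, 5, 3
X = rng.normal(size=(seq_len, d_model))
Wq, Wk, Wv, Wo = (rng.normal(size=(d_model, d_model)) for _ in range(4))
print(multi_head_attention(X, Wq, Wk, Wv, Wo, n_heads).shape)  # (5, 12)
```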
4) Pre-training and fine-tuning
BERT is pre-trained with objectives such as:
- MLM (Masked Language Modeling): predicting masked-out tokens.
- NSP (Next Sentence Prediction), in the original BERT.
It is then adapted by fine-tuning to a specific task (here, binary classification).
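To make the MLM objective concrete, here is a sketch of BERT's corruption recipe: select ~15% of positions; of those, 80% become [MASK], 10% a random token, 10% stay unchanged. The helper below is a hypothetical illustration on raw integer ids (the library handles this via `DataCollatorForLanguageModeling`):

```python
import numpy as np

MASK_ID = 103  # [MASK] id in bert-base-uncased's vocabulary

def mlm_mask(token_ids, rng, mask_prob=0.15):
    """BERT-style MLM corruption. Returns corrupted ids and labels,
    where -100 marks positions that are NOT predicted (ignored by the loss)."""
    ids = np.array(token_ids)
    labels = np.full_like(ids, -100)
    selected = rng.random(len(ids)) < mask_prob   # ~15% of positions
    labels[selected] = ids[selected]              # targets = original tokens
    roll = rng.random(len(ids))
    ids[selected & (roll < 0.8)] = MASK_ID        # 80% -> [MASK]
    random_pos = selected & (roll >= 0.8) & (roll < 0.9)
    ids[random_pos] = rng.integers(1000, 30000, size=random_pos.sum())  # 10% -> random
    return ids, labels                            # remaining 10% left unchanged

rng = np.random.default_rng(0)
corrupted, labels = mlm_mask(list(range(2000, 2200)), rng)
print((labels != -100).mean())  # roughly 0.15 of positions get a prediction target
```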
5) Classification layer
For sequence classification, the final hidden state of the [CLS] token is typically used:
$$\hat{y}=\text{softmax}(W\mathbf{h}_{[\text{CLS}]}+b)$$
The loss is cross-entropy:
$$\mathcal{L}=-\sum_{c \in \{0,1\}} y_c\log(\hat{y}_c)$$
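The two formulas above amount to a single linear layer plus a log. A minimal NumPy sketch with a random (untrained) head, using BERT base's hidden size of 768:

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def cls_head_loss(h_cls, W, b, y_true):
    """y_hat = softmax(W h_cls + b); cross-entropy against the true class."""
    y_hat = softmax(W @ h_cls + b)
    loss = -np.log(y_hat[y_true])   # only the true class term survives the sum
    return y_hat, loss

rng = np.random.default_rng(0)
h_cls = rng.normal(size=768)                    # final hidden state of [CLS]
W = rng.normal(size=(2, 768)) * 0.02            # 2 classes: negative/positive
b = np.zeros(2)
y_hat, loss = cls_head_loss(h_cls, W, b, y_true=1)
print(y_hat.sum())  # probabilities sum to 1
```

This is exactly the `classifier` head that `AutoModelForSequenceClassification` adds (and reports as newly initialized) on top of the pre-trained encoder.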
What you will learn by the end
- How to prepare text for BERT (tokenization and truncation/padding).
- How to train (or partially re-train) a sentiment classifier.
- How to evaluate rigorously with multiple metrics (accuracy, F1, confusion matrix, ROC-AUC).
- How to inspect attention maps and connect attention patterns with the model's errors and successes.
Note: training the full BERT can take a while. This notebook uses a configurable subset to keep it practical and reproducible in an educational setting.
1) Installation/imports and configuration
# Install dependencies if needed (uncomment in Colab/locally)
# !pip install -q transformers datasets evaluate accelerate scikit-learn seaborn matplotlib
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
DataCollatorWithPadding,
TrainingArguments,
Trainer,
)
from sklearn.metrics import (
precision_recall_fscore_support,
confusion_matrix,
classification_report,
roc_auc_score,
roc_curve,
)
import evaluate
# Visual style
sns.set(style="whitegrid")
plt.rcParams["figure.figsize"] = (9, 5)
# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Detected device: {device}")
Detected device: cuda
2) Loading the IMDb dataset
We will use the standard split (train and test). Then we will take subsets to train faster in didactic mode.
# Load IMDb from Hugging Face Datasets
imdb = load_dataset("imdb")
print(imdb)
# Look at one example
print("\nExample review:")
print(imdb["train"][0]["text"][:500], "...")
print("Label:", imdb["train"][0]["label"], "(0=neg, 1=pos)")
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
unsupervised: Dataset({
features: ['text', 'label'],
num_rows: 50000
})
})
Example review:
I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attent ...
Label: 0 (0=neg, 1=pos)
# Label distribution (sanity check)
train_labels = np.array(imdb["train"]["label"])
test_labels = np.array(imdb["test"]["label"])
print("Train distribution:", np.bincount(train_labels))
print("Test distribution:", np.bincount(test_labels))
plt.figure(figsize=(6,4))
plt.bar(["negative", "positive"], np.bincount(train_labels), color=["#ff6b6b", "#51cf66"])
plt.title("Class distribution in IMDb (train)")
plt.ylabel("Number of reviews")
plt.show()
Train distribution: [12500 12500]
Test distribution: [12500 12500]
3) Tokenization (WordPiece) and subset preparation
MODEL_NAME = "bert-base-uncased"
MAX_LENGTH = 256  # Reduce sequence length to speed up training on modest hardware
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def tokenize_batch(batch):
    # Tokenize whole batches; truncate to MAX_LENGTH
    return tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)
# Didactic sizes (increase if you have a powerful GPU)
N_TRAIN = 4000
N_EVAL = 1000
N_TEST = 2000
train_small = imdb["train"].shuffle(seed=SEED).select(range(N_TRAIN))
eval_small = imdb["train"].shuffle(seed=SEED + 1).select(range(N_EVAL))
test_small = imdb["test"].shuffle(seed=SEED).select(range(N_TEST))
train_tok = train_small.map(tokenize_batch, batched=True)
eval_tok = eval_small.map(tokenize_batch, batched=True)
test_tok = test_small.map(tokenize_batch, batched=True)
# The collator handles dynamic per-batch padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
print(train_tok)
Dataset({
features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
num_rows: 4000
})
4) Defining the BERT model for binary classification
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
num_labels=2,
)
model.to(device)
print(model.config)
print(f"Approximate number of parameters: {model.num_parameters():,}")
Loading weights: 100%|██████████| 199/199 [00:00<00:00, 13105.35it/s]
BertForSequenceClassification LOAD REPORT from: bert-base-uncased
Key                                        | Status     |
-------------------------------------------+------------+-
cls.predictions.transform.dense.weight     | UNEXPECTED |
cls.seq_relationship.weight                | UNEXPECTED |
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |
cls.seq_relationship.bias                  | UNEXPECTED |
cls.predictions.bias                       | UNEXPECTED |
cls.predictions.transform.dense.bias       | UNEXPECTED |
classifier.bias                            | MISSING    |
classifier.weight                          | MISSING    |
Notes:
- UNEXPECTED: can be ignored when loading from a different task/architecture; not OK if you expect an identical arch.
- MISSING: these params were newly initialized because they were missing from the checkpoint. Consider training on your downstream task.
BertConfig {
"add_cross_attention": false,
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": null,
"classifier_dropout": null,
"dtype": "float32",
"eos_token_id": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": false,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"tie_word_embeddings": true,
"transformers_version": "5.3.0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
Approximate number of parameters: 109,483,778
5) Evaluation metrics
We compute accuracy, precision, recall and macro F1 for a more robust picture than accuracy alone.
metric_acc = evaluate.load("accuracy")
metric_f1 = evaluate.load("f1")
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
acc = metric_acc.compute(predictions=preds, references=labels)["accuracy"]
f1_macro = metric_f1.compute(predictions=preds, references=labels, average="macro")["f1"]
precision, recall, f1_weighted, _ = precision_recall_fscore_support(
labels, preds, average="weighted", zero_division=0
)
return {
"accuracy": acc,
"f1_macro": f1_macro,
"precision_weighted": precision,
"recall_weighted": recall,
"f1_weighted": f1_weighted,
}
6) Fine-tuning with Trainer (didactic training)
training_args = TrainingArguments(
output_dir="./bert-imdb-results",
eval_strategy="epoch",
save_strategy="epoch",
logging_strategy="steps",
logging_steps=50,
learning_rate=2e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="f1_macro",
report_to="none",  # Avoid requiring wandb/mlflow
seed=SEED,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_tok,
eval_dataset=eval_tok,
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
train_output = trainer.train()
print(train_output)
| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro | Precision Weighted | Recall Weighted | F1 Weighted |
|---|---|---|---|---|---|---|---|
| 1 | 0.394222 | 0.331851 | 0.878000 | 0.877960 | 0.895363 | 0.878000 | 0.877780 |
| 2 | 0.226751 | 0.345660 | 0.910000 | 0.909449 | 0.910086 | 0.910000 | 0.910028 |
Writing model shards: 100%|██████████| 1/1 [00:01<00:00, 1.23s/it]
Writing model shards: 100%|██████████| 1/1 [00:01<00:00, 1.21s/it]
There were missing keys in the checkpoint model loaded: ['bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias', 'bert.encoder.layer.0.attention.output.LayerNorm.weight', ... 'bert.encoder.layer.11.output.LayerNorm.bias'] (the LayerNorm weight/bias of the embeddings and of all 12 encoder layers).
There were unexpected keys in the checkpoint model loaded: ['bert.embeddings.LayerNorm.beta', 'bert.embeddings.LayerNorm.gamma', 'bert.encoder.layer.0.attention.output.LayerNorm.beta', ... 'bert.encoder.layer.11.output.LayerNorm.gamma'] (the same LayerNorm parameters under their legacy beta/gamma names).
TrainOutput(global_step=1000, training_loss=0.2959137349128723, metrics={'train_runtime': 55.0103, 'train_samples_per_second': 145.427, 'train_steps_per_second': 18.178, 'total_flos': 1051663110494400.0, 'train_loss': 0.2959137349128723, 'epoch': 2.0})
7) Quantitative evaluation on test (IMDb)
# Remove the NotebookProgressCallback, which loses state across cells
from transformers.utils.notebook import NotebookProgressCallback
trainer.remove_callback(NotebookProgressCallback)
test_metrics = trainer.evaluate(test_tok)
print("Test metrics:")
for k, v in test_metrics.items():
    if isinstance(v, (int, float)):
        print(f"{k:25s}: {v:.4f}")
Test metrics:
eval_loss                : 0.3709
eval_accuracy            : 0.9055
eval_f1_macro            : 0.9055
eval_precision_weighted  : 0.9055
eval_recall_weighted     : 0.9055
eval_f1_weighted         : 0.9055
eval_runtime             : 3.0385
eval_samples_per_second  : 658.2270
eval_steps_per_second    : 41.1390
epoch                    : 2.0000
# Detailed predictions for further analysis
pred_output = trainer.predict(test_tok)
logits = pred_output.predictions
y_true = np.array(test_tok["label"])
y_pred = np.argmax(logits, axis=1)
# Probability of the positive class
probs_pos = torch.softmax(torch.tensor(logits), dim=1)[:, 1].numpy()
print(classification_report(y_true, y_pred, target_names=["negative", "positive"]))
              precision    recall  f1-score   support

    negative       0.90      0.91      0.91      1000
    positive       0.91      0.90      0.91      1000

    accuracy                           0.91      2000
   macro avg       0.91      0.91      0.91      2000
weighted avg       0.91      0.91      0.91      2000
# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
            xticklabels=["neg", "pos"], yticklabels=["neg", "pos"])
plt.title("Confusion matrix (IMDb test)")
plt.xlabel("Predicted")
plt.ylabel("True label")
plt.show()
# ROC curve and AUC
auc = roc_auc_score(y_true, probs_pos)
fpr, tpr, _ = roc_curve(y_true, probs_pos)
plt.figure(figsize=(7,5))
plt.plot(fpr, tpr, label=f"BERT (AUC = {auc:.4f})", linewidth=2)
plt.plot([0,1], [0,1], linestyle="--", color="gray", label="Chance")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC curve on IMDb")
plt.legend()
plt.show()
8) Qualitative error inspection
Reviewing misclassified examples helps detect patterns:
- sarcasm,
- mixed opinions,
- long sentences with complex negations,
- formatting noise.
# Show a few misclassified examples
errors_idx = np.where(y_true != y_pred)[0]
print(f"Number of errors in test_small: {len(errors_idx)}")
for idx in errors_idx[:5]:
    sample = test_small[int(idx)]
    pred_label = int(y_pred[idx])
    true_label = int(y_true[idx])
    conf = float(max(1 - probs_pos[idx], probs_pos[idx]))
    print("\n" + "="*90)
    print(f"TRUE: {true_label} | PRED: {pred_label} | Approx. confidence: {conf:.3f}")
    print("Text (truncated):")
    print(sample["text"][:600].replace("\n", " "), "...")
Number of errors in test_small: 189

==========================================================================================
TRUE: 0 | PRED: 1 | Approx. confidence: 0.587
Text (truncated):
Coming from Kiarostami, this art-house visual and sound exposition is a surprise. For a director known for his narratives and keen observation of humans, especially children, this excursion into minimalist cinematography begs for questions: Why did he do it? Was it to keep him busy during a vacation at the shore? <br /><br />"Five, 5 Long Takes" consists of, you guessed it, five long takes. They are (the title names are my own and the times approximate): <br /><br />"Driftwood and waves". The camera stands nearly still looking at a small piece of driftwood as it gets moved around by small wave ...

==========================================================================================
TRUE: 0 | PRED: 1 | Approx. confidence: 0.944
Text (truncated):
"An astronaut (Michael Emmet) dies while returning from a mission and his body is recovered by the military. The base where the dead astronaut is taken to becomes the scene of a bizarre invasion plan from outer space. Alien embryos inside the dead astronaut resurrect the corpse and begin a terrifying assault on the military staff in the hopes of conquering the world," according to the DVD sleeve's synopsis.<br /><br />A Roger Corman "American International" production. The man who fell to Earth impregnated, Mr. Emmet (as John Corcoran), does all right. Angela Greene is his pretty conflicted fi ...

==========================================================================================
TRUE: 0 | PRED: 1 | Approx. confidence: 0.995
Text (truncated):
I used to always love the bill because of its great script and characters, but lately i feel as though it has turned into an emotional type of soap. If you look at promotional pictures/posters of the bill now you will see either two of the officers hugging/kissing or something to do with friendships whereas promotional pictures of the bill a long time ago would have shown something to do with crime. This proves that it has changed a lot from being an absolutely amazing Police drama to an average type of television soap. When i watch it i feel like I'm watching a police version of Coronation St ...

==========================================================================================
TRUE: 0 | PRED: 1 | Approx. confidence: 0.997
Text (truncated):
A truly masterful piece of filmmaking. It managed to put me to sleep and to boggle my mind. So boring that it induces sleep and yet so ludicrous that it made me wonder how stuff like this gets made. Avoid at all costs. That is, unless you like taking invisible cranial punishment, in which case I highly recommend it. ...

==========================================================================================
TRUE: 1 | PRED: 0 | Approx. confidence: 0.866
Text (truncated):
First, the positives: an excellent job at depicting urban landscapes to suit the mood of the film. Some of the shots could be paintings by De Chirico. Sophie Marceau, beautiful.<br /><br />The negatives: the stories are hard to believe. Unreal, uni-dimensional characters preen and posture 100% of the time, as if they were in some kind of catwalk. This is neither the Antonioni of his earlier, much better movies nor the Wenders we've all come to know and appreciate. Malkovich is excess baggage in this movie. ...
9) Attention maps: which tokens does BERT focus on?
Now comes the interpretive part.
Key idea
In BERT, each layer and each head produce an attention matrix of size:
$$(n_{\text{tokens}} \times n_{\text{tokens}})$$
where row i indicates how much token i attends to the rest.
In sequence classification, the [CLS] token usually acts as a global "summary", so looking at the attention from [CLS] is a useful approximation for interpreting relevant signals.
⚠️ Important: attention is not always a perfect causal explanation. It is a valuable clue, but it should be combined with other techniques (ablations, Integrated Gradients, SHAP, etc.).
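One complementary heuristic that goes beyond inspecting a single layer/head is attention rollout (Abnar & Zuidema, 2020): average the heads in each layer, add the residual connection, renormalize, and multiply the matrices through the layers to approximate how information flows to [CLS]. A minimal NumPy sketch on random row-stochastic matrices (the shapes mimic, hypothetically, what `outputs.attentions` gives per layer once the batch dimension is removed):

```python
import numpy as np

def attention_rollout(attn_per_layer):
    """Attention rollout: average heads, add identity (residual),
    renormalize rows, and compose layer by layer."""
    result = None
    for layer_attn in attn_per_layer:          # each: (n_heads, seq, seq)
        a = layer_attn.mean(axis=0)            # average over heads
        a = a + np.eye(a.shape[0])             # residual connection
        a = a / a.sum(axis=-1, keepdims=True)  # keep rows as distributions
        result = a if result is None else a @ result
    return result  # row 0 ~ how [CLS] aggregates each input token

rng = np.random.default_rng(0)
# 3 toy layers, 4 heads, 6 tokens; each attention row already sums to 1
layers = [rng.dirichlet(np.ones(6), size=(4, 6)) for _ in range(3)]
rollout = attention_rollout(layers)
print(rollout.shape)  # (6, 6)
```

This is still a heuristic, not a causal explanation, but it often gives more stable token rankings than any single head.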
# Load the best trained model with attention outputs enabled
best_model = trainer.model
# Switch to the "eager" attention implementation to be able to extract attention maps
best_model.config._attn_implementation = "eager"
best_model.config.output_attentions = True
best_model.eval()
def get_attention_for_text(text, model, tokenizer, max_length=128):
"""
    Returns tokens, attentions and class probabilities for a text.
    attentions is a tuple of length n_layers; each element has shape:
    (batch_size, n_heads, seq_len, seq_len)
"""
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=max_length,
)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
attentions = outputs.attentions
logits = outputs.logits
probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
return tokens, attentions, probs
# Example sentences (you can add more)
example_texts = [
"This movie is absolutely wonderful, emotionally rich and beautifully directed.",
"I expected a lot, but the plot was boring and the acting was terrible.",
"The film is not bad at all, actually quite enjoyable in many scenes.",
]
for t in example_texts:
tokens, attentions, probs = get_attention_for_text(t, best_model, tokenizer)
pred = int(np.argmax(probs))
print("\nText:", t)
print(f"Probabilities -> negative: {probs[0]:.3f}, positive: {probs[1]:.3f} | pred={pred}")
print("Tokens:", tokens[:25], "...")
Text: This movie is absolutely wonderful, emotionally rich and beautifully directed.
Probabilities -> negative: 0.003, positive: 0.997 | pred=1
Tokens: ['[CLS]', 'this', 'movie', 'is', 'absolutely', 'wonderful', ',', 'emotionally', 'rich', 'and', 'beautifully', 'directed', '.', '[SEP]'] ...

Text: I expected a lot, but the plot was boring and the acting was terrible.
Probabilities -> negative: 0.998, positive: 0.002 | pred=0
Tokens: ['[CLS]', 'i', 'expected', 'a', 'lot', ',', 'but', 'the', 'plot', 'was', 'boring', 'and', 'the', 'acting', 'was', 'terrible', '.', '[SEP]'] ...

Text: The film is not bad at all, actually quite enjoyable in many scenes.
Probabilities -> negative: 0.053, positive: 0.947 | pred=1
Tokens: ['[CLS]', 'the', 'film', 'is', 'not', 'bad', 'at', 'all', ',', 'actually', 'quite', 'enjoyable', 'in', 'many', 'scenes', '.', '[SEP]'] ...
def plot_attention_head(tokens, attentions, layer=0, head=0, max_tokens=25):
"""
    Visualizes one attention matrix (layer, head) as a heatmap.
    Truncates to max_tokens for readability.
"""
# attentions[layer]: (1, n_heads, seq_len, seq_len)
attn = attentions[layer][0, head].detach().cpu().numpy()
n = min(max_tokens, len(tokens))
attn = attn[:n, :n]
tok = tokens[:n]
plt.figure(figsize=(8,6))
sns.heatmap(attn, cmap="magma", xticklabels=tok, yticklabels=tok)
plt.title(f"Attention map — layer {layer}, head {head}")
plt.xlabel("Token attended to")
plt.ylabel("Attending token")
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
# Example: visualize one specific head
sample_text = "The movie was not great, but the soundtrack was surprisingly good."
tokens, attentions, probs = get_attention_for_text(sample_text, best_model, tokenizer)
print(f"Positive probability: {probs[1]:.3f}")
plot_attention_head(tokens, attentions, layer=10, head=3, max_tokens=30)
Positive probability: 0.990
def top_tokens_attended_by_cls(tokens, attentions, layer=11, head=0, top_k=10):
"""
    Shows which tokens receive the most attention from [CLS]
    for a given layer/head.
"""
attn = attentions[layer][0, head].detach().cpu().numpy() # (seq_len, seq_len)
# Row 0 usually corresponds to the [CLS] token
cls_to_all = attn[0]
idx_sorted = np.argsort(cls_to_all)[::-1][:top_k]
results = [(tokens[i], float(cls_to_all[i])) for i in idx_sorted]
return results
for layer in [0, 5, 11]:
top = top_tokens_attended_by_cls(tokens, attentions, layer=layer, head=0, top_k=8)
print(f"\nTop tokens attended by [CLS] (layer={layer}, head=0):")
for tok, score in top:
print(f" {tok:15s} -> {score:.4f}")
Top tokens attended by [CLS] (layer=0, head=0):
 [SEP]           -> 0.2338
 the             -> 0.1033
 the             -> 0.0947
 .               -> 0.0839
 but             -> 0.0773
 great           -> 0.0579
 not             -> 0.0549
 [CLS]           -> 0.0474

Top tokens attended by [CLS] (layer=5, head=0):
 [SEP]           -> 0.7696
 [CLS]           -> 0.0559
 was             -> 0.0412
 the             -> 0.0300
 movie           -> 0.0291
 soundtrack      -> 0.0153
 great           -> 0.0137
 was             -> 0.0090

Top tokens attended by [CLS] (layer=11, head=0):
 [CLS]           -> 0.2266
 [SEP]           -> 0.1631
 .               -> 0.1184
 ,               -> 0.0983
 soundtrack      -> 0.0550
 was             -> 0.0541
 good            -> 0.0500
 the             -> 0.0388
10) Mini-experiments (didactic tests)
These blocks are useful for consolidating intuitions.
# TEST 1: sensitivity to negation
pairs = [
("The movie is good.", "The movie is not good."),
("The plot is interesting.", "The plot is not interesting."),
]
for a, b in pairs:
_, _, pa = get_attention_for_text(a, best_model, tokenizer)
_, _, pb = get_attention_for_text(b, best_model, tokenizer)
print("\n---")
print(a, "-> P(pos)=", round(float(pa[1]), 4))
print(b, "-> P(pos)=", round(float(pb[1]), 4))
---
The movie is good. -> P(pos)= 0.9954
The movie is not good. -> P(pos)= 0.0039
---
The plot is interesting. -> P(pos)= 0.9931
The plot is not interesting. -> P(pos)= 0.0071
# TEST 2: robustness to intensifiers
variants = [
"The acting is good.",
"The acting is very good.",
"The acting is extremely good.",
"The acting is unbelievably good.",
]
for v in variants:
_, _, p = get_attention_for_text(v, best_model, tokenizer)
print(f"{v:45s} -> P(pos)={p[1]:.4f}")
The acting is good.                           -> P(pos)=0.9918
The acting is very good.                      -> P(pos)=0.9909
The acting is extremely good.                 -> P(pos)=0.9915
The acting is unbelievably good.              -> P(pos)=0.9942
11) Conclusions
- BERT (encoder-only) performs very well on sentiment classification after fine-tuning.
- Attention analysis helps build intuition about which parts of the text influence the decision.
- We saw that multiple metrics (F1, ROC-AUC, confusion matrix) give a more complete evaluation.
- Reviewing errors qualitatively is key to detecting the model's limits.
12) What you could try next
- Compare BERT vs DistilBERT vs RoBERTa in training time and quality.
- Train with more data or more epochs and study overfitting.
- Try different maximum lengths (max_length=128/256/512).
- Use additional interpretability techniques (Integrated Gradients, LIME, SHAP).
- Analyze biases: are there words or styles that trigger false positives/negatives?
- Repeat the workflow in another domain (product reviews, news, tech support).
If you want to turn this notebook into a gradable assignment for the LLMs submodule, a good extension is to ask students for a comparative report between two models and a critical discussion of attention-based interpretability.