I’ve asked chatgpt to generate some sample code distill an embedding model and this is what I got
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
# 1️⃣ Load an NLI dataset with pair-class subset
:contentReference[oaicite:1]{index=1}
# hf_ds now contains 'premise', 'hypothesis', 'label' fields :contentReference[oaicite:2]{index=2}
sentences :contentReference[oaicite:3]{index=3}
:contentReference[oaicite:4]{index=4}
:contentReference[oaicite:5]{index=5}
:contentReference[oaicite:6]{index=6}
:contentReference[oaicite:7]{index=7}
:contentReference[oaicite:8]{index=8}
:contentReference[oaicite:9]{index=9}
:contentReference[oaicite:10]{index=10}
:contentReference[oaicite:11]{index=11}
# 2️⃣ Load teacher and student models
:contentReference[oaicite:12]{index=12}
:contentReference[oaicite:13]{index=13}
:contentReference[oaicite:14]{index=14}
:contentReference[oaicite:15]{index=15}
:contentReference[oaicite:16]{index=16}
:contentReference[oaicite:17]{index=17}
:contentReference[oaicite:18]{index=18}
:contentReference[oaicite:19]{index=19}
# 3️⃣ Mean pooling function
:contentReference[oaicite:20]{index=20}
:contentReference[oaicite:21]{index=21}
:contentReference[oaicite:22]{index=22}
:contentReference[oaicite:23]{index=23}
:contentReference[oaicite:24]{index=24}
# 4️⃣ Set up training
:contentReference[oaicite:25]{index=25}
:contentReference[oaicite:26]{index=26}
:contentReference[oaicite:27]{index=27}
:contentReference[oaicite:28]{index=28}
:contentReference[oaicite:29]{index=29}
# 5️⃣ Training loop
model_S.train()
:contentReference[oaicite:30]{index=30}
:contentReference[oaicite:31]{index=31}
:contentReference[oaicite:32]{index=32}
# Teacher embeddings
:contentReference[oaicite:33]{index=33}
:contentReference[oaicite:34]{index=34}
:contentReference[oaicite:35]{index=35}
:contentReference[oaicite:36]{index=36}
# Student embeddings
inp_S = tok_S(batch, padding=True, truncation=True, return_tensors="pt").to(device)
out_S = model_S(**inp_S)
emb_S = nn.functional.normalize(mean_pool(out_S.last_hidden_state, inp_S.attention_mask), dim=1)
# Cosine loss
labels = torch.ones(emb_T.size(0), device=device)
loss = loss_fn(emb_S, emb_T, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
avg_loss = epoch_loss / len(loader)
print(f"[Epoch {epoch+1}] Avg Cosine Loss: {avg_loss:.4f}")
The code was actually generated correctly at first but after the generation was done it was corrupted with these contentReference tags