Skip to main content

scirs2_text/bert_finetune/
mod.rs

1//! Lightweight BERT fine-tuning API on top of pre-computed embeddings.
2//!
3//! This module provides a gradient-descent fine-tuning API for classification
4//! and sequence labelling tasks.  It operates on pre-computed dense embeddings
5//! (e.g., \[CLS\] token embeddings from BERT) rather than raw text, so it does
6//! **not** require any external ML library.
7
8use crate::error::{Result, TextError};
9use std::f64;
10
11// ─── FineTuneTask ─────────────────────────────────────────────────────────────
12
/// Kind of downstream task; it decides how many outputs the classifier head produces.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum FineTuneTask {
    /// Classification of a single sentence.
    Classification {
        /// How many output classes there are.
        n_classes: usize,
    },
    /// Token-level sequence labelling (NER, POS, etc.).
    SequenceLabeling {
        /// How many label classes there are.
        n_labels: usize,
    },
    /// Classification of a sentence pair (e.g., NLI, STS).
    SentencePairClassification {
        /// How many output classes there are.
        n_classes: usize,
    },
}

impl FineTuneTask {
    /// Width of the head's output layer: the class count for the classification
    /// variants, the label count for sequence labelling.
    pub fn n_outputs(&self) -> usize {
        match *self {
            Self::Classification { n_classes }
            | Self::SentencePairClassification { n_classes } => n_classes,
            Self::SequenceLabeling { n_labels } => n_labels,
        }
    }
}
44
45// ─── FineTuneConfig ───────────────────────────────────────────────────────────
46
/// Hyperparameters for fine-tuning.
///
/// The [`Default`] values are `lr = 2e-5`, `n_epochs = 3`, `batch_size = 32`,
/// `warmup_steps = 100`, `max_grad_norm = 1.0` and `dropout = 0.1`.
#[derive(Debug, Clone)]
pub struct FineTuneConfig {
    /// Peak learning rate, reached at the end of the linear warmup.
    pub lr: f64,
    /// Number of training epochs.
    pub n_epochs: usize,
    /// Mini-batch size (`BertFineTuner::train` treats 0 as 1).
    pub batch_size: usize,
    /// Number of linear warmup steps, counted from global step 0.
    pub warmup_steps: usize,
    /// Maximum L2 norm of the gradient; larger gradients are rescaled down to this norm.
    pub max_grad_norm: f64,
    /// Dropout probability intended for the input embedding.
    ///
    /// NOTE: this value is stored but not currently read by `BertFineTuner::train`,
    /// so no dropout is applied during training.
    pub dropout: f64,
}

impl Default for FineTuneConfig {
    fn default() -> Self {
        Self {
            lr: 2e-5,
            n_epochs: 3,
            batch_size: 32,
            warmup_steps: 100,
            max_grad_norm: 1.0,
            dropout: 0.1,
        }
    }
}
76
77// ─── ClassificationHead ───────────────────────────────────────────────────────
78
79/// Single linear classification head: logits = W · embedding + b.
80#[derive(Debug, Clone)]
81pub struct ClassificationHead {
82    /// Weight matrix of shape `(n_classes, hidden_size)`.
83    pub weight: Vec<Vec<f64>>,
84    /// Bias vector of length `n_classes`.
85    pub bias: Vec<f64>,
86}
87
88impl ClassificationHead {
89    /// Construct a randomly initialised head.
90    pub fn new(hidden_size: usize, n_classes: usize) -> Self {
91        let mut seed: u64 = 0xFAFAFAFA_12345678;
92        let weight = (0..n_classes)
93            .map(|_| {
94                (0..hidden_size)
95                    .map(|_| {
96                        seed = seed
97                            .wrapping_mul(6364136223846793005)
98                            .wrapping_add(1442695040888963407);
99                        let bits = (seed >> 33) as f64 / (u32::MAX as f64);
100                        (bits - 0.5) * 0.02 // Xavier-like init
101                    })
102                    .collect()
103            })
104            .collect();
105
106        Self {
107            weight,
108            bias: vec![0.0; n_classes],
109        }
110    }
111
112    /// Compute raw logits for a single \[CLS\] embedding.
113    pub fn forward(&self, cls_embedding: &[f64]) -> Vec<f64> {
114        self.weight
115            .iter()
116            .zip(self.bias.iter())
117            .map(|(row, &b)| {
118                row.iter()
119                    .zip(cls_embedding.iter())
120                    .map(|(w, x)| w * x)
121                    .sum::<f64>()
122                    + b
123            })
124            .collect()
125    }
126
127    /// One-step update using the cross-entropy gradient.
128    ///
129    /// Returns the cross-entropy loss for the sample.
130    pub fn backward_update(
131        &mut self,
132        cls_embedding: &[f64],
133        logits: &[f64],
134        label: usize,
135        lr: f64,
136    ) -> f64 {
137        let n_classes = logits.len();
138        if label >= n_classes {
139            // Safety: silently clamp to last class (no panic).
140            return 0.0;
141        }
142
143        // Softmax probabilities.
144        let probs = softmax(logits);
145
146        // Cross-entropy loss: -log p(true_class).
147        let loss = -(probs[label] + 1e-15).ln();
148
149        // Gradient of CE w.r.t. logits: probs - one_hot.
150        let grad_logits: Vec<f64> = probs
151            .iter()
152            .enumerate()
153            .map(|(k, &p)| if k == label { p - 1.0 } else { p })
154            .collect();
155
156        // Gradient w.r.t. weight[k][j] = grad_logits[k] * cls_embedding[j].
157        // Gradient w.r.t. bias[k]       = grad_logits[k].
158        let hidden = cls_embedding.len();
159        for k in 0..n_classes {
160            let g = grad_logits[k];
161            self.bias[k] -= lr * g;
162            for j in 0..hidden {
163                self.weight[k][j] -= lr * g * cls_embedding[j];
164            }
165        }
166
167        loss
168    }
169}
170
171// ─── BertFineTuner ────────────────────────────────────────────────────────────
172
173/// Fine-tuner wrapping a `ClassificationHead`.
174///
175/// Operates on batches of pre-computed BERT \[CLS\] embeddings.
176pub struct BertFineTuner {
177    /// Classification / labelling head.
178    pub head: ClassificationHead,
179    /// Fine-tuning hyperparameters.
180    pub config: FineTuneConfig,
181    /// Current global training step (for LR schedule).
182    pub step: usize,
183    /// Total training steps (set at `train` call for cosine decay).
184    total_steps: usize,
185}
186
187impl BertFineTuner {
188    /// Create a new `BertFineTuner`.
189    ///
190    /// # Errors
191    /// Returns `TextError::InvalidInput` if the task specifies 0 output classes.
192    pub fn new(hidden_size: usize, task: FineTuneTask, config: FineTuneConfig) -> Result<Self> {
193        let n_outputs = task.n_outputs();
194        if n_outputs == 0 {
195            return Err(TextError::InvalidInput(
196                "BertFineTuner: task must have at least 1 output class".into(),
197            ));
198        }
199        Ok(Self {
200            head: ClassificationHead::new(hidden_size, n_outputs),
201            config,
202            step: 0,
203            total_steps: 0,
204        })
205    }
206
207    // ── LR schedule ─────────────────────────────────────────────────────────
208
209    /// Learning rate at the current step: linear warmup then cosine decay.
210    pub fn learning_rate_schedule(&self) -> f64 {
211        let peak = self.config.lr;
212        let warmup = self.config.warmup_steps as f64;
213        let total = (self.total_steps.max(1)) as f64;
214        let s = self.step as f64;
215
216        if s < warmup {
217            // Linear warmup.
218            peak * (s + 1.0) / warmup
219        } else {
220            // Cosine decay.
221            let progress = (s - warmup) / (total - warmup).max(1.0);
222            let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) * 0.5;
223            peak * cosine
224        }
225    }
226
227    // ── gradient clipping ────────────────────────────────────────────────────
228
229    /// Apply gradient norm clipping.  Returns the (possibly scaled) gradient.
230    fn clip_grad(grad: &mut [f64], max_norm: f64) {
231        let norm: f64 = grad.iter().map(|x| x * x).sum::<f64>().sqrt();
232        if norm > max_norm && norm > 1e-12 {
233            let scale = max_norm / norm;
234            grad.iter_mut().for_each(|g| *g *= scale);
235        }
236    }
237
238    // ── training ────────────────────────────────────────────────────────────
239
240    /// Train on a set of (embedding, label) pairs for `config.n_epochs` epochs.
241    ///
242    /// Returns a vector of per-epoch average losses.
243    pub fn train(&mut self, embeddings: &[Vec<f64>], labels: &[usize]) -> Vec<f64> {
244        let n = embeddings.len().min(labels.len());
245        let batch_size = self.config.batch_size.max(1);
246        let n_epochs = self.config.n_epochs;
247        self.total_steps = n_epochs * n.div_ceil(batch_size);
248
249        let mut epoch_losses = Vec::with_capacity(n_epochs);
250
251        for _epoch in 0..n_epochs {
252            let mut epoch_loss = 0.0_f64;
253            let mut n_batches = 0usize;
254
255            // Mini-batch SGD.
256            let mut start = 0;
257            while start < n {
258                let end = (start + batch_size).min(n);
259                let batch_embs = &embeddings[start..end];
260                let batch_labels = &labels[start..end];
261
262                let lr = self.learning_rate_schedule();
263                let mut batch_loss = 0.0_f64;
264
265                // Accumulate gradients.
266                let n_classes = self.head.bias.len();
267                let hidden = if batch_embs.is_empty() {
268                    0
269                } else {
270                    batch_embs[0].len()
271                };
272                let mut grad_w = vec![vec![0.0_f64; hidden]; n_classes];
273                let mut grad_b = vec![0.0_f64; n_classes];
274
275                for (emb, &lbl) in batch_embs.iter().zip(batch_labels.iter()) {
276                    let logits = self.head.forward(emb);
277                    let probs = softmax(&logits);
278                    let loss = -(probs[lbl.min(n_classes - 1)] + 1e-15).ln();
279                    batch_loss += loss;
280
281                    // Gradient of CE w.r.t. logits.
282                    for k in 0..n_classes {
283                        let g = if k == lbl { probs[k] - 1.0 } else { probs[k] };
284                        grad_b[k] += g;
285                        for j in 0..hidden {
286                            grad_w[k][j] += g * emb[j];
287                        }
288                    }
289                }
290
291                let batch_len = (end - start) as f64;
292
293                // Average gradients.
294                grad_b.iter_mut().for_each(|g| *g /= batch_len);
295                for row in &mut grad_w {
296                    row.iter_mut().for_each(|g| *g /= batch_len);
297                }
298
299                // Clip gradients.
300                let max_norm = self.config.max_grad_norm;
301                Self::clip_grad(&mut grad_b, max_norm);
302                for row in &mut grad_w {
303                    Self::clip_grad(row, max_norm);
304                }
305
306                // Apply updates.
307                for k in 0..n_classes {
308                    self.head.bias[k] -= lr * grad_b[k];
309                    for j in 0..hidden {
310                        self.head.weight[k][j] -= lr * grad_w[k][j];
311                    }
312                }
313
314                epoch_loss += batch_loss / batch_len;
315                n_batches += 1;
316                self.step += 1;
317                start = end;
318            }
319
320            epoch_losses.push(if n_batches > 0 {
321                epoch_loss / n_batches as f64
322            } else {
323                0.0
324            });
325        }
326
327        epoch_losses
328    }
329
330    // ── inference ───────────────────────────────────────────────────────────
331
332    /// Predict the class index for a single embedding (argmax of logits).
333    pub fn predict(&self, embedding: &[f64]) -> usize {
334        let logits = self.head.forward(embedding);
335        logits
336            .iter()
337            .enumerate()
338            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
339            .map(|(i, _)| i)
340            .unwrap_or(0)
341    }
342
343    /// Compute softmax probability distribution for a single embedding.
344    pub fn predict_proba(&self, embedding: &[f64]) -> Vec<f64> {
345        softmax(&self.head.forward(embedding))
346    }
347
348    /// Compute classification accuracy on a labelled dataset.
349    pub fn evaluate(&self, embeddings: &[Vec<f64>], labels: &[usize]) -> f64 {
350        let n = embeddings.len().min(labels.len());
351        if n == 0 {
352            return 0.0;
353        }
354        let correct: usize = embeddings[..n]
355            .iter()
356            .zip(labels[..n].iter())
357            .filter(|(emb, &lbl)| self.predict(emb) == lbl)
358            .count();
359        correct as f64 / n as f64
360    }
361}
362
363// ─── helpers ─────────────────────────────────────────────────────────────────
364
/// Numerically stable softmax: logits are shifted by their maximum before `exp`.
///
/// An empty slice yields an empty vector. If the shifted exponentials sum to less
/// than `1e-15`, they are returned without normalisation.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let shift = logits
        .iter()
        .fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));
    let mut out: Vec<f64> = logits.iter().map(|&v| (v - shift).exp()).collect();
    let total: f64 = out.iter().sum();
    if total < 1e-15 {
        return out;
    }
    for e in out.iter_mut() {
        *e /= total;
    }
    out
}
378
379// ─── Tests ────────────────────────────────────────────────────────────────────
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classification_head_shape() {
        let head = ClassificationHead::new(16, 4);
        assert_eq!(head.weight.len(), 4);
        assert_eq!(head.weight[0].len(), 16);
        assert_eq!(head.bias.len(), 4);

        let emb: Vec<f64> = (0..16).map(|i| i as f64 * 0.1).collect();
        let logits = head.forward(&emb);
        assert_eq!(logits.len(), 4, "logits must have one entry per class");
    }

    #[test]
    fn test_classification_head_backward_update_returns_loss() {
        let mut head = ClassificationHead::new(8, 3);
        let emb: Vec<f64> = vec![1.0; 8];
        let logits = head.forward(&emb);
        let loss = head.backward_update(&emb, &logits, 0, 1e-3);
        assert!(loss.is_finite(), "loss should be finite, got {}", loss);
        assert!(loss >= 0.0, "CE loss must be non-negative");
    }

    #[test]
    fn test_bert_finetuner_new_invalid_task() {
        // SequenceLabeling with 0 labels should fail.
        let result = BertFineTuner::new(
            16,
            FineTuneTask::SequenceLabeling { n_labels: 0 },
            FineTuneConfig::default(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_bert_finetuner_train_returns_epoch_losses() {
        let config = FineTuneConfig {
            lr: 0.1,
            n_epochs: 3,
            batch_size: 4,
            warmup_steps: 2,
            ..Default::default()
        };
        let mut tuner =
            BertFineTuner::new(4, FineTuneTask::Classification { n_classes: 2 }, config)
                .expect("should create tuner");

        let embeddings: Vec<Vec<f64>> = (0..8)
            .map(|i| vec![(i % 2) as f64, ((i + 1) % 2) as f64, 0.0, 0.0])
            .collect();
        let labels: Vec<usize> = (0..8).map(|i| i % 2).collect();

        let losses = tuner.train(&embeddings, &labels);
        assert_eq!(losses.len(), 3, "should return one loss per epoch");
        for &loss in &losses {
            assert!(loss.is_finite(), "loss must be finite");
        }
    }

    #[test]
    fn test_bert_finetuner_accuracy_improves_on_separable_data() {
        // Linearly separable 2-class dataset.
        // Class 0: embedding [1, 0, 0, 0], Class 1: embedding [0, 1, 0, 0]
        let hidden = 4;
        let config = FineTuneConfig {
            lr: 1.0,
            n_epochs: 20,
            batch_size: 2,
            warmup_steps: 5,
            max_grad_norm: 10.0,
            dropout: 0.0,
        };
        let mut tuner = BertFineTuner::new(
            hidden,
            FineTuneTask::Classification { n_classes: 2 },
            config,
        )
        .expect("should create tuner");

        let embeddings: Vec<Vec<f64>> = (0..20)
            .map(|i| {
                if i % 2 == 0 {
                    vec![1.0, 0.0, 0.0, 0.0]
                } else {
                    vec![0.0, 1.0, 0.0, 0.0]
                }
            })
            .collect();
        let labels: Vec<usize> = (0..20).map(|i| i % 2).collect();

        let initial_acc = tuner.evaluate(&embeddings, &labels);
        tuner.train(&embeddings, &labels);
        let final_acc = tuner.evaluate(&embeddings, &labels);

        assert!(
            final_acc >= initial_acc,
            "accuracy should not decrease after training on separable data: {} -> {}",
            initial_acc,
            final_acc
        );
        // Two orthogonal inputs are trivially separable; training should fit them exactly.
        assert!(
            (final_acc - 1.0).abs() < 1e-12,
            "expected perfect accuracy on separable data, got {}",
            final_acc
        );
    }

    #[test]
    fn test_predict_proba_sums_to_one() {
        let tuner = BertFineTuner::new(
            4,
            FineTuneTask::Classification { n_classes: 3 },
            FineTuneConfig::default(),
        )
        .expect("should create tuner");

        let emb = vec![0.1, 0.2, 0.3, 0.4];
        let proba = tuner.predict_proba(&emb);
        let sum: f64 = proba.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "probabilities must sum to 1, got {}",
            sum
        );
    }

    #[test]
    fn test_lr_schedule_warmup() {
        let config = FineTuneConfig {
            warmup_steps: 10,
            lr: 1.0,
            ..Default::default()
        };
        let mut tuner =
            BertFineTuner::new(2, FineTuneTask::Classification { n_classes: 2 }, config)
                .expect("tuner");
        tuner.total_steps = 100;

        // Step 0: lr should be 1/10 of peak.
        tuner.step = 0;
        let lr0 = tuner.learning_rate_schedule();
        assert!(
            (lr0 - 0.1).abs() < 1e-12,
            "step-0 warmup lr should be peak / warmup_steps, got {}",
            lr0
        );

        // After warmup, at exactly warmup_steps, cosine starts at 1.0 (full peak lr).
        tuner.step = 10;
        let lr_warm = tuner.learning_rate_schedule();
        assert!(
            (lr_warm - 1.0).abs() < 1e-12,
            "lr at the end of warmup should equal the peak, got {}",
            lr_warm
        );

        // Near the end of the horizon the cosine has decayed close to zero.
        tuner.step = 99;
        let lr_end = tuner.learning_rate_schedule();
        assert!(
            (0.0..0.01).contains(&lr_end),
            "lr near total_steps should be close to 0, got {}",
            lr_end
        );
    }
}
523        tuner.step = 10;
524        let lr_warm = tuner.learning_rate_schedule();
525        assert!(lr_warm > 0.0, "lr after warmup should be positive");
526    }
527}