scirs2_text/bert_finetune/
mod.rs1use crate::error::{Result, TextError};
9use std::f64;
10
/// Downstream task that a [`BertFineTuner`] head is trained for.
///
/// Every variant carries the number of distinct output labels, which
/// [`FineTuneTask::n_outputs`] exposes uniformly. The enum is `#[non_exhaustive]`, so code
/// outside this crate must include a wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum FineTuneTask {
    /// A classification task with `n_classes` classes.
    Classification {
        /// Number of target classes.
        n_classes: usize,
    },
    /// A sequence-labeling task with `n_labels` distinct labels.
    SequenceLabeling {
        /// Number of distinct labels.
        n_labels: usize,
    },
    /// A sentence-pair classification task with `n_classes` classes.
    SentencePairClassification {
        /// Number of target classes.
        n_classes: usize,
    },
}

impl FineTuneTask {
    /// Number of outputs (classes or labels) the head for this task must produce.
    #[must_use]
    pub fn n_outputs(&self) -> usize {
        match *self {
            Self::Classification { n_classes }
            | Self::SentencePairClassification { n_classes } => n_classes,
            Self::SequenceLabeling { n_labels } => n_labels,
        }
    }
}
44
/// Hyper-parameters controlling [`BertFineTuner`] training.
///
/// Use struct-update syntax to override individual fields, e.g.
/// `FineTuneConfig { n_epochs: 5, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct FineTuneConfig {
    /// Peak learning rate, reached at the end of the warmup phase.
    pub lr: f64,
    /// Number of full passes over the training data.
    pub n_epochs: usize,
    /// Mini-batch size; the training loop treats `0` as `1`.
    pub batch_size: usize,
    /// Number of optimisation steps spent linearly warming up the learning rate.
    pub warmup_steps: usize,
    /// L2-norm threshold used for gradient clipping.
    pub max_grad_norm: f64,
    /// Dropout probability. NOTE: not read by the training loop in this module.
    pub dropout: f64,
}

impl Default for FineTuneConfig {
    /// Defaults: `lr = 2e-5`, 3 epochs, batch size 32, 100 warmup steps,
    /// gradient-norm clip of 1.0, dropout 0.1.
    fn default() -> Self {
        FineTuneConfig {
            n_epochs: 3,
            batch_size: 32,
            warmup_steps: 100,
            lr: 2e-5,
            max_grad_norm: 1.0,
            dropout: 0.1,
        }
    }
}
76
77#[derive(Debug, Clone)]
81pub struct ClassificationHead {
82 pub weight: Vec<Vec<f64>>,
84 pub bias: Vec<f64>,
86}
87
88impl ClassificationHead {
89 pub fn new(hidden_size: usize, n_classes: usize) -> Self {
91 let mut seed: u64 = 0xFAFAFAFA_12345678;
92 let weight = (0..n_classes)
93 .map(|_| {
94 (0..hidden_size)
95 .map(|_| {
96 seed = seed
97 .wrapping_mul(6364136223846793005)
98 .wrapping_add(1442695040888963407);
99 let bits = (seed >> 33) as f64 / (u32::MAX as f64);
100 (bits - 0.5) * 0.02 })
102 .collect()
103 })
104 .collect();
105
106 Self {
107 weight,
108 bias: vec![0.0; n_classes],
109 }
110 }
111
112 pub fn forward(&self, cls_embedding: &[f64]) -> Vec<f64> {
114 self.weight
115 .iter()
116 .zip(self.bias.iter())
117 .map(|(row, &b)| {
118 row.iter()
119 .zip(cls_embedding.iter())
120 .map(|(w, x)| w * x)
121 .sum::<f64>()
122 + b
123 })
124 .collect()
125 }
126
127 pub fn backward_update(
131 &mut self,
132 cls_embedding: &[f64],
133 logits: &[f64],
134 label: usize,
135 lr: f64,
136 ) -> f64 {
137 let n_classes = logits.len();
138 if label >= n_classes {
139 return 0.0;
141 }
142
143 let probs = softmax(logits);
145
146 let loss = -(probs[label] + 1e-15).ln();
148
149 let grad_logits: Vec<f64> = probs
151 .iter()
152 .enumerate()
153 .map(|(k, &p)| if k == label { p - 1.0 } else { p })
154 .collect();
155
156 let hidden = cls_embedding.len();
159 for k in 0..n_classes {
160 let g = grad_logits[k];
161 self.bias[k] -= lr * g;
162 for j in 0..hidden {
163 self.weight[k][j] -= lr * g * cls_embedding[j];
164 }
165 }
166
167 loss
168 }
169}
170
171pub struct BertFineTuner {
177 pub head: ClassificationHead,
179 pub config: FineTuneConfig,
181 pub step: usize,
183 total_steps: usize,
185}
186
187impl BertFineTuner {
188 pub fn new(hidden_size: usize, task: FineTuneTask, config: FineTuneConfig) -> Result<Self> {
193 let n_outputs = task.n_outputs();
194 if n_outputs == 0 {
195 return Err(TextError::InvalidInput(
196 "BertFineTuner: task must have at least 1 output class".into(),
197 ));
198 }
199 Ok(Self {
200 head: ClassificationHead::new(hidden_size, n_outputs),
201 config,
202 step: 0,
203 total_steps: 0,
204 })
205 }
206
207 pub fn learning_rate_schedule(&self) -> f64 {
211 let peak = self.config.lr;
212 let warmup = self.config.warmup_steps as f64;
213 let total = (self.total_steps.max(1)) as f64;
214 let s = self.step as f64;
215
216 if s < warmup {
217 peak * (s + 1.0) / warmup
219 } else {
220 let progress = (s - warmup) / (total - warmup).max(1.0);
222 let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) * 0.5;
223 peak * cosine
224 }
225 }
226
227 fn clip_grad(grad: &mut [f64], max_norm: f64) {
231 let norm: f64 = grad.iter().map(|x| x * x).sum::<f64>().sqrt();
232 if norm > max_norm && norm > 1e-12 {
233 let scale = max_norm / norm;
234 grad.iter_mut().for_each(|g| *g *= scale);
235 }
236 }
237
238 pub fn train(&mut self, embeddings: &[Vec<f64>], labels: &[usize]) -> Vec<f64> {
244 let n = embeddings.len().min(labels.len());
245 let batch_size = self.config.batch_size.max(1);
246 let n_epochs = self.config.n_epochs;
247 self.total_steps = n_epochs * n.div_ceil(batch_size);
248
249 let mut epoch_losses = Vec::with_capacity(n_epochs);
250
251 for _epoch in 0..n_epochs {
252 let mut epoch_loss = 0.0_f64;
253 let mut n_batches = 0usize;
254
255 let mut start = 0;
257 while start < n {
258 let end = (start + batch_size).min(n);
259 let batch_embs = &embeddings[start..end];
260 let batch_labels = &labels[start..end];
261
262 let lr = self.learning_rate_schedule();
263 let mut batch_loss = 0.0_f64;
264
265 let n_classes = self.head.bias.len();
267 let hidden = if batch_embs.is_empty() {
268 0
269 } else {
270 batch_embs[0].len()
271 };
272 let mut grad_w = vec![vec![0.0_f64; hidden]; n_classes];
273 let mut grad_b = vec![0.0_f64; n_classes];
274
275 for (emb, &lbl) in batch_embs.iter().zip(batch_labels.iter()) {
276 let logits = self.head.forward(emb);
277 let probs = softmax(&logits);
278 let loss = -(probs[lbl.min(n_classes - 1)] + 1e-15).ln();
279 batch_loss += loss;
280
281 for k in 0..n_classes {
283 let g = if k == lbl { probs[k] - 1.0 } else { probs[k] };
284 grad_b[k] += g;
285 for j in 0..hidden {
286 grad_w[k][j] += g * emb[j];
287 }
288 }
289 }
290
291 let batch_len = (end - start) as f64;
292
293 grad_b.iter_mut().for_each(|g| *g /= batch_len);
295 for row in &mut grad_w {
296 row.iter_mut().for_each(|g| *g /= batch_len);
297 }
298
299 let max_norm = self.config.max_grad_norm;
301 Self::clip_grad(&mut grad_b, max_norm);
302 for row in &mut grad_w {
303 Self::clip_grad(row, max_norm);
304 }
305
306 for k in 0..n_classes {
308 self.head.bias[k] -= lr * grad_b[k];
309 for j in 0..hidden {
310 self.head.weight[k][j] -= lr * grad_w[k][j];
311 }
312 }
313
314 epoch_loss += batch_loss / batch_len;
315 n_batches += 1;
316 self.step += 1;
317 start = end;
318 }
319
320 epoch_losses.push(if n_batches > 0 {
321 epoch_loss / n_batches as f64
322 } else {
323 0.0
324 });
325 }
326
327 epoch_losses
328 }
329
330 pub fn predict(&self, embedding: &[f64]) -> usize {
334 let logits = self.head.forward(embedding);
335 logits
336 .iter()
337 .enumerate()
338 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
339 .map(|(i, _)| i)
340 .unwrap_or(0)
341 }
342
343 pub fn predict_proba(&self, embedding: &[f64]) -> Vec<f64> {
345 softmax(&self.head.forward(embedding))
346 }
347
348 pub fn evaluate(&self, embeddings: &[Vec<f64>], labels: &[usize]) -> f64 {
350 let n = embeddings.len().min(labels.len());
351 if n == 0 {
352 return 0.0;
353 }
354 let correct: usize = embeddings[..n]
355 .iter()
356 .zip(labels[..n].iter())
357 .filter(|(emb, &lbl)| self.predict(emb) == lbl)
358 .count();
359 correct as f64 / n as f64
360 }
361}
362
/// Numerically stable softmax over a slice of logits.
///
/// The maximum logit is subtracted before exponentiating, so large finite logits cannot
/// overflow. Edge cases:
///
/// * empty input returns an empty vector;
/// * if any logit is `NaN`, every output is `NaN`;
/// * if the maximum logit is infinite (all logits `-inf`, or at least one `+inf`), the
///   result is the limiting distribution: probability mass split evenly among the entries
///   equal to the maximum, `0.0` elsewhere. Subtracting an infinite maximum would
///   otherwise produce `inf - inf = NaN`.
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    if logits.iter().any(|x| x.is_nan()) {
        return vec![f64::NAN; logits.len()];
    }

    let max_v = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max_v.is_infinite() {
        let n_max = logits.iter().filter(|&&x| x == max_v).count() as f64;
        return logits
            .iter()
            .map(|&x| if x == max_v { 1.0 / n_max } else { 0.0 })
            .collect();
    }

    let exps: Vec<f64> = logits.iter().map(|&x| (x - max_v).exp()).collect();
    // The maximal entry contributes exp(0) = 1, so `sum >= 1` and the division is safe.
    let sum: f64 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// The head has one weight row and one bias per class, and one logit per class.
    #[test]
    fn test_classification_head_shape() {
        let head = ClassificationHead::new(16, 4);
        assert_eq!(head.weight.len(), 4);
        assert_eq!(head.weight[0].len(), 16);
        assert_eq!(head.bias.len(), 4);

        let emb: Vec<f64> = (0..16).map(|i| i as f64 * 0.1).collect();
        let logits = head.forward(&emb);
        assert_eq!(logits.len(), 4, "logits must have one entry per class");
    }

    /// A single SGD step reports a finite, non-negative cross-entropy loss.
    #[test]
    fn test_classification_head_backward_update_returns_loss() {
        let mut head = ClassificationHead::new(8, 3);
        let emb: Vec<f64> = vec![1.0; 8];
        let logits = head.forward(&emb);
        let loss = head.backward_update(&emb, &logits, 0, 1e-3);
        assert!(loss.is_finite(), "loss should be finite, got {}", loss);
        assert!(loss >= 0.0, "CE loss must be non-negative");
    }

    /// An out-of-range label is a no-op: zero loss and untouched parameters.
    #[test]
    fn test_classification_head_backward_update_ignores_invalid_label() {
        let mut head = ClassificationHead::new(4, 2);
        let before = head.clone();
        let emb = vec![1.0; 4];
        let logits = head.forward(&emb);
        let loss = head.backward_update(&emb, &logits, 2, 0.5);
        assert_eq!(loss, 0.0, "out-of-range label must yield zero loss");
        assert_eq!(head.weight, before.weight, "weights must be untouched");
        assert_eq!(head.bias, before.bias, "bias must be untouched");
    }

    /// A task with zero outputs is rejected at construction time.
    #[test]
    fn test_bert_finetuner_new_invalid_task() {
        let result = BertFineTuner::new(
            16,
            FineTuneTask::SequenceLabeling { n_labels: 0 },
            FineTuneConfig::default(),
        );
        assert!(result.is_err());
    }

    /// `train` returns exactly one finite loss per epoch.
    #[test]
    fn test_bert_finetuner_train_returns_epoch_losses() {
        let config = FineTuneConfig {
            lr: 0.1,
            n_epochs: 3,
            batch_size: 4,
            warmup_steps: 2,
            ..Default::default()
        };
        let mut tuner =
            BertFineTuner::new(4, FineTuneTask::Classification { n_classes: 2 }, config)
                .expect("should create tuner");

        let embeddings: Vec<Vec<f64>> = (0..8)
            .map(|i| vec![(i % 2) as f64, ((i + 1) % 2) as f64, 0.0, 0.0])
            .collect();
        let labels: Vec<usize> = (0..8).map(|i| i % 2).collect();

        let losses = tuner.train(&embeddings, &labels);
        assert_eq!(losses.len(), 3, "should return one loss per epoch");
        for &loss in &losses {
            assert!(loss.is_finite(), "loss must be finite");
        }
    }

    /// Training on no data reports a zero loss per epoch and takes no steps.
    #[test]
    fn test_bert_finetuner_train_on_empty_data() {
        let config = FineTuneConfig {
            n_epochs: 2,
            ..Default::default()
        };
        let mut tuner =
            BertFineTuner::new(3, FineTuneTask::Classification { n_classes: 2 }, config)
                .expect("should create tuner");
        let losses = tuner.train(&[], &[]);
        assert_eq!(losses, vec![0.0, 0.0]);
        assert_eq!(tuner.step, 0, "no batches means no optimisation steps");
    }

    /// Linearly separable one-hot data must be fitted perfectly.
    #[test]
    fn test_bert_finetuner_accuracy_improves_on_separable_data() {
        let hidden = 4;
        let config = FineTuneConfig {
            lr: 1.0,
            n_epochs: 20,
            batch_size: 2,
            warmup_steps: 5,
            max_grad_norm: 10.0,
            dropout: 0.0,
        };
        let mut tuner = BertFineTuner::new(
            hidden,
            FineTuneTask::Classification { n_classes: 2 },
            config,
        )
        .expect("should create tuner");

        let embeddings: Vec<Vec<f64>> = (0..20)
            .map(|i| {
                if i % 2 == 0 {
                    vec![1.0, 0.0, 0.0, 0.0]
                } else {
                    vec![0.0, 1.0, 0.0, 0.0]
                }
            })
            .collect();
        let labels: Vec<usize> = (0..20).map(|i| i % 2).collect();

        let initial_acc = tuner.evaluate(&embeddings, &labels);
        tuner.train(&embeddings, &labels);
        let final_acc = tuner.evaluate(&embeddings, &labels);

        assert!(
            final_acc >= initial_acc,
            "accuracy should not decrease after training on separable data: {} -> {}",
            initial_acc,
            final_acc
        );
        // `>=` alone is vacuous when the random init already scores well; the two classes
        // are one-hot and trivially separable, so training must reach 100%.
        assert!(
            (final_acc - 1.0).abs() < 1e-12,
            "separable data should be fitted perfectly, got accuracy {}",
            final_acc
        );
    }

    /// Predicted class probabilities form a distribution.
    #[test]
    fn test_predict_proba_sums_to_one() {
        let tuner = BertFineTuner::new(
            4,
            FineTuneTask::Classification { n_classes: 3 },
            FineTuneConfig::default(),
        )
        .expect("should create tuner");

        let emb = vec![0.1, 0.2, 0.3, 0.4];
        let proba = tuner.predict_proba(&emb);
        let sum: f64 = proba.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "probabilities must sum to 1, got {}",
            sum
        );
    }

    /// Softmax must not overflow for large logits.
    #[test]
    fn test_softmax_is_stable_for_large_logits() {
        let probs = softmax(&[1000.0, 0.0]);
        assert!(
            probs.iter().all(|p| p.is_finite()),
            "softmax must not overflow: {:?}",
            probs
        );
        assert!((probs[0] - 1.0).abs() < 1e-12);
        assert!(probs[1].abs() < 1e-12);
    }

    /// The learning rate is positive both during and right after warmup.
    #[test]
    fn test_lr_schedule_warmup() {
        let config = FineTuneConfig {
            warmup_steps: 10,
            lr: 1.0,
            ..Default::default()
        };
        let mut tuner =
            BertFineTuner::new(2, FineTuneTask::Classification { n_classes: 2 }, config)
                .expect("tuner");
        tuner.total_steps = 100;

        // First warmup step.
        tuner.step = 0;
        let lr0 = tuner.learning_rate_schedule();
        assert!(lr0 > 0.0 && lr0 <= 1.0, "warmup lr should be in (0, peak]");

        // First step after warmup (start of cosine decay).
        tuner.step = 10;
        let lr_warm = tuner.learning_rate_schedule();
        assert!(lr_warm > 0.0, "lr after warmup should be positive");
    }

    /// Over the whole planned run the learning rate stays in `[0, peak]` and warmup ends at the peak.
    #[test]
    fn test_lr_schedule_stays_within_zero_and_peak() {
        let config = FineTuneConfig {
            warmup_steps: 10,
            lr: 0.5,
            ..Default::default()
        };
        let mut tuner =
            BertFineTuner::new(2, FineTuneTask::Classification { n_classes: 2 }, config)
                .expect("tuner");
        tuner.total_steps = 50;

        for step in 0..50 {
            tuner.step = step;
            let lr = tuner.learning_rate_schedule();
            assert!(
                (0.0..=0.5).contains(&lr),
                "lr {} out of [0, peak] at step {}",
                lr,
                step
            );
        }

        tuner.step = 9;
        assert!(
            (tuner.learning_rate_schedule() - 0.5).abs() < 1e-12,
            "warmup must end at the peak lr"
        );
    }
}