Skip to main content

yscv_model/
loss.rs

1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
6fn validate_loss_inputs(
7    graph: &Graph,
8    prediction: NodeId,
9    target: NodeId,
10) -> Result<usize, ModelError> {
11    let prediction_shape = graph.value(prediction)?.shape().to_vec();
12    let target_shape = graph.value(target)?.shape().to_vec();
13    if prediction_shape != target_shape {
14        return Err(ModelError::PredictionTargetShapeMismatch {
15            prediction: prediction_shape,
16            target: target_shape,
17        });
18    }
19
20    let element_count = graph.value(prediction)?.len();
21    if element_count == 0 {
22        return Err(ModelError::EmptyLossTensor);
23    }
24    Ok(element_count)
25}
26
27fn abs_node(graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
28    let zero = graph.constant(Tensor::scalar(0.0));
29    let neg_input = graph.sub(zero, input)?;
30    let positive = graph.relu(input)?;
31    let negative = graph.relu(neg_input)?;
32    graph.add(positive, negative).map_err(Into::into)
33}
34
35/// Mean squared error loss: `mean((prediction - target)^2)`.
36pub fn mse_loss(
37    graph: &mut Graph,
38    prediction: NodeId,
39    target: NodeId,
40) -> Result<NodeId, ModelError> {
41    let element_count = validate_loss_inputs(graph, prediction, target)?;
42
43    let diff = graph.sub(prediction, target)?;
44    let sq = graph.mul(diff, diff)?;
45    let sum = graph.sum(sq)?;
46    let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
47    graph.mul(sum, inv_count).map_err(Into::into)
48}
49
50/// Mean absolute error loss: `mean(abs(prediction - target))`.
51pub fn mae_loss(
52    graph: &mut Graph,
53    prediction: NodeId,
54    target: NodeId,
55) -> Result<NodeId, ModelError> {
56    let element_count = validate_loss_inputs(graph, prediction, target)?;
57
58    let diff = graph.sub(prediction, target)?;
59    let abs = abs_node(graph, diff)?;
60    let sum = graph.sum(abs)?;
61    let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
62    graph.mul(sum, inv_count).map_err(Into::into)
63}
64
65/// Mean Huber loss:
66/// `mean(0.5 * min(|e|, delta)^2 + delta * max(|e| - delta, 0))`, where `e = prediction - target`.
67pub fn huber_loss(
68    graph: &mut Graph,
69    prediction: NodeId,
70    target: NodeId,
71    delta: f32,
72) -> Result<NodeId, ModelError> {
73    if !delta.is_finite() || delta <= 0.0 {
74        return Err(ModelError::InvalidHuberDelta { delta });
75    }
76    let element_count = validate_loss_inputs(graph, prediction, target)?;
77
78    let diff = graph.sub(prediction, target)?;
79    let abs = abs_node(graph, diff)?;
80    let delta_node = graph.constant(Tensor::scalar(delta));
81    let abs_minus_delta = graph.sub(abs, delta_node)?;
82    let excess = graph.relu(abs_minus_delta)?;
83    let clipped = graph.sub(abs, excess)?;
84
85    let clipped_sq = graph.mul(clipped, clipped)?;
86    let half = graph.constant(Tensor::scalar(0.5));
87    let quadratic = graph.mul(clipped_sq, half)?;
88    let linear = graph.mul(excess, delta_node)?;
89    let per_element = graph.add(quadratic, linear)?;
90    let sum = graph.sum(per_element)?;
91    let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
92    graph.mul(sum, inv_count).map_err(Into::into)
93}
94
95/// Mean hinge loss:
96/// `mean(max(0, margin - prediction * target))`.
97pub fn hinge_loss(
98    graph: &mut Graph,
99    prediction: NodeId,
100    target: NodeId,
101    margin: f32,
102) -> Result<NodeId, ModelError> {
103    if !margin.is_finite() || margin <= 0.0 {
104        return Err(ModelError::InvalidHingeMargin { margin });
105    }
106    let element_count = validate_loss_inputs(graph, prediction, target)?;
107
108    let product = graph.mul(prediction, target)?;
109    let margin_node = graph.constant(Tensor::scalar(margin));
110    let raw = graph.sub(margin_node, product)?;
111    let positive = graph.relu(raw)?;
112    let sum = graph.sum(positive)?;
113    let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
114    graph.mul(sum, inv_count).map_err(Into::into)
115}
116
117/// Binary cross-entropy loss for predictions already passed through sigmoid.
118/// `bce = -mean(target * log(pred) + (1 - target) * log(1 - pred))`.
119///
120/// `prediction` values are clamped to `[eps, 1-eps]` for numerical stability.
121pub fn bce_loss(
122    graph: &mut Graph,
123    prediction: NodeId,
124    target: NodeId,
125) -> Result<NodeId, ModelError> {
126    let element_count = validate_loss_inputs(graph, prediction, target)?;
127
128    let eps = 1e-7_f32;
129    let eps_node = graph.constant(Tensor::scalar(eps));
130    let one_node = graph.constant(Tensor::scalar(1.0));
131
132    // pred_safe = clamp(pred, eps, 1-eps) via relu chain
133    let shifted_low = graph.sub(prediction, eps_node)?;
134    let positive_part = graph.relu(shifted_low)?;
135    let pred_above_eps = graph.add(positive_part, eps_node)?;
136
137    let one_minus_eps_node = graph.constant(Tensor::scalar(1.0 - eps));
138    let over = graph.sub(pred_above_eps, one_minus_eps_node)?;
139    let excess = graph.relu(over)?;
140    let pred_safe = graph.sub(pred_above_eps, excess)?;
141
142    // log(pred_safe)
143    let log_pred = graph.log(pred_safe)?;
144
145    // log(1 - pred_safe + eps) for stability
146    let one_minus_pred = graph.sub(one_node, pred_safe)?;
147    let one_minus_pred_safe = graph.add(one_minus_pred, eps_node)?;
148    let log_one_minus_pred = graph.log(one_minus_pred_safe)?;
149
150    // -mean(t*log(p) + (1-t)*log(1-p))
151    let term1 = graph.mul(target, log_pred)?;
152    let one_minus_t = graph.sub(one_node, target)?;
153    let term2 = graph.mul(one_minus_t, log_one_minus_pred)?;
154    let combined = graph.add(term1, term2)?;
155    let sum = graph.sum(combined)?;
156    let neg_sum = graph.neg(sum)?;
157    let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
158    graph.mul(neg_sum, inv_count).map_err(Into::into)
159}
160
161/// Negative log-likelihood loss from log-probabilities.
162/// Expects `log_probs` shape `[batch, classes]` and `targets` shape `[batch, 1]`
163/// where targets contain class indices as f32.
164///
165/// `nll = -mean(log_probs[i, target[i]])` across the batch.
166pub fn nll_loss(
167    graph: &mut Graph,
168    log_probs: NodeId,
169    targets: NodeId,
170) -> Result<NodeId, ModelError> {
171    let lp_shape = graph.value(log_probs)?.shape().to_vec();
172    let t_shape = graph.value(targets)?.shape().to_vec();
173
174    if lp_shape.len() != 2 {
175        return Err(ModelError::InvalidInputShape {
176            expected_features: 0,
177            got: lp_shape,
178        });
179    }
180    if t_shape.len() != 2 || t_shape[1] != 1 {
181        return Err(ModelError::PredictionTargetShapeMismatch {
182            prediction: lp_shape.clone(),
183            target: t_shape,
184        });
185    }
186    let batch_size = lp_shape[0];
187    let num_classes = lp_shape[1];
188    if batch_size == 0 {
189        return Err(ModelError::EmptyLossTensor);
190    }
191
192    let lp_data = graph.value(log_probs)?.data().to_vec();
193    let t_data = graph.value(targets)?.data().to_vec();
194
195    let mut selected = vec![0.0f32; batch_size];
196    for i in 0..batch_size {
197        let class_idx = t_data[i] as usize;
198        if class_idx >= num_classes {
199            return Err(ModelError::InvalidDatasetRecordValue {
200                line: i,
201                field: "nll_target",
202                index: 0,
203                reason: "class index out of range",
204            });
205        }
206        selected[i] = lp_data[i * num_classes + class_idx];
207    }
208
209    let selected_node = graph.constant(Tensor::from_vec(vec![batch_size], selected)?);
210    let sum = graph.sum(selected_node)?;
211    let neg_sum = graph.neg(sum)?;
212    let inv_batch = graph.constant(Tensor::scalar(1.0 / batch_size as f32));
213    graph.mul(neg_sum, inv_batch).map_err(Into::into)
214}
215
216/// Cross-entropy loss from raw logits.
217/// Computes `nll_loss(log_softmax(logits), targets)`.
218///
219/// Expects `logits` shape `[batch, classes]` and `targets` shape `[batch, 1]`
220/// with class indices as f32.
221pub fn cross_entropy_loss(
222    graph: &mut Graph,
223    logits: NodeId,
224    targets: NodeId,
225) -> Result<NodeId, ModelError> {
226    let shape = graph.value(logits)?.shape().to_vec();
227    if shape.len() != 2 {
228        return Err(ModelError::InvalidInputShape {
229            expected_features: 0,
230            got: shape,
231        });
232    }
233    let batch_size = shape[0];
234    let num_classes = shape[1];
235
236    if batch_size == 0 {
237        return Err(ModelError::EmptyLossTensor);
238    }
239
240    let logits_data = graph.value(logits)?.data().to_vec();
241    let t_data = graph.value(targets)?.data().to_vec();
242    let t_shape = graph.value(targets)?.shape().to_vec();
243
244    if t_shape.len() != 2 || t_shape[1] != 1 || t_shape[0] != batch_size {
245        return Err(ModelError::PredictionTargetShapeMismatch {
246            prediction: shape.clone(),
247            target: t_shape,
248        });
249    }
250
251    // Compute log-softmax manually for numerical stability:
252    // log_softmax(x_i) = x_i - log(sum(exp(x_j - max_x)))  - max_x
253    let mut log_probs = vec![0.0f32; batch_size * num_classes];
254    for b in 0..batch_size {
255        let row = &logits_data[b * num_classes..(b + 1) * num_classes];
256        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
257        let sum_exp: f32 = row.iter().map(|&v| (v - max_val).exp()).sum();
258        let log_sum_exp = max_val + sum_exp.ln();
259        for c in 0..num_classes {
260            log_probs[b * num_classes + c] = row[c] - log_sum_exp;
261        }
262    }
263
264    // Gather: pick log_probs[i, target[i]]
265    let mut neg_sum = 0.0f32;
266    for i in 0..batch_size {
267        let class_idx = t_data[i] as usize;
268        if class_idx >= num_classes {
269            return Err(ModelError::InvalidDatasetRecordValue {
270                line: i,
271                field: "cross_entropy_target",
272                index: 0,
273                reason: "class index out of range",
274            });
275        }
276        neg_sum -= log_probs[i * num_classes + class_idx];
277    }
278
279    let loss_val = neg_sum / batch_size as f32;
280    let loss_node = graph.constant(Tensor::scalar(loss_val));
281    Ok(loss_node)
282}
283
284/// Focal loss for imbalanced classification.
285///
286/// `FL = -alpha * (1 - p_t)^gamma * log(p_t)` averaged over elements.
287/// `prediction` should be sigmoid probabilities, `target` binary labels.
288pub fn focal_loss(
289    graph: &mut Graph,
290    prediction: NodeId,
291    target: NodeId,
292    alpha: f32,
293    gamma: f32,
294) -> Result<NodeId, ModelError> {
295    let element_count = validate_loss_inputs(graph, prediction, target)?;
296    let eps = 1e-7_f32;
297
298    let pred_data = graph.value(prediction)?.data().to_vec();
299    let target_data = graph.value(target)?.data().to_vec();
300
301    let mut loss_sum = 0.0f32;
302    for i in 0..element_count {
303        let p = pred_data[i].clamp(eps, 1.0 - eps);
304        let t = target_data[i];
305        let pt = if t > 0.5 { p } else { 1.0 - p };
306        loss_sum += -alpha * (1.0 - pt).powf(gamma) * pt.ln();
307    }
308
309    let loss_val = loss_sum / element_count as f32;
310    Ok(graph.constant(Tensor::scalar(loss_val)))
311}
312
313/// Dice loss for segmentation.
314///
315/// `DiceLoss = 1 - (2 * |P ∩ G| + smooth) / (|P| + |G| + smooth)`
316/// where `prediction` and `target` are probability maps.
317pub fn dice_loss(
318    graph: &mut Graph,
319    prediction: NodeId,
320    target: NodeId,
321    smooth: f32,
322) -> Result<NodeId, ModelError> {
323    let element_count = validate_loss_inputs(graph, prediction, target)?;
324
325    let pred_data = graph.value(prediction)?.data().to_vec();
326    let target_data = graph.value(target)?.data().to_vec();
327
328    let mut intersection = 0.0f32;
329    let mut pred_sum = 0.0f32;
330    let mut target_sum = 0.0f32;
331    for i in 0..element_count {
332        intersection += pred_data[i] * target_data[i];
333        pred_sum += pred_data[i];
334        target_sum += target_data[i];
335    }
336
337    let dice = (2.0 * intersection + smooth) / (pred_sum + target_sum + smooth);
338    Ok(graph.constant(Tensor::scalar(1.0 - dice)))
339}
340
341/// Triplet loss for metric learning.
342///
343/// `L = mean(max(0, d(a,p) - d(a,n) + margin))` where d is L2 distance.
344/// All inputs must have the same shape `[batch, embedding_dim]`.
345pub fn triplet_loss(
346    graph: &mut Graph,
347    anchor: NodeId,
348    positive: NodeId,
349    negative: NodeId,
350    margin: f32,
351) -> Result<NodeId, ModelError> {
352    let a_shape = graph.value(anchor)?.shape().to_vec();
353    let p_shape = graph.value(positive)?.shape().to_vec();
354    let n_shape = graph.value(negative)?.shape().to_vec();
355    if a_shape != p_shape || a_shape != n_shape {
356        return Err(ModelError::PredictionTargetShapeMismatch {
357            prediction: a_shape,
358            target: p_shape,
359        });
360    }
361    if a_shape.len() != 2 || a_shape[0] == 0 {
362        return Err(ModelError::EmptyLossTensor);
363    }
364    let batch = a_shape[0];
365    let dim = a_shape[1];
366
367    let a_data = graph.value(anchor)?.data().to_vec();
368    let p_data = graph.value(positive)?.data().to_vec();
369    let n_data = graph.value(negative)?.data().to_vec();
370
371    let mut loss_sum = 0.0f32;
372    for b in 0..batch {
373        let mut dp = 0.0f32;
374        let mut dn = 0.0f32;
375        for d in 0..dim {
376            let idx = b * dim + d;
377            dp += (a_data[idx] - p_data[idx]).powi(2);
378            dn += (a_data[idx] - n_data[idx]).powi(2);
379        }
380        loss_sum += (dp.sqrt() - dn.sqrt() + margin).max(0.0);
381    }
382
383    Ok(graph.constant(Tensor::scalar(loss_sum / batch as f32)))
384}
385
386/// Contrastive loss for siamese networks.
387///
388/// `L = mean( y * d^2 + (1-y) * max(0, margin - d)^2 )` where d = L2 distance.
389/// `label`: 1.0 for same pair, 0.0 for different.
390pub fn contrastive_loss(
391    graph: &mut Graph,
392    x1: NodeId,
393    x2: NodeId,
394    label: NodeId,
395    margin: f32,
396) -> Result<NodeId, ModelError> {
397    let s1 = graph.value(x1)?.shape().to_vec();
398    let s2 = graph.value(x2)?.shape().to_vec();
399    if s1 != s2 || s1.len() != 2 || s1[0] == 0 {
400        return Err(ModelError::PredictionTargetShapeMismatch {
401            prediction: s1,
402            target: s2,
403        });
404    }
405    let batch = s1[0];
406    let dim = s1[1];
407
408    let x1d = graph.value(x1)?.data().to_vec();
409    let x2d = graph.value(x2)?.data().to_vec();
410    let ld = graph.value(label)?.data().to_vec();
411
412    let mut loss_sum = 0.0f32;
413    for b in 0..batch {
414        let mut dist_sq = 0.0f32;
415        for d in 0..dim {
416            let idx = b * dim + d;
417            dist_sq += (x1d[idx] - x2d[idx]).powi(2);
418        }
419        let dist = dist_sq.sqrt();
420        let y = ld[b];
421        loss_sum += y * dist_sq + (1.0 - y) * (margin - dist).max(0.0).powi(2);
422    }
423
424    Ok(graph.constant(Tensor::scalar(loss_sum / batch as f32)))
425}
426
427/// Cosine embedding loss.
428///
429/// `L = mean( y==1: 1-cos(x1,x2), y==-1: max(0, cos(x1,x2)-margin) )`
430pub fn cosine_embedding_loss(
431    graph: &mut Graph,
432    x1: NodeId,
433    x2: NodeId,
434    label: NodeId,
435    margin: f32,
436) -> Result<NodeId, ModelError> {
437    let s1 = graph.value(x1)?.shape().to_vec();
438    let s2 = graph.value(x2)?.shape().to_vec();
439    if s1 != s2 || s1.len() != 2 || s1[0] == 0 {
440        return Err(ModelError::PredictionTargetShapeMismatch {
441            prediction: s1,
442            target: s2,
443        });
444    }
445    let batch = s1[0];
446    let dim = s1[1];
447
448    let x1d = graph.value(x1)?.data().to_vec();
449    let x2d = graph.value(x2)?.data().to_vec();
450    let ld = graph.value(label)?.data().to_vec();
451
452    let mut loss_sum = 0.0f32;
453    for b in 0..batch {
454        let mut dot = 0.0f32;
455        let mut n1 = 0.0f32;
456        let mut n2 = 0.0f32;
457        for d in 0..dim {
458            let idx = b * dim + d;
459            dot += x1d[idx] * x2d[idx];
460            n1 += x1d[idx] * x1d[idx];
461            n2 += x2d[idx] * x2d[idx];
462        }
463        let cos = dot / (n1.sqrt() * n2.sqrt()).max(1e-8);
464        let y = ld[b];
465        if y > 0.0 {
466            loss_sum += 1.0 - cos;
467        } else {
468            loss_sum += (cos - margin).max(0.0);
469        }
470    }
471
472    Ok(graph.constant(Tensor::scalar(loss_sum / batch as f32)))
473}
474
475/// Cross-entropy with label smoothing.
476///
477/// Smoothed target: `(1 - smoothing) * one_hot + smoothing / num_classes`.
478/// Expects `logits` shape `[batch, classes]` and `targets` shape `[batch, 1]`.
479pub fn label_smoothing_cross_entropy(
480    graph: &mut Graph,
481    logits: NodeId,
482    targets: NodeId,
483    smoothing: f32,
484) -> Result<NodeId, ModelError> {
485    let shape = graph.value(logits)?.shape().to_vec();
486    if shape.len() != 2 || shape[0] == 0 {
487        return Err(ModelError::EmptyLossTensor);
488    }
489    let batch_size = shape[0];
490    let num_classes = shape[1];
491
492    let logits_data = graph.value(logits)?.data().to_vec();
493    let t_data = graph.value(targets)?.data().to_vec();
494
495    let smooth_val = smoothing / num_classes as f32;
496    let confidence = 1.0 - smoothing;
497
498    let mut total_loss = 0.0f32;
499    for b in 0..batch_size {
500        let row = &logits_data[b * num_classes..(b + 1) * num_classes];
501        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
502        let sum_exp: f32 = row.iter().map(|&v| (v - max_val).exp()).sum();
503        let log_sum_exp = max_val + sum_exp.ln();
504
505        let class_idx = t_data[b] as usize;
506        for c in 0..num_classes {
507            let log_prob = row[c] - log_sum_exp;
508            let target_prob = if c == class_idx {
509                confidence + smooth_val
510            } else {
511                smooth_val
512            };
513            total_loss -= target_prob * log_prob;
514        }
515    }
516
517    Ok(graph.constant(Tensor::scalar(total_loss / batch_size as f32)))
518}
519
/// CTC (Connectionist Temporal Classification) loss.
///
/// Simplified implementation for sequence-to-sequence tasks (OCR/ASR).
/// `log_probs`: `[T, batch, classes]`, `targets`: `[batch, S]`, lengths as 1-D tensors.
///
/// Runs the standard CTC forward (alpha) recursion over the blank-augmented
/// label sequence and returns `-mean(log P(target | input))` over the batch.
///
/// NOTE(review): several preconditions are not validated and will panic if
/// violated — `input_lengths`/`target_lengths` must each have at least
/// `batch` entries, every `input_len` must be >= 1, `blank` must be a valid
/// class index (< `classes`), and target entries must be valid class
/// indices. Confirm callers guarantee these.
pub fn ctc_loss(
    graph: &mut Graph,
    log_probs: NodeId,
    targets: NodeId,
    input_lengths: NodeId,
    target_lengths: NodeId,
    blank: usize,
) -> Result<NodeId, ModelError> {
    let lp_shape = graph.value(log_probs)?.shape().to_vec();
    if lp_shape.len() != 3 {
        return Err(ModelError::InvalidInputShape {
            expected_features: 0,
            got: lp_shape,
        });
    }
    let _t_max = lp_shape[0];
    let batch = lp_shape[1];
    let num_classes = lp_shape[2];

    // Flat copies of every input tensor; all indexing below is manual.
    let lp_data = graph.value(log_probs)?.data().to_vec();
    let tgt_data = graph.value(targets)?.data().to_vec();
    let il_data = graph.value(input_lengths)?.data().to_vec();
    let tl_data = graph.value(target_lengths)?.data().to_vec();

    let tgt_shape = graph.value(targets)?.shape().to_vec();
    // Row stride of the (possibly padded) target matrix. The 1-D fallback
    // assumes the data is laid out as `batch` equal-length rows.
    let s_max = if tgt_shape.len() >= 2 {
        tgt_shape[1]
    } else {
        tgt_shape[0] / batch
    };

    let mut total_loss = 0.0f32;

    for b in 0..batch {
        let input_len = il_data[b] as usize;
        let target_len = tl_data[b] as usize;

        // Build label sequence with blanks: [blank, l1, blank, l2, blank, ...]
        let label_len = 2 * target_len + 1;
        let mut labels = vec![blank; label_len];
        for s in 0..target_len {
            labels[2 * s + 1] = tgt_data[b * s_max + s] as usize;
        }

        // Forward pass (alpha): alpha[t * label_len + s] is the total
        // log-probability of all alignments that end in augmented state s
        // after consuming frame t. Log-space so products become sums.
        let mut alpha = vec![f32::NEG_INFINITY; label_len * input_len];
        // Init t=0: a valid path may start on the leading blank or the
        // first real label.
        alpha[0] = lp_data[b * num_classes + labels[0]];
        if label_len > 1 {
            alpha[1] = lp_data[b * num_classes + labels[1]];
        }

        for t in 1..input_len {
            for s in 0..label_len {
                // log_probs layout is [T, batch, classes].
                let lp_idx = t * batch * num_classes + b * num_classes + labels[s];
                let log_p = lp_data[lp_idx];
                // Transition 1: stay in the same state.
                let mut sum = alpha[(t - 1) * label_len + s];
                // Transition 2: advance from the previous state.
                if s > 0 {
                    sum = log_sum_exp_pair(sum, alpha[(t - 1) * label_len + s - 1]);
                }
                // Transition 3: skip over a blank — only allowed between two
                // *different* non-blank labels (repeats need the blank).
                if s > 1 && labels[s] != blank && labels[s] != labels[s - 2] {
                    sum = log_sum_exp_pair(sum, alpha[(t - 1) * label_len + s - 2]);
                }
                alpha[t * label_len + s] = sum + log_p;
            }
        }

        // Valid alignments end on the final blank or the final label.
        let last_t = input_len - 1;
        let log_likelihood = log_sum_exp_pair(
            alpha[last_t * label_len + label_len - 1],
            if label_len >= 2 {
                alpha[last_t * label_len + label_len - 2]
            } else {
                f32::NEG_INFINITY
            },
        );
        total_loss -= log_likelihood;
    }

    // NOTE(review): returned as a constant node, so gradients presumably do
    // not flow back to `log_probs` — confirm against Graph semantics.
    Ok(graph.constant(Tensor::scalar(total_loss / batch as f32)))
}
605
606/// Smooth L1 loss (detection-style parameterization of Huber loss):
607///
608/// ```text
609/// smooth_l1(x, beta) = 0.5 * x^2 / beta   if |x| < beta
610///                     = |x| - 0.5 * beta    otherwise
611/// ```
612///
613/// Equivalent to `huber_loss(pred, target, delta=beta) / beta`.
614/// `beta` must be positive and finite.
615pub fn smooth_l1_loss(
616    graph: &mut Graph,
617    prediction: NodeId,
618    target: NodeId,
619    beta: f32,
620) -> Result<NodeId, ModelError> {
621    if !beta.is_finite() || beta <= 0.0 {
622        return Err(ModelError::InvalidHuberDelta { delta: beta });
623    }
624    let element_count = validate_loss_inputs(graph, prediction, target)?;
625
626    let pred_data = graph.value(prediction)?.data().to_vec();
627    let target_data = graph.value(target)?.data().to_vec();
628
629    let mut loss_sum = 0.0f32;
630    for i in 0..element_count {
631        let x = (pred_data[i] - target_data[i]).abs();
632        if x < beta {
633            loss_sum += 0.5 * x * x / beta;
634        } else {
635            loss_sum += x - 0.5 * beta;
636        }
637    }
638
639    Ok(graph.constant(Tensor::scalar(loss_sum / element_count as f32)))
640}
641
642/// KL divergence loss:
643///
644/// ```text
645/// kl_div(log_pred, target) = sum(target * (log(target) - log_pred)) / n
646/// ```
647///
648/// `log_pred` is the log of the predicted distribution (already log-transformed).
649/// `target` is the true probability distribution. Target values <= 0 are skipped
650/// (their contribution is treated as zero, matching the convention that `0 * log(0) = 0`).
651pub fn kl_div_loss(
652    graph: &mut Graph,
653    log_prediction: NodeId,
654    target: NodeId,
655) -> Result<NodeId, ModelError> {
656    let element_count = validate_loss_inputs(graph, log_prediction, target)?;
657
658    let log_pred_data = graph.value(log_prediction)?.data().to_vec();
659    let target_data = graph.value(target)?.data().to_vec();
660
661    let mut loss_sum = 0.0f32;
662    for i in 0..element_count {
663        let t = target_data[i];
664        if t > 0.0 {
665            loss_sum += t * (t.ln() - log_pred_data[i]);
666        }
667    }
668
669    Ok(graph.constant(Tensor::scalar(loss_sum / element_count as f32)))
670}
671
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// `-inf` operands are handled explicitly so an impossible path does not
/// poison the result with `NaN` (from `-inf - -inf`).
fn log_sum_exp_pair(a: f32, b: f32) -> f32 {
    if a == f32::NEG_INFINITY {
        b
    } else if b == f32::NEG_INFINITY {
        a
    } else {
        // Factor out the larger operand before exponentiating.
        let hi = a.max(b);
        hi + ((a - hi).exp() + (b - hi).exp()).ln()
    }
}
682
683// ---------------------------------------------------------------------------
684// Knowledge Distillation
685// ---------------------------------------------------------------------------
686
687/// Knowledge distillation loss (Hinton et al., 2015).
688///
689/// Combines soft target KL divergence with hard target cross-entropy:
690///
691/// ```text
692/// L = alpha * T² * KL(softmax(student/T) || softmax(teacher/T))
693///   + (1 - alpha) * CrossEntropy(student, labels)
694/// ```
695///
696/// * `student`:  student logits `[batch, num_classes]`
697/// * `teacher`:  teacher logits `[batch, num_classes]` (detached / no grad)
698/// * `labels`:   hard labels node `[batch]` (class indices as f32)
699/// * `temperature`: softening temperature (typically 3-20)
700/// * `alpha`: weight for soft loss (0.0 = pure hard, 1.0 = pure soft)
701pub fn distillation_loss(
702    graph: &mut Graph,
703    student: NodeId,
704    teacher: NodeId,
705    labels: NodeId,
706    temperature: f32,
707    alpha: f32,
708) -> Result<NodeId, ModelError> {
709    // Soft targets: KL(softmax(s/T) || softmax(t/T)) * T²
710    let t_scalar = graph.constant(Tensor::scalar(temperature));
711    let t2_scalar = graph.constant(Tensor::scalar(temperature * temperature));
712
713    let s_scaled = graph.div(student, t_scalar)?;
714    let t_scaled = graph.div(teacher, t_scalar)?;
715
716    let s_log_softmax = graph.log_softmax(s_scaled)?;
717    let t_softmax = graph.softmax(t_scaled)?;
718
719    // KL divergence = sum(t_soft * (log(t_soft) - log_s_soft))
720    let t_log = graph.log(t_softmax)?;
721    let kl_pointwise = graph.sub(t_log, s_log_softmax)?;
722    let kl_weighted = graph.mul(t_softmax, kl_pointwise)?;
723    let kl_sum = graph.mean(kl_weighted)?;
724    let soft_loss = graph.mul(kl_sum, t2_scalar)?;
725
726    // Hard targets: cross-entropy(student, labels)
727    let hard_loss = cross_entropy_loss(graph, student, labels)?;
728
729    // Combined: alpha * soft + (1 - alpha) * hard
730    let alpha_node = graph.constant(Tensor::scalar(alpha));
731    let one_minus_alpha = graph.constant(Tensor::scalar(1.0 - alpha));
732
733    let weighted_soft = graph.mul(soft_loss, alpha_node)?;
734    let weighted_hard = graph.mul(hard_loss, one_minus_alpha)?;
735
736    graph.add(weighted_soft, weighted_hard).map_err(Into::into)
737}