tensorlogic_scirs_backend/gradient_ops.rs

//! Advanced gradient operations for non-differentiable logical operations.
//!
//! This module provides differentiable approximations for operations that are
//! typically non-differentiable, enabling end-to-end training of logical models.
//!
//! ## Gradient Estimators
//!
//! - **Straight-Through Estimator (STE)**: Passes gradients through non-differentiable
//!   operations (thresholding, argmax, etc.) by treating them as identity in the
//!   backward pass.
//!
//! - **Gumbel-Softmax**: Continuous relaxation of categorical distributions,
//!   allowing differentiable sampling with temperature annealing.
//!
//! - **Soft Quantifiers**: Differentiable approximations of ∃ (exists) and ∀ (forall)
//!   using smooth max/min or probabilistic interpretations.

use crate::error::TlBackendResult;
use crate::Scirs2Tensor;
use scirs2_core::ndarray::{Array, ArrayD, Axis};
use scirs2_core::random::arrays::OptimizedArrayRandom;
use scirs2_core::random::prelude::*;

/// Straight-Through Estimator (STE) gradient configuration.
///
/// STE allows gradients to flow through non-differentiable operations
/// by using a different forward and backward pass:
/// - Forward: Apply the discrete/thresholded operation
/// - Backward: Pass gradients through as if it were identity
#[derive(Debug, Clone, Copy)]
pub struct SteConfig {
    /// Threshold for binarization (default: 0.5)
    pub threshold: f64,
    /// Whether to clip gradients to [-1, 1] range
    pub clip_gradients: bool,
}

impl Default for SteConfig {
    fn default() -> Self {
        Self {
            threshold: 0.5,
            clip_gradients: false,
        }
    }
}

/// Gumbel-Softmax configuration for differentiable categorical sampling.
///
/// Provides a continuous relaxation of categorical distributions using
/// the Gumbel-Max trick combined with softmax temperature.
#[derive(Debug, Clone, Copy)]
pub struct GumbelSoftmaxConfig {
    /// Temperature parameter (τ): lower → harder, higher → softer.
    /// Typical range: [0.1, 10.0]; training starts high and anneals down.
    pub temperature: f64,
    /// Whether to use hard (one-hot) samples in the forward pass
    /// but soft samples for gradient computation (straight-through)
    pub hard: bool,
    /// Random seed for reproducibility (None = non-deterministic)
    pub seed: Option<u64>,
}

impl Default for GumbelSoftmaxConfig {
    fn default() -> Self {
        Self {
            temperature: 1.0,
            hard: false,
            seed: None,
        }
    }
}

/// Configuration for soft differentiable quantifiers.
///
/// Provides smooth approximations of logical quantifiers:
/// - ∃ (exists): At least one element is true
/// - ∀ (forall): All elements are true
#[derive(Debug, Clone, Copy)]
pub enum QuantifierMode {
    /// Hard quantifiers using max/min (non-differentiable)
    Hard,
    /// Soft quantifiers using smooth approximations (log-sum-exp for max)
    Smooth { temperature: f64 },
    /// Probabilistic interpretation (1 - ∏(1-x) for ∃)
    Probabilistic,
}

impl Default for QuantifierMode {
    fn default() -> Self {
        Self::Smooth { temperature: 1.0 }
    }
}
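
// A minimal sketch contrasting the three `QuantifierMode` variants on the same
// (arbitrarily chosen) input: `Hard` returns the exact maximum, `Smooth` overshoots
// it slightly because of the log-sum-exp relaxation, and `Probabilistic` treats the
// entries as independent event probabilities.
#[cfg(test)]
mod quantifier_mode_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn exists_under_each_mode() {
        let input = array![[0.2, 0.9]].into_dyn();

        // Hard: exact max along axis 1.
        let hard = soft_exists(&input, Some(1), QuantifierMode::Hard).unwrap();
        assert!((hard[0] - 0.9).abs() < 1e-9);

        // Smooth: log-sum-exp lies in [max, max + τ·ln(n)].
        let smooth =
            soft_exists(&input, Some(1), QuantifierMode::Smooth { temperature: 0.1 }).unwrap();
        assert!(smooth[0] >= 0.9 - 1e-9);
        assert!(smooth[0] <= 0.9 + 0.1 * 2.0_f64.ln() + 1e-9);

        // Probabilistic: 1 - (1 - 0.2)(1 - 0.9) = 0.92.
        let prob = soft_exists(&input, Some(1), QuantifierMode::Probabilistic).unwrap();
        assert!((prob[0] - 0.92).abs() < 1e-9);
    }
}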

/// Applies Straight-Through Estimator for binary thresholding.
///
/// Forward: y = (x >= threshold) ? 1.0 : 0.0
/// Backward: ∂L/∂x = ∂L/∂y (identity gradient)
///
/// # Arguments
/// * `input` - Input tensor with values in [0, 1]
/// * `config` - STE configuration
///
/// # Returns
/// Binarized tensor (forward pass result)
pub fn ste_threshold(input: &Scirs2Tensor, config: SteConfig) -> TlBackendResult<Scirs2Tensor> {
    // Forward: Apply threshold
    let output = input.mapv(|x| if x >= config.threshold { 1.0 } else { 0.0 });
    Ok(output)
}

/// Computes the backward pass gradient for STE threshold.
///
/// # Arguments
/// * `grad_output` - Gradient from downstream
/// * `input` - Original input tensor (unused in basic STE)
/// * `config` - STE configuration
///
/// # Returns
/// Gradient with respect to input (passed through)
pub fn ste_threshold_backward(
    grad_output: &Scirs2Tensor,
    _input: &Scirs2Tensor,
    config: SteConfig,
) -> TlBackendResult<Scirs2Tensor> {
    if config.clip_gradients {
        // Clip gradients to [-1, 1] range
        Ok(grad_output.mapv(|g| g.clamp(-1.0, 1.0)))
    } else {
        // Pass through gradient as-is
        Ok(grad_output.clone())
    }
}

/// Applies Gumbel-Softmax for differentiable categorical sampling.
///
/// Adds Gumbel noise to logits and applies temperature-scaled softmax:
/// y_i = exp((log(p_i) + g_i) / τ) / Σ_j exp((log(p_j) + g_j) / τ)
///
/// where g_i ~ Gumbel(0, 1), sampled as -log(-log(u)) with u ~ Uniform(0, 1)
///
/// # Arguments
/// * `logits` - Input logits (unnormalized log-probabilities)
/// * `config` - Gumbel-Softmax configuration
///
/// # Returns
/// Soft samples (or hard one-hot if config.hard = true)
pub fn gumbel_softmax(
    logits: &Scirs2Tensor,
    config: GumbelSoftmaxConfig,
) -> TlBackendResult<Scirs2Tensor> {
    // Sample Gumbel noise: g = -log(-log(u)) where u ~ Uniform(0, 1)
    let gumbel_noise = sample_gumbel(logits.shape(), config.seed)?;

    // Add noise to logits: logits + gumbel_noise
    let noisy_logits = logits + &gumbel_noise;

    // Apply temperature-scaled softmax
    let soft_samples = softmax_temperature(&noisy_logits, config.temperature)?;

    if config.hard {
        // Hard one-hot samples in forward, soft in backward (STE)
        let hard_samples = argmax_to_onehot(&soft_samples)?;
        Ok(hard_samples)
    } else {
        Ok(soft_samples)
    }
}

/// Computes gradient for Gumbel-Softmax backward pass.
///
/// For soft mode: Standard softmax gradient
/// For hard mode: Straight-through (gradient flows through soft samples)
///
/// # Arguments
/// * `grad_output` - Gradient from downstream
/// * `soft_samples` - Soft samples from forward pass
/// * `config` - Gumbel-Softmax configuration
///
/// # Returns
/// Gradient with respect to logits
pub fn gumbel_softmax_backward(
    grad_output: &Scirs2Tensor,
    soft_samples: &Scirs2Tensor,
    config: GumbelSoftmaxConfig,
) -> TlBackendResult<Scirs2Tensor> {
    // Softmax gradient: ∂L/∂logits = soft_samples * (∂L/∂y - (soft_samples · ∂L/∂y)) / τ
    // For hard mode, we use STE (pass through as if soft)

    // Compute dot product: sum(soft_samples * grad_output) along last axis
    let last_axis = soft_samples.ndim() - 1;
    let dot_product = (soft_samples * grad_output)
        .sum_axis(Axis(last_axis))
        .insert_axis(Axis(last_axis));

    // Compute gradient: soft_samples * (grad_output - dot_product)
    let grad_logits = soft_samples * &(grad_output - &dot_product);

    // Scale by temperature
    Ok(grad_logits.mapv(|g| g / config.temperature))
}
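
// A minimal annealing sketch: the config docs suggest starting training with a high
// temperature and lowering it over time. With a fixed seed the Gumbel noise is
// identical across calls, so a lower τ can only sharpen the resulting distribution.
// The schedule and values below are illustrative, not a recommended setting.
#[cfg(test)]
mod gumbel_annealing_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn annealing_sharpens_samples() {
        let logits = array![[1.0, 2.0, 3.0]].into_dyn();

        let mut temperature = 5.0;
        let mut first_peak: Option<f64> = None;
        let mut last_peak = 0.0_f64;

        while temperature >= 0.1 {
            let config = GumbelSoftmaxConfig {
                temperature,
                hard: false,
                seed: Some(42),
            };
            let samples = gumbel_softmax(&logits, config).unwrap();

            // Every annealing step still yields a valid probability distribution.
            let sum: f64 = samples.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6);

            let peak = samples.iter().fold(0.0_f64, |a, &b| a.max(b));
            first_peak.get_or_insert(peak);
            last_peak = peak;

            // Simple geometric annealing schedule.
            temperature *= 0.5;
        }

        // Lower temperature ⇒ more peaked (closer to one-hot) for the same noise.
        assert!(last_peak >= first_peak.unwrap());
    }
}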

/// Applies soft exists quantifier: ∃x. P(x).
///
/// Differentiable approximation of "at least one element is true".
///
/// # Arguments
/// * `input` - Input tensor with values in [0, 1]
/// * `axis` - Axis along which to apply quantifier (None = all axes)
/// * `mode` - Quantifier mode (Hard/Smooth/Probabilistic)
///
/// # Returns
/// Result of exists quantification
pub fn soft_exists(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    match mode {
        QuantifierMode::Hard => {
            // Hard max (non-differentiable)
            if let Some(ax) = axis {
                Ok(input.map_axis(Axis(ax), |slice| {
                    slice.iter().fold(0.0_f64, |a, &b| a.max(b))
                }))
            } else {
                let max_val = input.iter().fold(0.0_f64, |a, &b| a.max(b));
                Ok(Array::from_elem(vec![], max_val))
            }
        }
        QuantifierMode::Smooth { temperature } => {
            // Smooth max using log-sum-exp: max(x) ≈ τ * log(Σ exp(x/τ))
            smooth_max(input, axis, temperature)
        }
        QuantifierMode::Probabilistic => {
            // Probabilistic: 1 - ∏(1 - x_i)
            // This is equivalent to OR in probability theory
            probabilistic_exists(input, axis)
        }
    }
}

/// Computes gradient for soft exists quantifier.
///
/// # Arguments
/// * `grad_output` - Gradient from downstream
/// * `input` - Original input tensor
/// * `output` - Output from forward pass
/// * `axis` - Axis along which quantifier was applied
/// * `mode` - Quantifier mode
///
/// # Returns
/// Gradient with respect to input
pub fn soft_exists_backward(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    _output: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    match mode {
        QuantifierMode::Hard => {
            // For hard max, gradient goes only to the maximum element
            // This is similar to argmax gradient (sparse)
            argmax_gradient(grad_output, input, axis)
        }
        QuantifierMode::Smooth { temperature } => {
            // Smooth max gradient: softmax weights
            smooth_max_gradient(grad_output, input, temperature, axis)
        }
        QuantifierMode::Probabilistic => {
            // Probabilistic gradient: ∂(1 - ∏(1-x_i))/∂x_j = ∏_{i≠j}(1-x_i)
            probabilistic_exists_gradient(grad_output, input, axis)
        }
    }
}
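
// A minimal sketch of the claim above that the smooth-max gradient consists of
// softmax weights: the upstream gradient is split across the inputs in proportion
// to softmax(x/τ), so it sums back to the upstream value and favours larger inputs.
// Input values and the upstream gradient are arbitrary.
#[cfg(test)]
mod smooth_exists_gradient_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn gradient_is_softmax_weighted() {
        let input = array![[1.0, 2.0, 3.0]].into_dyn();
        let mode = QuantifierMode::Smooth { temperature: 1.0 };

        let output = soft_exists(&input, Some(1), mode).unwrap();
        let grad_output = array![2.0].into_dyn();

        let grad_input =
            soft_exists_backward(&grad_output, &input, &output, Some(1), mode).unwrap();

        // Softmax weights sum to 1, so the split gradients sum to the upstream 2.0.
        let total: f64 = grad_input.iter().sum();
        assert!((total - 2.0).abs() < 1e-6);

        // The largest input receives the largest share of the gradient.
        assert!(grad_input[[0, 2]] > grad_input[[0, 1]]);
        assert!(grad_input[[0, 1]] > grad_input[[0, 0]]);
    }
}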

/// Applies soft forall quantifier: ∀x. P(x).
///
/// Differentiable approximation of "all elements are true".
///
/// # Arguments
/// * `input` - Input tensor with values in [0, 1]
/// * `axis` - Axis along which to apply quantifier (None = all axes)
/// * `mode` - Quantifier mode (Hard/Smooth/Probabilistic)
///
/// # Returns
/// Result of forall quantification
pub fn soft_forall(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    // ∀x. P(x) is equivalent to ¬∃x. ¬P(x)
    // Or directly: min(x) in hard mode, product in probabilistic
    match mode {
        QuantifierMode::Hard => {
            // Hard min (non-differentiable)
            if let Some(ax) = axis {
                Ok(input.map_axis(Axis(ax), |slice| {
                    slice.iter().fold(1.0_f64, |a, &b| a.min(b))
                }))
            } else {
                let min_val = input.iter().fold(1.0_f64, |a, &b| a.min(b));
                Ok(Array::from_elem(vec![], min_val))
            }
        }
        QuantifierMode::Smooth { temperature } => {
            // Smooth min using -log-sum-exp(-x): min(x) ≈ -τ * log(Σ exp(-x/τ))
            smooth_min(input, axis, temperature)
        }
        QuantifierMode::Probabilistic => {
            // Probabilistic: ∏ x_i (product of probabilities)
            probabilistic_forall(input, axis)
        }
    }
}
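
// A small numeric check of the duality noted above (∀x. P(x) ⇔ ¬∃x. ¬P(x)): under the
// probabilistic interpretation, ∏ x_i equals 1 - exists(1 - x) exactly. The input
// values are arbitrary probabilities.
#[cfg(test)]
mod forall_duality_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn probabilistic_forall_matches_negated_exists() {
        let input = array![[0.3, 0.6, 0.9]].into_dyn();
        let mode = QuantifierMode::Probabilistic;

        // ∀: ∏ x_i = 0.3 * 0.6 * 0.9 = 0.162
        let forall = soft_forall(&input, Some(1), mode).unwrap();

        // ¬∃¬: 1 - exists(1 - x)
        let negated_input = input.mapv(|x| 1.0 - x);
        let exists_not = soft_exists(&negated_input, Some(1), mode).unwrap();
        let not_exists_not = 1.0 - exists_not[0];

        assert!((forall[0] - 0.162).abs() < 1e-9);
        assert!((forall[0] - not_exists_not).abs() < 1e-9);
    }
}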

/// Computes gradient for soft forall quantifier.
///
/// # Arguments
/// * `grad_output` - Gradient from downstream
/// * `input` - Original input tensor
/// * `output` - Output from forward pass
/// * `axis` - Axis along which quantifier was applied
/// * `mode` - Quantifier mode
///
/// # Returns
/// Gradient with respect to input
pub fn soft_forall_backward(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    output: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    match mode {
        QuantifierMode::Hard => {
            // For hard min, gradient goes only to the minimum element
            argmin_gradient(grad_output, input, axis)
        }
        QuantifierMode::Smooth { temperature } => {
            // Smooth min gradient: similar to smooth max but with negated values
            smooth_min_gradient(grad_output, input, temperature, axis)
        }
        QuantifierMode::Probabilistic => {
            // Product gradient: ∂(∏ x_i)/∂x_j = (∏ x_i) / x_j
            probabilistic_forall_gradient(grad_output, input, output, axis)
        }
    }
}

// ============================================================================
// Helper Functions
// ============================================================================

/// Samples Gumbel noise: g = -log(-log(u)) where u ~ Uniform(0, 1).
fn sample_gumbel(shape: &[usize], seed: Option<u64>) -> TlBackendResult<Scirs2Tensor> {
    use scirs2_core::ndarray::IxDyn;

    let uniform_dist = Uniform::new(1e-10, 1.0 - 1e-10).unwrap(); // Avoid log(0)
    let dyn_shape = IxDyn(shape);

    let gumbel = if let Some(s) = seed {
        let mut rng = seeded_rng(s);
        ArrayD::random_bulk(dyn_shape, uniform_dist, &mut rng)
    } else {
        let mut rng = thread_rng();
        ArrayD::random_bulk(dyn_shape, uniform_dist, &mut rng)
    };

    // Apply Gumbel transformation: -log(-log(u))
    let gumbel = gumbel.mapv(|u: f64| -(-u.ln()).ln());
    Ok(gumbel)
}
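
// An empirical sketch of the Gumbel-Max trick that motivates this noise: taking the
// argmax of (logits + g) with g ~ Gumbel(0, 1) draws from Categorical(softmax(logits)).
// The sample count, seed, and tolerance below are arbitrary but deliberately loose.
#[cfg(test)]
mod gumbel_max_trick_sketch {
    use super::*;
    use scirs2_core::ndarray::Axis;

    #[test]
    fn argmax_frequencies_roughly_match_softmax() {
        let logits = [1.0_f64, 2.0, 3.0];
        let n_samples = 2000;

        // One row of Gumbel noise per sample.
        let noise = sample_gumbel(&[n_samples, 3], Some(7)).unwrap();

        // Count how often each class wins the argmax of (logits + noise).
        let mut counts = [0.0_f64; 3];
        for row in noise.lanes(Axis(1)) {
            let mut winner = 0;
            let mut best = f64::NEG_INFINITY;
            for (j, &g) in row.iter().enumerate() {
                let value = logits[j] + g;
                if value > best {
                    best = value;
                    winner = j;
                }
            }
            counts[winner] += 1.0;
        }

        // softmax([1, 2, 3]) ≈ [0.090, 0.245, 0.665]; frequencies should be close.
        let expected = [0.090, 0.245, 0.665];
        for (count, p) in counts.iter().zip(expected.iter()) {
            assert!((count / n_samples as f64 - p).abs() < 0.1);
        }
    }
}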

/// Applies temperature-scaled softmax along the last axis.
fn softmax_temperature(logits: &Scirs2Tensor, temperature: f64) -> TlBackendResult<Scirs2Tensor> {
    // Scale by temperature
    let scaled = logits.mapv(|x| x / temperature);

    // Compute softmax along last axis
    let last_axis = scaled.ndim() - 1;

    // Subtract max for numerical stability
    let max_vals = scaled
        .map_axis(Axis(last_axis), |slice| {
            slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        })
        .insert_axis(Axis(last_axis));

    let exp_vals = (&scaled - &max_vals).mapv(|x| x.exp());
    let sum_exp = exp_vals
        .sum_axis(Axis(last_axis))
        .insert_axis(Axis(last_axis));

    Ok(&exp_vals / &sum_exp)
}

/// Converts soft samples to hard one-hot encoding (argmax).
fn argmax_to_onehot(soft_samples: &Scirs2Tensor) -> TlBackendResult<Scirs2Tensor> {
    let last_axis = soft_samples.ndim() - 1;
    let mut onehot = ArrayD::zeros(soft_samples.raw_dim());

    // Iterate over all elements except the last axis
    let n_classes = soft_samples.len_of(Axis(last_axis));

    // Get views along the last axis
    for i in 0..soft_samples.len() / n_classes {
        // Calculate multi-dimensional index
        let mut flat_idx = i;
        let mut indices = vec![0; soft_samples.ndim()];

        for dim in (0..last_axis).rev() {
            let size = soft_samples.len_of(Axis(dim));
            indices[dim] = flat_idx % size;
            flat_idx /= size;
        }

        // Find argmax along last dimension for this slice
        let mut max_val = f64::NEG_INFINITY;
        let mut max_idx = 0;

        for j in 0..n_classes {
            indices[last_axis] = j;
            let val = soft_samples[&indices[..]];
            if val > max_val {
                max_val = val;
                max_idx = j;
            }
        }

        // Set one-hot
        indices[last_axis] = max_idx;
        onehot[&indices[..]] = 1.0;
    }

    Ok(onehot)
}

/// Smooth max using log-sum-exp: max(x) ≈ τ * log(Σ exp(x/τ)).
fn smooth_max(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    temperature: f64,
) -> TlBackendResult<Scirs2Tensor> {
    let scaled = input.mapv(|x| x / temperature);

    if let Some(ax) = axis {
        // Max for numerical stability
        let max_vals = scaled.map_axis(Axis(ax), |slice| {
            slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        });

        // For broadcasting, temporarily insert axis
        let max_vals_broadcast = max_vals.clone().insert_axis(Axis(ax));
        let exp_vals = (&scaled - &max_vals_broadcast).mapv(|x| x.exp());
        let sum_exp = exp_vals.sum_axis(Axis(ax));
        let log_sum_exp = &max_vals + &sum_exp.mapv(|x| x.ln());

        Ok(log_sum_exp.mapv(|x| x * temperature))
    } else {
        let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_vals = scaled.mapv(|x| (x - max_val).exp());
        let sum_exp: f64 = exp_vals.iter().sum();
        let result = temperature * (max_val + sum_exp.ln());
        Ok(Array::from_elem(vec![], result))
    }
}
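
// A numeric sketch of the bound that makes `smooth_max` a useful relaxation:
// max(x) ≤ τ·log(Σ exp(x/τ)) ≤ max(x) + τ·ln(n), so it never undershoots the hard
// max and converges to it as τ → 0. Temperatures and input values are arbitrary.
#[cfg(test)]
mod smooth_max_bound_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn log_sum_exp_brackets_the_hard_max() {
        let input = array![[1.0, 2.0, 3.0]].into_dyn();
        let hard_max = 3.0;
        let n = 3.0_f64;

        for &tau in &[0.1, 1.0, 5.0] {
            let smooth = smooth_max(&input, Some(1), tau).unwrap();
            assert!(smooth[0] >= hard_max - 1e-9);
            assert!(smooth[0] <= hard_max + tau * n.ln() + 1e-9);
        }
    }
}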

/// Gradient for smooth max.
fn smooth_max_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    temperature: f64,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    // Gradient: softmax weights
    let scaled = input.mapv(|x| x / temperature);

    if let Some(ax) = axis {
        let max_vals = scaled
            .map_axis(Axis(ax), |slice| {
                slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            })
            .insert_axis(Axis(ax));

        let exp_vals = (&scaled - &max_vals).mapv(|x| x.exp());
        let sum_exp = exp_vals.sum_axis(Axis(ax)).insert_axis(Axis(ax));
        let weights = &exp_vals / &sum_exp;

        // Re-insert the reduced axis so grad_output broadcasts against the weights
        let grad_output_expanded = grad_output.clone().insert_axis(Axis(ax));
        Ok(&weights * &grad_output_expanded)
    } else {
        let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_vals = scaled.mapv(|x| (x - max_val).exp());
        let sum_exp: f64 = exp_vals.iter().sum();
        let weights = exp_vals.mapv(|x| x / sum_exp);

        let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
        Ok(weights.mapv(|w| w * grad_scalar))
    }
}

/// Smooth min using -log-sum-exp(-x).
fn smooth_min(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    temperature: f64,
) -> TlBackendResult<Scirs2Tensor> {
    // min(x) = -max(-x)
    let negated = input.mapv(|x| -x);
    let result = smooth_max(&negated, axis, temperature)?;
    Ok(result.mapv(|x| -x))
}

/// Gradient for smooth min.
fn smooth_min_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    temperature: f64,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    // min(x) = -smooth_max(-x), so the two sign flips cancel:
    // ∂min/∂x_j = softmax(-x/τ)_j, i.e. the smooth-max weights of the negated input
    let negated = input.mapv(|x| -x);
    smooth_max_gradient(grad_output, &negated, temperature, axis)
}

/// Probabilistic exists: 1 - ∏(1 - x_i).
fn probabilistic_exists(
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let one_minus_input = input.mapv(|x| 1.0 - x);

    if let Some(ax) = axis {
        let product = one_minus_input.map_axis(Axis(ax), |slice| slice.iter().product::<f64>());
        Ok(product.mapv(|p| 1.0 - p))
    } else {
        let product: f64 = one_minus_input.iter().product();
        Ok(Array::from_elem(vec![], 1.0 - product))
    }
}

/// Gradient for probabilistic exists.
fn probabilistic_exists_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    // ∂(1 - ∏(1-x_i))/∂x_j = ∏_{i≠j}(1-x_i)
    let one_minus_input = input.mapv(|x| 1.0 - x);

    if let Some(ax) = axis {
        // Per-lane product ∏(1-x_i) along `ax`, kept with a unit axis for indexing
        let product = one_minus_input
            .map_axis(Axis(ax), |slice| slice.iter().product::<f64>())
            .insert_axis(Axis(ax));

        // ∏_{i≠j}(1-x_i) = ∏(1-x_i) / (1-x_j); fall back to 0.0 when (1-x_j) ≈ 0
        let mut grad = ArrayD::zeros(input.raw_dim());
        for j in 0..input.len_of(Axis(ax)) {
            let safe_denom = one_minus_input
                .index_axis(Axis(ax), j)
                .mapv(|d| if d.abs() > 1e-10 { d } else { f64::INFINITY });
            let lane_grad = &product.index_axis(Axis(ax), 0) / &safe_denom;
            grad.index_axis_mut(Axis(ax), j).assign(&lane_grad);
        }

        // Broadcast the downstream gradient back along the reduced axis
        let grad_output_expanded = grad_output.clone().insert_axis(Axis(ax));
        Ok(&grad * &grad_output_expanded)
    } else {
        let product: f64 = one_minus_input.iter().product();
        let grad = input.mapv(|x| {
            let denom = 1.0 - x;
            if denom.abs() > 1e-10 {
                product / denom
            } else {
                0.0
            }
        });

        let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
        Ok(grad.mapv(|g| g * grad_scalar))
    }
}

/// Probabilistic forall: ∏ x_i.
fn probabilistic_forall(
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    if let Some(ax) = axis {
        Ok(input.map_axis(Axis(ax), |slice| slice.iter().product::<f64>()))
    } else {
        let product: f64 = input.iter().product();
        Ok(Array::from_elem(vec![], product))
    }
}

/// Gradient for probabilistic forall.
fn probabilistic_forall_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    output: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    // ∂(∏ x_i)/∂x_j = (∏ x_i) / x_j = output / x_j

    if let Some(ax) = axis {
        // Re-insert the reduced axis so `output` and `grad_output` broadcast against
        // `input`, and guard against division by zero (x_j ≈ 0 → gradient set to 0.0)
        let output_expanded = output.clone().insert_axis(Axis(ax));
        let safe_input = input.mapv(|x| if x.abs() > 1e-10 { x } else { f64::INFINITY });
        let grad = &output_expanded / &safe_input;

        let grad_output_expanded = grad_output.clone().insert_axis(Axis(ax));
        Ok(&grad * &grad_output_expanded)
    } else {
        let output_val = output.iter().next().unwrap_or(&0.0);
        let grad = input.mapv(|x| if x.abs() > 1e-10 { output_val / x } else { 0.0 });

        let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
        Ok(grad.mapv(|g| g * grad_scalar))
    }
}

/// Gradient for hard argmax (sparse gradient to maximum element).
fn argmax_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let mut grad = ArrayD::zeros(input.raw_dim());

    if let Some(ax) = axis {
        // Route each lane's downstream gradient to the element that attained the
        // maximum along `ax`; all other elements receive zero gradient.
        let grad_output_expanded = grad_output.clone().insert_axis(Axis(ax));
        for ((mut grad_lane, input_lane), go_lane) in grad
            .lanes_mut(Axis(ax))
            .into_iter()
            .zip(input.lanes(Axis(ax)))
            .zip(grad_output_expanded.lanes(Axis(ax)))
        {
            let max_idx = input_lane
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            grad_lane[max_idx] = go_lane[0];
        }
    } else {
        let max_idx = input
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        grad.as_slice_mut().unwrap()[max_idx] = *grad_output.iter().next().unwrap_or(&0.0);
    }

    Ok(grad)
}

/// Gradient for hard argmin (sparse gradient to minimum element).
fn argmin_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let mut grad = ArrayD::zeros(input.raw_dim());

    if let Some(ax) = axis {
        // Route each lane's downstream gradient to the element that attained the
        // minimum along `ax`; all other elements receive zero gradient.
        let grad_output_expanded = grad_output.clone().insert_axis(Axis(ax));
        for ((mut grad_lane, input_lane), go_lane) in grad
            .lanes_mut(Axis(ax))
            .into_iter()
            .zip(input.lanes(Axis(ax)))
            .zip(grad_output_expanded.lanes(Axis(ax)))
        {
            let min_idx = input_lane
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            grad_lane[min_idx] = go_lane[0];
        }
    } else {
        let min_idx = input
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        grad.as_slice_mut().unwrap()[min_idx] = *grad_output.iter().next().unwrap_or(&0.0);
    }

    Ok(grad)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_ste_threshold_forward() {
        let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
        let config = SteConfig::default();

        let output = ste_threshold(&input, config).unwrap();
        let expected = array![[0.0, 1.0], [0.0, 1.0]].into_dyn();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_ste_threshold_backward() {
        let grad_output = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
        let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
        let config = SteConfig::default();

        let grad_input = ste_threshold_backward(&grad_output, &input, config).unwrap();

        // Gradient passes through unchanged
        assert_eq!(grad_input, grad_output);
    }

    #[test]
    fn test_ste_gradient_clipping() {
        let grad_output = array![[5.0, -3.0], [0.5, -10.0]].into_dyn();
        let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
        let config = SteConfig {
            threshold: 0.5,
            clip_gradients: true,
        };

        let grad_input = ste_threshold_backward(&grad_output, &input, config).unwrap();
        let expected = array![[1.0, -1.0], [0.5, -1.0]].into_dyn();

        assert_eq!(grad_input, expected);
    }

    #[test]
    fn test_gumbel_softmax_deterministic() {
        let logits = array![[1.0, 2.0, 3.0]].into_dyn();
        let config = GumbelSoftmaxConfig {
            temperature: 1.0,
            hard: false,
            seed: Some(42),
        };

        let samples = gumbel_softmax(&logits, config).unwrap();

        // Check output is valid probability distribution
        assert_eq!(samples.shape(), &[1, 3]);
        let sum: f64 = samples.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // All values should be in [0, 1]
        for &val in samples.iter() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_gumbel_softmax_hard_mode() {
        let logits = array![[1.0, 5.0, 2.0]].into_dyn();
        let config = GumbelSoftmaxConfig {
            temperature: 0.1,
            hard: true,
            seed: Some(123),
        };

        let samples = gumbel_softmax(&logits, config).unwrap();

        // In hard mode with low temperature and high logit[1], should be close to one-hot
        let sum: f64 = samples.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // At least one value should be close to 1.0
        let max_val = samples.iter().fold(0.0_f64, |a, &b| a.max(b));
        assert!(max_val >= 0.9);
    }

    #[test]
    fn test_soft_exists_smooth() {
        let input = array![[0.1, 0.3], [0.2, 0.9]].into_dyn();
        let mode = QuantifierMode::Smooth { temperature: 1.0 };

        let output = soft_exists(&input, Some(1), mode).unwrap();

        // Should be approximately max along axis 1 (but higher due to log-sum-exp)
        // smooth_max([0.1, 0.3], τ=1) ≈ 0.898
        // smooth_max([0.2, 0.9], τ=1) ≈ 1.303
        assert_eq!(output.shape(), &[2]);
        assert!(
            output[0] >= 0.85 && output[0] <= 0.95,
            "output[0] = {} not in [0.85, 0.95]",
            output[0]
        );
        assert!(
            output[1] >= 1.25 && output[1] <= 1.35,
            "output[1] = {} not in [1.25, 1.35]",
            output[1]
        );
    }

    #[test]
    fn test_soft_exists_probabilistic() {
        let input = array![[0.5, 0.5]].into_dyn();
        let mode = QuantifierMode::Probabilistic;

        let output = soft_exists(&input, Some(1), mode).unwrap();

        // 1 - (1-0.5)*(1-0.5) = 1 - 0.25 = 0.75
        assert!((output[0] - 0.75).abs() < 1e-6);
    }

    #[test]
    fn test_soft_forall_probabilistic() {
        let input = array![[0.5, 0.5]].into_dyn();
        let mode = QuantifierMode::Probabilistic;

        let output = soft_forall(&input, Some(1), mode).unwrap();

        // 0.5 * 0.5 = 0.25
        assert!((output[0] - 0.25).abs() < 1e-6);
    }

    #[test]
    fn test_probabilistic_forall_gradient() {
        let input = array![[0.5, 0.8]].into_dyn();
        let output = array![0.4].into_dyn(); // 0.5 * 0.8
        let grad_output = array![1.0].into_dyn();

        let grad_input =
            probabilistic_forall_gradient(&grad_output, &input, &output, Some(1)).unwrap();

        // ∂(0.4)/∂x[0] = 0.4 / 0.5 = 0.8
        // ∂(0.4)/∂x[1] = 0.4 / 0.8 = 0.5
        assert!((grad_input[[0, 0]] - 0.8).abs() < 1e-6);
        assert!((grad_input[[0, 1]] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_smooth_max_vs_hard_max() {
        let input = array![[1.0, 2.0, 3.0]].into_dyn();

        // Hard max
        let hard = soft_exists(&input, Some(1), QuantifierMode::Hard).unwrap();
        assert!((hard[0] - 3.0).abs() < 1e-6);

        // Smooth max with low temperature
        let smooth = soft_exists(
            &input,
            Some(1),
            QuantifierMode::Smooth { temperature: 0.01 },
        )
        .unwrap();
        assert!((smooth[0] - 3.0).abs() < 0.1); // Should be close to 3.0
    }

    #[test]
    fn test_gumbel_noise_properties() {
        // Test that Gumbel samples have correct properties
        let shape = &[1000];
        let noise = sample_gumbel(shape, Some(42)).unwrap();

        // Mean of Gumbel(0,1) is Euler-Mascheroni constant ≈ 0.5772
        let mean: f64 = noise.iter().sum::<f64>() / noise.len() as f64;
        assert!((mean - 0.5772).abs() < 0.1); // Rough check

        // Check no NaN or Inf
        for &val in noise.iter() {
            assert!(val.is_finite());
        }
    }
}
879}