// Copyright (c) 2025 ToRSh Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! # Differentiable Compilation
//!
//! This module implements differentiable compilation where the compilation process itself
//! is differentiable, enabling gradient-based optimization of compilation decisions.
//!
//! ## Key Concepts
//!
//! - **Differentiable Optimization Passes**: Each optimization can compute gradients
//! - **Compilation Graph**: Represent compilation as a differentiable computation graph
//! - **Meta-Optimization**: Optimize compilation strategies using gradient descent
//! - **Soft Decision Making**: Use continuous relaxations of discrete choices
//! - **End-to-End Learning**: Jointly learn program behavior and compilation strategy
//!
//! ## Architecture
//!
//! ```text
//! Input Program → Soft Compilation Decisions → Optimized Program → Loss
//!                  ↑                                                  ↓
//!                  └────────────── Gradients ───────────────────────┘
//! ```
//!
//! ## Example
//!
//! ```rust,ignore
//! use torsh_jit::differentiable_compilation::{DifferentiableCompiler, CompilationParams};
//!
//! let mut compiler = DifferentiableCompiler::new();
//!
//! // Define compilation parameters (learnable)
//! let mut params = CompilationParams::new();
//!
//! // Compile with soft decisions
//! let result = compiler.compile_differentiable(&graph, &params)?;
//!
//! // Compute gradients based on performance
//! let grads = compiler.backward(&result, performance_loss)?;
//!
//! // Update compilation strategy: install the gradients, then take a step
//! params.gradients = grads.gradients;
//! params.update(learning_rate);
//! ```
use crate::graph::{ComputationGraph, NodeId};
// use crate::optimizer::OptimizationPass; // Reserved for future integration
use crate::JitResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
51
// ============================================================================
// Differentiable Parameters
// ============================================================================
55
/// Learnable compilation parameters.
///
/// Every field is a continuous relaxation of an otherwise-discrete compilation
/// choice, so the whole set can be tuned with gradient descent via
/// [`CompilationParams::update`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompilationParams {
    /// Weight for each optimization pass, kept in [0, 1] by `update`.
    pub pass_weights: HashMap<String, f32>,

    /// Fusion temperature (controls fusion aggressiveness); `update` keeps it >= 0.1.
    pub fusion_temp: f32,

    /// Loop unrolling factor (continuous relaxation); `update` clamps it to [1, 32].
    pub unroll_factor: f32,

    /// Vectorization width (soft selection); `update` clamps it to [1, 16].
    pub vector_width: f32,

    /// Memory layout preference (continuous); `update` clamps it to [0, 1].
    pub layout_preference: f32,

    /// Accumulated gradients, keyed by parameter name.
    /// Skipped by serde: gradients are transient training state, not configuration.
    #[serde(skip)]
    pub gradients: HashMap<String, f32>,
}
78
79impl CompilationParams {
80    /// Create new parameters with default initialization
81    pub fn new() -> Self {
82        let mut pass_weights = HashMap::new();
83
84        // Initialize pass weights (will be learned)
85        for pass_name in [
86            "constant_folding",
87            "dead_code_elimination",
88            "common_subexpression_elimination",
89            "loop_invariant_motion",
90            "strength_reduction",
91            "fusion",
92            "vectorization",
93            "parallelization",
94        ] {
95            pass_weights.insert(pass_name.to_string(), 0.5); // Neutral initialization
96        }
97
98        Self {
99            pass_weights,
100            fusion_temp: 1.0,
101            unroll_factor: 4.0,
102            vector_width: 4.0,
103            layout_preference: 0.5,
104            gradients: HashMap::new(),
105        }
106    }
107
108    /// Update parameters using gradients (gradient descent)
109    pub fn update(&mut self, learning_rate: f32) {
110        // Update pass weights
111        for (name, weight) in &mut self.pass_weights {
112            if let Some(&grad) = self.gradients.get(name) {
113                *weight -= learning_rate * grad;
114                *weight = weight.clamp(0.0, 1.0); // Keep in valid range
115            }
116        }
117
118        // Update other parameters
119        if let Some(&grad) = self.gradients.get("fusion_temp") {
120            self.fusion_temp -= learning_rate * grad;
121            self.fusion_temp = self.fusion_temp.max(0.1); // Avoid zero/negative
122        }
123
124        if let Some(&grad) = self.gradients.get("unroll_factor") {
125            self.unroll_factor -= learning_rate * grad;
126            self.unroll_factor = self.unroll_factor.clamp(1.0, 32.0);
127        }
128
129        if let Some(&grad) = self.gradients.get("vector_width") {
130            self.vector_width -= learning_rate * grad;
131            self.vector_width = self.vector_width.clamp(1.0, 16.0);
132        }
133
134        if let Some(&grad) = self.gradients.get("layout_preference") {
135            self.layout_preference -= learning_rate * grad;
136            self.layout_preference = self.layout_preference.clamp(0.0, 1.0);
137        }
138
139        // Clear gradients after update
140        self.gradients.clear();
141    }
142
143    /// Zero out all gradients
144    pub fn zero_grad(&mut self) {
145        self.gradients.clear();
146    }
147
148    /// Accumulate gradients
149    pub fn accumulate_grad(&mut self, name: &str, grad: f32) {
150        *self.gradients.entry(name.to_string()).or_insert(0.0) += grad;
151    }
152}
153
154impl Default for CompilationParams {
155    fn default() -> Self {
156        Self::new()
157    }
158}
159
// ============================================================================
// Soft Operations
// ============================================================================
163
/// Soft (differentiable) version of a binary decision.
#[derive(Debug, Clone)]
pub struct SoftDecision {
    /// Probability of taking the decision, clamped to [0, 1] at construction.
    pub probability: f32,

    /// Accumulated gradient with respect to the probability.
    pub gradient: f32,
}

impl SoftDecision {
    /// Create a new soft decision; `probability` is clamped into [0, 1].
    pub fn new(probability: f32) -> Self {
        Self {
            probability: probability.clamp(0.0, 1.0),
            gradient: 0.0,
        }
    }

    /// Apply the soft decision as a weighted combination of the two outcomes.
    ///
    /// Generalized to accept any closure, not just `fn` pointers; this is
    /// backward compatible because plain `fn` items implement `Fn`.
    pub fn apply<T: Clone, F: Fn(&T, &T, f32) -> T>(
        &self,
        if_true: T,
        if_false: T,
        blend_fn: F,
    ) -> T {
        blend_fn(&if_true, &if_false, self.probability)
    }

    /// Backward pass: accumulate the upstream gradient.
    pub fn backward(&mut self, upstream_grad: f32) {
        self.gradient += upstream_grad;
    }

    /// Logistic sigmoid, for smooth binary decisions.
    pub fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Numerically stable softmax for multi-way decisions
    /// (shifts by the max logit before exponentiating).
    /// Returns an empty vector for empty input.
    pub fn softmax(logits: &[f32]) -> Vec<f32> {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_values: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exp_values.iter().sum();
        exp_values.iter().map(|&x| x / sum).collect()
    }
}
206
// ============================================================================
// Differentiable Compilation Result
// ============================================================================
210
/// Result of differentiable compilation.
#[derive(Debug, Clone)]
pub struct DiffCompilationResult {
    /// The (soft-)compiled computation graph.
    pub graph: ComputationGraph,

    /// Compilation decisions made, each with its soft probability.
    pub decisions: Vec<CompilationDecision>,

    /// Estimated performance metrics for the compiled graph.
    pub estimated_performance: PerformanceMetrics,

    /// Computation tape recorded during the forward pass; shared so the
    /// backward pass can replay it.
    pub tape: Arc<Mutex<ComputationTape>>,
}
226
/// A single compilation decision recorded during the forward pass.
#[derive(Debug, Clone)]
pub struct CompilationDecision {
    /// Human-readable name of the decision (e.g. the pass name).
    pub name: String,

    /// What kind of decision this is (pass application, fusion, ...).
    pub decision_type: DecisionType,

    /// The soft (probabilistic) decision value.
    pub decision: SoftDecision,

    /// Estimated impact on performance; used to approximate gradients.
    pub impact: f32,
}
242
/// Types of compilation decisions.
#[derive(Debug, Clone, PartialEq)]
pub enum DecisionType {
    /// Whether to apply the named optimization pass.
    ApplyOptimization(String),

    /// Whether to fuse the two given operations.
    FuseOperations(NodeId, NodeId),

    /// Loop unrolling decision with the chosen factor.
    UnrollLoop(usize),

    /// Vectorization decision with the chosen width.
    Vectorize(usize),

    /// Memory layout choice, identified by name.
    MemoryLayout(String),
}
261
/// Estimated performance metrics for a compiled graph.
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
    /// Estimated execution time (microseconds).
    pub exec_time_us: f32,

    /// Estimated memory usage (bytes).
    pub memory_bytes: f32,

    /// Estimated floating-point operation count.
    pub flops: f32,

    /// Cache efficiency score in [0, 1].
    pub cache_efficiency: f32,
}
277
/// Computation tape for automatic differentiation: records forward-pass
/// operations so gradients can be propagated backward through them.
#[derive(Debug, Clone, Default)]
pub struct ComputationTape {
    /// Operations recorded during the forward pass, in execution order.
    pub operations: Vec<TapeOperation>,

    /// Gradient accumulated for each named variable.
    pub gradients: HashMap<String, f32>,
}
287
/// A single operation recorded on the tape.
#[derive(Debug, Clone)]
pub struct TapeOperation {
    /// Operation name (e.g. "apply_fusion").
    pub name: String,

    /// Names of the input variables.
    pub inputs: Vec<String>,

    /// Name of the output variable.
    pub output: String,

    /// Value produced by the forward computation.
    pub forward_val: f32,

    /// How to compute this operation's local gradient (chain rule).
    pub grad_fn: GradientFunction,
}
306
/// Gradient computation function for a tape operation.
#[derive(Debug, Clone)]
pub enum GradientFunction {
    /// Linear: dy/dx = a (the stored constant).
    Linear(f32),

    /// Product: dy/dx = other_input (the stored factor).
    Product(f32),

    /// Sigmoid: dy/dx = sigmoid(x) * (1 - sigmoid(x)).
    Sigmoid,

    /// ReLU: dy/dx = 1 if x > 0, else 0.
    ReLU,

    /// Custom gradient function of (input, upstream gradient).
    Custom(fn(f32, f32) -> f32),
}
325
// ============================================================================
// Differentiable Compiler
// ============================================================================
329
/// Main differentiable compilation engine.
///
/// Performs forward (soft) compilation and backward gradient computation over
/// compilation decisions; accumulates training statistics across calls.
pub struct DifferentiableCompiler {
    /// Compiler configuration (gradient clipping, relaxation temperature, ...).
    config: DiffCompilerConfig,

    /// Running training statistics (compilations, gradient updates, losses).
    stats: CompilerStatistics,
}
338
/// Configuration for the differentiable compiler.
#[derive(Debug, Clone)]
pub struct DiffCompilerConfig {
    /// Enable gradient checkpointing to save memory.
    pub gradient_checkpointing: bool,

    /// Use straight-through estimators for discrete operations.
    pub straight_through: bool,

    /// Temperature for the Gumbel-Softmax relaxation.
    pub gumbel_temperature: f32,

    /// Threshold used to clip gradients during the backward pass.
    pub grad_clip: f32,
}

impl Default for DiffCompilerConfig {
    /// Sensible defaults: checkpointing and straight-through estimation on,
    /// unit temperature, gradients clipped at ±10.
    fn default() -> Self {
        DiffCompilerConfig {
            grad_clip: 10.0,
            gumbel_temperature: 1.0,
            straight_through: true,
            gradient_checkpointing: true,
        }
    }
}
365
/// Running statistics gathered by the compiler across training.
#[derive(Debug, Clone, Default)]
pub struct CompilerStatistics {
    /// Number of forward compilations performed.
    pub compilations: usize,

    /// Total number of backward (gradient) passes.
    pub gradient_updates: usize,

    /// Running mean of the loss over all backward passes.
    pub avg_loss: f32,

    /// Best performance achieved so far.
    pub best_performance: f32,
}
381
382impl DifferentiableCompiler {
383    /// Create a new differentiable compiler
384    pub fn new() -> Self {
385        Self::with_config(DiffCompilerConfig::default())
386    }
387
388    /// Create with custom configuration
389    pub fn with_config(config: DiffCompilerConfig) -> Self {
390        Self {
391            config,
392            stats: CompilerStatistics::default(),
393        }
394    }
395
396    /// Compile a graph with differentiable decisions
397    pub fn compile_differentiable(
398        &mut self,
399        graph: &ComputationGraph,
400        params: &CompilationParams,
401    ) -> JitResult<DiffCompilationResult> {
402        let mut tape = ComputationTape::default();
403        let mut decisions = Vec::new();
404
405        // Create a copy of the graph for modification
406        let mut compiled_graph = graph.clone();
407
408        // Apply soft optimization passes
409        for (pass_name, &weight) in &params.pass_weights {
410            let decision = SoftDecision::new(weight);
411
412            // Record decision
413            decisions.push(CompilationDecision {
414                name: pass_name.clone(),
415                decision_type: DecisionType::ApplyOptimization(pass_name.clone()),
416                decision: decision.clone(),
417                impact: self.estimate_pass_impact(pass_name, graph),
418            });
419
420            // Apply optimization with probability weight
421            if weight > 0.5 {
422                // For now, deterministic application
423                // In full implementation, would use soft blending
424                compiled_graph =
425                    self.apply_soft_optimization(&compiled_graph, pass_name, weight)?;
426            }
427
428            // Record in tape
429            tape.operations.push(TapeOperation {
430                name: format!("apply_{}", pass_name),
431                inputs: vec!["graph".to_string()],
432                output: "graph".to_string(),
433                forward_val: weight,
434                grad_fn: GradientFunction::Linear(1.0),
435            });
436        }
437
438        // Estimate performance
439        let estimated_performance = self.estimate_performance(&compiled_graph, params);
440
441        self.stats.compilations += 1;
442
443        Ok(DiffCompilationResult {
444            graph: compiled_graph,
445            decisions,
446            estimated_performance,
447            tape: Arc::new(Mutex::new(tape)),
448        })
449    }
450
451    /// Backward pass to compute gradients
452    pub fn backward(
453        &mut self,
454        result: &DiffCompilationResult,
455        loss: f32,
456    ) -> JitResult<CompilationParams> {
457        let mut params_grad = CompilationParams::new();
458        params_grad.zero_grad();
459
460        // Compute gradients for each decision
461        for decision in &result.decisions {
462            match &decision.decision_type {
463                DecisionType::ApplyOptimization(pass_name) => {
464                    // Simple gradient: dL/dw = dL/dperf * dperf/dw
465                    // Approximate: if optimization helps, gradient is negative (minimize loss)
466                    let grad = if decision.impact > 0.0 {
467                        -loss * decision.impact
468                    } else {
469                        loss * decision.impact.abs()
470                    };
471
472                    params_grad.accumulate_grad(pass_name, grad);
473                }
474                _ => {
475                    // Handle other decision types
476                }
477            }
478        }
479
480        // Gradient clipping
481        for (_name, grad) in &mut params_grad.gradients {
482            *grad = grad.clamp(-self.config.grad_clip, self.config.grad_clip);
483        }
484
485        self.stats.gradient_updates += 1;
486        self.stats.avg_loss = (self.stats.avg_loss * (self.stats.gradient_updates - 1) as f32
487            + loss)
488            / self.stats.gradient_updates as f32;
489
490        Ok(params_grad)
491    }
492
493    /// Apply soft optimization (weighted)
494    fn apply_soft_optimization(
495        &self,
496        graph: &ComputationGraph,
497        pass_name: &str,
498        _weight: f32,
499    ) -> JitResult<ComputationGraph> {
500        // Simplified: just return graph (full implementation would apply soft transformations)
501        // In practice, would blend original and optimized versions based on weight
502        log::debug!("Applying soft optimization: {} with weight", pass_name);
503        Ok(graph.clone())
504    }
505
506    /// Estimate impact of an optimization pass
507    fn estimate_pass_impact(&self, pass_name: &str, _graph: &ComputationGraph) -> f32 {
508        // Heuristic estimates (in production, would be learned)
509        match pass_name {
510            "constant_folding" => 0.1,
511            "dead_code_elimination" => 0.15,
512            "common_subexpression_elimination" => 0.2,
513            "fusion" => 0.3,
514            "vectorization" => 0.4,
515            "parallelization" => 0.5,
516            _ => 0.05,
517        }
518    }
519
520    /// Estimate performance of compiled graph
521    fn estimate_performance(
522        &self,
523        graph: &ComputationGraph,
524        params: &CompilationParams,
525    ) -> PerformanceMetrics {
526        let node_count = graph.node_count() as f32;
527
528        // Simple performance model (in production, would use learned model)
529        let base_time = node_count * 10.0; // 10 us per node
530
531        // Apply optimization effects
532        let mut speedup = 1.0;
533        for (pass_name, &weight) in &params.pass_weights {
534            let impact = self.estimate_pass_impact(pass_name, graph);
535            speedup += weight * impact;
536        }
537
538        let exec_time_us = base_time / speedup;
539        let memory_bytes = node_count * 1024.0; // Simplified
540
541        PerformanceMetrics {
542            exec_time_us,
543            memory_bytes,
544            flops: node_count * 100.0,
545            cache_efficiency: 0.7 + params.layout_preference * 0.3,
546        }
547    }
548
549    /// Get compiler statistics
550    pub fn statistics(&self) -> &CompilerStatistics {
551        &self.stats
552    }
553
554    /// Reset statistics
555    pub fn reset_stats(&mut self) {
556        self.stats = CompilerStatistics::default();
557    }
558}
559
560impl Default for DifferentiableCompiler {
561    fn default() -> Self {
562        Self::new()
563    }
564}
565
// ============================================================================
// Gumbel-Softmax Trick
// ============================================================================
569
570/// Gumbel-Softmax trick for differentiable discrete decisions
571pub struct GumbelSoftmax {
572    /// Temperature parameter
573    temperature: f32,
574}
575
576impl GumbelSoftmax {
577    /// Create new Gumbel-Softmax with temperature
578    pub fn new(temperature: f32) -> Self {
579        Self { temperature }
580    }
581
582    /// Sample from Gumbel distribution
583    fn sample_gumbel(&self) -> f32 {
584        // Note: In production, use proper random sampling from scirs2-core
585        // For now, using a simple hash-based pseudo-random approach
586        use std::time::{SystemTime, UNIX_EPOCH};
587        let nanos = SystemTime::now()
588            .duration_since(UNIX_EPOCH)
589            .expect("system time should be after UNIX_EPOCH")
590            .subsec_nanos();
591        let u = ((nanos % 1000) as f32 / 1000.0).max(1e-10);
592        -(-u).ln().ln()
593    }
594
595    /// Apply Gumbel-Softmax to logits
596    pub fn apply(&self, logits: &[f32]) -> Vec<f32> {
597        let gumbel_logits: Vec<f32> = logits
598            .iter()
599            .map(|&logit| (logit + self.sample_gumbel()) / self.temperature)
600            .collect();
601
602        SoftDecision::softmax(&gumbel_logits)
603    }
604
605    /// Straight-through estimator (forward: argmax, backward: softmax)
606    pub fn straight_through(&self, logits: &[f32]) -> (usize, Vec<f32>) {
607        let probs = SoftDecision::softmax(logits);
608
609        // Forward: argmax (discrete)
610        let choice = probs
611            .iter()
612            .enumerate()
613            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
614            .map(|(i, _)| i)
615            .unwrap_or(0);
616
617        // Backward: use softmax gradients
618        (choice, probs)
619    }
620}
621
// ============================================================================
// Training Loop Helper
// ============================================================================
625
/// Helper that drives training of compilation parameters: owns the compiler,
/// the learnable parameters, and the per-epoch history.
pub struct CompilationTrainer {
    /// The differentiable compiler used for forward/backward passes.
    compiler: DifferentiableCompiler,

    /// Parameters being optimized by gradient descent.
    params: CompilationParams,

    /// Learning rate applied on each parameter update.
    learning_rate: f32,

    /// Per-epoch training records, appended by `train`.
    history: Vec<TrainingEpoch>,
}
640
/// Record of one training epoch.
#[derive(Debug, Clone)]
pub struct TrainingEpoch {
    /// Zero-based epoch number.
    pub epoch: usize,

    /// Average loss across the epoch's training graphs.
    pub loss: f32,

    /// Performance metrics measured at the end of the epoch.
    pub performance: PerformanceMetrics,

    /// Snapshot of the parameters at the end of the epoch.
    pub params: CompilationParams,
}
656
657impl CompilationTrainer {
658    /// Create new trainer
659    pub fn new(learning_rate: f32) -> Self {
660        Self {
661            compiler: DifferentiableCompiler::new(),
662            params: CompilationParams::new(),
663            learning_rate,
664            history: Vec::new(),
665        }
666    }
667
668    /// Train on a single graph
669    pub fn train_step(
670        &mut self,
671        graph: &ComputationGraph,
672        target_performance: f32,
673    ) -> JitResult<f32> {
674        // Forward pass
675        let result = self.compiler.compile_differentiable(graph, &self.params)?;
676
677        // Compute loss (MSE on execution time)
678        let loss = (result.estimated_performance.exec_time_us - target_performance).powi(2);
679
680        // Backward pass
681        let grads = self.compiler.backward(&result, loss)?;
682
683        // Update parameters
684        self.params.gradients = grads.gradients;
685        self.params.update(self.learning_rate);
686
687        Ok(loss)
688    }
689
690    /// Train for multiple epochs
691    pub fn train(
692        &mut self,
693        graphs: &[ComputationGraph],
694        targets: &[f32],
695        epochs: usize,
696    ) -> JitResult<Vec<TrainingEpoch>> {
697        for epoch in 0..epochs {
698            let mut total_loss = 0.0;
699
700            for (graph, &target) in graphs.iter().zip(targets.iter()) {
701                let loss = self.train_step(graph, target)?;
702                total_loss += loss;
703            }
704
705            let avg_loss = total_loss / graphs.len() as f32;
706
707            // Record epoch
708            let result = self
709                .compiler
710                .compile_differentiable(&graphs[0], &self.params)?;
711            self.history.push(TrainingEpoch {
712                epoch,
713                loss: avg_loss,
714                performance: result.estimated_performance.clone(),
715                params: self.params.clone(),
716            });
717
718            log::info!("Epoch {}: loss = {:.4}", epoch, avg_loss);
719        }
720
721        Ok(self.history.clone())
722    }
723
724    /// Get best parameters from training
725    pub fn best_params(&self) -> &CompilationParams {
726        self.history
727            .iter()
728            .min_by(|a, b| {
729                a.loss
730                    .partial_cmp(&b.loss)
731                    .unwrap_or(std::cmp::Ordering::Equal)
732            })
733            .map(|e| &e.params)
734            .unwrap_or(&self.params)
735    }
736}
737
// ============================================================================
// Tests
// ============================================================================
741
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use torsh_core::{DType, Shape};

    /// Parameters initialize with pass weights and respond to gradient updates.
    #[test]
    fn test_compilation_params() {
        let mut params = CompilationParams::new();
        assert!(!params.pass_weights.is_empty());

        params.accumulate_grad("fusion", 0.1);
        params.update(0.01);

        assert!(params.pass_weights.contains_key("fusion"));
        // A positive gradient must have decreased the weight from 0.5.
        assert!(params.pass_weights["fusion"] < 0.5);
    }

    /// Soft decisions clamp probabilities and softmax normalizes to 1.
    #[test]
    fn test_soft_decision() {
        let decision = SoftDecision::new(0.7);
        assert!((decision.probability - 0.7).abs() < 1e-6);

        let probs = SoftDecision::softmax(&[1.0, 2.0, 3.0]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    /// Forward compilation records decisions and produces a positive time estimate.
    #[test]
    fn test_differentiable_compilation() {
        let mut compiler = DifferentiableCompiler::new();
        let params = CompilationParams::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![10, 10]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let result = compiler.compile_differentiable(&graph, &params).unwrap();

        assert!(!result.decisions.is_empty());
        assert!(result.estimated_performance.exec_time_us > 0.0);
    }

    /// The backward pass yields gradients for the recorded decisions.
    #[test]
    fn test_backward_pass() {
        let mut compiler = DifferentiableCompiler::new();
        let params = CompilationParams::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![5, 5]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let result = compiler.compile_differentiable(&graph, &params).unwrap();

        let loss = 100.0; // Simulated loss
        let grads = compiler.backward(&result, loss).unwrap();

        assert!(!grads.gradients.is_empty());
    }

    /// A single trainer step completes and returns a non-negative loss.
    #[test]
    fn test_compilation_trainer() {
        let mut trainer = CompilationTrainer::new(0.01);

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![3, 3]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let loss = trainer.train_step(&graph, 50.0).unwrap();

        assert!(loss >= 0.0);
    }

    /// Straight-through estimation picks a valid index with normalized probs.
    #[test]
    fn test_gumbel_softmax() {
        let gumbel = GumbelSoftmax::new(1.0);
        let logits = vec![1.0, 2.0, 3.0];

        let (choice, probs) = gumbel.straight_through(&logits);
        assert!(choice < 3);

        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
}