1use crate::graph::{ComputationGraph, NodeId};
46use crate::JitResult;
48use serde::{Deserialize, Serialize};
49use std::collections::HashMap;
50use std::sync::{Arc, Mutex};
51
/// Learnable parameters that steer the differentiable compiler's decisions.
///
/// All fields are plain `f32` knobs adjusted by gradient descent via
/// `accumulate_grad` / `update`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompilationParams {
    /// Per-optimization-pass application weight; kept in [0.0, 1.0] by `update`.
    pub pass_weights: HashMap<String, f32>,

    /// Temperature for soft fusion decisions; kept >= 0.1 by `update`.
    pub fusion_temp: f32,

    /// Soft loop-unroll factor; kept in [1.0, 32.0] by `update`.
    pub unroll_factor: f32,

    /// Soft vectorization width; kept in [1.0, 16.0] by `update`.
    pub vector_width: f32,

    /// Memory-layout preference in [0.0, 1.0]; feeds the cache-efficiency
    /// estimate in `DifferentiableCompiler::estimate_performance`.
    pub layout_preference: f32,

    /// Accumulated gradients keyed by parameter name (pass name or scalar
    /// knob name). Transient training state, excluded from serialization.
    #[serde(skip)]
    pub gradients: HashMap<String, f32>,
}
78
79impl CompilationParams {
80 pub fn new() -> Self {
82 let mut pass_weights = HashMap::new();
83
84 for pass_name in [
86 "constant_folding",
87 "dead_code_elimination",
88 "common_subexpression_elimination",
89 "loop_invariant_motion",
90 "strength_reduction",
91 "fusion",
92 "vectorization",
93 "parallelization",
94 ] {
95 pass_weights.insert(pass_name.to_string(), 0.5); }
97
98 Self {
99 pass_weights,
100 fusion_temp: 1.0,
101 unroll_factor: 4.0,
102 vector_width: 4.0,
103 layout_preference: 0.5,
104 gradients: HashMap::new(),
105 }
106 }
107
108 pub fn update(&mut self, learning_rate: f32) {
110 for (name, weight) in &mut self.pass_weights {
112 if let Some(&grad) = self.gradients.get(name) {
113 *weight -= learning_rate * grad;
114 *weight = weight.clamp(0.0, 1.0); }
116 }
117
118 if let Some(&grad) = self.gradients.get("fusion_temp") {
120 self.fusion_temp -= learning_rate * grad;
121 self.fusion_temp = self.fusion_temp.max(0.1); }
123
124 if let Some(&grad) = self.gradients.get("unroll_factor") {
125 self.unroll_factor -= learning_rate * grad;
126 self.unroll_factor = self.unroll_factor.clamp(1.0, 32.0);
127 }
128
129 if let Some(&grad) = self.gradients.get("vector_width") {
130 self.vector_width -= learning_rate * grad;
131 self.vector_width = self.vector_width.clamp(1.0, 16.0);
132 }
133
134 if let Some(&grad) = self.gradients.get("layout_preference") {
135 self.layout_preference -= learning_rate * grad;
136 self.layout_preference = self.layout_preference.clamp(0.0, 1.0);
137 }
138
139 self.gradients.clear();
141 }
142
143 pub fn zero_grad(&mut self) {
145 self.gradients.clear();
146 }
147
148 pub fn accumulate_grad(&mut self, name: &str, grad: f32) {
150 *self.gradients.entry(name.to_string()).or_insert(0.0) += grad;
151 }
152}
153
154impl Default for CompilationParams {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
/// A differentiable (soft) binary decision.
#[derive(Debug, Clone)]
pub struct SoftDecision {
    /// Probability of taking the "true" branch; clamped to [0.0, 1.0] by `new`.
    pub probability: f32,

    /// Gradient accumulated by `backward`.
    pub gradient: f32,
}
173
174impl SoftDecision {
175 pub fn new(probability: f32) -> Self {
177 Self {
178 probability: probability.clamp(0.0, 1.0),
179 gradient: 0.0,
180 }
181 }
182
183 pub fn apply<T: Clone>(&self, if_true: T, if_false: T, blend_fn: fn(&T, &T, f32) -> T) -> T {
185 blend_fn(&if_true, &if_false, self.probability)
186 }
187
188 pub fn backward(&mut self, upstream_grad: f32) {
190 self.gradient += upstream_grad;
191 }
192
193 pub fn sigmoid(x: f32) -> f32 {
195 1.0 / (1.0 + (-x).exp())
196 }
197
198 pub fn softmax(logits: &[f32]) -> Vec<f32> {
200 let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
201 let exp_values: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
202 let sum: f32 = exp_values.iter().sum();
203 exp_values.iter().map(|&x| x / sum).collect()
204 }
205}
206
/// Output of a differentiable compilation run.
#[derive(Debug, Clone)]
pub struct DiffCompilationResult {
    /// The (possibly transformed) computation graph.
    pub graph: ComputationGraph,

    /// Soft decisions recorded during compilation, in application order.
    pub decisions: Vec<CompilationDecision>,

    /// Analytical performance estimate for the compiled graph.
    pub estimated_performance: PerformanceMetrics,

    /// Tape of recorded forward operations, shared for later backward passes.
    pub tape: Arc<Mutex<ComputationTape>>,
}
226
/// A single recorded compilation decision.
#[derive(Debug, Clone)]
pub struct CompilationDecision {
    /// Human-readable decision name (e.g. the optimization pass name).
    pub name: String,

    /// Which kind of transformation this decision controls.
    pub decision_type: DecisionType,

    /// The soft (probabilistic) decision value.
    pub decision: SoftDecision,

    /// Estimated performance impact; used to derive gradients in
    /// `DifferentiableCompiler::backward`.
    pub impact: f32,
}
242
/// Kinds of transformations a compilation decision can select.
#[derive(Debug, Clone, PartialEq)]
pub enum DecisionType {
    /// Apply the named optimization pass.
    ApplyOptimization(String),

    /// Fuse the two referenced graph nodes.
    FuseOperations(NodeId, NodeId),

    /// Unroll a loop by the given factor.
    UnrollLoop(usize),

    /// Vectorize with the given lane width.
    Vectorize(usize),

    /// Choose the named memory layout.
    MemoryLayout(String),
}
261
/// Analytical performance estimate for a compiled graph.
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
    /// Estimated execution time in microseconds.
    pub exec_time_us: f32,

    /// Estimated memory footprint in bytes.
    pub memory_bytes: f32,

    /// Estimated floating-point operation count.
    pub flops: f32,

    /// Estimated cache efficiency; computed as 0.7 plus a layout bonus of
    /// up to 0.3 in `estimate_performance`, so it lies in [0.7, 1.0].
    pub cache_efficiency: f32,
}
277
/// Record of forward operations, used to drive gradient computation.
#[derive(Debug, Clone, Default)]
pub struct ComputationTape {
    /// Operations in the order they were recorded.
    pub operations: Vec<TapeOperation>,

    /// Gradients keyed by parameter name.
    pub gradients: HashMap<String, f32>,
}
287
/// One recorded operation on the computation tape.
#[derive(Debug, Clone)]
pub struct TapeOperation {
    /// Operation name (e.g. `apply_<pass>` for pass applications).
    pub name: String,

    /// Names of the operation's inputs.
    pub inputs: Vec<String>,

    /// Name of the operation's output.
    pub output: String,

    /// Value recorded for the forward pass (the pass weight, for pass
    /// application operations).
    pub forward_val: f32,

    /// Rule for propagating gradients through this operation.
    pub grad_fn: GradientFunction,
}
306
/// Gradient rules attachable to tape operations.
#[derive(Debug, Clone)]
pub enum GradientFunction {
    /// Linear with the given constant slope.
    Linear(f32),

    /// Product rule with the given factor.
    Product(f32),

    /// Sigmoid derivative.
    Sigmoid,

    /// ReLU derivative.
    ReLU,

    /// Custom gradient function of two floats.
    /// NOTE(review): argument semantics (upstream grad vs. forward value,
    /// and their order) are not established anywhere in this file — confirm
    /// with the consumer before relying on it.
    Custom(fn(f32, f32) -> f32),
}
325
/// Compiler whose optimization decisions are soft (probabilistic) and
/// differentiable, enabling gradient-based tuning of compilation parameters.
pub struct DifferentiableCompiler {
    /// Behavioral configuration (gradient clipping, temperatures, ...).
    config: DiffCompilerConfig,

    /// Running statistics across compilations and gradient updates.
    stats: CompilerStatistics,
}
338
/// Configuration for [`DifferentiableCompiler`].
#[derive(Debug, Clone)]
pub struct DiffCompilerConfig {
    /// Enable gradient checkpointing.
    /// NOTE(review): not consumed anywhere in this file — confirm intended use.
    pub gradient_checkpointing: bool,

    /// Use the straight-through estimator for hard decisions.
    /// NOTE(review): not consumed anywhere in this file — confirm intended use.
    pub straight_through: bool,

    /// Temperature for Gumbel-softmax sampling.
    /// NOTE(review): not consumed anywhere in this file — confirm intended use.
    pub gumbel_temperature: f32,

    /// Symmetric bound applied to every gradient in `backward`
    /// (gradients are clamped to [-grad_clip, grad_clip]).
    pub grad_clip: f32,
}
354
355impl Default for DiffCompilerConfig {
356 fn default() -> Self {
357 Self {
358 gradient_checkpointing: true,
359 straight_through: true,
360 gumbel_temperature: 1.0,
361 grad_clip: 10.0,
362 }
363 }
364}
365
/// Running statistics for a [`DifferentiableCompiler`].
#[derive(Debug, Clone, Default)]
pub struct CompilerStatistics {
    /// Number of completed compilations.
    pub compilations: usize,

    /// Number of backward (gradient) passes performed.
    pub gradient_updates: usize,

    /// Running mean of the loss across all gradient updates.
    pub avg_loss: f32,

    /// Best performance observed.
    /// NOTE(review): never written in this file — confirm who produces it.
    pub best_performance: f32,
}
381
382impl DifferentiableCompiler {
383 pub fn new() -> Self {
385 Self::with_config(DiffCompilerConfig::default())
386 }
387
388 pub fn with_config(config: DiffCompilerConfig) -> Self {
390 Self {
391 config,
392 stats: CompilerStatistics::default(),
393 }
394 }
395
396 pub fn compile_differentiable(
398 &mut self,
399 graph: &ComputationGraph,
400 params: &CompilationParams,
401 ) -> JitResult<DiffCompilationResult> {
402 let mut tape = ComputationTape::default();
403 let mut decisions = Vec::new();
404
405 let mut compiled_graph = graph.clone();
407
408 for (pass_name, &weight) in ¶ms.pass_weights {
410 let decision = SoftDecision::new(weight);
411
412 decisions.push(CompilationDecision {
414 name: pass_name.clone(),
415 decision_type: DecisionType::ApplyOptimization(pass_name.clone()),
416 decision: decision.clone(),
417 impact: self.estimate_pass_impact(pass_name, graph),
418 });
419
420 if weight > 0.5 {
422 compiled_graph =
425 self.apply_soft_optimization(&compiled_graph, pass_name, weight)?;
426 }
427
428 tape.operations.push(TapeOperation {
430 name: format!("apply_{}", pass_name),
431 inputs: vec!["graph".to_string()],
432 output: "graph".to_string(),
433 forward_val: weight,
434 grad_fn: GradientFunction::Linear(1.0),
435 });
436 }
437
438 let estimated_performance = self.estimate_performance(&compiled_graph, params);
440
441 self.stats.compilations += 1;
442
443 Ok(DiffCompilationResult {
444 graph: compiled_graph,
445 decisions,
446 estimated_performance,
447 tape: Arc::new(Mutex::new(tape)),
448 })
449 }
450
451 pub fn backward(
453 &mut self,
454 result: &DiffCompilationResult,
455 loss: f32,
456 ) -> JitResult<CompilationParams> {
457 let mut params_grad = CompilationParams::new();
458 params_grad.zero_grad();
459
460 for decision in &result.decisions {
462 match &decision.decision_type {
463 DecisionType::ApplyOptimization(pass_name) => {
464 let grad = if decision.impact > 0.0 {
467 -loss * decision.impact
468 } else {
469 loss * decision.impact.abs()
470 };
471
472 params_grad.accumulate_grad(pass_name, grad);
473 }
474 _ => {
475 }
477 }
478 }
479
480 for (_name, grad) in &mut params_grad.gradients {
482 *grad = grad.clamp(-self.config.grad_clip, self.config.grad_clip);
483 }
484
485 self.stats.gradient_updates += 1;
486 self.stats.avg_loss = (self.stats.avg_loss * (self.stats.gradient_updates - 1) as f32
487 + loss)
488 / self.stats.gradient_updates as f32;
489
490 Ok(params_grad)
491 }
492
493 fn apply_soft_optimization(
495 &self,
496 graph: &ComputationGraph,
497 pass_name: &str,
498 _weight: f32,
499 ) -> JitResult<ComputationGraph> {
500 log::debug!("Applying soft optimization: {} with weight", pass_name);
503 Ok(graph.clone())
504 }
505
506 fn estimate_pass_impact(&self, pass_name: &str, _graph: &ComputationGraph) -> f32 {
508 match pass_name {
510 "constant_folding" => 0.1,
511 "dead_code_elimination" => 0.15,
512 "common_subexpression_elimination" => 0.2,
513 "fusion" => 0.3,
514 "vectorization" => 0.4,
515 "parallelization" => 0.5,
516 _ => 0.05,
517 }
518 }
519
520 fn estimate_performance(
522 &self,
523 graph: &ComputationGraph,
524 params: &CompilationParams,
525 ) -> PerformanceMetrics {
526 let node_count = graph.node_count() as f32;
527
528 let base_time = node_count * 10.0; let mut speedup = 1.0;
533 for (pass_name, &weight) in ¶ms.pass_weights {
534 let impact = self.estimate_pass_impact(pass_name, graph);
535 speedup += weight * impact;
536 }
537
538 let exec_time_us = base_time / speedup;
539 let memory_bytes = node_count * 1024.0; PerformanceMetrics {
542 exec_time_us,
543 memory_bytes,
544 flops: node_count * 100.0,
545 cache_efficiency: 0.7 + params.layout_preference * 0.3,
546 }
547 }
548
549 pub fn statistics(&self) -> &CompilerStatistics {
551 &self.stats
552 }
553
554 pub fn reset_stats(&mut self) {
556 self.stats = CompilerStatistics::default();
557 }
558}
559
560impl Default for DifferentiableCompiler {
561 fn default() -> Self {
562 Self::new()
563 }
564}
565
/// Gumbel-softmax sampler for differentiable discrete choices.
pub struct GumbelSoftmax {
    /// Softmax temperature: logits are divided by this before normalization,
    /// so lower values sharpen the distribution.
    temperature: f32,
}
575
576impl GumbelSoftmax {
577 pub fn new(temperature: f32) -> Self {
579 Self { temperature }
580 }
581
582 fn sample_gumbel(&self) -> f32 {
584 use std::time::{SystemTime, UNIX_EPOCH};
587 let nanos = SystemTime::now()
588 .duration_since(UNIX_EPOCH)
589 .expect("system time should be after UNIX_EPOCH")
590 .subsec_nanos();
591 let u = ((nanos % 1000) as f32 / 1000.0).max(1e-10);
592 -(-u).ln().ln()
593 }
594
595 pub fn apply(&self, logits: &[f32]) -> Vec<f32> {
597 let gumbel_logits: Vec<f32> = logits
598 .iter()
599 .map(|&logit| (logit + self.sample_gumbel()) / self.temperature)
600 .collect();
601
602 SoftDecision::softmax(&gumbel_logits)
603 }
604
605 pub fn straight_through(&self, logits: &[f32]) -> (usize, Vec<f32>) {
607 let probs = SoftDecision::softmax(logits);
608
609 let choice = probs
611 .iter()
612 .enumerate()
613 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
614 .map(|(i, _)| i)
615 .unwrap_or(0);
616
617 (choice, probs)
619 }
620}
621
/// Trains `CompilationParams` by gradient descent toward a target
/// performance, recording per-epoch history.
pub struct CompilationTrainer {
    /// The differentiable compiler used for forward/backward passes.
    compiler: DifferentiableCompiler,

    /// Current learnable parameters.
    params: CompilationParams,

    /// Gradient-descent learning rate.
    learning_rate: f32,

    /// Snapshot of each completed training epoch.
    history: Vec<TrainingEpoch>,
}
640
/// Snapshot of one completed training epoch.
#[derive(Debug, Clone)]
pub struct TrainingEpoch {
    /// Zero-based epoch index.
    pub epoch: usize,

    /// Average loss over all graphs in this epoch.
    pub loss: f32,

    /// Performance estimate (for the first graph) after this epoch.
    pub performance: PerformanceMetrics,

    /// Parameter values at the end of this epoch.
    pub params: CompilationParams,
}
656
657impl CompilationTrainer {
658 pub fn new(learning_rate: f32) -> Self {
660 Self {
661 compiler: DifferentiableCompiler::new(),
662 params: CompilationParams::new(),
663 learning_rate,
664 history: Vec::new(),
665 }
666 }
667
668 pub fn train_step(
670 &mut self,
671 graph: &ComputationGraph,
672 target_performance: f32,
673 ) -> JitResult<f32> {
674 let result = self.compiler.compile_differentiable(graph, &self.params)?;
676
677 let loss = (result.estimated_performance.exec_time_us - target_performance).powi(2);
679
680 let grads = self.compiler.backward(&result, loss)?;
682
683 self.params.gradients = grads.gradients;
685 self.params.update(self.learning_rate);
686
687 Ok(loss)
688 }
689
690 pub fn train(
692 &mut self,
693 graphs: &[ComputationGraph],
694 targets: &[f32],
695 epochs: usize,
696 ) -> JitResult<Vec<TrainingEpoch>> {
697 for epoch in 0..epochs {
698 let mut total_loss = 0.0;
699
700 for (graph, &target) in graphs.iter().zip(targets.iter()) {
701 let loss = self.train_step(graph, target)?;
702 total_loss += loss;
703 }
704
705 let avg_loss = total_loss / graphs.len() as f32;
706
707 let result = self
709 .compiler
710 .compile_differentiable(&graphs[0], &self.params)?;
711 self.history.push(TrainingEpoch {
712 epoch,
713 loss: avg_loss,
714 performance: result.estimated_performance.clone(),
715 params: self.params.clone(),
716 });
717
718 log::info!("Epoch {}: loss = {:.4}", epoch, avg_loss);
719 }
720
721 Ok(self.history.clone())
722 }
723
724 pub fn best_params(&self) -> &CompilationParams {
726 self.history
727 .iter()
728 .min_by(|a, b| {
729 a.loss
730 .partial_cmp(&b.loss)
731 .unwrap_or(std::cmp::Ordering::Equal)
732 })
733 .map(|e| &e.params)
734 .unwrap_or(&self.params)
735 }
736}
737
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use torsh_core::{DType, Shape};

    #[test]
    fn test_compilation_params() {
        let mut params = CompilationParams::new();
        assert!(!params.pass_weights.is_empty());

        params.accumulate_grad("fusion", 0.1);
        params.update(0.01);

        assert!(params.pass_weights.contains_key("fusion"));
    }

    #[test]
    fn test_soft_decision() {
        let decision = SoftDecision::new(0.7);
        assert!((decision.probability - 0.7).abs() < 1e-6);

        let probs = SoftDecision::softmax(&[1.0, 2.0, 3.0]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_differentiable_compilation() {
        let mut compiler = DifferentiableCompiler::new();
        let params = CompilationParams::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![10, 10]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        // BUG FIX: `&params` was garbled to `¶ms` (mojibake).
        let result = compiler.compile_differentiable(&graph, &params).unwrap();

        assert!(!result.decisions.is_empty());
        assert!(result.estimated_performance.exec_time_us > 0.0);
    }

    #[test]
    fn test_backward_pass() {
        let mut compiler = DifferentiableCompiler::new();
        let params = CompilationParams::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![5, 5]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        // BUG FIX: `&params` was garbled to `¶ms` (mojibake).
        let result = compiler.compile_differentiable(&graph, &params).unwrap();

        let loss = 100.0;
        let grads = compiler.backward(&result, loss).unwrap();

        assert!(!grads.gradients.is_empty());
    }

    #[test]
    fn test_compilation_trainer() {
        let mut trainer = CompilationTrainer::new(0.01);

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![3, 3]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let loss = trainer.train_step(&graph, 50.0).unwrap();

        assert!(loss >= 0.0);
    }

    #[test]
    fn test_gumbel_softmax() {
        let gumbel = GumbelSoftmax::new(1.0);
        let logits = vec![1.0, 2.0, 3.0];

        let (choice, probs) = gumbel.straight_through(&logits);
        assert!(choice < 3);

        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
}