//! Learned optimizer for computation-graph optimization decisions.
//!
//! Combines an online linear cost model (trained by stochastic gradient
//! descent) with a tabular Q-learning agent to predict execution costs and
//! recommend fusion and scheduling actions.

use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Errors that can occur during learned optimization.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum LearnedOptError {
    #[error("Insufficient training data: {0}")]
    InsufficientData(String),

    #[error("Model not trained: {0}")]
    ModelNotTrained(String),

    #[error("Feature extraction failed: {0}")]
    FeatureExtractionFailed(String),

    #[error("Prediction failed: {0}")]
    PredictionFailed(String),

    #[error("Invalid model configuration: {0}")]
    InvalidConfig(String),
}

/// Identifier for a node in the computation graph.
pub type NodeId = String;

/// Strategy used to train and apply the learned optimizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LearningStrategy {
    /// Learn from labeled (features, cost) examples.
    Supervised,
    /// Learn from reward signals via Q-learning.
    ReinforcementLearning,
    /// Update the model incrementally as observations arrive.
    Online,
    /// Reuse knowledge learned on another workload.
    Transfer,
}

/// Type of model backing the cost predictor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    LinearRegression,
    DecisionTree,
    RandomForest,
    NeuralNetwork,
    GradientBoosting,
}

/// A named feature vector describing a computation graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureVector {
    pub features: Vec<f64>,
    pub feature_names: Vec<String>,
}

impl FeatureVector {
    fn new() -> Self {
        Self {
            features: Vec::new(),
            feature_names: Vec::new(),
        }
    }

    fn add_feature(&mut self, name: String, value: f64) {
        self.feature_names.push(name);
        self.features.push(value);
    }
}

/// A single supervised training example: features plus the observed cost.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    pub features: FeatureVector,
    pub label: f64,
}

/// A reinforcement-learning transition: state features, the action taken,
/// the reward received, and (optionally) the resulting state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewardSignal {
    pub state_features: FeatureVector,
    pub action: OptimizationAction,
    pub reward: f64,
    pub next_state_features: Option<FeatureVector>,
}

/// Actions the optimizer can recommend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationAction {
    Fuse,
    DontFuse,
    Parallelize,
    Sequential,
    CacheResult,
    Recompute,
}

/// Recommendation on whether to fuse, with a confidence and expected speedup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionRecommendation {
    pub should_fuse: bool,
    pub confidence: f64,
    pub expected_speedup: f64,
}

/// Recommended execution schedule with its expected time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleRecommendation {
    pub schedule: Vec<NodeId>,
    pub confidence: f64,
    pub expected_time_us: f64,
}

/// Predicted execution cost in microseconds, with a rough confidence interval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostPrediction {
    pub predicted_cost_us: f64,
    pub confidence_interval: (f64, f64),
    pub model_confidence: f64,
}

/// Summary statistics about the learning process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStats {
    pub training_examples: usize,
    pub model_accuracy: f64,
    pub average_prediction_error: f64,
    pub total_updates: usize,
    pub learning_rate: f64,
}

/// Linear regression model trained online by stochastic gradient descent.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    weights: Vec<f64>,
    bias: f64,
    learning_rate: f64,
}

impl LinearModel {
    fn new(num_features: usize, learning_rate: f64) -> Self {
        Self {
            weights: vec![0.0; num_features],
            bias: 0.0,
            learning_rate,
        }
    }

    /// Dot product of weights and features, plus the bias.
    fn predict(&self, features: &[f64]) -> f64 {
        let mut result = self.bias;
        for (w, f) in self.weights.iter().zip(features.iter()) {
            result += w * f;
        }
        result
    }

    /// Single SGD step on the squared error:
    /// w_i += lr * (y - y_hat) * x_i, b += lr * (y - y_hat).
    fn update(&mut self, features: &[f64], target: f64) {
        let prediction = self.predict(features);
        let error = target - prediction;

        for (w, f) in self.weights.iter_mut().zip(features.iter()) {
            *w += self.learning_rate * error * f;
        }
        self.bias += self.learning_rate * error;
    }
}

/// Tabular Q-learning agent keyed by a string-encoded state.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QLearningAgent {
    q_table: HashMap<(String, OptimizationAction), f64>,
    learning_rate: f64,
    discount_factor: f64,
    epsilon: f64,
}

impl QLearningAgent {
    fn new(learning_rate: f64) -> Self {
        Self {
            q_table: HashMap::new(),
            learning_rate,
            discount_factor: 0.95,
            epsilon: 0.1,
        }
    }

    fn get_q_value(&self, state: &str, action: OptimizationAction) -> f64 {
        *self
            .q_table
            .get(&(state.to_string(), action))
            .unwrap_or(&0.0)
    }

    /// Standard Q-learning update:
    /// Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a)).
    fn update_q_value(
        &mut self,
        state: &str,
        action: OptimizationAction,
        reward: f64,
        next_state: Option<&str>,
    ) {
        let current_q = self.get_q_value(state, action);

        let max_next_q = if let Some(ns) = next_state {
            [
                self.get_q_value(ns, OptimizationAction::Fuse),
                self.get_q_value(ns, OptimizationAction::DontFuse),
                self.get_q_value(ns, OptimizationAction::Parallelize),
                self.get_q_value(ns, OptimizationAction::Sequential),
            ]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        } else {
            0.0
        };

        let new_q = current_q
            + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q);

        self.q_table.insert((state.to_string(), action), new_q);
    }

    /// Epsilon-greedy selection: explore with probability `epsilon` when
    /// `explore` is set, otherwise pick the highest-valued action.
    fn select_action(&self, state: &str, explore: bool) -> OptimizationAction {
        if explore && rand::random::<f64>() < self.epsilon {
            let actions = [
                OptimizationAction::Fuse,
                OptimizationAction::DontFuse,
                OptimizationAction::Parallelize,
                OptimizationAction::Sequential,
            ];
            actions[rand::rng().random_range(0..actions.len())]
        } else {
            let actions = [
                (
                    OptimizationAction::Fuse,
                    self.get_q_value(state, OptimizationAction::Fuse),
                ),
                (
                    OptimizationAction::DontFuse,
                    self.get_q_value(state, OptimizationAction::DontFuse),
                ),
                (
                    OptimizationAction::Parallelize,
                    self.get_q_value(state, OptimizationAction::Parallelize),
                ),
                (
                    OptimizationAction::Sequential,
                    self.get_q_value(state, OptimizationAction::Sequential),
                ),
            ];

            actions
                .iter()
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .map(|(action, _)| *action)
                .unwrap_or(OptimizationAction::DontFuse)
        }
    }
}

/// Learned optimizer that predicts execution costs and recommends
/// optimization actions based on observed behaviour.
pub struct LearnedOptimizer {
    strategy: LearningStrategy,
    model_type: ModelType,
    cost_model: Option<LinearModel>,
    q_agent: Option<QLearningAgent>,
    training_examples: Vec<TrainingExample>,
    learning_rate: f64,
    stats: LearningStats,
    min_training_examples: usize,
}

impl LearnedOptimizer {
    /// Create an optimizer with online learning and a linear cost model.
    pub fn new() -> Self {
        Self {
            strategy: LearningStrategy::Online,
            model_type: ModelType::LinearRegression,
            cost_model: None,
            q_agent: None,
            training_examples: Vec::new(),
            learning_rate: 0.01,
            stats: LearningStats {
                training_examples: 0,
                model_accuracy: 0.0,
                average_prediction_error: 0.0,
                total_updates: 0,
                learning_rate: 0.01,
            },
            min_training_examples: 10,
        }
    }

    /// Set the learning strategy.
    pub fn with_strategy(mut self, strategy: LearningStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set the model type.
    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
        self.model_type = model_type;
        self
    }

    /// Set the learning rate, clamped to the range [0.0001, 1.0].
    pub fn with_learning_rate(mut self, rate: f64) -> Self {
        self.learning_rate = rate.clamp(0.0001, 1.0);
        self.stats.learning_rate = self.learning_rate;
        self
    }

    /// Extract a feature vector from a graph description. Missing keys fall
    /// back to 0.0 (or 1.0 for the parallelism factor).
    pub fn extract_features(
        &self,
        graph_desc: &HashMap<String, f64>,
    ) -> Result<FeatureVector, LearnedOptError> {
        let mut features = FeatureVector::new();

        features.add_feature(
            "num_nodes".to_string(),
            *graph_desc.get("num_nodes").unwrap_or(&0.0),
        );
        features.add_feature(
            "num_edges".to_string(),
            *graph_desc.get("num_edges").unwrap_or(&0.0),
        );
        features.add_feature(
            "avg_node_degree".to_string(),
            *graph_desc.get("avg_degree").unwrap_or(&0.0),
        );
        features.add_feature(
            "graph_depth".to_string(),
            *graph_desc.get("depth").unwrap_or(&0.0),
        );
        features.add_feature(
            "total_memory".to_string(),
            *graph_desc.get("memory").unwrap_or(&0.0),
        );
        features.add_feature(
            "parallelism_factor".to_string(),
            *graph_desc.get("parallelism").unwrap_or(&1.0),
        );

        Ok(features)
    }

    /// Record an observed (features, cost) pair and update the cost model
    /// online. The linear model is created lazily on the first observation.
    pub fn observe(
        &mut self,
        features: FeatureVector,
        actual_cost: f64,
    ) -> Result<(), LearnedOptError> {
        let example = TrainingExample {
            features: features.clone(),
            label: actual_cost,
        };

        self.training_examples.push(example);
        self.stats.training_examples += 1;

        if self.cost_model.is_none() && !features.features.is_empty() {
            self.cost_model = Some(LinearModel::new(
                features.features.len(),
                self.learning_rate,
            ));
        }

        if let Some(model) = &mut self.cost_model {
            model.update(&features.features, actual_cost);
            self.stats.total_updates += 1;
        }

        Ok(())
    }

    /// Record a reward signal for a previously taken action. Only valid when
    /// the optimizer uses the `ReinforcementLearning` strategy.
    pub fn observe_reward(&mut self, signal: RewardSignal) -> Result<(), LearnedOptError> {
        if self.strategy != LearningStrategy::ReinforcementLearning {
            return Err(LearnedOptError::InvalidConfig(
                "Reward observation requires ReinforcementLearning strategy".to_string(),
            ));
        }

        if self.q_agent.is_none() {
            self.q_agent = Some(QLearningAgent::new(self.learning_rate));
        }

        // States are keyed by the debug representation of the feature values.
        let state = format!("{:?}", signal.state_features.features);
        let next_state = signal
            .next_state_features
            .as_ref()
            .map(|f| format!("{:?}", f.features));

        if let Some(agent) = &mut self.q_agent {
            agent.update_q_value(&state, signal.action, signal.reward, next_state.as_deref());
        }

        self.stats.total_updates += 1;

        Ok(())
    }

    /// Predict the execution cost for a feature vector. Requires a trained
    /// model and at least `min_training_examples` observations.
    pub fn predict_cost(
        &self,
        features: &FeatureVector,
    ) -> Result<CostPrediction, LearnedOptError> {
        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        if self.training_examples.len() < self.min_training_examples {
            return Err(LearnedOptError::InsufficientData(format!(
                "Need at least {} examples, have {}",
                self.min_training_examples,
                self.training_examples.len()
            )));
        }

        let predicted_cost = model.predict(&features.features);

        // Rough +/-20% interval around the point prediction.
        let margin = predicted_cost * 0.2;
        let confidence_interval = (predicted_cost - margin, predicted_cost + margin);

        // Confidence grows with the number of observations, saturating at 1.0.
        let model_confidence = (self.training_examples.len() as f64
            / (self.min_training_examples * 10) as f64)
            .min(1.0);

        Ok(CostPrediction {
            predicted_cost_us: predicted_cost.max(0.0),
            confidence_interval,
            model_confidence,
        })
    }

    /// Recommend whether to fuse, either from the Q-learning agent (RL
    /// strategy) or from the predicted cost (all other strategies).
    pub fn recommend_fusion(
        &self,
        features: &FeatureVector,
    ) -> Result<FusionRecommendation, LearnedOptError> {
        match self.strategy {
            LearningStrategy::ReinforcementLearning => {
                let agent = self.q_agent.as_ref().ok_or_else(|| {
                    LearnedOptError::ModelNotTrained("Q-learning agent not initialized".to_string())
                })?;

                let state = format!("{:?}", features.features);
                let action = agent.select_action(&state, false);

                let should_fuse = action == OptimizationAction::Fuse;
                let q_fuse = agent.get_q_value(&state, OptimizationAction::Fuse);
                let q_no_fuse = agent.get_q_value(&state, OptimizationAction::DontFuse);

                // Confidence reflects how far apart the two Q-values are.
                let confidence =
                    (q_fuse - q_no_fuse).abs() / (q_fuse.abs() + q_no_fuse.abs() + 1.0);
                let expected_speedup = if should_fuse { q_fuse.max(1.0) } else { 1.0 };

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence,
                    expected_speedup,
                })
            }
            _ => {
                let cost_pred = self.predict_cost(features)?;

                // Heuristic threshold (in microseconds) below which fusion is
                // assumed to pay off.
                let threshold = 100.0;
                let should_fuse = cost_pred.predicted_cost_us < threshold;

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence: cost_pred.model_confidence,
                    expected_speedup: if should_fuse { 1.5 } else { 1.0 },
                })
            }
        }
    }

    /// Current learning statistics.
    pub fn get_stats(&self) -> &LearningStats {
        &self.stats
    }

    /// Evaluate prediction accuracy over the stored training examples and
    /// update the internal statistics. Returns a value in [0.0, 1.0].
    pub fn evaluate_accuracy(&mut self) -> Result<f64, LearnedOptError> {
        if self.training_examples.is_empty() {
            return Ok(0.0);
        }

        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        let mut total_error = 0.0;

        for example in &self.training_examples {
            let prediction = model.predict(&example.features.features);
            let error = (prediction - example.label).abs();
            total_error += error;
        }

        let avg_error = total_error / self.training_examples.len() as f64;
        self.stats.average_prediction_error = avg_error;

        // Normalize the average error by the largest observed label.
        let max_label = self
            .training_examples
            .iter()
            .map(|e| e.label)
            .fold(f64::NEG_INFINITY, f64::max);

        let accuracy = if max_label > 0.0 {
            (1.0 - (avg_error / max_label)).max(0.0)
        } else {
            0.0
        };

        self.stats.model_accuracy = accuracy;

        Ok(accuracy)
    }

    /// Clear all training data, models, and statistics.
    pub fn reset(&mut self) {
        self.training_examples.clear();
        self.cost_model = None;
        self.q_agent = None;
        self.stats = LearningStats {
            training_examples: 0,
            model_accuracy: 0.0,
            average_prediction_error: 0.0,
            total_updates: 0,
            learning_rate: self.learning_rate,
        };
    }
}

impl Default for LearnedOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_features() -> FeatureVector {
        let mut features = FeatureVector::new();
        features.add_feature("num_nodes".to_string(), 10.0);
        features.add_feature("num_edges".to_string(), 15.0);
        features.add_feature("avg_degree".to_string(), 1.5);
        features
    }

    #[test]
    fn test_learned_optimizer_creation() {
        let optimizer = LearnedOptimizer::new();
        assert_eq!(optimizer.strategy, LearningStrategy::Online);
        assert_eq!(optimizer.model_type, ModelType::LinearRegression);
    }

    #[test]
    fn test_builder_pattern() {
        let optimizer = LearnedOptimizer::new()
            .with_strategy(LearningStrategy::ReinforcementLearning)
            .with_model_type(ModelType::NeuralNetwork)
            .with_learning_rate(0.05);

        assert_eq!(optimizer.strategy, LearningStrategy::ReinforcementLearning);
        assert_eq!(optimizer.model_type, ModelType::NeuralNetwork);
        assert_eq!(optimizer.learning_rate, 0.05);
    }
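
    // Illustrative sketch of the clamping behaviour in `with_learning_rate`:
    // out-of-range values are clamped into [0.0001, 1.0] rather than rejected.
    #[test]
    fn test_learning_rate_is_clamped() {
        let too_high = LearnedOptimizer::new().with_learning_rate(5.0);
        assert_eq!(too_high.learning_rate, 1.0);

        let too_low = LearnedOptimizer::new().with_learning_rate(0.0);
        assert_eq!(too_low.learning_rate, 0.0001);
        assert_eq!(too_low.stats.learning_rate, 0.0001);
    }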

    #[test]
    fn test_feature_extraction() {
        let optimizer = LearnedOptimizer::new();
        let mut graph_desc = HashMap::new();
        graph_desc.insert("num_nodes".to_string(), 10.0);
        graph_desc.insert("num_edges".to_string(), 15.0);

        let features = optimizer.extract_features(&graph_desc).unwrap();
        assert!(!features.features.is_empty());
    }

    #[test]
    fn test_observe_and_learn() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer.observe(features.clone(), 100.0).unwrap();
        optimizer.observe(features.clone(), 95.0).unwrap();

        assert_eq!(optimizer.stats.training_examples, 2);
        assert_eq!(optimizer.stats.total_updates, 2);
    }
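
    // Illustrative end-to-end sketch: features extracted from a graph
    // description are fed straight back into `observe`, then reused for a
    // cost prediction once enough examples have accumulated.
    #[test]
    fn test_extract_observe_predict_roundtrip() {
        let mut optimizer = LearnedOptimizer::new();
        let mut graph_desc = HashMap::new();
        graph_desc.insert("num_nodes".to_string(), 10.0);
        graph_desc.insert("num_edges".to_string(), 15.0);
        graph_desc.insert("depth".to_string(), 4.0);

        for i in 0..15 {
            let mut desc = graph_desc.clone();
            desc.insert("num_nodes".to_string(), 10.0 + i as f64);
            let features = optimizer.extract_features(&desc).unwrap();
            optimizer.observe(features, 100.0 + i as f64).unwrap();
        }

        let features = optimizer.extract_features(&graph_desc).unwrap();
        let prediction = optimizer.predict_cost(&features).unwrap();
        assert!(prediction.predicted_cost_us >= 0.0);
        assert!(prediction.model_confidence > 0.0);
    }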

    #[test]
    fn test_cost_prediction_insufficient_data() {
        let optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        let result = optimizer.predict_cost(&features);
        assert!(result.is_err());
    }

    #[test]
    fn test_cost_prediction_with_training() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer.observe(f, 100.0 + i as f64).unwrap();
        }

        let prediction = optimizer.predict_cost(&features).unwrap();
        assert!(prediction.predicted_cost_us >= 0.0);
        assert!(prediction.model_confidence > 0.0);
    }

    #[test]
    fn test_reinforcement_learning_observation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let signal = RewardSignal {
            state_features: create_test_features(),
            action: OptimizationAction::Fuse,
            reward: 10.0,
            next_state_features: Some(create_test_features()),
        };

        optimizer.observe_reward(signal).unwrap();
        assert_eq!(optimizer.stats.total_updates, 1);
    }
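
    // Sketch of the error path: reward observations are rejected unless the
    // optimizer was built with the ReinforcementLearning strategy.
    #[test]
    fn test_reward_rejected_without_rl_strategy() {
        let mut optimizer = LearnedOptimizer::new();
        let signal = RewardSignal {
            state_features: create_test_features(),
            action: OptimizationAction::Fuse,
            reward: 10.0,
            next_state_features: None,
        };

        let result = optimizer.observe_reward(signal);
        assert!(matches!(result, Err(LearnedOptError::InvalidConfig(_))));
    }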

    #[test]
    fn test_fusion_recommendation() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer.observe(f, 50.0 + i as f64).unwrap();
        }

        let recommendation = optimizer.recommend_fusion(&features).unwrap();
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_rl_fusion_recommendation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let features = create_test_features();

        for _ in 0..10 {
            let signal = RewardSignal {
                state_features: features.clone(),
                action: OptimizationAction::Fuse,
                reward: 15.0,
                next_state_features: None,
            };
            optimizer.observe_reward(signal).unwrap();
        }

        let recommendation = optimizer.recommend_fusion(&features).unwrap();
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_accuracy_evaluation() {
        let mut optimizer = LearnedOptimizer::new();

        for i in 0..20 {
            let mut features = FeatureVector::new();
            features.add_feature("x".to_string(), i as f64);
            optimizer.observe(features, i as f64 * 2.0).unwrap();
        }

        let accuracy = optimizer.evaluate_accuracy().unwrap();
        assert!(accuracy >= 0.0 && accuracy <= 1.0);
    }

    #[test]
    fn test_reset() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer.observe(features, 100.0).unwrap();
        assert_eq!(optimizer.stats.training_examples, 1);

        optimizer.reset();
        assert_eq!(optimizer.stats.training_examples, 0);
        assert!(optimizer.training_examples.is_empty());
    }
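
    // Sketch: after `reset` the optimizer can be retrained from scratch; the
    // lazily created cost model is rebuilt on the next observation.
    #[test]
    fn test_retrain_after_reset() {
        let mut optimizer = LearnedOptimizer::new();

        optimizer.observe(create_test_features(), 100.0).unwrap();
        optimizer.reset();
        assert!(optimizer.cost_model.is_none());

        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer.observe(f, 100.0 + i as f64).unwrap();
        }

        assert!(optimizer.predict_cost(&create_test_features()).is_ok());
    }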

    #[test]
    fn test_linear_model_prediction() {
        let model = LinearModel::new(3, 0.01);
        let features = vec![1.0, 2.0, 3.0];

        let prediction = model.predict(&features);
        assert!(prediction.is_finite());
    }

    #[test]
    fn test_linear_model_update() {
        let mut model = LinearModel::new(2, 0.1);
        let features = vec![1.0, 2.0];

        model.update(&features, 10.0);
        let pred_after = model.predict(&features);

        assert!(pred_after.is_finite());
    }
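
    // Illustrative convergence check: repeated SGD updates on one example
    // drive the prediction toward the target (here lr * (||x||^2 + 1) = 0.6,
    // so the error shrinks geometrically).
    #[test]
    fn test_linear_model_converges_on_single_example() {
        let mut model = LinearModel::new(2, 0.1);
        let features = vec![1.0, 2.0];

        for _ in 0..50 {
            model.update(&features, 10.0);
        }

        assert!((model.predict(&features) - 10.0).abs() < 0.1);
    }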

    #[test]
    fn test_q_learning_agent() {
        let mut agent = QLearningAgent::new(0.1);

        agent.update_q_value("state1", OptimizationAction::Fuse, 10.0, Some("state2"));

        let q_value = agent.get_q_value("state1", OptimizationAction::Fuse);
        assert!(q_value > 0.0);
    }

    #[test]
    fn test_q_learning_action_selection() {
        let mut agent = QLearningAgent::new(0.1);

        for _ in 0..10 {
            agent.update_q_value("state1", OptimizationAction::Fuse, 20.0, None);
        }

        let action = agent.select_action("state1", false);
        assert!(
            action == OptimizationAction::Fuse
                || action == OptimizationAction::DontFuse
                || action == OptimizationAction::Parallelize
                || action == OptimizationAction::Sequential
        );
    }
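
    // Sketch of the exploratory path: with `explore = true` the agent may pick
    // a random action with probability epsilon, but the result is always one
    // of the four fusion/scheduling actions.
    #[test]
    fn test_q_learning_exploration_returns_known_action() {
        let agent = QLearningAgent::new(0.1);

        for _ in 0..20 {
            let action = agent.select_action("unseen_state", true);
            assert!(matches!(
                action,
                OptimizationAction::Fuse
                    | OptimizationAction::DontFuse
                    | OptimizationAction::Parallelize
                    | OptimizationAction::Sequential
            ));
        }
    }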

    #[test]
    fn test_different_learning_strategies() {
        let strategies = vec![
            LearningStrategy::Supervised,
            LearningStrategy::Online,
            LearningStrategy::Transfer,
        ];

        for strategy in strategies {
            let optimizer = LearnedOptimizer::new().with_strategy(strategy);
            assert_eq!(optimizer.strategy, strategy);
        }
    }

    #[test]
    fn test_different_model_types() {
        let model_types = vec![
            ModelType::LinearRegression,
            ModelType::DecisionTree,
            ModelType::RandomForest,
            ModelType::NeuralNetwork,
            ModelType::GradientBoosting,
        ];

        for model_type in model_types {
            let optimizer = LearnedOptimizer::new().with_model_type(model_type);
            assert_eq!(optimizer.model_type, model_type);
        }
    }
}