use scirs2_core::random::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Errors that can occur during learned optimization.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum LearnedOptError {
    #[error("Insufficient training data: {0}")]
    InsufficientData(String),

    #[error("Model not trained: {0}")]
    ModelNotTrained(String),

    #[error("Feature extraction failed: {0}")]
    FeatureExtractionFailed(String),

    #[error("Prediction failed: {0}")]
    PredictionFailed(String),

    #[error("Invalid model configuration: {0}")]
    InvalidConfig(String),
}

/// Identifier for a node in the computation graph.
pub type NodeId = String;

/// Strategy used to train the optimizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LearningStrategy {
    /// Learn from labeled (features, cost) examples.
    Supervised,
    /// Learn from reward signals via tabular Q-learning.
    ReinforcementLearning,
    /// Update the model incrementally as observations arrive (the default).
    Online,
    /// Reuse knowledge from previously trained models.
    Transfer,
}

/// Type of predictive model backing the optimizer.
///
/// Only `LinearRegression` is currently wired to an implementation; the other
/// variants are accepted by the builder, but the optimizer still trains a
/// linear cost model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    LinearRegression,
    DecisionTree,
    RandomForest,
    NeuralNetwork,
    GradientBoosting,
}

/// A named vector of numeric features describing a computation graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureVector {
    pub features: Vec<f64>,
    pub feature_names: Vec<String>,
}

impl FeatureVector {
    fn new() -> Self {
        Self {
            features: Vec::new(),
            feature_names: Vec::new(),
        }
    }

    /// Appends a feature, keeping names and values aligned by index.
    fn add_feature(&mut self, name: String, value: f64) {
        self.feature_names.push(name);
        self.features.push(value);
    }
}

/// A supervised training example pairing features with an observed cost label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    pub features: FeatureVector,
    pub label: f64,
}

/// A reinforcement-learning observation: state, action taken, reward received,
/// and the resulting state (if any).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewardSignal {
    pub state_features: FeatureVector,
    pub action: OptimizationAction,
    pub reward: f64,
    pub next_state_features: Option<FeatureVector>,
}

/// Actions the optimizer can take on a region of the computation graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationAction {
    Fuse,
    DontFuse,
    Parallelize,
    Sequential,
    CacheResult,
    Recompute,
}

/// Recommendation on whether to fuse operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionRecommendation {
    pub should_fuse: bool,
    pub confidence: f64,
    pub expected_speedup: f64,
}

/// Recommended execution schedule over graph nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleRecommendation {
    pub schedule: Vec<NodeId>,
    pub confidence: f64,
    pub expected_time_us: f64,
}

/// Predicted execution cost with a crude uncertainty estimate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostPrediction {
    pub predicted_cost_us: f64,
    pub confidence_interval: (f64, f64),
    pub model_confidence: f64,
}

/// Aggregate statistics about the learning process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStats {
    pub training_examples: usize,
    pub model_accuracy: f64,
    pub average_prediction_error: f64,
    pub total_updates: usize,
    pub learning_rate: f64,
}

/// Online linear regression model trained by stochastic gradient descent.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    weights: Vec<f64>,
    bias: f64,
    learning_rate: f64,
}

impl LinearModel {
    fn new(num_features: usize, learning_rate: f64) -> Self {
        Self {
            weights: vec![0.0; num_features],
            bias: 0.0,
            learning_rate,
        }
    }

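    // Prediction is the affine map y = w . x + b. `zip` stops at the shorter
    // of the two slices, so extra or missing features are silently truncated.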
    fn predict(&self, features: &[f64]) -> f64 {
        let mut result = self.bias;
        for (w, f) in self.weights.iter().zip(features.iter()) {
            result += w * f;
        }
        result
    }

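    // One stochastic-gradient-descent step on squared error:
    //   w_i <- w_i + lr * (target - prediction) * x_i
    //   b   <- b   + lr * (target - prediction)
    // Worked example (hypothetical values): from zero weights with lr = 0.1,
    // updating on features [1.0, 2.0] toward target 10.0 gives error = 10.0,
    // so w = [1.0, 2.0] and b = 1.0; the next prediction is 1.0 + 4.0 + 1.0 = 6.0.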
    fn update(&mut self, features: &[f64], target: f64) {
        let prediction = self.predict(features);
        let error = target - prediction;

        for (w, f) in self.weights.iter_mut().zip(features.iter()) {
            *w += self.learning_rate * error * f;
        }
        self.bias += self.learning_rate * error;
    }
}

/// Tabular Q-learning agent keyed by (state string, action) pairs.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QLearningAgent {
    q_table: HashMap<(String, OptimizationAction), f64>,
    learning_rate: f64,
    discount_factor: f64,
    epsilon: f64,
}

impl QLearningAgent {
    fn new(learning_rate: f64) -> Self {
        Self {
            q_table: HashMap::new(),
            learning_rate,
            discount_factor: 0.95,
            epsilon: 0.1,
        }
    }

    fn get_q_value(&self, state: &str, action: OptimizationAction) -> f64 {
        *self
            .q_table
            .get(&(state.to_string(), action))
            .unwrap_or(&0.0)
    }

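    // Standard Q-learning (Bellman) update:
    //   Q(s, a) <- Q(s, a) + lr * (reward + gamma * max_a' Q(s', a') - Q(s, a))
    // where max_a' Q(s', a') is taken as 0.0 when there is no next state.
    // Worked example (hypothetical values): with lr = 0.1 and an empty table,
    // a reward of 10.0 and no next state stores Q(s, a) = 0.1 * 10.0 = 1.0.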
    fn update_q_value(
        &mut self,
        state: &str,
        action: OptimizationAction,
        reward: f64,
        next_state: Option<&str>,
    ) {
        let current_q = self.get_q_value(state, action);

        // Only the four fusion/scheduling actions participate in the backup;
        // CacheResult and Recompute are never valued by this agent.
        let max_next_q = if let Some(ns) = next_state {
            [
                self.get_q_value(ns, OptimizationAction::Fuse),
                self.get_q_value(ns, OptimizationAction::DontFuse),
                self.get_q_value(ns, OptimizationAction::Parallelize),
                self.get_q_value(ns, OptimizationAction::Sequential),
            ]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        } else {
            0.0
        };

        let new_q = current_q
            + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q);

        self.q_table.insert((state.to_string(), action), new_q);
    }

    fn select_action(&self, state: &str, explore: bool) -> OptimizationAction {
        if explore && scirs2_core::random::random::<f64>() < self.epsilon {
            // Explore: pick one of the candidate actions uniformly at random.
            let actions = [
                OptimizationAction::Fuse,
                OptimizationAction::DontFuse,
                OptimizationAction::Parallelize,
                OptimizationAction::Sequential,
            ];
            actions[scirs2_core::random::rng().random_range(0..actions.len())]
        } else {
            // Exploit: pick the action with the highest learned Q-value.
            let actions = [
                (
                    OptimizationAction::Fuse,
                    self.get_q_value(state, OptimizationAction::Fuse),
                ),
                (
                    OptimizationAction::DontFuse,
                    self.get_q_value(state, OptimizationAction::DontFuse),
                ),
                (
                    OptimizationAction::Parallelize,
                    self.get_q_value(state, OptimizationAction::Parallelize),
                ),
                (
                    OptimizationAction::Sequential,
                    self.get_q_value(state, OptimizationAction::Sequential),
                ),
            ];

            actions
                .iter()
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(action, _)| *action)
                .unwrap_or(OptimizationAction::DontFuse)
        }
    }
}

/// An optimizer that learns cost models and fusion policies from runtime
/// observations.
pub struct LearnedOptimizer {
    strategy: LearningStrategy,
    model_type: ModelType,
    cost_model: Option<LinearModel>,
    q_agent: Option<QLearningAgent>,
    training_examples: Vec<TrainingExample>,
    learning_rate: f64,
    stats: LearningStats,
    min_training_examples: usize,
}

impl LearnedOptimizer {
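    /// Creates an optimizer with online learning and a linear cost model.
    ///
    /// A minimal usage sketch; `graph_desc` is a hypothetical
    /// `HashMap<String, f64>` of graph statistics, not part of this API:
    ///
    /// ```ignore
    /// let mut opt = LearnedOptimizer::new().with_learning_rate(0.05);
    /// let features = opt.extract_features(&graph_desc)?;
    /// opt.observe(features.clone(), 123.0)?; // measured cost in microseconds
    /// let prediction = opt.predict_cost(&features)?;
    /// ```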
    pub fn new() -> Self {
        Self {
            strategy: LearningStrategy::Online,
            model_type: ModelType::LinearRegression,
            cost_model: None,
            q_agent: None,
            training_examples: Vec::new(),
            learning_rate: 0.01,
            stats: LearningStats {
                training_examples: 0,
                model_accuracy: 0.0,
                average_prediction_error: 0.0,
                total_updates: 0,
                learning_rate: 0.01,
            },
            min_training_examples: 10,
        }
    }

    /// Sets the learning strategy.
    pub fn with_strategy(mut self, strategy: LearningStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the model type.
    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
        self.model_type = model_type;
        self
    }

    /// Sets the learning rate, clamped to `[0.0001, 1.0]`.
    pub fn with_learning_rate(mut self, rate: f64) -> Self {
        self.learning_rate = rate.clamp(0.0001, 1.0);
        self.stats.learning_rate = self.learning_rate;
        self
    }

    /// Extracts a feature vector from a graph description map.
    ///
    /// Missing keys default to 0.0, except `parallelism`, which defaults to 1.0.
    pub fn extract_features(
        &self,
        graph_desc: &HashMap<String, f64>,
    ) -> Result<FeatureVector, LearnedOptError> {
        let mut features = FeatureVector::new();

        features.add_feature(
            "num_nodes".to_string(),
            *graph_desc.get("num_nodes").unwrap_or(&0.0),
        );
        features.add_feature(
            "num_edges".to_string(),
            *graph_desc.get("num_edges").unwrap_or(&0.0),
        );
        features.add_feature(
            "avg_node_degree".to_string(),
            *graph_desc.get("avg_degree").unwrap_or(&0.0),
        );
        features.add_feature(
            "graph_depth".to_string(),
            *graph_desc.get("depth").unwrap_or(&0.0),
        );
        features.add_feature(
            "total_memory".to_string(),
            *graph_desc.get("memory").unwrap_or(&0.0),
        );
        features.add_feature(
            "parallelism_factor".to_string(),
            *graph_desc.get("parallelism").unwrap_or(&1.0),
        );

        Ok(features)
    }

    /// Records an observed (features, cost) pair and performs one online
    /// update of the cost model, creating the model lazily on the first
    /// non-empty feature vector.
    pub fn observe(
        &mut self,
        features: FeatureVector,
        actual_cost: f64,
    ) -> Result<(), LearnedOptError> {
        let example = TrainingExample {
            features: features.clone(),
            label: actual_cost,
        };

        self.training_examples.push(example);
        self.stats.training_examples += 1;

        if self.cost_model.is_none() && !features.features.is_empty() {
            self.cost_model = Some(LinearModel::new(
                features.features.len(),
                self.learning_rate,
            ));
        }

        if let Some(model) = &mut self.cost_model {
            model.update(&features.features, actual_cost);
            self.stats.total_updates += 1;
        }

        Ok(())
    }

    /// Records a reinforcement-learning reward signal, creating the Q-agent
    /// lazily on first use.
    ///
    /// Returns an error unless the strategy is `ReinforcementLearning`.
    pub fn observe_reward(&mut self, signal: RewardSignal) -> Result<(), LearnedOptError> {
        if self.strategy != LearningStrategy::ReinforcementLearning {
            return Err(LearnedOptError::InvalidConfig(
                "Reward observation requires ReinforcementLearning strategy".to_string(),
            ));
        }

        if self.q_agent.is_none() {
            self.q_agent = Some(QLearningAgent::new(self.learning_rate));
        }

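        // The Debug formatting of the feature vector serves as a coarse state
        // key: two states share Q-values only if their features match exactly.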
        let state = format!("{:?}", signal.state_features.features);
        let next_state = signal
            .next_state_features
            .as_ref()
            .map(|f| format!("{:?}", f.features));

        if let Some(agent) = &mut self.q_agent {
            agent.update_q_value(&state, signal.action, signal.reward, next_state.as_deref());
        }

        self.stats.total_updates += 1;

        Ok(())
    }

    /// Predicts execution cost in microseconds for the given features.
    ///
    /// Fails if the model has not been trained or if fewer than
    /// `min_training_examples` observations have been recorded.
    pub fn predict_cost(
        &self,
        features: &FeatureVector,
    ) -> Result<CostPrediction, LearnedOptError> {
        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        if self.training_examples.len() < self.min_training_examples {
            return Err(LearnedOptError::InsufficientData(format!(
                "Need at least {} examples, have {}",
                self.min_training_examples,
                self.training_examples.len()
            )));
        }

        let predicted_cost = model.predict(&features.features);

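        // Heuristic uncertainty: a fixed +/-20% band around the point estimate,
        // plus a confidence score that grows linearly with the training set and
        // saturates at 10x the minimum required examples.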
        let margin = predicted_cost * 0.2;
        let confidence_interval = (predicted_cost - margin, predicted_cost + margin);

        let model_confidence = (self.training_examples.len() as f64
            / (self.min_training_examples * 10) as f64)
            .min(1.0);

        Ok(CostPrediction {
            predicted_cost_us: predicted_cost.max(0.0),
            confidence_interval,
            model_confidence,
        })
    }

    /// Recommends whether to fuse, using the Q-agent under reinforcement
    /// learning and the learned cost model otherwise.
    pub fn recommend_fusion(
        &self,
        features: &FeatureVector,
    ) -> Result<FusionRecommendation, LearnedOptError> {
        match self.strategy {
            LearningStrategy::ReinforcementLearning => {
                let agent = self.q_agent.as_ref().ok_or_else(|| {
                    LearnedOptError::ModelNotTrained("Q-learning agent not initialized".to_string())
                })?;

                let state = format!("{:?}", features.features);
                let action = agent.select_action(&state, false);

                let should_fuse = action == OptimizationAction::Fuse;
                let q_fuse = agent.get_q_value(&state, OptimizationAction::Fuse);
                let q_no_fuse = agent.get_q_value(&state, OptimizationAction::DontFuse);

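                // Confidence is the normalized gap between the two Q-values;
                // the +1.0 in the denominator avoids division by zero and keeps
                // the score in [0, 1).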
                let confidence =
                    (q_fuse - q_no_fuse).abs() / (q_fuse.abs() + q_no_fuse.abs() + 1.0);
                let expected_speedup = if should_fuse { q_fuse.max(1.0) } else { 1.0 };

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence,
                    expected_speedup,
                })
            }
            _ => {
                let cost_pred = self.predict_cost(features)?;

                // Heuristic fallback: fuse when the predicted cost is below a
                // fixed threshold (in microseconds).
                let threshold = 100.0;
                let should_fuse = cost_pred.predicted_cost_us < threshold;

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence: cost_pred.model_confidence,
                    expected_speedup: if should_fuse { 1.5 } else { 1.0 },
                })
            }
        }
    }

    /// Returns current learning statistics.
    pub fn get_stats(&self) -> &LearningStats {
        &self.stats
    }

    /// Evaluates mean absolute error over the stored training examples and
    /// derives a normalized accuracy score, updating `stats` in place.
    pub fn evaluate_accuracy(&mut self) -> Result<f64, LearnedOptError> {
        if self.training_examples.is_empty() {
            return Ok(0.0);
        }

        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        let mut total_error = 0.0;

        for example in &self.training_examples {
            let prediction = model.predict(&example.features.features);
            let error = (prediction - example.label).abs();
            total_error += error;
        }

        let avg_error = total_error / self.training_examples.len() as f64;
        self.stats.average_prediction_error = avg_error;

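        // Normalize: accuracy = 1 - (mean absolute error / max observed label),
        // clamped to be non-negative; zero when no positive label exists.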
        let max_label = self
            .training_examples
            .iter()
            .map(|e| e.label)
            .fold(f64::NEG_INFINITY, f64::max);

        let accuracy = if max_label > 0.0 {
            (1.0 - (avg_error / max_label)).max(0.0)
        } else {
            0.0
        };

        self.stats.model_accuracy = accuracy;

        Ok(accuracy)
    }

    /// Clears all learned state and statistics, keeping the configured
    /// learning rate.
    pub fn reset(&mut self) {
        self.training_examples.clear();
        self.cost_model = None;
        self.q_agent = None;
        self.stats = LearningStats {
            training_examples: 0,
            model_accuracy: 0.0,
            average_prediction_error: 0.0,
            total_updates: 0,
            learning_rate: self.learning_rate,
        };
    }
}

impl Default for LearnedOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_features() -> FeatureVector {
        let mut features = FeatureVector::new();
        features.add_feature("num_nodes".to_string(), 10.0);
        features.add_feature("num_edges".to_string(), 15.0);
        features.add_feature("avg_degree".to_string(), 1.5);
        features
    }

    #[test]
    fn test_learned_optimizer_creation() {
        let optimizer = LearnedOptimizer::new();
        assert_eq!(optimizer.strategy, LearningStrategy::Online);
        assert_eq!(optimizer.model_type, ModelType::LinearRegression);
    }

    #[test]
    fn test_builder_pattern() {
        let optimizer = LearnedOptimizer::new()
            .with_strategy(LearningStrategy::ReinforcementLearning)
            .with_model_type(ModelType::NeuralNetwork)
            .with_learning_rate(0.05);

        assert_eq!(optimizer.strategy, LearningStrategy::ReinforcementLearning);
        assert_eq!(optimizer.model_type, ModelType::NeuralNetwork);
        assert_eq!(optimizer.learning_rate, 0.05);
    }

    #[test]
    fn test_feature_extraction() {
        let optimizer = LearnedOptimizer::new();
        let mut graph_desc = HashMap::new();
        graph_desc.insert("num_nodes".to_string(), 10.0);
        graph_desc.insert("num_edges".to_string(), 15.0);

        let features = optimizer
            .extract_features(&graph_desc)
            .expect("feature extraction should succeed");
        assert!(!features.features.is_empty());
    }

    #[test]
    fn test_observe_and_learn() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer
            .observe(features.clone(), 100.0)
            .expect("observe should succeed");
        optimizer
            .observe(features.clone(), 95.0)
            .expect("observe should succeed");

        assert_eq!(optimizer.stats.training_examples, 2);
        assert_eq!(optimizer.stats.total_updates, 2);
    }

    #[test]
    fn test_cost_prediction_insufficient_data() {
        let optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        let result = optimizer.predict_cost(&features);
        assert!(result.is_err());
    }

    #[test]
    fn test_cost_prediction_with_training() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer
                .observe(f, 100.0 + i as f64)
                .expect("observe should succeed");
        }

        let prediction = optimizer
            .predict_cost(&features)
            .expect("prediction should succeed");
        assert!(prediction.predicted_cost_us >= 0.0);
        assert!(prediction.model_confidence > 0.0);
    }

    #[test]
    fn test_reinforcement_learning_observation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let signal = RewardSignal {
            state_features: create_test_features(),
            action: OptimizationAction::Fuse,
            reward: 10.0,
            next_state_features: Some(create_test_features()),
        };

        optimizer
            .observe_reward(signal)
            .expect("reward observation should succeed");
        assert_eq!(optimizer.stats.total_updates, 1);
    }

    #[test]
    fn test_fusion_recommendation() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer
                .observe(f, 50.0 + i as f64)
                .expect("observe should succeed");
        }

        let recommendation = optimizer
            .recommend_fusion(&features)
            .expect("recommendation should succeed");
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_rl_fusion_recommendation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let features = create_test_features();

        for _ in 0..10 {
            let signal = RewardSignal {
                state_features: features.clone(),
                action: OptimizationAction::Fuse,
                reward: 15.0,
                next_state_features: None,
            };
            optimizer
                .observe_reward(signal)
                .expect("reward observation should succeed");
        }

        let recommendation = optimizer
            .recommend_fusion(&features)
            .expect("recommendation should succeed");
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_accuracy_evaluation() {
        let mut optimizer = LearnedOptimizer::new();

        for i in 0..20 {
            let mut features = FeatureVector::new();
            features.add_feature("x".to_string(), i as f64);
            optimizer
                .observe(features, i as f64 * 2.0)
                .expect("observe should succeed");
        }

        let accuracy = optimizer
            .evaluate_accuracy()
            .expect("accuracy evaluation should succeed");
        assert!((0.0..=1.0).contains(&accuracy));
    }

    #[test]
    fn test_reset() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer
            .observe(features, 100.0)
            .expect("observe should succeed");
        assert_eq!(optimizer.stats.training_examples, 1);

        optimizer.reset();
        assert_eq!(optimizer.stats.training_examples, 0);
        assert!(optimizer.training_examples.is_empty());
    }

    #[test]
    fn test_linear_model_prediction() {
        let model = LinearModel::new(3, 0.01);
        let features = vec![1.0, 2.0, 3.0];

        let prediction = model.predict(&features);
        assert!(prediction.is_finite());
    }

    #[test]
    fn test_linear_model_update() {
        let mut model = LinearModel::new(2, 0.1);
        let features = vec![1.0, 2.0];

        model.update(&features, 10.0);
        let pred_after = model.predict(&features);

        assert!(pred_after.is_finite());
    }

    #[test]
    fn test_q_learning_agent() {
        let mut agent = QLearningAgent::new(0.1);

        agent.update_q_value("state1", OptimizationAction::Fuse, 10.0, Some("state2"));

        let q_value = agent.get_q_value("state1", OptimizationAction::Fuse);
        assert!(q_value > 0.0);
    }

    #[test]
    fn test_q_learning_action_selection() {
        let mut agent = QLearningAgent::new(0.1);

        for _ in 0..10 {
            agent.update_q_value("state1", OptimizationAction::Fuse, 20.0, None);
        }

        let action = agent.select_action("state1", false);
        assert!(
            action == OptimizationAction::Fuse
                || action == OptimizationAction::DontFuse
                || action == OptimizationAction::Parallelize
                || action == OptimizationAction::Sequential
        );
    }

    #[test]
    fn test_different_learning_strategies() {
        let strategies = vec![
            LearningStrategy::Supervised,
            LearningStrategy::Online,
            LearningStrategy::Transfer,
        ];

        for strategy in strategies {
            let optimizer = LearnedOptimizer::new().with_strategy(strategy);
            assert_eq!(optimizer.strategy, strategy);
        }
    }

    #[test]
    fn test_different_model_types() {
        let model_types = vec![
            ModelType::LinearRegression,
            ModelType::DecisionTree,
            ModelType::RandomForest,
            ModelType::NeuralNetwork,
            ModelType::GradientBoosting,
        ];

        for model_type in model_types {
            let optimizer = LearnedOptimizer::new().with_model_type(model_type);
            assert_eq!(optimizer.model_type, model_type);
        }
    }
}