
tensorlogic_infer/learned_opt.rs

//! Machine learning-based optimization decisions.
//!
//! This module implements learned optimization strategies:
//! - **ML-based fusion decisions**: Learn which operations to fuse
//! - **Learned cost models**: Predict operation costs using ML
//! - **Reinforcement learning for scheduling**: Learn optimal execution schedules
//! - **Feature extraction**: Extract relevant features from computation graphs
//! - **Online learning**: Continuously improve from observed performance
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_infer::{
//!     LearnedOptimizer, LearningStrategy, ModelType, OptimizationAction, RewardSignal,
//! };
//!
//! // Create learned optimizer
//! let mut optimizer = LearnedOptimizer::new()
//!     .with_strategy(LearningStrategy::ReinforcementLearning)
//!     .with_model_type(ModelType::NeuralNetwork)
//!     .with_learning_rate(0.01);
//!
//! // Train from observed executions (reward = measured speedup or slowdown)
//! for (graph_features, action, reward) in training_data {
//!     optimizer.observe_reward(RewardSignal {
//!         state_features: graph_features,
//!         action,
//!         reward,
//!         next_state_features: None,
//!     })?;
//! }
//!
//! // Use the learned model for optimization
//! let decision = optimizer.recommend_fusion(&new_graph_features)?;
//! ```
//!
//! ## SCIRS2 Policy Note
//!
//! This module uses `rand` for epsilon-greedy exploration in Q-learning, which is
//! acceptable under the SCIRS2 policy exception for planning layer crates
//! (tensorlogic-infer). The planning layer is permitted to avoid heavy SciRS2
//! dependencies so that it remains lightweight and engine-agnostic. The `rand`
//! usage here is minimal (only for the exploration strategy) and does not involve
//! heavy tensor operations.

use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Learned optimization errors.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum LearnedOptError {
    #[error("Insufficient training data: {0}")]
    InsufficientData(String),

    #[error("Model not trained: {0}")]
    ModelNotTrained(String),

    #[error("Feature extraction failed: {0}")]
    FeatureExtractionFailed(String),

    #[error("Prediction failed: {0}")]
    PredictionFailed(String),

    #[error("Invalid model configuration: {0}")]
    InvalidConfig(String),
}

/// Node ID in the computation graph.
pub type NodeId = String;

/// Learning strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LearningStrategy {
    /// Supervised learning from labeled examples
    Supervised,
    /// Reinforcement learning from rewards
    ReinforcementLearning,
    /// Online learning with continuous updates
    Online,
    /// Transfer learning from pre-trained models
    Transfer,
}

/// Model type for learned optimization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Linear regression model
    LinearRegression,
    /// Decision tree
    DecisionTree,
    /// Random forest
    RandomForest,
    /// Neural network (simplified)
    NeuralNetwork,
    /// Gradient boosting
    GradientBoosting,
}

/// Feature vector for graph/node characteristics.
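///
/// The two vectors are parallel: `features[i]` holds the value for the name
/// stored at `feature_names[i]`.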
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureVector {
    pub features: Vec<f64>,
    pub feature_names: Vec<String>,
}

impl FeatureVector {
    fn new() -> Self {
        Self {
            features: Vec::new(),
            feature_names: Vec::new(),
        }
    }

    fn add_feature(&mut self, name: String, value: f64) {
        self.feature_names.push(name);
        self.features.push(value);
    }
}

/// Training example for supervised learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    pub features: FeatureVector,
    pub label: f64, // For regression: execution time, for classification: 0/1
}

/// Reward signal for reinforcement learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewardSignal {
    pub state_features: FeatureVector,
    pub action: OptimizationAction,
    pub reward: f64, // Positive for speedup, negative for slowdown
    pub next_state_features: Option<FeatureVector>,
}

/// Optimization action that can be learned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationAction {
    Fuse,
    DontFuse,
    Parallelize,
    Sequential,
    CacheResult,
    Recompute,
}

/// Fusion recommendation from learned model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionRecommendation {
    pub should_fuse: bool,
    pub confidence: f64,
    pub expected_speedup: f64,
}

/// Scheduling recommendation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleRecommendation {
    pub schedule: Vec<NodeId>,
    pub confidence: f64,
    pub expected_time_us: f64,
}

/// Cost prediction from learned model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostPrediction {
    pub predicted_cost_us: f64,
    pub confidence_interval: (f64, f64), // (lower, upper)
    pub model_confidence: f64,
}

/// Learning statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStats {
    pub training_examples: usize,
    pub model_accuracy: f64,
    pub average_prediction_error: f64,
    pub total_updates: usize,
    pub learning_rate: f64,
}

/// Simplified linear model for cost prediction.
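///
/// Prediction is `y_hat = bias + sum_i(w_i * x_i)`. `update` performs one
/// stochastic gradient descent step on the squared error, nudging each weight by
/// `learning_rate * (target - y_hat) * x_i` and the bias by
/// `learning_rate * (target - y_hat)`.
///
/// For example, with `learning_rate = 0.1`, zero-initialized parameters, features
/// `[1.0, 2.0]`, and target `10.0`: the first prediction is `0.0`, so one update
/// sets the weights to `[1.0, 2.0]` and the bias to `1.0`, giving a next
/// prediction of `6.0` (closer to the target).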
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    weights: Vec<f64>,
    bias: f64,
    learning_rate: f64,
}

impl LinearModel {
    fn new(num_features: usize, learning_rate: f64) -> Self {
        Self {
            weights: vec![0.0; num_features],
            bias: 0.0,
            learning_rate,
        }
    }

    fn predict(&self, features: &[f64]) -> f64 {
        let mut result = self.bias;
        for (w, f) in self.weights.iter().zip(features.iter()) {
            result += w * f;
        }
        result
    }

    fn update(&mut self, features: &[f64], target: f64) {
        let prediction = self.predict(features);
        let error = target - prediction;

        // Gradient descent update
        for (w, f) in self.weights.iter_mut().zip(features.iter()) {
            *w += self.learning_rate * error * f;
        }
        self.bias += self.learning_rate * error;
    }
}

/// Q-learning agent for reinforcement learning.
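///
/// Values are updated with the standard tabular rule
/// `Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))`,
/// where `alpha` is `learning_rate` and `gamma` is `discount_factor`; a missing
/// next state contributes `0.0`. Action selection is epsilon-greedy with
/// exploration rate `epsilon`.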
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QLearningAgent {
    q_table: HashMap<(String, OptimizationAction), f64>, // (state, action) -> Q-value
    learning_rate: f64,
    discount_factor: f64,
    epsilon: f64, // Exploration rate
}

impl QLearningAgent {
    fn new(learning_rate: f64) -> Self {
        Self {
            q_table: HashMap::new(),
            learning_rate,
            discount_factor: 0.95,
            epsilon: 0.1,
        }
    }

    fn get_q_value(&self, state: &str, action: OptimizationAction) -> f64 {
        *self
            .q_table
            .get(&(state.to_string(), action))
            .unwrap_or(&0.0)
    }

    fn update_q_value(
        &mut self,
        state: &str,
        action: OptimizationAction,
        reward: f64,
        next_state: Option<&str>,
    ) {
        let current_q = self.get_q_value(state, action);

        let max_next_q = if let Some(ns) = next_state {
            [
                self.get_q_value(ns, OptimizationAction::Fuse),
                self.get_q_value(ns, OptimizationAction::DontFuse),
                self.get_q_value(ns, OptimizationAction::Parallelize),
                self.get_q_value(ns, OptimizationAction::Sequential),
            ]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        } else {
            0.0
        };

        let new_q = current_q
            + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q);

        self.q_table.insert((state.to_string(), action), new_q);
    }

    fn select_action(&self, state: &str, explore: bool) -> OptimizationAction {
        if explore && rand::random::<f64>() < self.epsilon {
            // Random exploration
            let actions = [
                OptimizationAction::Fuse,
                OptimizationAction::DontFuse,
                OptimizationAction::Parallelize,
                OptimizationAction::Sequential,
            ];
            actions[rand::rng().random_range(0..actions.len())]
        } else {
            // Greedy exploitation
            let actions = [
                (
                    OptimizationAction::Fuse,
                    self.get_q_value(state, OptimizationAction::Fuse),
                ),
                (
                    OptimizationAction::DontFuse,
                    self.get_q_value(state, OptimizationAction::DontFuse),
                ),
                (
                    OptimizationAction::Parallelize,
                    self.get_q_value(state, OptimizationAction::Parallelize),
                ),
                (
                    OptimizationAction::Sequential,
                    self.get_q_value(state, OptimizationAction::Sequential),
                ),
            ];

            actions
                .iter()
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .map(|(action, _)| *action)
                .unwrap_or(OptimizationAction::DontFuse)
        }
    }
}

/// Learned optimizer.
pub struct LearnedOptimizer {
    strategy: LearningStrategy,
    model_type: ModelType,
    cost_model: Option<LinearModel>,
    q_agent: Option<QLearningAgent>,
    training_examples: Vec<TrainingExample>,
    learning_rate: f64,
    stats: LearningStats,
    min_training_examples: usize,
}

impl LearnedOptimizer {
    /// Create a new learned optimizer with default settings.
    pub fn new() -> Self {
        Self {
            strategy: LearningStrategy::Online,
            model_type: ModelType::LinearRegression,
            cost_model: None,
            q_agent: None,
            training_examples: Vec::new(),
            learning_rate: 0.01,
            stats: LearningStats {
                training_examples: 0,
                model_accuracy: 0.0,
                average_prediction_error: 0.0,
                total_updates: 0,
                learning_rate: 0.01,
            },
            min_training_examples: 10,
        }
    }

    /// Set learning strategy.
    pub fn with_strategy(mut self, strategy: LearningStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set model type.
    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
        self.model_type = model_type;
        self
    }

    /// Set learning rate (clamped to `[0.0001, 1.0]`).
    pub fn with_learning_rate(mut self, rate: f64) -> Self {
        self.learning_rate = rate.clamp(0.0001, 1.0);
        self.stats.learning_rate = self.learning_rate;
        self
    }

    /// Extract features from graph description.
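    ///
    /// Missing keys default to `0.0` (`1.0` for `parallelism`), so the returned
    /// vector always has six entries. A minimal sketch of the expected input
    /// (`optimizer` is any `LearnedOptimizer`):
    ///
    /// ```rust,ignore
    /// use std::collections::HashMap;
    ///
    /// let mut graph_desc = HashMap::new();
    /// graph_desc.insert("num_nodes".to_string(), 12.0);
    /// graph_desc.insert("num_edges".to_string(), 18.0);
    /// graph_desc.insert("depth".to_string(), 4.0);
    ///
    /// let features = optimizer.extract_features(&graph_desc)?;
    /// assert_eq!(features.features.len(), 6);
    /// ```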
    pub fn extract_features(
        &self,
        graph_desc: &HashMap<String, f64>,
    ) -> Result<FeatureVector, LearnedOptError> {
        let mut features = FeatureVector::new();

        // Extract common graph features
        features.add_feature(
            "num_nodes".to_string(),
            *graph_desc.get("num_nodes").unwrap_or(&0.0),
        );
        features.add_feature(
            "num_edges".to_string(),
            *graph_desc.get("num_edges").unwrap_or(&0.0),
        );
        features.add_feature(
            "avg_node_degree".to_string(),
            *graph_desc.get("avg_degree").unwrap_or(&0.0),
        );
        features.add_feature(
            "graph_depth".to_string(),
            *graph_desc.get("depth").unwrap_or(&0.0),
        );
        features.add_feature(
            "total_memory".to_string(),
            *graph_desc.get("memory").unwrap_or(&0.0),
        );
        features.add_feature(
            "parallelism_factor".to_string(),
            *graph_desc.get("parallelism").unwrap_or(&1.0),
        );

        Ok(features)
    }

    /// Observe execution and update model (online learning).
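    ///
    /// A minimal sketch, assuming `features` came from `extract_features` and
    /// `measured_cost_us` is the observed execution time in microseconds:
    ///
    /// ```rust,ignore
    /// optimizer.observe(features, measured_cost_us)?;
    /// ```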
    pub fn observe(
        &mut self,
        features: FeatureVector,
        actual_cost: f64,
    ) -> Result<(), LearnedOptError> {
        let example = TrainingExample {
            features: features.clone(),
            label: actual_cost,
        };

        self.training_examples.push(example);
        self.stats.training_examples += 1;

        // Initialize model if needed
        if self.cost_model.is_none() && !features.features.is_empty() {
            self.cost_model = Some(LinearModel::new(
                features.features.len(),
                self.learning_rate,
            ));
        }

        // Update model online
        if let Some(model) = &mut self.cost_model {
            model.update(&features.features, actual_cost);
            self.stats.total_updates += 1;
        }

        Ok(())
    }

    /// Observe reward signal for reinforcement learning.
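    ///
    /// Returns an error unless the optimizer was built with
    /// `LearningStrategy::ReinforcementLearning`. A minimal sketch:
    ///
    /// ```rust,ignore
    /// optimizer.observe_reward(RewardSignal {
    ///     state_features: features,
    ///     action: OptimizationAction::Fuse,
    ///     reward: 2.0, // positive = observed speedup
    ///     next_state_features: None,
    /// })?;
    /// ```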
    pub fn observe_reward(&mut self, signal: RewardSignal) -> Result<(), LearnedOptError> {
        if self.strategy != LearningStrategy::ReinforcementLearning {
            return Err(LearnedOptError::InvalidConfig(
                "Reward observation requires ReinforcementLearning strategy".to_string(),
            ));
        }

        // Initialize Q-learning agent if needed
        if self.q_agent.is_none() {
            self.q_agent = Some(QLearningAgent::new(self.learning_rate));
        }

        // Create state representation (simplified: the Debug string of the feature values)
        let state = format!("{:?}", signal.state_features.features);
        let next_state = signal
            .next_state_features
            .as_ref()
            .map(|f| format!("{:?}", f.features));

        if let Some(agent) = &mut self.q_agent {
            agent.update_q_value(&state, signal.action, signal.reward, next_state.as_deref());
        }

        self.stats.total_updates += 1;

        Ok(())
    }

    /// Predict cost for given features.
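    ///
    /// Requires a trained cost model and at least `min_training_examples`
    /// observations. A minimal sketch:
    ///
    /// ```rust,ignore
    /// let prediction = optimizer.predict_cost(&features)?;
    /// println!(
    ///     "predicted ~{:.0} us (confidence {:.2})",
    ///     prediction.predicted_cost_us, prediction.model_confidence
    /// );
    /// ```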
    pub fn predict_cost(
        &self,
        features: &FeatureVector,
    ) -> Result<CostPrediction, LearnedOptError> {
        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        if self.training_examples.len() < self.min_training_examples {
            return Err(LearnedOptError::InsufficientData(format!(
                "Need at least {} examples, have {}",
                self.min_training_examples,
                self.training_examples.len()
            )));
        }

        let predicted_cost = model.predict(&features.features);

        // Simplified confidence interval (±20% of prediction)
        let margin = predicted_cost * 0.2;
        let confidence_interval = (predicted_cost - margin, predicted_cost + margin);

        // Model confidence based on training data size
        let model_confidence = (self.training_examples.len() as f64
            / (self.min_training_examples * 10) as f64)
            .min(1.0);

        Ok(CostPrediction {
            predicted_cost_us: predicted_cost.max(0.0),
            confidence_interval,
            model_confidence,
        })
    }

    /// Recommend whether to fuse operations.
    pub fn recommend_fusion(
        &self,
        features: &FeatureVector,
    ) -> Result<FusionRecommendation, LearnedOptError> {
        match self.strategy {
            LearningStrategy::ReinforcementLearning => {
                let agent = self.q_agent.as_ref().ok_or_else(|| {
                    LearnedOptError::ModelNotTrained("Q-learning agent not initialized".to_string())
                })?;

                let state = format!("{:?}", features.features);
                let action = agent.select_action(&state, false);

                let should_fuse = action == OptimizationAction::Fuse;
                let q_fuse = agent.get_q_value(&state, OptimizationAction::Fuse);
                let q_no_fuse = agent.get_q_value(&state, OptimizationAction::DontFuse);

                let confidence =
                    (q_fuse - q_no_fuse).abs() / (q_fuse.abs() + q_no_fuse.abs() + 1.0);
                let expected_speedup = if should_fuse { q_fuse.max(1.0) } else { 1.0 };

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence,
                    expected_speedup,
                })
            }
            _ => {
                // Use cost model to estimate fusion benefit
                let cost_pred = self.predict_cost(features)?;

                // Heuristic: fuse if predicted cost is below threshold
                let threshold = 100.0; // microseconds
                let should_fuse = cost_pred.predicted_cost_us < threshold;

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence: cost_pred.model_confidence,
                    expected_speedup: if should_fuse { 1.5 } else { 1.0 },
                })
            }
        }
    }

    /// Get learning statistics.
    pub fn get_stats(&self) -> &LearningStats {
        &self.stats
    }

    /// Evaluate model accuracy on training data.
    pub fn evaluate_accuracy(&mut self) -> Result<f64, LearnedOptError> {
        if self.training_examples.is_empty() {
            return Ok(0.0);
        }

        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        let mut total_error = 0.0;

        for example in &self.training_examples {
            let prediction = model.predict(&example.features.features);
            let error = (prediction - example.label).abs();
            total_error += error;
        }

        let avg_error = total_error / self.training_examples.len() as f64;
        self.stats.average_prediction_error = avg_error;

        // Accuracy = 1 - normalized error
        let max_label = self
            .training_examples
            .iter()
            .map(|e| e.label)
            .fold(f64::NEG_INFINITY, f64::max);

        let accuracy = if max_label > 0.0 {
            (1.0 - (avg_error / max_label)).max(0.0)
        } else {
            0.0
        };

        self.stats.model_accuracy = accuracy;

        Ok(accuracy)
    }

    /// Reset learning state.
    pub fn reset(&mut self) {
        self.training_examples.clear();
        self.cost_model = None;
        self.q_agent = None;
        self.stats = LearningStats {
            training_examples: 0,
            model_accuracy: 0.0,
            average_prediction_error: 0.0,
            total_updates: 0,
            learning_rate: self.learning_rate,
        };
    }
}

impl Default for LearnedOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_features() -> FeatureVector {
        let mut features = FeatureVector::new();
        features.add_feature("num_nodes".to_string(), 10.0);
        features.add_feature("num_edges".to_string(), 15.0);
        features.add_feature("avg_degree".to_string(), 1.5);
        features
    }

    #[test]
    fn test_learned_optimizer_creation() {
        let optimizer = LearnedOptimizer::new();
        assert_eq!(optimizer.strategy, LearningStrategy::Online);
        assert_eq!(optimizer.model_type, ModelType::LinearRegression);
    }

    #[test]
    fn test_builder_pattern() {
        let optimizer = LearnedOptimizer::new()
            .with_strategy(LearningStrategy::ReinforcementLearning)
            .with_model_type(ModelType::NeuralNetwork)
            .with_learning_rate(0.05);

        assert_eq!(optimizer.strategy, LearningStrategy::ReinforcementLearning);
        assert_eq!(optimizer.model_type, ModelType::NeuralNetwork);
        assert_eq!(optimizer.learning_rate, 0.05);
    }

    #[test]
    fn test_feature_extraction() {
        let optimizer = LearnedOptimizer::new();
        let mut graph_desc = HashMap::new();
        graph_desc.insert("num_nodes".to_string(), 10.0);
        graph_desc.insert("num_edges".to_string(), 15.0);

        let features = optimizer.extract_features(&graph_desc).unwrap();
        assert!(!features.features.is_empty());
    }

    #[test]
    fn test_observe_and_learn() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer.observe(features.clone(), 100.0).unwrap();
        optimizer.observe(features.clone(), 95.0).unwrap();

        assert_eq!(optimizer.stats.training_examples, 2);
        assert_eq!(optimizer.stats.total_updates, 2);
    }

    #[test]
    fn test_cost_prediction_insufficient_data() {
        let optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        let result = optimizer.predict_cost(&features);
        assert!(result.is_err());
    }

    #[test]
    fn test_cost_prediction_with_training() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        // Add sufficient training examples
        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer.observe(f, 100.0 + i as f64).unwrap();
        }

        let prediction = optimizer.predict_cost(&features).unwrap();
        assert!(prediction.predicted_cost_us >= 0.0);
        assert!(prediction.model_confidence > 0.0);
    }

    #[test]
    fn test_reinforcement_learning_observation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let signal = RewardSignal {
            state_features: create_test_features(),
            action: OptimizationAction::Fuse,
            reward: 10.0, // Positive reward for speedup
            next_state_features: Some(create_test_features()),
        };

        optimizer.observe_reward(signal).unwrap();
        assert_eq!(optimizer.stats.total_updates, 1);
    }

    #[test]
    fn test_fusion_recommendation() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        // Train with examples
        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer.observe(f, 50.0 + i as f64).unwrap(); // Low cost -> should recommend fusion
        }

        let recommendation = optimizer.recommend_fusion(&features).unwrap();
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_rl_fusion_recommendation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let features = create_test_features();

        // Train with rewards
        for _ in 0..10 {
            let signal = RewardSignal {
                state_features: features.clone(),
                action: OptimizationAction::Fuse,
                reward: 15.0,
                next_state_features: None,
            };
            optimizer.observe_reward(signal).unwrap();
        }

        let recommendation = optimizer.recommend_fusion(&features).unwrap();
        // Just check it returns a valid recommendation
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_accuracy_evaluation() {
        let mut optimizer = LearnedOptimizer::new();

        // Add training examples with known relationship
        for i in 0..20 {
            let mut features = FeatureVector::new();
            features.add_feature("x".to_string(), i as f64);
            optimizer.observe(features, i as f64 * 2.0).unwrap(); // y = 2x
        }

        let accuracy = optimizer.evaluate_accuracy().unwrap();
        assert!(accuracy >= 0.0 && accuracy <= 1.0);
    }

    #[test]
    fn test_reset() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer.observe(features, 100.0).unwrap();
        assert_eq!(optimizer.stats.training_examples, 1);

        optimizer.reset();
        assert_eq!(optimizer.stats.training_examples, 0);
        assert!(optimizer.training_examples.is_empty());
    }

    #[test]
    fn test_linear_model_prediction() {
        let model = LinearModel::new(3, 0.01);
        let features = vec![1.0, 2.0, 3.0];

        let prediction = model.predict(&features);
        assert!(prediction.is_finite());
    }

    #[test]
    fn test_linear_model_update() {
        let mut model = LinearModel::new(2, 0.1);
        let features = vec![1.0, 2.0];

        let pred_before = model.predict(&features);
        model.update(&features, 10.0);
        let pred_after = model.predict(&features);

        // After the update, the prediction should move towards the target
        assert!(pred_after.is_finite());
        assert!((10.0 - pred_after).abs() < (10.0 - pred_before).abs());
    }

    #[test]
    fn test_q_learning_agent() {
        let mut agent = QLearningAgent::new(0.1);

        agent.update_q_value("state1", OptimizationAction::Fuse, 10.0, Some("state2"));

        let q_value = agent.get_q_value("state1", OptimizationAction::Fuse);
        assert!(q_value > 0.0);
    }

    #[test]
    fn test_q_learning_action_selection() {
        let mut agent = QLearningAgent::new(0.1);

        // Train with high reward for Fuse action
        for _ in 0..10 {
            agent.update_q_value("state1", OptimizationAction::Fuse, 20.0, None);
        }

        let action = agent.select_action("state1", false);
        // With exploration disabled, greedy selection picks the action with the
        // highest Q-value, which is Fuse after the updates above
        assert_eq!(action, OptimizationAction::Fuse);
    }

    #[test]
    fn test_different_learning_strategies() {
        let strategies = vec![
            LearningStrategy::Supervised,
            LearningStrategy::Online,
            LearningStrategy::Transfer,
        ];

        for strategy in strategies {
            let optimizer = LearnedOptimizer::new().with_strategy(strategy);
            assert_eq!(optimizer.strategy, strategy);
        }
    }

    #[test]
    fn test_different_model_types() {
        let model_types = vec![
            ModelType::LinearRegression,
            ModelType::DecisionTree,
            ModelType::RandomForest,
            ModelType::NeuralNetwork,
            ModelType::GradientBoosting,
        ];

        for model_type in model_types {
            let optimizer = LearnedOptimizer::new().with_model_type(model_type);
            assert_eq!(optimizer.model_type, model_type);
        }
    }
}