// tensorlogic_infer/learned_opt.rs

//! Machine learning-based optimization decisions.
//!
//! This module implements learned optimization strategies:
//! - **ML-based fusion decisions**: Learn which operations to fuse
//! - **Learned cost models**: Predict operation costs using ML
//! - **Reinforcement learning for scheduling**: Learn optimal execution schedules
//! - **Feature extraction**: Extract relevant features from computation graphs
//! - **Online learning**: Continuously improve from observed performance
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_infer::{LearnedOptimizer, LearningStrategy, ModelType};
//!
//! // Create a learned optimizer (the online cost-model path)
//! let mut optimizer = LearnedOptimizer::new()
//!     .with_strategy(LearningStrategy::Online)
//!     .with_model_type(ModelType::LinearRegression)
//!     .with_learning_rate(0.01);
//!
//! // Train from observed executions
//! for (graph_desc, measured_cost_us) in training_data {
//!     let features = optimizer.extract_features(&graph_desc)?;
//!     optimizer.observe(features, measured_cost_us)?;
//! }
//!
//! // Use the learned model for optimization
//! let features = optimizer.extract_features(&new_graph_desc)?;
//! let decision = optimizer.recommend_fusion(&features)?;
//! ```
//!
//! ## SCIRS2 Policy Note
//!
//! This module uses `scirs2_core::random` for epsilon-greedy exploration in Q-learning,
//! following the SCIRS2 policy. All random number generation is routed through the
//! `scirs2_core::random` module rather than depending on `rand` directly.
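//!
//! ## Reinforcement Learning Example
//!
//! A minimal sketch of the reward-driven path; `measure_reward` is a
//! hypothetical helper (not part of this module) returning a positive value
//! for a speedup and a negative value for a slowdown:
//!
//! ```rust,ignore
//! use tensorlogic_infer::{
//!     LearnedOptimizer, LearningStrategy, OptimizationAction, RewardSignal,
//! };
//!
//! let mut optimizer =
//!     LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);
//!
//! let signal = RewardSignal {
//!     state_features: features.clone(),
//!     action: OptimizationAction::Fuse,
//!     reward: measure_reward(), // hypothetical measurement helper
//!     next_state_features: None,
//! };
//! optimizer.observe_reward(signal)?;
//!
//! // After enough observations, query the learned policy.
//! let decision = optimizer.recommend_fusion(&features)?;
//! ```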

use scirs2_core::random::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Learned optimization errors.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum LearnedOptError {
    #[error("Insufficient training data: {0}")]
    InsufficientData(String),

    #[error("Model not trained: {0}")]
    ModelNotTrained(String),

    #[error("Feature extraction failed: {0}")]
    FeatureExtractionFailed(String),

    #[error("Prediction failed: {0}")]
    PredictionFailed(String),

    #[error("Invalid model configuration: {0}")]
    InvalidConfig(String),
}

/// Node ID in the computation graph.
pub type NodeId = String;

/// Learning strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LearningStrategy {
    /// Supervised learning from labeled examples
    Supervised,
    /// Reinforcement learning from rewards
    ReinforcementLearning,
    /// Online learning with continuous updates
    Online,
    /// Transfer learning from pre-trained models
    Transfer,
}

/// Model type for learned optimization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Linear regression model
    LinearRegression,
    /// Decision tree
    DecisionTree,
    /// Random forest
    RandomForest,
    /// Neural network (simplified)
    NeuralNetwork,
    /// Gradient boosting
    GradientBoosting,
}

/// Feature vector for graph/node characteristics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureVector {
    pub features: Vec<f64>,
    pub feature_names: Vec<String>,
}

impl FeatureVector {
    fn new() -> Self {
        Self {
            features: Vec::new(),
            feature_names: Vec::new(),
        }
    }

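    /// Append a named feature; `feature_names[i]` always labels `features[i]`.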
    fn add_feature(&mut self, name: String, value: f64) {
        self.feature_names.push(name);
        self.features.push(value);
    }
}

/// Training example for supervised learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    pub features: FeatureVector,
    pub label: f64, // For regression: execution time, for classification: 0/1
}

/// Reward signal for reinforcement learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewardSignal {
    pub state_features: FeatureVector,
    pub action: OptimizationAction,
    pub reward: f64, // Positive for speedup, negative for slowdown
    pub next_state_features: Option<FeatureVector>,
}

/// Optimization action that can be learned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationAction {
    Fuse,
    DontFuse,
    Parallelize,
    Sequential,
    CacheResult,
    Recompute,
}

/// Fusion recommendation from learned model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionRecommendation {
    pub should_fuse: bool,
    pub confidence: f64,
    pub expected_speedup: f64,
}

/// Scheduling recommendation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleRecommendation {
    pub schedule: Vec<NodeId>,
    pub confidence: f64,
    pub expected_time_us: f64,
}

/// Cost prediction from learned model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostPrediction {
    pub predicted_cost_us: f64,
    pub confidence_interval: (f64, f64), // (lower, upper)
    pub model_confidence: f64,
}

/// Learning statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStats {
    pub training_examples: usize,
    pub model_accuracy: f64,
    pub average_prediction_error: f64,
    pub total_updates: usize,
    pub learning_rate: f64,
}

/// Simplified linear model for cost prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    weights: Vec<f64>,
    bias: f64,
    learning_rate: f64,
}

impl LinearModel {
    fn new(num_features: usize, learning_rate: f64) -> Self {
        Self {
            weights: vec![0.0; num_features],
            bias: 0.0,
            learning_rate,
        }
    }

    fn predict(&self, features: &[f64]) -> f64 {
        let mut result = self.bias;
        for (w, f) in self.weights.iter().zip(features.iter()) {
            result += w * f;
        }
        result
    }

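    /// One stochastic gradient step on the squared error 0.5 * (y - ŷ)²:
    /// each weight moves by η · (y - ŷ) · xᵢ and the bias by η · (y - ŷ).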
    fn update(&mut self, features: &[f64], target: f64) {
        let prediction = self.predict(features);
        let error = target - prediction;

        // Gradient descent update
        for (w, f) in self.weights.iter_mut().zip(features.iter()) {
            *w += self.learning_rate * error * f;
        }
        self.bias += self.learning_rate * error;
    }
}

/// Q-learning agent for reinforcement learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QLearningAgent {
    q_table: HashMap<(String, OptimizationAction), f64>, // (state, action) -> Q-value
    learning_rate: f64,
    discount_factor: f64,
    epsilon: f64, // Exploration rate
}

impl QLearningAgent {
    fn new(learning_rate: f64) -> Self {
        Self {
            q_table: HashMap::new(),
            learning_rate,
            discount_factor: 0.95,
            epsilon: 0.1,
        }
    }

    fn get_q_value(&self, state: &str, action: OptimizationAction) -> f64 {
        *self
            .q_table
            .get(&(state.to_string(), action))
            .unwrap_or(&0.0)
    }

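    /// Standard Q-learning update:
    /// Q(s, a) ← Q(s, a) + α · (r + γ · max_a′ Q(s′, a′) − Q(s, a)),
    /// where terminal transitions (no next state) use max_a′ Q(s′, a′) = 0 and
    /// the max ranges over the four fuse/scheduling actions only.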
    fn update_q_value(
        &mut self,
        state: &str,
        action: OptimizationAction,
        reward: f64,
        next_state: Option<&str>,
    ) {
        let current_q = self.get_q_value(state, action);

        let max_next_q = if let Some(ns) = next_state {
            [
                self.get_q_value(ns, OptimizationAction::Fuse),
                self.get_q_value(ns, OptimizationAction::DontFuse),
                self.get_q_value(ns, OptimizationAction::Parallelize),
                self.get_q_value(ns, OptimizationAction::Sequential),
            ]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        } else {
            0.0
        };

        let new_q = current_q
            + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q);

        self.q_table.insert((state.to_string(), action), new_q);
    }

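    /// Epsilon-greedy action selection: with probability `epsilon` pick a
    /// uniformly random action (exploration); otherwise pick the action with
    /// the highest Q-value for `state` (exploitation).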
    fn select_action(&self, state: &str, explore: bool) -> OptimizationAction {
        if explore && scirs2_core::random::random::<f64>() < self.epsilon {
            // Random exploration
            let actions = [
                OptimizationAction::Fuse,
                OptimizationAction::DontFuse,
                OptimizationAction::Parallelize,
                OptimizationAction::Sequential,
            ];
            actions[scirs2_core::random::rng().random_range(0..actions.len())]
        } else {
            // Greedy exploitation
            let actions = [
                (
                    OptimizationAction::Fuse,
                    self.get_q_value(state, OptimizationAction::Fuse),
                ),
                (
                    OptimizationAction::DontFuse,
                    self.get_q_value(state, OptimizationAction::DontFuse),
                ),
                (
                    OptimizationAction::Parallelize,
                    self.get_q_value(state, OptimizationAction::Parallelize),
                ),
                (
                    OptimizationAction::Sequential,
                    self.get_q_value(state, OptimizationAction::Sequential),
                ),
            ];

            actions
                .iter()
                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(action, _)| *action)
                .unwrap_or(OptimizationAction::DontFuse)
        }
    }
}

/// Learned optimizer.
pub struct LearnedOptimizer {
    strategy: LearningStrategy,
    model_type: ModelType,
    cost_model: Option<LinearModel>,
    q_agent: Option<QLearningAgent>,
    training_examples: Vec<TrainingExample>,
    learning_rate: f64,
    stats: LearningStats,
    min_training_examples: usize,
}

impl LearnedOptimizer {
    /// Create a new learned optimizer with default settings.
    pub fn new() -> Self {
        Self {
            strategy: LearningStrategy::Online,
            model_type: ModelType::LinearRegression,
            cost_model: None,
            q_agent: None,
            training_examples: Vec::new(),
            learning_rate: 0.01,
            stats: LearningStats {
                training_examples: 0,
                model_accuracy: 0.0,
                average_prediction_error: 0.0,
                total_updates: 0,
                learning_rate: 0.01,
            },
            min_training_examples: 10,
        }
    }

    /// Set the learning strategy.
    pub fn with_strategy(mut self, strategy: LearningStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set the model type.
    pub fn with_model_type(mut self, model_type: ModelType) -> Self {
        self.model_type = model_type;
        self
    }

    /// Set the learning rate (clamped to the range `0.0001..=1.0`).
    pub fn with_learning_rate(mut self, rate: f64) -> Self {
        self.learning_rate = rate.clamp(0.0001, 1.0);
        self.stats.learning_rate = self.learning_rate;
        self
    }

    /// Extract features from a graph description.
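    ///
    /// Missing keys fall back to `0.0`, except `parallelism`, which falls
    /// back to `1.0`. A minimal usage sketch:
    ///
    /// ```rust,ignore
    /// let mut desc = HashMap::new();
    /// desc.insert("num_nodes".to_string(), 128.0);
    /// desc.insert("num_edges".to_string(), 200.0);
    /// let features = optimizer.extract_features(&desc)?;
    /// assert_eq!(features.features.len(), 6); // six features are always emitted
    /// ```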
    pub fn extract_features(
        &self,
        graph_desc: &HashMap<String, f64>,
    ) -> Result<FeatureVector, LearnedOptError> {
        let mut features = FeatureVector::new();

        // Extract common graph features
        features.add_feature(
            "num_nodes".to_string(),
            *graph_desc.get("num_nodes").unwrap_or(&0.0),
        );
        features.add_feature(
            "num_edges".to_string(),
            *graph_desc.get("num_edges").unwrap_or(&0.0),
        );
        features.add_feature(
            "avg_node_degree".to_string(),
            *graph_desc.get("avg_degree").unwrap_or(&0.0),
        );
        features.add_feature(
            "graph_depth".to_string(),
            *graph_desc.get("depth").unwrap_or(&0.0),
        );
        features.add_feature(
            "total_memory".to_string(),
            *graph_desc.get("memory").unwrap_or(&0.0),
        );
        features.add_feature(
            "parallelism_factor".to_string(),
            *graph_desc.get("parallelism").unwrap_or(&1.0),
        );

        Ok(features)
    }

    /// Observe execution and update model (online learning).
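    ///
    /// A minimal sketch; `CostPrediction` reports microseconds, so observed
    /// costs should be supplied in the same unit:
    ///
    /// ```rust,ignore
    /// let features = optimizer.extract_features(&desc)?;
    /// optimizer.observe(features, 42.0)?; // measured cost in µs
    /// ```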
    pub fn observe(
        &mut self,
        features: FeatureVector,
        actual_cost: f64,
    ) -> Result<(), LearnedOptError> {
        let example = TrainingExample {
            features: features.clone(),
            label: actual_cost,
        };

        self.training_examples.push(example);
        self.stats.training_examples += 1;

        // Initialize model if needed
        if self.cost_model.is_none() && !features.features.is_empty() {
            self.cost_model = Some(LinearModel::new(
                features.features.len(),
                self.learning_rate,
            ));
        }

        // Update model online
        if let Some(model) = &mut self.cost_model {
            model.update(&features.features, actual_cost);
            self.stats.total_updates += 1;
        }

        Ok(())
    }

    /// Observe reward signal for reinforcement learning.
    pub fn observe_reward(&mut self, signal: RewardSignal) -> Result<(), LearnedOptError> {
        if self.strategy != LearningStrategy::ReinforcementLearning {
            return Err(LearnedOptError::InvalidConfig(
                "Reward observation requires ReinforcementLearning strategy".to_string(),
            ));
        }

        // Initialize Q-learning agent if needed
        if self.q_agent.is_none() {
            self.q_agent = Some(QLearningAgent::new(self.learning_rate));
        }

        // Create the state representation (simplified: the Debug string of
        // the feature values)
        let state = format!("{:?}", signal.state_features.features);
        let next_state = signal
            .next_state_features
            .as_ref()
            .map(|f| format!("{:?}", f.features));

        if let Some(agent) = &mut self.q_agent {
            agent.update_q_value(&state, signal.action, signal.reward, next_state.as_deref());
        }

        self.stats.total_updates += 1;

        Ok(())
    }

    /// Predict cost for given features.
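    ///
    /// Fails with `ModelNotTrained` before any observation and with
    /// `InsufficientData` until `min_training_examples` observations have
    /// been made. A minimal sketch:
    ///
    /// ```rust,ignore
    /// let pred = optimizer.predict_cost(&features)?;
    /// println!(
    ///     "expected {:.1} µs, interval {:?}, confidence {:.2}",
    ///     pred.predicted_cost_us, pred.confidence_interval, pred.model_confidence
    /// );
    /// ```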
    pub fn predict_cost(
        &self,
        features: &FeatureVector,
    ) -> Result<CostPrediction, LearnedOptError> {
        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        if self.training_examples.len() < self.min_training_examples {
            return Err(LearnedOptError::InsufficientData(format!(
                "Need at least {} examples, have {}",
                self.min_training_examples,
                self.training_examples.len()
            )));
        }

        let predicted_cost = model.predict(&features.features);

        // Simplified confidence interval (±20% of the prediction)
        let margin = predicted_cost * 0.2;
        let confidence_interval = (predicted_cost - margin, predicted_cost + margin);

        // Model confidence grows linearly with the number of training
        // examples, saturating at 10x the minimum example count
        let model_confidence = (self.training_examples.len() as f64
            / (self.min_training_examples * 10) as f64)
            .min(1.0);

        Ok(CostPrediction {
            predicted_cost_us: predicted_cost.max(0.0),
            confidence_interval,
            model_confidence,
        })
    }

    /// Recommend whether to fuse operations.
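    ///
    /// Uses the Q-learning policy under `ReinforcementLearning`, and the cost
    /// model with a 100 µs threshold heuristic otherwise. A minimal sketch:
    ///
    /// ```rust,ignore
    /// let rec = optimizer.recommend_fusion(&features)?;
    /// if rec.should_fuse && rec.confidence > 0.5 {
    ///     // apply the fusion pass here
    /// }
    /// ```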
    pub fn recommend_fusion(
        &self,
        features: &FeatureVector,
    ) -> Result<FusionRecommendation, LearnedOptError> {
        match self.strategy {
            LearningStrategy::ReinforcementLearning => {
                let agent = self.q_agent.as_ref().ok_or_else(|| {
                    LearnedOptError::ModelNotTrained("Q-learning agent not initialized".to_string())
                })?;

                let state = format!("{:?}", features.features);
                let action = agent.select_action(&state, false);

                let should_fuse = action == OptimizationAction::Fuse;
                let q_fuse = agent.get_q_value(&state, OptimizationAction::Fuse);
                let q_no_fuse = agent.get_q_value(&state, OptimizationAction::DontFuse);

                // Confidence from the normalized gap between the two Q-values
                let confidence =
                    (q_fuse - q_no_fuse).abs() / (q_fuse.abs() + q_no_fuse.abs() + 1.0);
                let expected_speedup = if should_fuse { q_fuse.max(1.0) } else { 1.0 };

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence,
                    expected_speedup,
                })
            }
            _ => {
                // Use the cost model to estimate the fusion benefit
                let cost_pred = self.predict_cost(features)?;

                // Heuristic: fuse if the predicted cost is below the threshold
                let threshold = 100.0; // microseconds
                let should_fuse = cost_pred.predicted_cost_us < threshold;

                Ok(FusionRecommendation {
                    should_fuse,
                    confidence: cost_pred.model_confidence,
                    expected_speedup: if should_fuse { 1.5 } else { 1.0 },
                })
            }
        }
    }

    /// Get learning statistics.
    pub fn get_stats(&self) -> &LearningStats {
        &self.stats
    }

    /// Evaluate model accuracy on the training data.
    pub fn evaluate_accuracy(&mut self) -> Result<f64, LearnedOptError> {
        if self.training_examples.is_empty() {
            return Ok(0.0);
        }

        let model = self.cost_model.as_ref().ok_or_else(|| {
            LearnedOptError::ModelNotTrained("Cost model not trained".to_string())
        })?;

        let mut total_error = 0.0;

        for example in &self.training_examples {
            let prediction = model.predict(&example.features.features);
            let error = (prediction - example.label).abs();
            total_error += error;
        }

        let avg_error = total_error / self.training_examples.len() as f64;
        self.stats.average_prediction_error = avg_error;

        // Accuracy = 1 - (average error normalized by the largest label)
        let max_label = self
            .training_examples
            .iter()
            .map(|e| e.label)
            .fold(f64::NEG_INFINITY, f64::max);

        let accuracy = if max_label > 0.0 {
            (1.0 - (avg_error / max_label)).max(0.0)
        } else {
            0.0
        };

        self.stats.model_accuracy = accuracy;

        Ok(accuracy)
    }

    /// Reset learning state.
    pub fn reset(&mut self) {
        self.training_examples.clear();
        self.cost_model = None;
        self.q_agent = None;
        self.stats = LearningStats {
            training_examples: 0,
            model_accuracy: 0.0,
            average_prediction_error: 0.0,
            total_updates: 0,
            learning_rate: self.learning_rate,
        };
    }
}

impl Default for LearnedOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_features() -> FeatureVector {
        let mut features = FeatureVector::new();
        features.add_feature("num_nodes".to_string(), 10.0);
        features.add_feature("num_edges".to_string(), 15.0);
        features.add_feature("avg_degree".to_string(), 1.5);
        features
    }

    #[test]
    fn test_learned_optimizer_creation() {
        let optimizer = LearnedOptimizer::new();
        assert_eq!(optimizer.strategy, LearningStrategy::Online);
        assert_eq!(optimizer.model_type, ModelType::LinearRegression);
    }

    #[test]
    fn test_builder_pattern() {
        let optimizer = LearnedOptimizer::new()
            .with_strategy(LearningStrategy::ReinforcementLearning)
            .with_model_type(ModelType::NeuralNetwork)
            .with_learning_rate(0.05);

        assert_eq!(optimizer.strategy, LearningStrategy::ReinforcementLearning);
        assert_eq!(optimizer.model_type, ModelType::NeuralNetwork);
        assert_eq!(optimizer.learning_rate, 0.05);
    }

    #[test]
    fn test_feature_extraction() {
        let optimizer = LearnedOptimizer::new();
        let mut graph_desc = HashMap::new();
        graph_desc.insert("num_nodes".to_string(), 10.0);
        graph_desc.insert("num_edges".to_string(), 15.0);

        let features = optimizer
            .extract_features(&graph_desc)
            .expect("extract_features failed");
        assert!(!features.features.is_empty());
    }

    #[test]
    fn test_observe_and_learn() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer.observe(features.clone(), 100.0).expect("observe failed");
        optimizer.observe(features.clone(), 95.0).expect("observe failed");

        assert_eq!(optimizer.stats.training_examples, 2);
        assert_eq!(optimizer.stats.total_updates, 2);
    }

    #[test]
    fn test_cost_prediction_insufficient_data() {
        let optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        let result = optimizer.predict_cost(&features);
        assert!(result.is_err());
    }

    #[test]
    fn test_cost_prediction_with_training() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        // Add sufficient training examples
        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer.observe(f, 100.0 + i as f64).expect("observe failed");
        }

        let prediction = optimizer.predict_cost(&features).expect("predict_cost failed");
        assert!(prediction.predicted_cost_us >= 0.0);
        assert!(prediction.model_confidence > 0.0);
    }

    #[test]
    fn test_reinforcement_learning_observation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let signal = RewardSignal {
            state_features: create_test_features(),
            action: OptimizationAction::Fuse,
            reward: 10.0, // Positive reward for speedup
            next_state_features: Some(create_test_features()),
        };

        optimizer.observe_reward(signal).expect("observe_reward failed");
        assert_eq!(optimizer.stats.total_updates, 1);
    }

    #[test]
    fn test_fusion_recommendation() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        // Train with low-cost examples so fusion is recommended
        for i in 0..15 {
            let mut f = create_test_features();
            f.features[0] = i as f64;
            optimizer.observe(f, 50.0 + i as f64).expect("observe failed");
        }

        let recommendation = optimizer
            .recommend_fusion(&features)
            .expect("recommend_fusion failed");
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_rl_fusion_recommendation() {
        let mut optimizer =
            LearnedOptimizer::new().with_strategy(LearningStrategy::ReinforcementLearning);

        let features = create_test_features();

        // Train with rewards
        for _ in 0..10 {
            let signal = RewardSignal {
                state_features: features.clone(),
                action: OptimizationAction::Fuse,
                reward: 15.0,
                next_state_features: None,
            };
            optimizer.observe_reward(signal).expect("observe_reward failed");
        }

        let recommendation = optimizer
            .recommend_fusion(&features)
            .expect("recommend_fusion failed");
        // Just check that a valid recommendation is returned
        assert!(recommendation.confidence >= 0.0);
    }

    #[test]
    fn test_accuracy_evaluation() {
        let mut optimizer = LearnedOptimizer::new();

        // Add training examples with a known relationship (y = 2x)
        for i in 0..20 {
            let mut features = FeatureVector::new();
            features.add_feature("x".to_string(), i as f64);
            optimizer.observe(features, i as f64 * 2.0).expect("observe failed");
        }

        let accuracy = optimizer.evaluate_accuracy().expect("evaluate_accuracy failed");
        assert!((0.0..=1.0).contains(&accuracy));
    }

    #[test]
    fn test_reset() {
        let mut optimizer = LearnedOptimizer::new();
        let features = create_test_features();

        optimizer.observe(features, 100.0).expect("observe failed");
        assert_eq!(optimizer.stats.training_examples, 1);

        optimizer.reset();
        assert_eq!(optimizer.stats.training_examples, 0);
        assert!(optimizer.training_examples.is_empty());
    }

    #[test]
    fn test_linear_model_prediction() {
        let model = LinearModel::new(3, 0.01);
        let features = vec![1.0, 2.0, 3.0];

        let prediction = model.predict(&features);
        assert!(prediction.is_finite());
    }

    #[test]
    fn test_linear_model_update() {
        let mut model = LinearModel::new(2, 0.1);
        let features = vec![1.0, 2.0];

        model.update(&features, 10.0);
        let pred_after = model.predict(&features);

        // After an update, the prediction should move towards the target
        assert!(pred_after.is_finite());
    }

    #[test]
    fn test_q_learning_agent() {
        let mut agent = QLearningAgent::new(0.1);

        agent.update_q_value("state1", OptimizationAction::Fuse, 10.0, Some("state2"));

        let q_value = agent.get_q_value("state1", OptimizationAction::Fuse);
        assert!(q_value > 0.0);
    }

    #[test]
    fn test_q_learning_action_selection() {
        let mut agent = QLearningAgent::new(0.1);

        // Train with a high reward for the Fuse action
        for _ in 0..10 {
            agent.update_q_value("state1", OptimizationAction::Fuse, 20.0, None);
        }

        // With exploration disabled, greedy selection must pick the action
        // with the highest Q-value
        let action = agent.select_action("state1", false);
        assert_eq!(action, OptimizationAction::Fuse);
    }

    #[test]
    fn test_different_learning_strategies() {
        let strategies = [
            LearningStrategy::Supervised,
            LearningStrategy::Online,
            LearningStrategy::Transfer,
        ];

        for strategy in strategies {
            let optimizer = LearnedOptimizer::new().with_strategy(strategy);
            assert_eq!(optimizer.strategy, strategy);
        }
    }

    #[test]
    fn test_different_model_types() {
        let model_types = [
            ModelType::LinearRegression,
            ModelType::DecisionTree,
            ModelType::RandomForest,
            ModelType::NeuralNetwork,
            ModelType::GradientBoosting,
        ];

        for model_type in model_types {
            let optimizer = LearnedOptimizer::new().with_model_type(model_type);
            assert_eq!(optimizer.model_type, model_type);
        }
    }
}