scirs2_sparse/neural_adaptive_sparse/reinforcement_learning.rs

//! Reinforcement learning components for adaptive sparse matrix optimization
//!
//! This module implements RL agents that learn optimal strategies for sparse
//! matrix operations based on performance feedback and matrix characteristics.
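//!
//! A minimal usage sketch with illustrative parameter values (not compiled as a
//! doctest, since the agent types are crate-internal):
//!
//! ```ignore
//! let mut agent = RLAgent::new(16, 9, RLAlgorithm::DQN, 0.001, 0.1);
//! let mut buffer = ExperienceBuffer::new(1024);
//!
//! let state = vec![0.0; 16];
//! let action = agent.select_action(&state);
//! // ... run the sparse operation with `action` and measure it ...
//! let metrics = PerformanceMetrics::new(0.8, 0.9, 0.5, 0.7, 0.6, action);
//! buffer.add(Experience {
//!     state: state.clone(),
//!     action,
//!     reward: metrics.compute_reward(1.0),
//!     next_state: state,
//!     done: false,
//!     timestamp: 0,
//! });
//!
//! agent.train(&buffer.sample(32))?;
//! agent.decay_epsilon(0.995);
//! ```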

use super::neural_network::NeuralNetwork;
use super::pattern_memory::OptimizationStrategy;
use crate::error::SparseResult;
use scirs2_core::random::Rng;
use std::collections::VecDeque;

/// Reinforcement learning algorithms
#[derive(Debug, Clone, Copy)]
pub enum RLAlgorithm {
    /// Q-Learning with experience replay
    DQN,
    /// Policy gradient methods
    PolicyGradient,
    /// Actor-Critic methods
    ActorCritic,
    /// Proximal Policy Optimization
    PPO,
    /// Soft Actor-Critic
    SAC,
}

/// Reinforcement learning agent
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct RLAgent {
    pub q_network: NeuralNetwork,
    pub target_network: Option<NeuralNetwork>,
    pub policy_network: Option<NeuralNetwork>,
    pub value_network: Option<NeuralNetwork>,
    pub algorithm: RLAlgorithm,
    pub epsilon: f64,
    pub learningrate: f64,
}

/// Experience for reinforcement learning
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct Experience {
    pub state: Vec<f64>,
    pub action: OptimizationStrategy,
    pub reward: f64,
    pub next_state: Vec<f64>,
    pub done: bool,
    pub timestamp: u64,
}

/// Experience replay buffer
#[derive(Debug)]
pub(crate) struct ExperienceBuffer {
    pub buffer: VecDeque<Experience>,
    pub capacity: usize,
    pub priority_weights: Vec<f64>,
}

/// Performance metrics for reinforcement learning
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    #[allow(dead_code)]
    pub executiontime: f64,
    #[allow(dead_code)]
    pub cache_efficiency: f64,
    #[allow(dead_code)]
    pub simd_utilization: f64,
    #[allow(dead_code)]
    pub parallel_efficiency: f64,
    #[allow(dead_code)]
    pub memory_bandwidth: f64,
    pub strategy_used: OptimizationStrategy,
}

impl RLAgent {
    /// Create a new RL agent
    pub fn new(
        state_size: usize,
        action_size: usize,
        algorithm: RLAlgorithm,
        learning_rate: f64,
        epsilon: f64,
    ) -> Self {
        let q_network = NeuralNetwork::new(state_size, 3, 64, action_size, 4);

        let target_network = match algorithm {
            RLAlgorithm::DQN => Some(q_network.clone()),
            _ => None,
        };

        let (policy_network, value_network) = match algorithm {
            RLAlgorithm::ActorCritic | RLAlgorithm::PPO | RLAlgorithm::SAC => {
                let policy = NeuralNetwork::new(state_size, 2, 32, action_size, 4);
                let value = NeuralNetwork::new(state_size, 2, 32, 1, 4);
                (Some(policy), Some(value))
            }
            _ => (None, None),
        };

        Self {
            q_network,
            target_network,
            policy_network,
            value_network,
            algorithm,
            epsilon,
            learningrate: learning_rate,
        }
    }

    /// Select action using current policy
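    ///
    /// For DQN this is ε-greedy: with probability `epsilon` a random strategy is
    /// explored; otherwise the strategy with the highest estimated Q-value (or
    /// policy probability for the other algorithms) is chosen.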
    pub fn select_action(&self, state: &[f64]) -> OptimizationStrategy {
        let mut rng = scirs2_core::random::thread_rng();

        // Epsilon-greedy action selection for DQN
        if matches!(self.algorithm, RLAlgorithm::DQN) && rng.gen::<f64>() < self.epsilon {
            // Random action
            self.random_action()
        } else {
            // Greedy action
            self.greedy_action(state)
        }
    }

    /// Select greedy action based on current Q-values or policy
    fn greedy_action(&self, state: &[f64]) -> OptimizationStrategy {
        match self.algorithm {
            RLAlgorithm::DQN => {
                let q_values = self.q_network.forward(state);
                let best_action_idx = q_values
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.total_cmp(b))
                    .map(|(idx, _)| idx)
                    .unwrap_or(0);
                self.idx_to_strategy(best_action_idx)
            }
            RLAlgorithm::PolicyGradient
            | RLAlgorithm::ActorCritic
            | RLAlgorithm::PPO
            | RLAlgorithm::SAC => {
                if let Some(ref policy_network) = self.policy_network {
                    let action_probs = policy_network.forward(state);
                    let best_action_idx = action_probs
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| a.total_cmp(b))
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.idx_to_strategy(best_action_idx)
                } else {
                    self.random_action()
                }
            }
        }
    }

    /// Select random action
    fn random_action(&self) -> OptimizationStrategy {
        let mut rng = scirs2_core::random::thread_rng();
        let strategies = [
            OptimizationStrategy::RowWiseCache,
            OptimizationStrategy::ColumnWiseLocality,
            OptimizationStrategy::BlockStructured,
            OptimizationStrategy::DiagonalOptimized,
            OptimizationStrategy::Hierarchical,
            OptimizationStrategy::StreamingCompute,
            OptimizationStrategy::SIMDVectorized,
            OptimizationStrategy::ParallelWorkStealing,
            OptimizationStrategy::AdaptiveHybrid,
        ];

        strategies[rng.gen_range(0..strategies.len())]
    }

    /// Convert action index to optimization strategy
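    ///
    /// Note: this mapping must stay consistent with `strategy_to_idx_static` below.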
    fn idx_to_strategy(&self, idx: usize) -> OptimizationStrategy {
        match idx % 9 {
            0 => OptimizationStrategy::RowWiseCache,
            1 => OptimizationStrategy::ColumnWiseLocality,
            2 => OptimizationStrategy::BlockStructured,
            3 => OptimizationStrategy::DiagonalOptimized,
            4 => OptimizationStrategy::Hierarchical,
            5 => OptimizationStrategy::StreamingCompute,
            6 => OptimizationStrategy::SIMDVectorized,
            7 => OptimizationStrategy::ParallelWorkStealing,
            _ => OptimizationStrategy::AdaptiveHybrid,
        }
    }

    /// Convert optimization strategy to action index
    fn strategy_to_idx(&self, strategy: OptimizationStrategy) -> usize {
        Self::strategy_to_idx_static(strategy)
    }

    /// Static version of strategy_to_idx to avoid borrowing issues
    fn strategy_to_idx_static(strategy: OptimizationStrategy) -> usize {
        match strategy {
            OptimizationStrategy::RowWiseCache => 0,
            OptimizationStrategy::ColumnWiseLocality => 1,
            OptimizationStrategy::BlockStructured => 2,
            OptimizationStrategy::DiagonalOptimized => 3,
            OptimizationStrategy::Hierarchical => 4,
            OptimizationStrategy::StreamingCompute => 5,
            OptimizationStrategy::SIMDVectorized => 6,
            OptimizationStrategy::ParallelWorkStealing => 7,
            OptimizationStrategy::AdaptiveHybrid => 8,
        }
    }

    /// Train the agent on a batch of experiences
    pub fn train(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        if experiences.is_empty() {
            return Ok(());
        }

        match self.algorithm {
            RLAlgorithm::DQN => self.train_dqn(experiences),
            RLAlgorithm::PolicyGradient => self.train_policy_gradient(experiences),
            RLAlgorithm::ActorCritic => self.train_actor_critic(experiences),
            RLAlgorithm::PPO => self.train_ppo(experiences),
            RLAlgorithm::SAC => self.train_sac(experiences),
        }
    }

    /// Train DQN algorithm
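    ///
    /// Each experience is regressed toward the temporal-difference target
    /// `r + γ · max_a' Q_target(s', a')` with γ = 0.99, where `Q_target` comes from
    /// the periodically synchronized target network (or `r` alone for terminal states).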
    fn train_dqn(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        for experience in experiences {
            let current_q_values = self.q_network.forward(&experience.state);
            let action_idx = self.strategy_to_idx(experience.action);

            let target = if experience.done {
                experience.reward
            } else if let Some(ref target_network) = self.target_network {
                let next_q_values = target_network.forward(&experience.next_state);
                let max_next_q = next_q_values
                    .iter()
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                experience.reward + 0.99 * max_next_q // gamma = 0.99
            } else {
                experience.reward
            };

            // Simplified Q-learning update: regress the Q-value of the taken
            // action toward the TD target and backpropagate the error.
            let mut target_q_values = current_q_values;
            if action_idx < target_q_values.len() {
                target_q_values[action_idx] = target;
            }

            // Update Q-network (simplified)
            let (_, cache) = self.q_network.forward_with_cache(&experience.state);
            let gradients =
                self.q_network
                    .compute_gradients(&experience.state, &target_q_values, &cache);
            self.q_network.update_weights(&gradients, self.learningrate);
        }

        Ok(())
    }

    /// Train policy gradient algorithm
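    ///
    /// Simplified REINFORCE-style update: the probability of the taken action is
    /// nudged in proportion to the observed reward.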
    fn train_policy_gradient(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        // Simplified policy gradient implementation
        let learning_rate = self.learningrate;
        if let Some(ref mut policy_network) = self.policy_network {
            for experience in experiences {
                let action_probs = policy_network.forward(&experience.state);
                let action_idx = Self::strategy_to_idx_static(experience.action);

                // Compute policy gradient (simplified)
                let mut target_probs = action_probs;
                if action_idx < target_probs.len() {
                    target_probs[action_idx] += learning_rate * experience.reward;
                }

                // Update policy network
                let (_, cache) = policy_network.forward_with_cache(&experience.state);
                let gradients =
                    policy_network.compute_gradients(&experience.state, &target_probs, &cache);
                policy_network.update_weights(&gradients, learning_rate);
            }
        }

        Ok(())
    }

    /// Train Actor-Critic algorithm
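    ///
    /// The critic is regressed toward `r + γ · V(s')` (γ = 0.99) and the actor is
    /// updated with the advantage `A = r + γ · V(s') - V(s)`.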
    fn train_actor_critic(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        // Simplified Actor-Critic implementation
        let learning_rate = self.learningrate;
        for experience in experiences {
            // Update critic (value network)
            if let Some(ref mut value_network) = self.value_network {
                let current_value = value_network.forward(&experience.state)[0];
                let target_value = if experience.done {
                    experience.reward
                } else {
                    let next_value = value_network.forward(&experience.next_state)[0];
                    experience.reward + 0.99 * next_value
                };

                let (_, cache) = value_network.forward_with_cache(&experience.state);
                let gradients =
                    value_network.compute_gradients(&experience.state, &[target_value], &cache);
                value_network.update_weights(&gradients, learning_rate);

                // Update actor (policy network)
                if let Some(ref mut policy_network) = self.policy_network {
                    let advantage = target_value - current_value;
                    let action_probs = policy_network.forward(&experience.state);
                    let action_idx = Self::strategy_to_idx_static(experience.action);

                    let mut target_probs = action_probs;
                    if action_idx < target_probs.len() {
                        target_probs[action_idx] += learning_rate * advantage;
                    }

                    let (_, cache) = policy_network.forward_with_cache(&experience.state);
                    let gradients =
                        policy_network.compute_gradients(&experience.state, &target_probs, &cache);
                    policy_network.update_weights(&gradients, learning_rate);
                }
            }
        }

        Ok(())
    }

    /// Train PPO algorithm (simplified)
    fn train_ppo(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        // Simplified PPO implementation
        self.train_actor_critic(experiences) // Using Actor-Critic as base
    }

    /// Train SAC algorithm (simplified)
    fn train_sac(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        // Simplified SAC implementation
        self.train_actor_critic(experiences) // Using Actor-Critic as base
    }

    /// Update target network (for DQN)
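    ///
    /// Copies the online Q-network parameters into the target network; intended to
    /// be called periodically during training to stabilize the TD targets.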
    pub fn update_target_network(&mut self) {
        if let Some(ref mut target_network) = self.target_network {
            let params = self.q_network.get_parameters();
            target_network.set_parameters(&params);
        }
    }

    /// Decay exploration rate
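    ///
    /// Multiplicative decay with a floor: `epsilon ← max(epsilon · decay_rate, 0.01)`.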
    pub fn decay_epsilon(&mut self, decay_rate: f64) {
        self.epsilon *= decay_rate;
        self.epsilon = self.epsilon.max(0.01); // Minimum epsilon
    }

    /// Compute value estimate for a state
    pub fn estimate_value(&self, state: &[f64]) -> f64 {
        match self.algorithm {
            RLAlgorithm::DQN => {
                let q_values = self.q_network.forward(state);
                q_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            }
            _ => {
                if let Some(ref value_network) = self.value_network {
                    value_network.forward(state)[0]
                } else {
                    0.0
                }
            }
        }
    }
}

impl ExperienceBuffer {
    /// Create a new experience buffer
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
            priority_weights: Vec::with_capacity(capacity),
        }
    }

    /// Add experience to buffer
    pub fn add(&mut self, experience: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
            if !self.priority_weights.is_empty() {
                self.priority_weights.remove(0);
            }
        }

        self.buffer.push_back(experience);
        self.priority_weights.push(1.0); // Default priority
    }

    /// Sample a batch of experiences
    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let mut rng = scirs2_core::random::thread_rng();
        let mut batch = Vec::new();

        for _ in 0..batch_size.min(self.buffer.len()) {
            let idx = rng.gen_range(0..self.buffer.len());
            if let Some(experience) = self.buffer.get(idx) {
                batch.push(experience.clone());
            }
        }

        batch
    }

    /// Sample with priority weights
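    ///
    /// Roulette-wheel selection: each experience is drawn with probability
    /// proportional to its priority weight (falls back to uniform sampling when no
    /// weights are present).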
    pub fn sample_prioritized(&self, batch_size: usize) -> Vec<Experience> {
        if self.priority_weights.is_empty() {
            return self.sample(batch_size);
        }

        let mut rng = scirs2_core::random::thread_rng();
        let mut batch = Vec::new();
        let total_weight: f64 = self.priority_weights.iter().sum();

        for _ in 0..batch_size.min(self.buffer.len()) {
            let mut weight_sum = 0.0;
            let target = rng.gen::<f64>() * total_weight;

            for (idx, &weight) in self.priority_weights.iter().enumerate() {
                weight_sum += weight;
                if weight_sum >= target {
                    if let Some(experience) = self.buffer.get(idx) {
                        batch.push(experience.clone());
                        break;
                    }
                }
            }
        }

        batch
    }

    /// Update priority for experience
    pub fn update_priority(&mut self, idx: usize, priority: f64) {
        if idx < self.priority_weights.len() {
            self.priority_weights[idx] = priority.max(0.01); // Minimum priority
        }
    }

    /// Get buffer size
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Check if buffer is empty
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Clear the buffer
    pub fn clear(&mut self) {
        self.buffer.clear();
        self.priority_weights.clear();
    }
}

impl PerformanceMetrics {
    /// Create new performance metrics
    pub fn new(
        execution_time: f64,
        cache_efficiency: f64,
        simd_utilization: f64,
        parallel_efficiency: f64,
        memory_bandwidth: f64,
        strategy_used: OptimizationStrategy,
    ) -> Self {
        Self {
            executiontime: execution_time,
            cache_efficiency,
            simd_utilization,
            parallel_efficiency,
            memory_bandwidth,
            strategy_used,
        }
    }

    /// Compute reward for reinforcement learning
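    ///
    /// `reward = 10 · (baseline_time - executiontime) / baseline_time
    ///         + 5 · mean(cache_efficiency, simd_utilization, parallel_efficiency)`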
    pub fn compute_reward(&self, baseline_time: f64) -> f64 {
        // Reward based on performance improvement
        let time_improvement = (baseline_time - self.executiontime) / baseline_time;
        let efficiency_score =
            (self.cache_efficiency + self.simd_utilization + self.parallel_efficiency) / 3.0;

        // Combined reward considering both time improvement and efficiency
        time_improvement * 10.0 + efficiency_score * 5.0
    }

    /// Get overall performance score
    pub fn performance_score(&self) -> f64 {
        let time_score = 1.0 / (1.0 + self.executiontime); // Lower time is better
        let efficiency_score = (self.cache_efficiency
            + self.simd_utilization
            + self.parallel_efficiency
            + self.memory_bandwidth)
            / 4.0;

        (time_score + efficiency_score) / 2.0
    }
}
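
// A minimal smoke-test sketch covering the plain-Rust parts of this module
// (replay-buffer bookkeeping and reward arithmetic); it deliberately avoids
// exercising the `NeuralNetwork`-backed training paths.
#[cfg(test)]
mod tests {
    use super::*;

    fn dummy_experience(reward: f64) -> Experience {
        Experience {
            state: vec![0.0; 4],
            action: OptimizationStrategy::RowWiseCache,
            reward,
            next_state: vec![0.0; 4],
            done: false,
            timestamp: 0,
        }
    }

    #[test]
    fn buffer_respects_capacity() {
        let mut buffer = ExperienceBuffer::new(2);
        for i in 0..5 {
            buffer.add(dummy_experience(i as f64));
        }
        // Oldest experiences (and their priorities) are evicted once the buffer is full.
        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.priority_weights.len(), 2);
        assert_eq!(buffer.sample(8).len(), 2);
    }

    #[test]
    fn reward_increases_with_speedup() {
        let fast = PerformanceMetrics::new(0.5, 0.8, 0.6, 0.7, 0.5, OptimizationStrategy::SIMDVectorized);
        let slow = PerformanceMetrics::new(2.0, 0.8, 0.6, 0.7, 0.5, OptimizationStrategy::SIMDVectorized);
        // Faster-than-baseline runs should earn a higher reward than slower ones.
        assert!(fast.compute_reward(1.0) > slow.compute_reward(1.0));
    }
}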