// scirs2_sparse/neural_adaptive_sparse/reinforcement_learning.rs

use super::neural_network::NeuralNetwork;
use super::pattern_memory::OptimizationStrategy;
use crate::error::SparseResult;
use scirs2_core::random::Rng;
use std::collections::VecDeque;

#[derive(Debug, Clone, Copy)]
pub enum RLAlgorithm {
    /// Deep Q-Network.
    DQN,
    /// Policy-gradient (REINFORCE-style) learning.
    PolicyGradient,
    /// Actor-critic learning.
    ActorCritic,
    /// Proximal Policy Optimization.
    PPO,
    /// Soft Actor-Critic.
    SAC,
}

/// Reinforcement learning agent that selects sparse-matrix optimization strategies.
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct RLAgent {
    pub q_network: NeuralNetwork,
    pub target_network: Option<NeuralNetwork>,
    pub policy_network: Option<NeuralNetwork>,
    pub value_network: Option<NeuralNetwork>,
    pub algorithm: RLAlgorithm,
    pub epsilon: f64,
    pub learningrate: f64,
}

/// A single state/action/reward transition observed by the agent.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct Experience {
    pub state: Vec<f64>,
    pub action: OptimizationStrategy,
    pub reward: f64,
    pub next_state: Vec<f64>,
    pub done: bool,
    pub timestamp: u64,
}

/// Replay buffer with optional per-experience priority weights.
#[derive(Debug)]
pub(crate) struct ExperienceBuffer {
    pub buffer: VecDeque<Experience>,
    pub capacity: usize,
    pub priority_weights: Vec<f64>,
}

/// Measured performance of a single optimization run, used to derive rewards.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    #[allow(dead_code)]
    pub executiontime: f64,
    #[allow(dead_code)]
    pub cache_efficiency: f64,
    #[allow(dead_code)]
    pub simd_utilization: f64,
    #[allow(dead_code)]
    pub parallel_efficiency: f64,
    #[allow(dead_code)]
    pub memory_bandwidth: f64,
    pub strategy_used: OptimizationStrategy,
}

impl RLAgent {
    /// Creates an agent, building the extra networks required by the chosen algorithm.
    pub fn new(
        state_size: usize,
        action_size: usize,
        algorithm: RLAlgorithm,
        learning_rate: f64,
        epsilon: f64,
    ) -> Self {
        let q_network = NeuralNetwork::new(state_size, 3, 64, action_size, 4);

        // DQN keeps a separate target network for stable TD targets.
        let target_network = match algorithm {
            RLAlgorithm::DQN => Some(q_network.clone()),
            _ => None,
        };

        // Policy-based algorithms need a policy network; actor-critic variants also
        // keep a value network for the critic.
        let (policy_network, value_network) = match algorithm {
            RLAlgorithm::PolicyGradient => {
                let policy = NeuralNetwork::new(state_size, 2, 32, action_size, 4);
                (Some(policy), None)
            }
            RLAlgorithm::ActorCritic | RLAlgorithm::PPO | RLAlgorithm::SAC => {
                let policy = NeuralNetwork::new(state_size, 2, 32, action_size, 4);
                let value = NeuralNetwork::new(state_size, 2, 32, 1, 4);
                (Some(policy), Some(value))
            }
            _ => (None, None),
        };

        Self {
            q_network,
            target_network,
            policy_network,
            value_network,
            algorithm,
            epsilon,
            learningrate: learning_rate,
        }
    }

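    /// Chooses an optimization strategy for the given state. For DQN this is
    /// epsilon-greedy: with probability `epsilon` a uniformly random strategy is
    /// returned, otherwise the greedy (highest-valued) strategy.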
    pub fn select_action(&self, state: &[f64]) -> OptimizationStrategy {
        let mut rng = scirs2_core::random::thread_rng();

        if matches!(self.algorithm, RLAlgorithm::DQN) && rng.gen::<f64>() < self.epsilon {
            self.random_action()
        } else {
            self.greedy_action(state)
        }
    }

    fn greedy_action(&self, state: &[f64]) -> OptimizationStrategy {
        match self.algorithm {
            RLAlgorithm::DQN => {
                let q_values = self.q_network.forward(state);
                let best_action_idx = q_values
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx)
                    .unwrap_or(0);
                self.idx_to_strategy(best_action_idx)
            }
            RLAlgorithm::PolicyGradient
            | RLAlgorithm::ActorCritic
            | RLAlgorithm::PPO
            | RLAlgorithm::SAC => {
                if let Some(ref policy_network) = self.policy_network {
                    let action_probs = policy_network.forward(state);
                    let best_action_idx = action_probs
                        .iter()
                        .enumerate()
                        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                        .map(|(idx, _)| idx)
                        .unwrap_or(0);
                    self.idx_to_strategy(best_action_idx)
                } else {
                    self.random_action()
                }
            }
        }
    }

    fn random_action(&self) -> OptimizationStrategy {
        let mut rng = scirs2_core::random::thread_rng();
        let strategies = [
            OptimizationStrategy::RowWiseCache,
            OptimizationStrategy::ColumnWiseLocality,
            OptimizationStrategy::BlockStructured,
            OptimizationStrategy::DiagonalOptimized,
            OptimizationStrategy::Hierarchical,
            OptimizationStrategy::StreamingCompute,
            OptimizationStrategy::SIMDVectorized,
            OptimizationStrategy::ParallelWorkStealing,
            OptimizationStrategy::AdaptiveHybrid,
        ];

        strategies[rng.gen_range(0..strategies.len())]
    }

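    /// Maps a network output index to a concrete strategy; the index is taken
    /// modulo 9 so any argmax over the output vector yields a valid strategy.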
    fn idx_to_strategy(&self, idx: usize) -> OptimizationStrategy {
        match idx % 9 {
            0 => OptimizationStrategy::RowWiseCache,
            1 => OptimizationStrategy::ColumnWiseLocality,
            2 => OptimizationStrategy::BlockStructured,
            3 => OptimizationStrategy::DiagonalOptimized,
            4 => OptimizationStrategy::Hierarchical,
            5 => OptimizationStrategy::StreamingCompute,
            6 => OptimizationStrategy::SIMDVectorized,
            7 => OptimizationStrategy::ParallelWorkStealing,
            _ => OptimizationStrategy::AdaptiveHybrid,
        }
    }

    fn strategy_to_idx(&self, strategy: OptimizationStrategy) -> usize {
        Self::strategy_to_idx_static(strategy)
    }

    fn strategy_to_idx_static(strategy: OptimizationStrategy) -> usize {
        match strategy {
            OptimizationStrategy::RowWiseCache => 0,
            OptimizationStrategy::ColumnWiseLocality => 1,
            OptimizationStrategy::BlockStructured => 2,
            OptimizationStrategy::DiagonalOptimized => 3,
            OptimizationStrategy::Hierarchical => 4,
            OptimizationStrategy::StreamingCompute => 5,
            OptimizationStrategy::SIMDVectorized => 6,
            OptimizationStrategy::ParallelWorkStealing => 7,
            OptimizationStrategy::AdaptiveHybrid => 8,
        }
    }

    pub fn train(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        if experiences.is_empty() {
            return Ok(());
        }

        match self.algorithm {
            RLAlgorithm::DQN => self.train_dqn(experiences),
            RLAlgorithm::PolicyGradient => self.train_policy_gradient(experiences),
            RLAlgorithm::ActorCritic => self.train_actor_critic(experiences),
            RLAlgorithm::PPO => self.train_ppo(experiences),
            RLAlgorithm::SAC => self.train_sac(experiences),
        }
    }

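    /// DQN update: for each transition the temporal-difference target is
    /// `reward + 0.99 * max_a' Q_target(next_state, a')`, or just `reward` for
    /// terminal transitions, and the online Q-network is regressed toward that
    /// target for the action that was actually taken.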
    fn train_dqn(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        for experience in experiences {
            let current_q_values = self.q_network.forward(&experience.state);
            let action_idx = self.strategy_to_idx(experience.action);

            // Compute the TD target, using the target network when available.
            let target = if experience.done {
                experience.reward
            } else if let Some(ref target_network) = self.target_network {
                let next_q_values = target_network.forward(&experience.next_state);
                let max_next_q = next_q_values
                    .iter()
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                experience.reward + 0.99 * max_next_q
            } else {
                experience.reward
            };

            // Only the Q-value of the taken action is moved toward the target.
            let mut target_q_values = current_q_values;
            if action_idx < target_q_values.len() {
                target_q_values[action_idx] = target;
            }

            // Backpropagate toward the adjusted Q-values.
            let (_, cache) = self.q_network.forward_with_cache(&experience.state);
            let gradients =
                self.q_network
                    .compute_gradients(&experience.state, &target_q_values, &cache);
            self.q_network.update_weights(&gradients, self.learningrate);
        }

        Ok(())
    }

    fn train_policy_gradient(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        let learning_rate = self.learningrate;
        if let Some(ref mut policy_network) = self.policy_network {
            for experience in experiences {
                let action_probs = policy_network.forward(&experience.state);
                let action_idx = Self::strategy_to_idx_static(experience.action);

                // Nudge the probability of the taken action in proportion to the reward.
                let mut target_probs = action_probs;
                if action_idx < target_probs.len() {
                    target_probs[action_idx] += learning_rate * experience.reward;
                }

                let (_, cache) = policy_network.forward_with_cache(&experience.state);
                let gradients =
                    policy_network.compute_gradients(&experience.state, &target_probs, &cache);
                policy_network.update_weights(&gradients, learning_rate);
            }
        }

        Ok(())
    }

    fn train_actor_critic(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        let learning_rate = self.learningrate;
        for experience in experiences {
            if let Some(ref mut value_network) = self.value_network {
                // Critic update: regress the state value toward the TD target.
                let current_value = value_network.forward(&experience.state)[0];
                let target_value = if experience.done {
                    experience.reward
                } else {
                    let next_value = value_network.forward(&experience.next_state)[0];
                    experience.reward + 0.99 * next_value
                };

                let (_, cache) = value_network.forward_with_cache(&experience.state);
                let gradients =
                    value_network.compute_gradients(&experience.state, &[target_value], &cache);
                value_network.update_weights(&gradients, learning_rate);

                // Actor update: shift the policy toward the taken action by its advantage.
                if let Some(ref mut policy_network) = self.policy_network {
                    let advantage = target_value - current_value;
                    let action_probs = policy_network.forward(&experience.state);
                    let action_idx = Self::strategy_to_idx_static(experience.action);

                    let mut target_probs = action_probs;
                    if action_idx < target_probs.len() {
                        target_probs[action_idx] += learning_rate * advantage;
                    }

                    let (_, cache) = policy_network.forward_with_cache(&experience.state);
                    let gradients =
                        policy_network.compute_gradients(&experience.state, &target_probs, &cache);
                    policy_network.update_weights(&gradients, learning_rate);
                }
            }
        }

        Ok(())
    }

    fn train_ppo(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        // Currently delegates to the actor-critic update.
        self.train_actor_critic(experiences)
    }

    fn train_sac(&mut self, experiences: &[Experience]) -> SparseResult<()> {
        // Currently delegates to the actor-critic update.
        self.train_actor_critic(experiences)
    }

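    /// Copies the online Q-network parameters into the DQN target network
    /// (a no-op for algorithms that do not keep a target network).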
    pub fn update_target_network(&mut self) {
        if let Some(ref mut target_network) = self.target_network {
            let params = self.q_network.get_parameters();
            target_network.set_parameters(&params);
        }
    }

    pub fn decay_epsilon(&mut self, decay_rate: f64) {
        // Multiplicative decay with a floor of 0.01 to keep some exploration.
        self.epsilon *= decay_rate;
        self.epsilon = self.epsilon.max(0.01);
    }

    /// Estimates the value of a state: the maximum Q-value for DQN, otherwise the
    /// critic's output (0.0 when no value network exists).
    pub fn estimate_value(&self, state: &[f64]) -> f64 {
        match self.algorithm {
            RLAlgorithm::DQN => {
                let q_values = self.q_network.forward(state);
                q_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            }
            _ => {
                if let Some(ref value_network) = self.value_network {
                    value_network.forward(state)[0]
                } else {
                    0.0
                }
            }
        }
    }
}

impl ExperienceBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::new(),
            capacity,
            priority_weights: Vec::new(),
        }
    }

    pub fn add(&mut self, experience: Experience) {
        // Evict the oldest experience (and its priority weight) when at capacity.
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
            if !self.priority_weights.is_empty() {
                self.priority_weights.remove(0);
            }
        }

        self.buffer.push_back(experience);
        self.priority_weights.push(1.0);
    }

    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let mut rng = scirs2_core::random::thread_rng();
        let mut batch = Vec::new();

        for _ in 0..batch_size.min(self.buffer.len()) {
            let idx = rng.gen_range(0..self.buffer.len());
            if let Some(experience) = self.buffer.get(idx) {
                batch.push(experience.clone());
            }
        }

        batch
    }

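    /// Samples experiences with probability proportional to their priority weight,
    /// using roulette-wheel selection over the cumulative weights; falls back to
    /// uniform sampling when no weights are present.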
    pub fn sample_prioritized(&self, batch_size: usize) -> Vec<Experience> {
        if self.priority_weights.is_empty() {
            return self.sample(batch_size);
        }

        let mut rng = scirs2_core::random::thread_rng();
        let mut batch = Vec::new();
        let total_weight: f64 = self.priority_weights.iter().sum();

        for _ in 0..batch_size.min(self.buffer.len()) {
            let mut weight_sum = 0.0;
            let target = rng.gen::<f64>() * total_weight;

            for (idx, &weight) in self.priority_weights.iter().enumerate() {
                weight_sum += weight;
                if weight_sum >= target {
                    if let Some(experience) = self.buffer.get(idx) {
                        batch.push(experience.clone());
                        break;
                    }
                }
            }
        }

        batch
    }

    pub fn update_priority(&mut self, idx: usize, priority: f64) {
        if idx < self.priority_weights.len() {
            self.priority_weights[idx] = priority.max(0.01);
        }
    }

447
448 pub fn len(&self) -> usize {
450 self.buffer.len()
451 }
452
453 pub fn is_empty(&self) -> bool {
455 self.buffer.is_empty()
456 }
457
458 pub fn clear(&mut self) {
460 self.buffer.clear();
461 self.priority_weights.clear();
462 }
463}

impl PerformanceMetrics {
    pub fn new(
        execution_time: f64,
        cache_efficiency: f64,
        simd_utilization: f64,
        parallel_efficiency: f64,
        memory_bandwidth: f64,
        strategy_used: OptimizationStrategy,
    ) -> Self {
        Self {
            executiontime: execution_time,
            cache_efficiency,
            simd_utilization,
            parallel_efficiency,
            memory_bandwidth,
            strategy_used,
        }
    }

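    /// Reward is `10 * relative time improvement + 5 * mean(cache, SIMD, parallel efficiency)`.
    /// For example, with `baseline_time = 2.0`, `executiontime = 1.0`, and all three
    /// efficiencies at 0.8, the reward is `10 * 0.5 + 5 * 0.8 = 9.0`.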
    pub fn compute_reward(&self, baseline_time: f64) -> f64 {
        let time_improvement = (baseline_time - self.executiontime) / baseline_time;
        let efficiency_score =
            (self.cache_efficiency + self.simd_utilization + self.parallel_efficiency) / 3.0;

        time_improvement * 10.0 + efficiency_score * 5.0
    }

    pub fn performance_score(&self) -> f64 {
        // Combine an inverse-time score with the average of the efficiency metrics.
        let time_score = 1.0 / (1.0 + self.executiontime);
        let efficiency_score = (self.cache_efficiency
            + self.simd_utilization
            + self.parallel_efficiency
            + self.memory_bandwidth)
            / 4.0;

        (time_score + efficiency_score) / 2.0
    }
}
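
// A minimal smoke test sketching how these pieces fit together: an agent picks a
// strategy, the run is measured, the reward becomes an experience, and a sampled
// batch trains the agent. It only uses the APIs exercised above and assumes the
// NeuralNetwork calls behave as in those methods; the concrete numbers (state size
// 4, replay capacity 16, baseline time 2.0) are illustrative, not prescribed.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dqn_agent_smoke_test() {
        // Nine actions, one per OptimizationStrategy variant.
        let mut agent = RLAgent::new(4, 9, RLAlgorithm::DQN, 0.01, 0.1);
        let mut buffer = ExperienceBuffer::new(16);

        let state = vec![0.1, 0.2, 0.3, 0.4];
        let action = agent.select_action(&state);

        // Pretend the chosen strategy halved the execution time at 80% efficiency.
        let metrics = PerformanceMetrics::new(1.0, 0.8, 0.8, 0.8, 0.8, action);
        let reward = metrics.compute_reward(2.0);
        assert!((reward - 9.0).abs() < 1e-9);

        buffer.add(Experience {
            state: state.clone(),
            action,
            reward,
            next_state: state.clone(),
            done: false,
            timestamp: 0,
        });

        // Train on a sampled batch, then refresh the target network and decay epsilon.
        let batch = buffer.sample(1);
        agent.train(&batch).expect("training should succeed");
        agent.update_target_network();
        agent.decay_epsilon(0.9);
        assert!(agent.epsilon >= 0.01);
    }
}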