scirs2_sparse/neural_adaptive_sparse/
processor.rs

//! Main neural-adaptive sparse matrix processor
//!
//! This module contains the main processor that coordinates all neural network,
//! reinforcement learning, and pattern memory components for adaptive optimization.
5
6use super::config::NeuralAdaptiveConfig;
7use super::neural_network::NeuralNetwork;
8use super::pattern_memory::{MatrixFingerprint, OptimizationStrategy, PatternMemory};
9use super::reinforcement_learning::{Experience, ExperienceBuffer, PerformanceMetrics, RLAgent};
10use super::transformer::TransformerModel;
11use crate::error::SparseResult;
12use num_traits::{Float, NumAssign};
13use scirs2_core::simd_ops::SimdUnifiedOps;
14use std::collections::VecDeque;
15use std::sync::atomic::{AtomicUsize, Ordering};
16
17/// Neural-adaptive sparse matrix processor
18pub struct NeuralAdaptiveSparseProcessor {
19    config: NeuralAdaptiveConfig,
20    neural_network: NeuralNetwork,
21    pattern_memory: PatternMemory,
22    performance_history: VecDeque<PerformanceMetrics>,
23    adaptation_counter: AtomicUsize,
24    optimization_strategies: Vec<OptimizationStrategy>,
25    /// Reinforcement learning agent
26    rl_agent: Option<RLAgent>,
27    /// Transformer model for attention-based optimization
28    transformer: Option<TransformerModel>,
29    /// Experience replay buffer for RL
30    experience_buffer: ExperienceBuffer,
31    /// Current exploration rate (decays over time)
32    current_exploration_rate: f64,
33}
34
35/// Statistics for neural processor performance
36#[derive(Debug, Clone)]
37pub struct NeuralProcessorStats {
38    pub total_operations: usize,
39    pub successful_adaptations: usize,
40    pub average_performance_improvement: f64,
41    pub most_effective_strategy: OptimizationStrategy,
42    pub neural_network_accuracy: f64,
43    pub rl_agent_reward: f64,
44    pub pattern_memory_hit_rate: f64,
45    pub transformer_attention_score: f64,
46}
47
48impl NeuralAdaptiveSparseProcessor {
49    /// Create a new neural-adaptive sparse matrix processor
50    pub fn new(config: NeuralAdaptiveConfig) -> Self {
51        // Validate configuration
52        if let Err(e) = config.validate() {
53            panic!("Invalid configuration: {}", e);
54        }
55
56        let neural_network = NeuralNetwork::new(
57            config.modeldim,
58            config.hidden_layers,
59            config.neurons_per_layer,
60            9, // Number of optimization strategies
61            config.attention_heads,
62        );
63        let pattern_memory = PatternMemory::new(config.memory_capacity);
64
65        let optimization_strategies = vec![
66            OptimizationStrategy::RowWiseCache,
67            OptimizationStrategy::ColumnWiseLocality,
68            OptimizationStrategy::BlockStructured,
69            OptimizationStrategy::DiagonalOptimized,
70            OptimizationStrategy::Hierarchical,
71            OptimizationStrategy::StreamingCompute,
72            OptimizationStrategy::SIMDVectorized,
73            OptimizationStrategy::ParallelWorkStealing,
74            OptimizationStrategy::AdaptiveHybrid,
75        ];
76
77        // Initialize RL agent if enabled
78        let rl_agent = if config.reinforcement_learning {
79            Some(RLAgent::new(
80                config.modeldim,
81                9, // Number of actions (optimization strategies)
82                config.rl_algorithm,
83                config.learningrate,
84                config.exploration_rate,
85            ))
86        } else {
87            None
88        };
89
90        // Initialize transformer if self-attention is enabled
91        let transformer = if config.self_attention {
92            Some(TransformerModel::new(
93                config.modeldim,
94                config.transformer_layers,
95                config.attention_heads,
96                config.ff_dim,
97                1000, // Max sequence length
98            ))
99        } else {
100            None
101        };
102
103        let experience_buffer = ExperienceBuffer::new(config.replay_buffer_size);
104
105        Self {
106            config: config.clone(),
107            neural_network,
108            pattern_memory,
109            performance_history: VecDeque::new(),
110            adaptation_counter: AtomicUsize::new(0),
111            optimization_strategies,
112            rl_agent,
113            transformer,
114            experience_buffer,
115            current_exploration_rate: config.exploration_rate,
116        }
117    }
118
119    /// Process sparse matrix operation with adaptive optimization
120    pub fn optimize_operation<T>(
121        &mut self,
122        matrix_features: &[f64],
123        operation_context: &OperationContext,
124    ) -> SparseResult<OptimizationStrategy>
125    where
126        T: Float + NumAssign + SimdUnifiedOps + std::fmt::Debug + Copy + Send + Sync + 'static,
127    {
128        // Extract matrix fingerprint
129        let fingerprint = self.extract_matrix_fingerprint(matrix_features, operation_context);
130
131        // Try pattern memory first
132        if let Some(strategy) = self.pattern_memory.get_strategy(&fingerprint) {
133            return Ok(strategy);
134        }
135
136        // Use neural network and RL agent for optimization
137        let state = self.encode_state(matrix_features, operation_context)?;
138
139        let strategy = if let Some(ref rl_agent) = self.rl_agent {
140            rl_agent.select_action(&state)
141        } else {
142            self.neural_network_select_action(&state)?
143        };
144
145        // Store the experience for later learning
146        let experience = Experience {
147            state: state.clone(),
148            action: strategy,
149            reward: 0.0,       // Will be updated after operation execution
150            next_state: state, // Will be updated with next state
151            done: false,
152            timestamp: std::time::SystemTime::now()
153                .duration_since(std::time::UNIX_EPOCH)
154                .unwrap_or_default()
155                .as_secs(),
156        };
157
158        self.experience_buffer.add(experience);
159
160        Ok(strategy)
161    }
162
163    /// Learn from operation performance
164    pub fn learn_from_performance(
165        &mut self,
166        strategy: OptimizationStrategy,
167        performance: PerformanceMetrics,
168        matrix_features: &[f64],
169        operation_context: &OperationContext,
170    ) -> SparseResult<()> {
171        // Compute reward
172        let baseline_time = self.estimate_baseline_performance(matrix_features);
173        let reward = performance.compute_reward(baseline_time);
174
175        // Update experience buffer with reward
176        if let Some(mut experience) = self.experience_buffer.buffer.back_mut() {
177            experience.reward = reward;
178        }
179
180        // Store successful patterns
181        let fingerprint = self.extract_matrix_fingerprint(matrix_features, operation_context);
182        if reward > 0.0 {
183            self.pattern_memory.store_pattern(fingerprint, strategy);
184        }
185
186        // Train RL agent
187        if let Some(ref mut rl_agent) = self.rl_agent {
188            let batch_size = 32.min(self.experience_buffer.len());
189            if batch_size > 0 {
190                let batch = self.experience_buffer.sample(batch_size);
191                rl_agent.train(&batch)?;
192            }
193        }
194
195        // Update performance history
196        self.performance_history.push_back(performance);
197        if self.performance_history.len() > 1000 {
198            self.performance_history.pop_front();
199        }
200
201        // Increment adaptation counter
202        self.adaptation_counter.fetch_add(1, Ordering::Relaxed);
203
204        // Decay exploration rate
205        if let Some(ref mut rl_agent) = self.rl_agent {
206            rl_agent.decay_epsilon(0.995);
207        }
208
209        Ok(())
210    }
211
212    /// Extract matrix fingerprint from features
213    fn extract_matrix_fingerprint(
214        &self,
215        features: &[f64],
216        context: &OperationContext,
217    ) -> MatrixFingerprint {
218        // Extract basic properties from features
219        let rows = context.matrix_shape.0;
220        let cols = context.matrix_shape.1;
221        let nnz = context.nnz;
222
223        // Compute a simple hash of the sparsity pattern
224        use std::collections::hash_map::DefaultHasher;
225        use std::hash::{Hash, Hasher};
226        let mut hasher = DefaultHasher::new();
227        for (i, &feature) in features.iter().enumerate().take(100) {
228            ((feature * 1000.0) as i64).hash(&mut hasher);
229        }
230        let sparsity_pattern_hash = hasher.finish();
231
232        // Analyze distributions (simplified)
233        let row_distribution_type = super::pattern_memory::DistributionType::Random;
234        let column_distribution_type = super::pattern_memory::DistributionType::Random;
235
236        MatrixFingerprint {
237            rows,
238            cols,
239            nnz,
240            sparsity_pattern_hash,
241            row_distribution_type,
242            column_distribution_type,
243        }
244    }
245
246    /// Encode state for neural network/RL agent
247    fn encode_state(
248        &self,
249        matrix_features: &[f64],
250        context: &OperationContext,
251    ) -> SparseResult<Vec<f64>> {
252        let mut state = Vec::new();
253
254        // Matrix properties
255        state.push(context.matrix_shape.0 as f64);
256        state.push(context.matrix_shape.1 as f64);
257        state.push(context.nnz as f64);
258        state.push(context.nnz as f64 / (context.matrix_shape.0 * context.matrix_shape.1) as f64); // Sparsity
259
260        // Operation type
261        state.push(match context.operation_type {
262            OperationType::MatVec => 1.0,
263            OperationType::MatMat => 2.0,
264            OperationType::Solve => 3.0,
265            OperationType::Factorization => 4.0,
266        });
267
268        // Matrix features (truncated/padded to fixed size)
269        let feature_size = self.config.modeldim.saturating_sub(state.len());
270        for i in 0..feature_size {
271            if i < matrix_features.len() {
272                state.push(matrix_features[i]);
273            } else {
274                state.push(0.0);
275            }
276        }
277
278        // Use transformer for feature encoding if available
279        if let Some(ref transformer) = self.transformer {
280            let encoded = transformer.encode_matrix_pattern(&state);
281            Ok(encoded)
282        } else {
283            Ok(state)
284        }
285    }
286
287    /// Select action using neural network
288    fn neural_network_select_action(&self, state: &[f64]) -> SparseResult<OptimizationStrategy> {
289        let outputs = self.neural_network.forward(state);
290
291        let best_idx = outputs
292            .iter()
293            .enumerate()
294            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
295            .map(|(idx, _)| idx)
296            .unwrap_or(0);
297
298        Ok(self.optimization_strategies[best_idx % self.optimization_strategies.len()])
299    }
300
301    /// Estimate baseline performance for reward computation
302    fn estimate_baseline_performance(&self, _features: &[f64]) -> f64 {
303        // Simple baseline estimation
304        if let Some(last_performance) = self.performance_history.back() {
305            last_performance.executiontime
306        } else {
307            1.0 // Default baseline
308        }
309    }
310
311    /// Get processor statistics
312    pub fn get_statistics(&self) -> NeuralProcessorStats {
313        let total_operations = self.adaptation_counter.load(Ordering::Relaxed);
314        let successful_adaptations = self
315            .performance_history
316            .iter()
317            .filter(|p| p.compute_reward(1.0) > 0.0)
318            .count();
319
320        let average_improvement = if !self.performance_history.is_empty() {
321            self.performance_history
322                .iter()
323                .map(|p| p.performance_score())
324                .sum::<f64>()
325                / self.performance_history.len() as f64
326        } else {
327            0.0
328        };
329
330        let most_effective_strategy = self.get_most_effective_strategy();
331        let rl_reward = if let Some(ref rl_agent) = self.rl_agent {
332            // Estimate current RL performance
333            let dummy_state = vec![0.0; self.config.modeldim];
334            rl_agent.estimate_value(&dummy_state)
335        } else {
336            0.0
337        };
338
339        let pattern_memory_stats = self.pattern_memory.get_statistics();
340        let pattern_hit_rate = if total_operations > 0 {
341            pattern_memory_stats.stored_patterns as f64 / total_operations as f64
342        } else {
343            0.0
344        };
345
346        NeuralProcessorStats {
347            total_operations,
348            successful_adaptations,
349            average_performance_improvement: average_improvement,
350            most_effective_strategy,
351            neural_network_accuracy: 0.85, // Placeholder
352            rl_agent_reward: rl_reward,
353            pattern_memory_hit_rate: pattern_hit_rate,
354            transformer_attention_score: 0.75, // Placeholder
355        }
356    }
357
358    /// Get most effective optimization strategy
359    fn get_most_effective_strategy(&self) -> OptimizationStrategy {
360        let mut strategy_scores = std::collections::HashMap::new();
361
362        for performance in &self.performance_history {
363            let score = performance.performance_score();
364            let entry = strategy_scores
365                .entry(performance.strategy_used)
366                .or_insert((0.0, 0));
367            entry.0 += score;
368            entry.1 += 1;
369        }
370
371        strategy_scores
372            .into_iter()
373            .max_by(|(_, (score1, count1)), (_, (score2, count2))| {
374                let avg1 = score1 / *count1 as f64;
375                let avg2 = score2 / *count2 as f64;
376                avg1.partial_cmp(&avg2).unwrap()
377            })
378            .map(|(strategy, _)| strategy)
379            .unwrap_or(OptimizationStrategy::AdaptiveHybrid)
380    }
381
382    /// Update target networks (for DQN)
383    pub fn update_target_networks(&mut self) {
384        if let Some(ref mut rl_agent) = self.rl_agent {
385            rl_agent.update_target_network();
386        }
387    }
388
389    /// Save processor state
390    pub fn save_state(&self) -> ProcessorState {
391        let neural_params = self.neural_network.get_parameters();
392        let pattern_stats = self.pattern_memory.get_statistics();
393
394        ProcessorState {
395            neural_network_params: neural_params,
396            total_operations: self.adaptation_counter.load(Ordering::Relaxed),
397            pattern_memory_size: pattern_stats.stored_patterns,
398            current_exploration_rate: self.current_exploration_rate,
399        }
400    }
401
402    /// Load processor state
403    pub fn load_state(&mut self, state: ProcessorState) {
404        self.neural_network
405            .set_parameters(&state.neural_network_params);
406        self.adaptation_counter
407            .store(state.total_operations, Ordering::Relaxed);
408        self.current_exploration_rate = state.current_exploration_rate;
409    }
410
411    /// Adaptive sparse matrix-vector multiplication
412    pub fn adaptive_spmv<T>(
413        &mut self,
414        rows: &[usize],
415        cols: &[usize],
416        indptr: &[usize],
417        indices: &[usize],
418        data: &[T],
419        x: &[T],
420        y: &mut [T],
421    ) -> SparseResult<()>
422    where
423        T: Float + NumAssign + SimdUnifiedOps + std::fmt::Debug + Copy + Send + Sync + 'static,
424    {
425        // Extract matrix features for optimization decision
426        let matrix_features = self.extract_matrix_features(rows, cols, data);
427
428        // Create operation context
429        let context = OperationContext {
430            matrix_shape: (rows.len(), cols.len()),
431            nnz: data.len(),
432            operation_type: OperationType::MatVec,
433            performance_target: PerformanceTarget::Speed,
434        };
435
436        // Get optimization strategy
437        let strategy = self.optimize_operation::<T>(&matrix_features, &context)?;
438
439        // Execute the operation using the selected strategy
440        let start_time = std::time::Instant::now();
441        self.execute_spmv_with_strategy(strategy, indptr, indices, data, x, y)?;
442        let execution_time = start_time.elapsed().as_secs_f64();
443
444        // Learn from performance
445        let performance = PerformanceMetrics::new(
446            execution_time,
447            0.8,  // cache_efficiency (placeholder)
448            0.9,  // simd_utilization (placeholder)
449            0.7,  // parallel_efficiency (placeholder)
450            0.85, // memory_bandwidth (placeholder)
451            strategy,
452        );
453
454        self.learn_from_performance(strategy, performance, &matrix_features, &context)?;
455
456        Ok(())
457    }
458
459    /// Extract matrix features for neural network
460    fn extract_matrix_features<T>(&self, rows: &[usize], cols: &[usize], data: &[T]) -> Vec<f64>
461    where
462        T: Float + std::fmt::Debug + Copy,
463    {
464        let mut features = Vec::new();
465
466        // Basic statistics
467        features.push(rows.len() as f64);
468        features.push(cols.len() as f64);
469        features.push(data.len() as f64);
470
471        // Row statistics
472        if !rows.is_empty() {
473            let min_row = *rows.iter().min().unwrap_or(&0) as f64;
474            let max_row = *rows.iter().max().unwrap_or(&0) as f64;
475            features.push(min_row);
476            features.push(max_row);
477            features.push(max_row - min_row); // row span
478        } else {
479            features.extend(&[0.0, 0.0, 0.0]);
480        }
481
482        // Column statistics
483        if !cols.is_empty() {
484            let min_col = *cols.iter().min().unwrap_or(&0) as f64;
485            let max_col = *cols.iter().max().unwrap_or(&0) as f64;
486            features.push(min_col);
487            features.push(max_col);
488            features.push(max_col - min_col); // column span
489        } else {
490            features.extend(&[0.0, 0.0, 0.0]);
491        }
492
493        // Data statistics (simplified)
494        if !data.is_empty() {
495            // Convert to f64 for statistics
496            let data_f64: Vec<f64> = data.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
497            let sum: f64 = data_f64.iter().sum();
498            let mean = sum / data_f64.len() as f64;
499            let variance =
500                data_f64.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / data_f64.len() as f64;
501
502            features.push(mean);
503            features.push(variance.sqrt()); // standard deviation
504            features.push(
505                *data_f64
506                    .iter()
507                    .min_by(|a, b| a.partial_cmp(b).unwrap())
508                    .unwrap_or(&0.0),
509            );
510            features.push(
511                *data_f64
512                    .iter()
513                    .max_by(|a, b| a.partial_cmp(b).unwrap())
514                    .unwrap_or(&0.0),
515            );
516        } else {
517            features.extend(&[0.0, 0.0, 0.0, 0.0]);
518        }
519
520        // Pad/truncate to fixed size for neural network
521        let target_size = 20;
522        features.resize(target_size, 0.0);
523        features
524    }
525
526    /// Execute SpMV with specific strategy
527    fn execute_spmv_with_strategy<T>(
528        &self,
529        strategy: OptimizationStrategy,
530        indptr: &[usize],
531        indices: &[usize],
532        data: &[T],
533        x: &[T],
534        y: &mut [T],
535    ) -> SparseResult<()>
536    where
537        T: Float + NumAssign + SimdUnifiedOps + std::fmt::Debug + Copy + Send + Sync,
538    {
539        match strategy {
540            OptimizationStrategy::RowWiseCache => {
541                self.execute_rowwise_spmv(indptr, indices, data, x, y)
542            }
543            OptimizationStrategy::SIMDVectorized => {
544                self.execute_simd_spmv(indptr, indices, data, x, y)
545            }
546            OptimizationStrategy::ParallelWorkStealing => {
547                self.execute_parallel_spmv(indptr, indices, data, x, y)
548            }
549            _ => {
550                // Default implementation for other strategies
551                self.execute_basic_spmv(indptr, indices, data, x, y)
552            }
553        }
554    }
555
556    /// Basic CSR SpMV implementation
557    fn execute_basic_spmv<T>(
558        &self,
559        indptr: &[usize],
560        indices: &[usize],
561        data: &[T],
562        x: &[T],
563        y: &mut [T],
564    ) -> SparseResult<()>
565    where
566        T: Float + NumAssign + std::fmt::Debug + Copy,
567    {
568        for (i, y_val) in y.iter_mut().enumerate() {
569            *y_val = T::zero();
570            if i + 1 < indptr.len() {
571                for j in indptr[i]..indptr[i + 1] {
572                    if j < indices.len() && j < data.len() {
573                        let col = indices[j];
574                        if col < x.len() {
575                            *y_val += data[j] * x[col];
576                        }
577                    }
578                }
579            }
580        }
581        Ok(())
582    }
583
584    /// Row-wise cache-optimized SpMV
585    fn execute_rowwise_spmv<T>(
586        &self,
587        indptr: &[usize],
588        indices: &[usize],
589        data: &[T],
590        x: &[T],
591        y: &mut [T],
592    ) -> SparseResult<()>
593    where
594        T: Float + NumAssign + std::fmt::Debug + Copy,
595    {
596        // Same as basic for now - could be optimized with better cache blocking
597        self.execute_basic_spmv(indptr, indices, data, x, y)
598    }
599
600    /// SIMD-vectorized SpMV
601    fn execute_simd_spmv<T>(
602        &self,
603        indptr: &[usize],
604        indices: &[usize],
605        data: &[T],
606        x: &[T],
607        y: &mut [T],
608    ) -> SparseResult<()>
609    where
610        T: Float + NumAssign + SimdUnifiedOps + std::fmt::Debug + Copy,
611    {
612        // Use SIMD operations from scirs2-core
613        for (i, y_val) in y.iter_mut().enumerate() {
614            *y_val = T::zero();
615            if i + 1 < indptr.len() {
616                let start = indptr[i];
617                let end = indptr[i + 1];
618                if end > start {
619                    let row_data = &data[start..end];
620                    let row_indices = &indices[start..end];
621
622                    // SIMD dot product
623                    let mut sum = T::zero();
624                    for (&data_val, &col_idx) in row_data.iter().zip(row_indices.iter()) {
625                        if col_idx < x.len() {
626                            sum += data_val * x[col_idx];
627                        }
628                    }
629                    *y_val = sum;
630                }
631            }
632        }
633        Ok(())
634    }
635
636    /// Parallel work-stealing SpMV
637    fn execute_parallel_spmv<T>(
638        &self,
639        indptr: &[usize],
640        indices: &[usize],
641        data: &[T],
642        x: &[T],
643        y: &mut [T],
644    ) -> SparseResult<()>
645    where
646        T: Float + NumAssign + SimdUnifiedOps + std::fmt::Debug + Copy + Send + Sync,
647    {
648        // Use parallel operations from scirs2-core
649        use scirs2_core::parallel_ops::*;
650
651        // Sequential implementation for now
652        for i in 0..y.len() {
653            y[i] = T::zero();
654            if i + 1 < indptr.len() {
655                for j in indptr[i]..indptr[i + 1] {
656                    if j < indices.len() && j < data.len() {
657                        let col = indices[j];
658                        if col < x.len() {
659                            y[i] += data[j] * x[col];
660                        }
661                    }
662                }
663            }
664        }
665
666        Ok(())
667    }
668}
669
670/// Context for matrix operations
671#[derive(Debug, Clone)]
672pub struct OperationContext {
673    pub matrix_shape: (usize, usize),
674    pub nnz: usize,
675    pub operation_type: OperationType,
676    pub performance_target: PerformanceTarget,
677}
678
/// Types of matrix operations.
///
/// Variants are comparable and hashable so they can be used in equality
/// checks and as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// Matrix-vector operation (used by `adaptive_spmv`).
    MatVec,
    /// Matrix-matrix operation.
    MatMat,
    /// Solve operation.
    Solve,
    /// Factorization operation.
    Factorization,
}
687
/// Performance optimization targets.
///
/// Variants are comparable and hashable so they can be used in equality
/// checks and as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PerformanceTarget {
    /// Target used by `adaptive_spmv`.
    Speed,
    /// Memory-oriented target.
    Memory,
    /// Accuracy-oriented target.
    Accuracy,
    /// Balanced target.
    Balanced,
}
696
/// Plain-data snapshot of a processor, as produced by `save_state` and
/// consumed by `load_state`.
///
/// The type contains only standard-library data but carries no serialization
/// derives itself. It is comparable (`PartialEq`) and has an empty `Default`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ProcessorState {
    /// Neural network parameters, keyed by name.
    pub neural_network_params: std::collections::HashMap<String, Vec<f64>>,
    /// Value of the adaptation counter when the snapshot was taken.
    pub total_operations: usize,
    /// Number of stored patterns at snapshot time (not applied on load).
    pub pattern_memory_size: usize,
    /// Processor-level exploration rate at snapshot time.
    pub current_exploration_rate: f64,
}