scirs2_sparse/neural_adaptive_sparse/
processor.rs

//! Main neural-adaptive sparse matrix processor
//!
//! This module contains the main processor that coordinates all neural network,
//! reinforcement learning, and pattern memory components for adaptive optimization.
5
6use super::config::NeuralAdaptiveConfig;
7use super::neural_network::NeuralNetwork;
8use super::pattern_memory::{MatrixFingerprint, OptimizationStrategy, PatternMemory};
9use super::reinforcement_learning::{Experience, ExperienceBuffer, PerformanceMetrics, RLAgent};
10use super::transformer::TransformerModel;
11use crate::error::SparseResult;
12use scirs2_core::numeric::{Float, NumAssign, SparseElement};
13use scirs2_core::simd_ops::SimdUnifiedOps;
14use std::collections::VecDeque;
15use std::sync::atomic::{AtomicUsize, Ordering};
16
17/// Neural-adaptive sparse matrix processor
18pub struct NeuralAdaptiveSparseProcessor {
19    config: NeuralAdaptiveConfig,
20    neural_network: NeuralNetwork,
21    pattern_memory: PatternMemory,
22    performance_history: VecDeque<PerformanceMetrics>,
23    adaptation_counter: AtomicUsize,
24    optimization_strategies: Vec<OptimizationStrategy>,
25    /// Reinforcement learning agent
26    rl_agent: Option<RLAgent>,
27    /// Transformer model for attention-based optimization
28    transformer: Option<TransformerModel>,
29    /// Experience replay buffer for RL
30    experience_buffer: ExperienceBuffer,
31    /// Current exploration rate (decays over time)
32    current_exploration_rate: f64,
33}
34
35/// Statistics for neural processor performance
36#[derive(Debug, Clone)]
37pub struct NeuralProcessorStats {
38    pub total_operations: usize,
39    pub successful_adaptations: usize,
40    pub average_performance_improvement: f64,
41    pub most_effective_strategy: OptimizationStrategy,
42    pub neural_network_accuracy: f64,
43    pub rl_agent_reward: f64,
44    pub pattern_memory_hit_rate: f64,
45    pub transformer_attention_score: f64,
46}
47
48impl NeuralAdaptiveSparseProcessor {
49    /// Create a new neural-adaptive sparse matrix processor
50    pub fn new(config: NeuralAdaptiveConfig) -> Self {
51        // Validate configuration
52        if let Err(e) = config.validate() {
53            panic!("Invalid configuration: {}", e);
54        }
55
56        let neural_network = NeuralNetwork::new(
57            config.modeldim,
58            config.hidden_layers,
59            config.neurons_per_layer,
60            9, // Number of optimization strategies
61            config.attention_heads,
62        );
63        let pattern_memory = PatternMemory::new(config.memory_capacity);
64
65        let optimization_strategies = vec![
66            OptimizationStrategy::RowWiseCache,
67            OptimizationStrategy::ColumnWiseLocality,
68            OptimizationStrategy::BlockStructured,
69            OptimizationStrategy::DiagonalOptimized,
70            OptimizationStrategy::Hierarchical,
71            OptimizationStrategy::StreamingCompute,
72            OptimizationStrategy::SIMDVectorized,
73            OptimizationStrategy::ParallelWorkStealing,
74            OptimizationStrategy::AdaptiveHybrid,
75        ];
76
77        // Initialize RL agent if enabled
78        let rl_agent = if config.reinforcement_learning {
79            Some(RLAgent::new(
80                config.modeldim,
81                9, // Number of actions (optimization strategies)
82                config.rl_algorithm,
83                config.learningrate,
84                config.exploration_rate,
85            ))
86        } else {
87            None
88        };
89
90        // Initialize transformer if self-attention is enabled
91        let transformer = if config.self_attention {
92            Some(TransformerModel::new(
93                config.modeldim,
94                config.transformer_layers,
95                config.attention_heads,
96                config.ff_dim,
97                1000, // Max sequence length
98            ))
99        } else {
100            None
101        };
102
103        let experience_buffer = ExperienceBuffer::new(config.replay_buffer_size);
104
105        Self {
106            config: config.clone(),
107            neural_network,
108            pattern_memory,
109            performance_history: VecDeque::new(),
110            adaptation_counter: AtomicUsize::new(0),
111            optimization_strategies,
112            rl_agent,
113            transformer,
114            experience_buffer,
115            current_exploration_rate: config.exploration_rate,
116        }
117    }
118
119    /// Process sparse matrix operation with adaptive optimization
120    pub fn optimize_operation<T>(
121        &mut self,
122        matrix_features: &[f64],
123        operation_context: &OperationContext,
124    ) -> SparseResult<OptimizationStrategy>
125    where
126        T: Float
127            + SparseElement
128            + NumAssign
129            + SimdUnifiedOps
130            + std::fmt::Debug
131            + Copy
132            + Send
133            + Sync
134            + 'static,
135    {
136        // Extract matrix fingerprint
137        let fingerprint = self.extract_matrix_fingerprint(matrix_features, operation_context);
138
139        // Try pattern memory first
140        if let Some(strategy) = self.pattern_memory.get_strategy(&fingerprint) {
141            return Ok(strategy);
142        }
143
144        // Use neural network and RL agent for optimization
145        let state = self.encode_state(matrix_features, operation_context)?;
146
147        let strategy = if let Some(ref rl_agent) = self.rl_agent {
148            rl_agent.select_action(&state)
149        } else {
150            self.neural_network_select_action(&state)?
151        };
152
153        // Store the experience for later learning
154        let experience = Experience {
155            state: state.clone(),
156            action: strategy,
157            reward: 0.0,       // Will be updated after operation execution
158            next_state: state, // Will be updated with next state
159            done: false,
160            timestamp: std::time::SystemTime::now()
161                .duration_since(std::time::UNIX_EPOCH)
162                .unwrap_or_default()
163                .as_secs(),
164        };
165
166        self.experience_buffer.add(experience);
167
168        Ok(strategy)
169    }
170
171    /// Learn from operation performance
172    pub fn learn_from_performance(
173        &mut self,
174        strategy: OptimizationStrategy,
175        performance: PerformanceMetrics,
176        matrix_features: &[f64],
177        operation_context: &OperationContext,
178    ) -> SparseResult<()> {
179        // Compute reward
180        let baseline_time = self.estimate_baseline_performance(matrix_features);
181        let reward = performance.compute_reward(baseline_time);
182
183        // Update experience buffer with reward
184        if let Some(mut experience) = self.experience_buffer.buffer.back_mut() {
185            experience.reward = reward;
186        }
187
188        // Store successful patterns
189        let fingerprint = self.extract_matrix_fingerprint(matrix_features, operation_context);
190        if reward > 0.0 {
191            self.pattern_memory.store_pattern(fingerprint, strategy);
192        }
193
194        // Train RL agent
195        if let Some(ref mut rl_agent) = self.rl_agent {
196            let batch_size = 32.min(self.experience_buffer.len());
197            if batch_size > 0 {
198                let batch = self.experience_buffer.sample(batch_size);
199                rl_agent.train(&batch)?;
200            }
201        }
202
203        // Update performance history
204        self.performance_history.push_back(performance);
205        if self.performance_history.len() > 1000 {
206            self.performance_history.pop_front();
207        }
208
209        // Increment adaptation counter
210        self.adaptation_counter.fetch_add(1, Ordering::Relaxed);
211
212        // Decay exploration rate
213        if let Some(ref mut rl_agent) = self.rl_agent {
214            rl_agent.decay_epsilon(0.995);
215        }
216
217        Ok(())
218    }
219
220    /// Extract matrix fingerprint from features
221    fn extract_matrix_fingerprint(
222        &self,
223        features: &[f64],
224        context: &OperationContext,
225    ) -> MatrixFingerprint {
226        // Extract basic properties from features
227        let rows = context.matrix_shape.0;
228        let cols = context.matrix_shape.1;
229        let nnz = context.nnz;
230
231        // Compute a simple hash of the sparsity pattern
232        use std::collections::hash_map::DefaultHasher;
233        use std::hash::{Hash, Hasher};
234        let mut hasher = DefaultHasher::new();
235        for (i, &feature) in features.iter().enumerate().take(100) {
236            ((feature * 1000.0) as i64).hash(&mut hasher);
237        }
238        let sparsity_pattern_hash = hasher.finish();
239
240        // Analyze distributions (simplified)
241        let row_distribution_type = super::pattern_memory::DistributionType::Random;
242        let column_distribution_type = super::pattern_memory::DistributionType::Random;
243
244        MatrixFingerprint {
245            rows,
246            cols,
247            nnz,
248            sparsity_pattern_hash,
249            row_distribution_type,
250            column_distribution_type,
251        }
252    }
253
254    /// Encode state for neural network/RL agent
255    fn encode_state(
256        &self,
257        matrix_features: &[f64],
258        context: &OperationContext,
259    ) -> SparseResult<Vec<f64>> {
260        let mut state = Vec::new();
261
262        // Matrix properties
263        state.push(context.matrix_shape.0 as f64);
264        state.push(context.matrix_shape.1 as f64);
265        state.push(context.nnz as f64);
266        state.push(context.nnz as f64 / (context.matrix_shape.0 * context.matrix_shape.1) as f64); // Sparsity
267
268        // Operation type
269        state.push(match context.operation_type {
270            OperationType::MatVec => 1.0,
271            OperationType::MatMat => 2.0,
272            OperationType::Solve => 3.0,
273            OperationType::Factorization => 4.0,
274        });
275
276        // Matrix features (truncated/padded to fixed size)
277        let feature_size = self.config.modeldim.saturating_sub(state.len());
278        for i in 0..feature_size {
279            if i < matrix_features.len() {
280                state.push(matrix_features[i]);
281            } else {
282                state.push(0.0);
283            }
284        }
285
286        // Use transformer for feature encoding if available
287        if let Some(ref transformer) = self.transformer {
288            let encoded = transformer.encode_matrix_pattern(&state);
289            Ok(encoded)
290        } else {
291            Ok(state)
292        }
293    }
294
295    /// Select action using neural network
296    fn neural_network_select_action(&self, state: &[f64]) -> SparseResult<OptimizationStrategy> {
297        let outputs = self.neural_network.forward(state);
298
299        let best_idx = outputs
300            .iter()
301            .enumerate()
302            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
303            .map(|(idx, _)| idx)
304            .unwrap_or(0);
305
306        Ok(self.optimization_strategies[best_idx % self.optimization_strategies.len()])
307    }
308
309    /// Estimate baseline performance for reward computation
310    fn estimate_baseline_performance(&self, _features: &[f64]) -> f64 {
311        // Simple baseline estimation
312        if let Some(last_performance) = self.performance_history.back() {
313            last_performance.executiontime
314        } else {
315            1.0 // Default baseline
316        }
317    }
318
319    /// Get processor statistics
320    pub fn get_statistics(&self) -> NeuralProcessorStats {
321        let total_operations = self.adaptation_counter.load(Ordering::Relaxed);
322        let successful_adaptations = self
323            .performance_history
324            .iter()
325            .filter(|p| p.compute_reward(1.0) > 0.0)
326            .count();
327
328        let average_improvement = if !self.performance_history.is_empty() {
329            self.performance_history
330                .iter()
331                .map(|p| p.performance_score())
332                .sum::<f64>()
333                / self.performance_history.len() as f64
334        } else {
335            0.0
336        };
337
338        let most_effective_strategy = self.get_most_effective_strategy();
339        let rl_reward = if let Some(ref rl_agent) = self.rl_agent {
340            // Estimate current RL performance
341            let dummy_state = vec![0.0; self.config.modeldim];
342            rl_agent.estimate_value(&dummy_state)
343        } else {
344            0.0
345        };
346
347        let pattern_memory_stats = self.pattern_memory.get_statistics();
348        let pattern_hit_rate = if total_operations > 0 {
349            pattern_memory_stats.stored_patterns as f64 / total_operations as f64
350        } else {
351            0.0
352        };
353
354        NeuralProcessorStats {
355            total_operations,
356            successful_adaptations,
357            average_performance_improvement: average_improvement,
358            most_effective_strategy,
359            neural_network_accuracy: 0.85, // Placeholder
360            rl_agent_reward: rl_reward,
361            pattern_memory_hit_rate: pattern_hit_rate,
362            transformer_attention_score: 0.75, // Placeholder
363        }
364    }
365
366    /// Get most effective optimization strategy
367    fn get_most_effective_strategy(&self) -> OptimizationStrategy {
368        let mut strategy_scores = std::collections::HashMap::new();
369
370        for performance in &self.performance_history {
371            let score = performance.performance_score();
372            let entry = strategy_scores
373                .entry(performance.strategy_used)
374                .or_insert((0.0, 0));
375            entry.0 += score;
376            entry.1 += 1;
377        }
378
379        strategy_scores
380            .into_iter()
381            .max_by(|(_, (score1, count1)), (_, (score2, count2))| {
382                let avg1 = score1 / *count1 as f64;
383                let avg2 = score2 / *count2 as f64;
384                avg1.partial_cmp(&avg2).unwrap()
385            })
386            .map(|(strategy, _)| strategy)
387            .unwrap_or(OptimizationStrategy::AdaptiveHybrid)
388    }
389
390    /// Update target networks (for DQN)
391    pub fn update_target_networks(&mut self) {
392        if let Some(ref mut rl_agent) = self.rl_agent {
393            rl_agent.update_target_network();
394        }
395    }
396
397    /// Save processor state
398    pub fn save_state(&self) -> ProcessorState {
399        let neural_params = self.neural_network.get_parameters();
400        let pattern_stats = self.pattern_memory.get_statistics();
401
402        ProcessorState {
403            neural_network_params: neural_params,
404            total_operations: self.adaptation_counter.load(Ordering::Relaxed),
405            pattern_memory_size: pattern_stats.stored_patterns,
406            current_exploration_rate: self.current_exploration_rate,
407        }
408    }
409
410    /// Load processor state
411    pub fn load_state(&mut self, state: ProcessorState) {
412        self.neural_network
413            .set_parameters(&state.neural_network_params);
414        self.adaptation_counter
415            .store(state.total_operations, Ordering::Relaxed);
416        self.current_exploration_rate = state.current_exploration_rate;
417    }
418
419    /// Adaptive sparse matrix-vector multiplication
420    pub fn adaptive_spmv<T>(
421        &mut self,
422        rows: &[usize],
423        cols: &[usize],
424        indptr: &[usize],
425        indices: &[usize],
426        data: &[T],
427        x: &[T],
428        y: &mut [T],
429    ) -> SparseResult<()>
430    where
431        T: Float
432            + SparseElement
433            + NumAssign
434            + SimdUnifiedOps
435            + std::fmt::Debug
436            + Copy
437            + Send
438            + Sync
439            + 'static,
440    {
441        // Extract matrix features for optimization decision
442        let matrix_features = self.extract_matrix_features(rows, cols, data);
443
444        // Create operation context
445        let context = OperationContext {
446            matrix_shape: (rows.len(), cols.len()),
447            nnz: data.len(),
448            operation_type: OperationType::MatVec,
449            performance_target: PerformanceTarget::Speed,
450        };
451
452        // Get optimization strategy
453        let strategy = self.optimize_operation::<T>(&matrix_features, &context)?;
454
455        // Execute the operation using the selected strategy
456        let start_time = std::time::Instant::now();
457        self.execute_spmv_with_strategy(strategy, indptr, indices, data, x, y)?;
458        let execution_time = start_time.elapsed().as_secs_f64();
459
460        // Learn from performance
461        let performance = PerformanceMetrics::new(
462            execution_time,
463            0.8,  // cache_efficiency (placeholder)
464            0.9,  // simd_utilization (placeholder)
465            0.7,  // parallel_efficiency (placeholder)
466            0.85, // memory_bandwidth (placeholder)
467            strategy,
468        );
469
470        self.learn_from_performance(strategy, performance, &matrix_features, &context)?;
471
472        Ok(())
473    }
474
475    /// Extract matrix features for neural network
476    fn extract_matrix_features<T>(&self, rows: &[usize], cols: &[usize], data: &[T]) -> Vec<f64>
477    where
478        T: Float + std::fmt::Debug + Copy,
479    {
480        let mut features = Vec::new();
481
482        // Basic statistics
483        features.push(rows.len() as f64);
484        features.push(cols.len() as f64);
485        features.push(data.len() as f64);
486
487        // Row statistics
488        if !rows.is_empty() {
489            let min_row = *rows.iter().min().unwrap_or(&0) as f64;
490            let max_row = *rows.iter().max().unwrap_or(&0) as f64;
491            features.push(min_row);
492            features.push(max_row);
493            features.push(max_row - min_row); // row span
494        } else {
495            features.extend(&[0.0, 0.0, 0.0]);
496        }
497
498        // Column statistics
499        if !cols.is_empty() {
500            let min_col = *cols.iter().min().unwrap_or(&0) as f64;
501            let max_col = *cols.iter().max().unwrap_or(&0) as f64;
502            features.push(min_col);
503            features.push(max_col);
504            features.push(max_col - min_col); // column span
505        } else {
506            features.extend(&[0.0, 0.0, 0.0]);
507        }
508
509        // Data statistics (simplified)
510        if !data.is_empty() {
511            // Convert to f64 for statistics
512            let data_f64: Vec<f64> = data.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
513            let sum: f64 = data_f64.iter().sum();
514            let mean = sum / data_f64.len() as f64;
515            let variance =
516                data_f64.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / data_f64.len() as f64;
517
518            features.push(mean);
519            features.push(variance.sqrt()); // standard deviation
520            features.push(
521                *data_f64
522                    .iter()
523                    .min_by(|a, b| a.partial_cmp(b).unwrap())
524                    .unwrap_or(&0.0),
525            );
526            features.push(
527                *data_f64
528                    .iter()
529                    .max_by(|a, b| a.partial_cmp(b).unwrap())
530                    .unwrap_or(&0.0),
531            );
532        } else {
533            features.extend(&[0.0, 0.0, 0.0, 0.0]);
534        }
535
536        // Pad/truncate to fixed size for neural network
537        let target_size = 20;
538        features.resize(target_size, 0.0);
539        features
540    }
541
542    /// Execute SpMV with specific strategy
543    fn execute_spmv_with_strategy<T>(
544        &self,
545        strategy: OptimizationStrategy,
546        indptr: &[usize],
547        indices: &[usize],
548        data: &[T],
549        x: &[T],
550        y: &mut [T],
551    ) -> SparseResult<()>
552    where
553        T: Float
554            + SparseElement
555            + NumAssign
556            + SimdUnifiedOps
557            + std::fmt::Debug
558            + Copy
559            + Send
560            + Sync,
561    {
562        match strategy {
563            OptimizationStrategy::RowWiseCache => {
564                self.execute_rowwise_spmv(indptr, indices, data, x, y)
565            }
566            OptimizationStrategy::SIMDVectorized => {
567                self.execute_simd_spmv(indptr, indices, data, x, y)
568            }
569            OptimizationStrategy::ParallelWorkStealing => {
570                self.execute_parallel_spmv(indptr, indices, data, x, y)
571            }
572            _ => {
573                // Default implementation for other strategies
574                self.execute_basic_spmv(indptr, indices, data, x, y)
575            }
576        }
577    }
578
579    /// Basic CSR SpMV implementation
580    fn execute_basic_spmv<T>(
581        &self,
582        indptr: &[usize],
583        indices: &[usize],
584        data: &[T],
585        x: &[T],
586        y: &mut [T],
587    ) -> SparseResult<()>
588    where
589        T: Float + SparseElement + NumAssign + std::fmt::Debug + Copy,
590    {
591        for (i, y_val) in y.iter_mut().enumerate() {
592            *y_val = T::sparse_zero();
593            if i + 1 < indptr.len() {
594                for j in indptr[i]..indptr[i + 1] {
595                    if j < indices.len() && j < data.len() {
596                        let col = indices[j];
597                        if col < x.len() {
598                            *y_val += data[j] * x[col];
599                        }
600                    }
601                }
602            }
603        }
604        Ok(())
605    }
606
607    /// Row-wise cache-optimized SpMV
608    fn execute_rowwise_spmv<T>(
609        &self,
610        indptr: &[usize],
611        indices: &[usize],
612        data: &[T],
613        x: &[T],
614        y: &mut [T],
615    ) -> SparseResult<()>
616    where
617        T: Float + SparseElement + NumAssign + std::fmt::Debug + Copy,
618    {
619        // Same as basic for now - could be optimized with better cache blocking
620        self.execute_basic_spmv(indptr, indices, data, x, y)
621    }
622
623    /// SIMD-vectorized SpMV
624    fn execute_simd_spmv<T>(
625        &self,
626        indptr: &[usize],
627        indices: &[usize],
628        data: &[T],
629        x: &[T],
630        y: &mut [T],
631    ) -> SparseResult<()>
632    where
633        T: Float + SparseElement + NumAssign + SimdUnifiedOps + std::fmt::Debug + Copy,
634    {
635        // Use SIMD operations from scirs2-core
636        for (i, y_val) in y.iter_mut().enumerate() {
637            *y_val = T::sparse_zero();
638            if i + 1 < indptr.len() {
639                let start = indptr[i];
640                let end = indptr[i + 1];
641                if end > start {
642                    let row_data = &data[start..end];
643                    let row_indices = &indices[start..end];
644
645                    // SIMD dot product
646                    let mut sum = T::sparse_zero();
647                    for (&data_val, &col_idx) in row_data.iter().zip(row_indices.iter()) {
648                        if col_idx < x.len() {
649                            sum += data_val * x[col_idx];
650                        }
651                    }
652                    *y_val = sum;
653                }
654            }
655        }
656        Ok(())
657    }
658
659    /// Parallel work-stealing SpMV
660    fn execute_parallel_spmv<T>(
661        &self,
662        indptr: &[usize],
663        indices: &[usize],
664        data: &[T],
665        x: &[T],
666        y: &mut [T],
667    ) -> SparseResult<()>
668    where
669        T: Float
670            + SparseElement
671            + NumAssign
672            + SimdUnifiedOps
673            + std::fmt::Debug
674            + Copy
675            + Send
676            + Sync,
677    {
678        // Use parallel operations from scirs2-core
679        use scirs2_core::parallel_ops::*;
680
681        // Sequential implementation for now
682        for i in 0..y.len() {
683            y[i] = T::sparse_zero();
684            if i + 1 < indptr.len() {
685                for j in indptr[i]..indptr[i + 1] {
686                    if j < indices.len() && j < data.len() {
687                        let col = indices[j];
688                        if col < x.len() {
689                            y[i] += data[j] * x[col];
690                        }
691                    }
692                }
693            }
694        }
695
696        Ok(())
697    }
698}
699
700/// Context for matrix operations
701#[derive(Debug, Clone)]
702pub struct OperationContext {
703    pub matrix_shape: (usize, usize),
704    pub nnz: usize,
705    pub operation_type: OperationType,
706    pub performance_target: PerformanceTarget,
707}
708
/// Types of matrix operations.
///
/// Comparable and hashable so callers can test for a specific operation or use
/// it as a map / set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// Matrix-vector product.
    MatVec,
    /// Matrix-matrix product.
    MatMat,
    /// Linear system solve.
    Solve,
    /// Matrix factorization.
    Factorization,
}
717
/// Performance optimization targets.
///
/// Comparable and hashable so callers can test for a specific target or use it
/// as a map / set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PerformanceTarget {
    /// Prioritize execution speed.
    Speed,
    /// Prioritize memory usage.
    Memory,
    /// Prioritize numerical accuracy.
    Accuracy,
    /// Balance the other targets.
    Balanced,
}
726
/// Snapshot of processor state produced by `save_state` and consumed by `load_state`.
///
/// Snapshots can be compared for equality, e.g. to check a save / load round trip.
#[derive(Debug, Clone, PartialEq)]
pub struct ProcessorState {
    /// Neural network parameters keyed by name.
    pub neural_network_params: std::collections::HashMap<String, Vec<f64>>,
    /// Value of the adaptation counter when the snapshot was taken.
    pub total_operations: usize,
    /// Number of patterns stored in the pattern memory when the snapshot was taken.
    pub pattern_memory_size: usize,
    /// Exploration rate held by the processor when the snapshot was taken.
    pub current_exploration_rate: f64,
}