sklears_python/linear/
common.rs

1//! Common functionality for linear model Python bindings
2//!
3//! This module contains shared imports, types, and utilities used
4//! across all linear model implementations.
5
6// Re-export commonly used types and traits - Using SciRS2-Core for improved performance
7pub use numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
8pub use pyo3::exceptions::PyValueError;
9pub use pyo3::prelude::*;
10pub use scirs2_core::ndarray::{Array1, Array2};
11
12// Performance optimization imports
13#[cfg(feature = "parallel")]
14pub use rayon::prelude::*;
15
16/// Common error type for linear model operations
17pub type LinearModelResult<T> = Result<T, PyValueError>;
18
/// Error categories produced by the sklears-python linear-model bindings.
///
/// Every variant carries a human-readable message. The enum derives `Clone`,
/// `PartialEq` and `Eq` in addition to `Debug`, so errors can be duplicated
/// (e.g. for logging) and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SklearsPythonError {
    /// Input validation errors (shapes, emptiness, etc.).
    ValidationError(String),
    /// Model fitting errors.
    FittingError(String),
    /// Prediction errors.
    PredictionError(String),
    /// Memory allocation / dataset size errors.
    MemoryError(String),
    /// Numerical computation errors (e.g. non-finite values).
    NumericalError(String),
    /// Configuration errors.
    ConfigurationError(String),
}
35
36impl std::fmt::Display for SklearsPythonError {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            SklearsPythonError::ValidationError(msg) => write!(f, "Validation Error: {}", msg),
40            SklearsPythonError::FittingError(msg) => write!(f, "Model Fitting Error: {}", msg),
41            SklearsPythonError::PredictionError(msg) => write!(f, "Prediction Error: {}", msg),
42            SklearsPythonError::MemoryError(msg) => write!(f, "Memory Error: {}", msg),
43            SklearsPythonError::NumericalError(msg) => write!(f, "Numerical Error: {}", msg),
44            SklearsPythonError::ConfigurationError(msg) => {
45                write!(f, "Configuration Error: {}", msg)
46            }
47        }
48    }
49}
50
51impl std::error::Error for SklearsPythonError {}
52
53impl From<SklearsPythonError> for PyErr {
54    fn from(err: SklearsPythonError) -> Self {
55        match err {
56            SklearsPythonError::ValidationError(msg) => PyValueError::new_err(msg),
57            SklearsPythonError::FittingError(msg) => PyRuntimeError::new_err(msg),
58            SklearsPythonError::PredictionError(msg) => PyRuntimeError::new_err(msg),
59            SklearsPythonError::MemoryError(msg) => {
60                use pyo3::exceptions::PyMemoryError;
61                PyMemoryError::new_err(msg)
62            }
63            SklearsPythonError::NumericalError(msg) => PyArithmeticError::new_err(msg),
64            SklearsPythonError::ConfigurationError(msg) => PyValueError::new_err(msg),
65        }
66    }
67}
68
69// Import additional exception types
70use pyo3::exceptions::{PyArithmeticError, PyRuntimeError};
71
72/// Calculate R² score with optimized array operations
73pub fn calculate_r2_score(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
74    let y_mean = y_true.mean().unwrap_or(0.0);
75
76    // Use optimized array operations
77    let y_centered: Array1<f64> = y_true.mapv(|y| y - y_mean);
78    let residuals: Array1<f64> = y_true - y_pred;
79    let ss_tot = y_centered.dot(&y_centered);
80    let ss_res = residuals.dot(&residuals);
81
82    1.0 - (ss_res / ss_tot)
83}
84
85/// Validate input arrays for model fitting
86pub fn validate_fit_arrays(x: &Array2<f64>, y: &Array1<f64>) -> PyResult<()> {
87    if x.nrows() != y.len() {
88        return Err(PyValueError::new_err(format!(
89            "X and y have incompatible shapes: X has {} samples, y has {} samples",
90            x.nrows(),
91            y.len()
92        )));
93    }
94
95    if x.nrows() == 0 {
96        return Err(PyValueError::new_err("X and y must not be empty"));
97    }
98
99    if x.ncols() == 0 {
100        return Err(PyValueError::new_err("X must have at least one feature"));
101    }
102
103    Ok(())
104}
105
106/// Validate input arrays for prediction
107pub fn validate_predict_array(x: &Array2<f64>) -> PyResult<()> {
108    if x.nrows() == 0 {
109        return Err(SklearsPythonError::ValidationError("X must not be empty".to_string()).into());
110    }
111
112    if x.ncols() == 0 {
113        return Err(SklearsPythonError::ValidationError(
114            "X must have at least one feature".to_string(),
115        )
116        .into());
117    }
118
119    // Check for invalid values
120    validate_finite_values(x)?;
121
122    Ok(())
123}
124
125/// Enhanced validation functions with better error handling
126pub fn validate_fit_arrays_enhanced(
127    x: &Array2<f64>,
128    y: &Array1<f64>,
129) -> Result<(), SklearsPythonError> {
130    if x.nrows() != y.len() {
131        return Err(SklearsPythonError::ValidationError(format!(
132            "X and y have incompatible shapes: X has {} samples, y has {} samples",
133            x.nrows(),
134            y.len()
135        )));
136    }
137
138    if x.nrows() == 0 {
139        return Err(SklearsPythonError::ValidationError(
140            "X and y must not be empty".to_string(),
141        ));
142    }
143
144    if x.ncols() == 0 {
145        return Err(SklearsPythonError::ValidationError(
146            "X must have at least one feature".to_string(),
147        ));
148    }
149
150    // Check for infinite or NaN values
151    validate_finite_values(x)?;
152    validate_finite_values_1d(y)?;
153
154    // Memory usage validation (warn if arrays are very large)
155    check_memory_usage(x, y)?;
156
157    Ok(())
158}
159
160/// Validate that array contains only finite values
161pub fn validate_finite_values(arr: &Array2<f64>) -> Result<(), SklearsPythonError> {
162    for value in arr.iter() {
163        if !value.is_finite() {
164            return Err(SklearsPythonError::NumericalError(
165                "Input array contains non-finite values (NaN or infinite)".to_string(),
166            ));
167        }
168    }
169    Ok(())
170}
171
172/// Validate that 1D array contains only finite values
173pub fn validate_finite_values_1d(arr: &Array1<f64>) -> Result<(), SklearsPythonError> {
174    for value in arr.iter() {
175        if !value.is_finite() {
176            return Err(SklearsPythonError::NumericalError(
177                "Target array contains non-finite values (NaN or infinite)".to_string(),
178            ));
179        }
180    }
181    Ok(())
182}
183
184/// Check memory usage and warn if arrays are very large
185pub fn check_memory_usage(x: &Array2<f64>, y: &Array1<f64>) -> Result<(), SklearsPythonError> {
186    let x_memory_mb = (x.len() * std::mem::size_of::<f64>()) as f64 / (1024.0 * 1024.0);
187    let y_memory_mb = (y.len() * std::mem::size_of::<f64>()) as f64 / (1024.0 * 1024.0);
188    let total_memory_mb = x_memory_mb + y_memory_mb;
189
190    // Warn if using more than 1GB of memory
191    if total_memory_mb > 1024.0 {
192        eprintln!("Warning: Large dataset detected ({:.2} MB). Consider using batch processing or data preprocessing to reduce memory usage.", total_memory_mb);
193    }
194
195    // Error if using more than 4GB (likely will cause issues)
196    if total_memory_mb > 4096.0 {
197        return Err(SklearsPythonError::MemoryError(format!(
198            "Dataset is too large ({:.2} MB). Consider using data preprocessing to reduce memory usage.",
199            total_memory_mb
200        )));
201    }
202
203    Ok(())
204}
205
/// Reports the memory currently available on this machine, in MiB
/// (1 MiB = 1024 × 1024 bytes).
///
/// On Linux the value comes from the `MemAvailable` line of `/proc/meminfo`,
/// which the kernel reports in KiB (labelled "kB"). A fixed default of
/// 8192 MiB (8 GiB) is returned instead when any of these holds:
/// * the file cannot be read (e.g. on other operating systems);
/// * the line is missing (kernels before 3.14);
/// * the value is unparsable or zero.
pub fn get_available_memory_mb() -> f64 {
    // Fallback when the OS does not expose an available-memory figure.
    const DEFAULT_MB: f64 = 8192.0;

    std::fs::read_to_string("/proc/meminfo")
        .ok()
        .and_then(|meminfo| {
            meminfo.lines().find_map(|line| {
                line.strip_prefix("MemAvailable:")?
                    .split_whitespace()
                    .next()?
                    .parse::<u64>()
                    .ok()
            })
        })
        .filter(|&kib| kib > 0)
        .map(|kib| kib as f64 / 1024.0)
        .unwrap_or(DEFAULT_MB)
}
213
/// Counters and optional measurements describing how a model has performed.
///
/// `Default` leaves every measurement as `None` and both cache counters at
/// zero.
#[derive(Debug, Clone, Default)]
pub struct PerformanceStats {
    /// Training time in milliseconds, or `None` if not recorded.
    pub training_time_ms: Option<f64>,
    /// Prediction time in milliseconds, or `None` if not recorded.
    pub prediction_time_ms: Option<f64>,
    /// Memory usage in MB, or `None` if not recorded.
    pub memory_usage_mb: Option<f64>,
    /// Number of cache hits counted so far.
    pub cache_hits: usize,
    /// Number of cache misses counted so far.
    pub cache_misses: usize,
}