1use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
7use num_traits::{Float, NumCast};
8
9use crate::error::{Result, TransformError};
10
11pub struct VarianceThreshold {
16 threshold: f64,
18 variances_: Option<Array1<f64>>,
20 selected_features_: Option<Vec<usize>>,
22}
23
24impl VarianceThreshold {
25 pub fn new(threshold: f64) -> Result<Self> {
41 if threshold < 0.0 {
42 return Err(TransformError::InvalidInput(
43 "Threshold must be non-negative".to_string(),
44 ));
45 }
46
47 Ok(VarianceThreshold {
48 threshold,
49 variances_: None,
50 selected_features_: None,
51 })
52 }
53
54 pub fn with_defaults() -> Self {
58 Self::new(0.0).unwrap()
59 }
60
61 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
69 where
70 S: Data,
71 S::Elem: Float + NumCast,
72 {
73 let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
74
75 let n_samples = x_f64.shape()[0];
76 let n_features = x_f64.shape()[1];
77
78 if n_samples == 0 || n_features == 0 {
79 return Err(TransformError::InvalidInput("Empty input data".to_string()));
80 }
81
82 if n_samples < 2 {
83 return Err(TransformError::InvalidInput(
84 "At least 2 samples required to compute variance".to_string(),
85 ));
86 }
87
88 let mut variances = Array1::zeros(n_features);
90 let mut selected_features = Vec::new();
91
92 for j in 0..n_features {
93 let feature_data = x_f64.column(j);
94
95 let mean = feature_data.iter().sum::<f64>() / n_samples as f64;
97
98 let variance = feature_data
100 .iter()
101 .map(|&x| (x - mean).powi(2))
102 .sum::<f64>()
103 / n_samples as f64;
104
105 variances[j] = variance;
106
107 if variance > self.threshold {
109 selected_features.push(j);
110 }
111 }
112
113 self.variances_ = Some(variances);
114 self.selected_features_ = Some(selected_features);
115
116 Ok(())
117 }
118
119 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
127 where
128 S: Data,
129 S::Elem: Float + NumCast,
130 {
131 let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
132
133 let n_samples = x_f64.shape()[0];
134 let n_features = x_f64.shape()[1];
135
136 if self.selected_features_.is_none() {
137 return Err(TransformError::TransformationError(
138 "VarianceThreshold has not been fitted".to_string(),
139 ));
140 }
141
142 let selected_features = self.selected_features_.as_ref().unwrap();
143
144 if let Some(ref variances) = self.variances_ {
146 if n_features != variances.len() {
147 return Err(TransformError::InvalidInput(format!(
148 "x has {} features, but VarianceThreshold was fitted with {} features",
149 n_features,
150 variances.len()
151 )));
152 }
153 }
154
155 let n_selected = selected_features.len();
156 let mut transformed = Array2::zeros((n_samples, n_selected));
157
158 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
160 for i in 0..n_samples {
161 transformed[[i, new_idx]] = x_f64[[i, old_idx]];
162 }
163 }
164
165 Ok(transformed)
166 }
167
168 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
176 where
177 S: Data,
178 S::Elem: Float + NumCast,
179 {
180 self.fit(x)?;
181 self.transform(x)
182 }
183
184 pub fn variances(&self) -> Option<&Array1<f64>> {
189 self.variances_.as_ref()
190 }
191
192 pub fn get_support(&self) -> Option<&Vec<usize>> {
197 self.selected_features_.as_ref()
198 }
199
200 pub fn get_support_mask(&self) -> Option<Array1<bool>> {
205 if let (Some(ref variances), Some(ref selected)) =
206 (&self.variances_, &self.selected_features_)
207 {
208 let n_features = variances.len();
209 let mut mask = Array1::from_elem(n_features, false);
210
211 for &idx in selected {
212 mask[idx] = true;
213 }
214
215 Some(mask)
216 } else {
217 None
218 }
219 }
220
221 pub fn n_features_selected(&self) -> Option<usize> {
226 self.selected_features_.as_ref().map(|s| s.len())
227 }
228
229 pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
234 where
235 S: Data,
236 S::Elem: Float + NumCast,
237 {
238 Err(TransformError::TransformationError(
239 "inverse_transform is not supported for feature selection".to_string(),
240 ))
241 }
242}
243
244#[cfg(test)]
245mod tests {
246 use super::*;
247 use approx::assert_abs_diff_eq;
248 use ndarray::Array;
249
250 #[test]
251 fn test_variance_threshold_basic() {
252 let data = Array::from_shape_vec(
258 (3, 4),
259 vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
260 )
261 .unwrap();
262
263 let mut selector = VarianceThreshold::with_defaults();
264 let transformed = selector.fit_transform(&data).unwrap();
265
266 assert_eq!(transformed.shape(), &[3, 2]);
268
269 let selected = selector.get_support().unwrap();
271 assert_eq!(selected, &[1, 3]);
272
273 assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[2, 0]], 3.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[0, 1]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[1, 1]], 3.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[2, 1]], 5.0, epsilon = 1e-10); }
282
283 #[test]
284 fn test_variance_threshold_custom() {
285 let data = Array::from_shape_vec(
287 (4, 3),
288 vec![
289 1.0, 1.0, 1.0, 2.0, 1.1, 2.0, 3.0, 1.0, 3.0, 4.0, 1.1, 4.0, ],
294 )
295 .unwrap();
296
297 let mut selector = VarianceThreshold::new(0.1).unwrap();
299 let transformed = selector.fit_transform(&data).unwrap();
300
301 assert_eq!(transformed.shape(), &[4, 2]);
304
305 let selected = selector.get_support().unwrap();
306 assert_eq!(selected, &[0, 2]);
307
308 let variances = selector.variances().unwrap();
310 assert!(variances[0] > 0.1); assert!(variances[1] <= 0.1); assert!(variances[2] > 0.1); }
314
315 #[test]
316 fn test_variance_threshold_support_mask() {
317 let data = Array::from_shape_vec(
318 (3, 4),
319 vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
320 )
321 .unwrap();
322
323 let mut selector = VarianceThreshold::with_defaults();
324 selector.fit(&data).unwrap();
325
326 let mask = selector.get_support_mask().unwrap();
327 assert_eq!(mask.len(), 4);
328 assert!(!mask[0]); assert!(mask[1]); assert!(!mask[2]); assert!(mask[3]); assert_eq!(selector.n_features_selected().unwrap(), 2);
334 }
335
336 #[test]
337 fn test_variance_threshold_all_removed() {
338 let data = Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0]).unwrap();
340
341 let mut selector = VarianceThreshold::with_defaults();
342 let transformed = selector.fit_transform(&data).unwrap();
343
344 assert_eq!(transformed.shape(), &[3, 0]);
346 assert_eq!(selector.n_features_selected().unwrap(), 0);
347 }
348
349 #[test]
350 fn test_variance_threshold_errors() {
351 assert!(VarianceThreshold::new(-0.1).is_err());
353
354 let small_data = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
356 let mut selector = VarianceThreshold::with_defaults();
357 assert!(selector.fit(&small_data).is_err());
358
359 let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
361 let selector_unfitted = VarianceThreshold::with_defaults();
362 assert!(selector_unfitted.transform(&data).is_err());
363
364 let mut selector = VarianceThreshold::with_defaults();
366 selector.fit(&data).unwrap();
367 assert!(selector.inverse_transform(&data).is_err());
368 }
369
370 #[test]
371 fn test_variance_threshold_feature_mismatch() {
372 let train_data =
373 Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
374 .unwrap();
375 let test_data = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); let mut selector = VarianceThreshold::with_defaults();
378 selector.fit(&train_data).unwrap();
379 assert!(selector.transform(&test_data).is_err());
380 }
381
382 #[test]
383 fn test_variance_calculation() {
384 let data = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
387
388 let mut selector = VarianceThreshold::with_defaults();
389 selector.fit(&data).unwrap();
390
391 let variances = selector.variances().unwrap();
392 let expected_variance = 2.0 / 3.0;
393 assert_abs_diff_eq!(variances[0], expected_variance, epsilon = 1e-10);
394 }
395}