//! Text feature selection: document-frequency filtering plus combined IDF/chi2 scoring.

use crate::base::SelectorMixin;
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
8use sklears_core::{
9 error::{validate, Result as SklResult, SklearsError},
10 traits::{Estimator, Fit, Trained, Transform, Untrained},
11 types::Float,
12};
13use std::collections::HashMap;
14use std::marker::PhantomData;
15
/// Feature selector for pre-vectorized text data (a document-term count matrix).
///
/// Typed-state pattern: `State` is `Untrained` until `fit` is called, which
/// returns a `TextFeatureSelector<Trained>` with the trailing-underscore
/// fields populated.
#[derive(Debug, Clone)]
pub struct TextFeatureSelector<State = Untrained> {
    // Minimum document frequency (as a fraction of documents) a term must reach.
    min_df: f64,
    // Maximum document frequency (fraction) a term may have before being dropped.
    max_df: f64,
    // Upper bound on the number of selected features; `None` keeps all survivors.
    max_features: Option<usize>,
    // (min_n, max_n) n-gram span — stored but not used by the fitting code visible here.
    ngram_range: (usize, usize),
    // Whether to include part-of-speech features — not used in the visible fitting code.
    include_pos: bool,
    // Whether to include syntactic features — not used in the visible fitting code.
    include_syntax: bool,
    // Zero-sized marker carrying the Untrained/Trained state.
    state: PhantomData<State>,
    // term-name -> position map over the selected features (set after fit).
    vocabulary_: Option<HashMap<String, usize>>,
    // Smoothed IDF scores for the df-filtered features, in original column order (set after fit).
    idf_scores_: Option<Array1<Float>>,
    // Original column indices of the selected features, sorted by combined score (set after fit).
    selected_features_: Option<Vec<usize>>,
    // Synthetic names ("term_{i}") aligned with `selected_features_` (set after fit).
    feature_names_: Option<Vec<String>>,
}
74
75impl TextFeatureSelector<Untrained> {
76 pub fn new() -> Self {
77 Self {
78 min_df: 0.01,
79 max_df: 0.95,
80 max_features: Some(1000),
81 ngram_range: (1, 1),
82 include_pos: false,
83 include_syntax: false,
84 state: PhantomData,
85 vocabulary_: None,
86 idf_scores_: None,
87 selected_features_: None,
88 feature_names_: None,
89 }
90 }
91
92 pub fn min_df(mut self, min_df: f64) -> Self {
101 self.min_df = min_df;
102 self
103 }
104
105 pub fn max_df(mut self, max_df: f64) -> Self {
114 self.max_df = max_df;
115 self
116 }
117
118 pub fn max_features(mut self, max_features: Option<usize>) -> Self {
123 self.max_features = max_features;
124 self
125 }
126
127 pub fn ngram_range(mut self, ngram_range: (usize, usize)) -> Self {
137 self.ngram_range = ngram_range;
138 self
139 }
140
141 pub fn include_pos(mut self, include_pos: bool) -> Self {
148 self.include_pos = include_pos;
149 self
150 }
151
152 pub fn include_syntax(mut self, include_syntax: bool) -> Self {
159 self.include_syntax = include_syntax;
160 self
161 }
162}
163
164impl Default for TextFeatureSelector<Untrained> {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
// Minimal `Estimator` plumbing: this selector carries no separate config object.
impl Estimator for TextFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        // `()` is a zero-sized type and `&()` is promoted to a `'static`
        // reference (rvalue static promotion), so returning it is sound.
        &()
    }
}
179
180impl Fit<Array2<Float>, Array1<Float>> for TextFeatureSelector<Untrained> {
181 type Fitted = TextFeatureSelector<Trained>;
182
183 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
184 validate::check_consistent_length(x, y)?;
185
186 let (n_documents, n_features) = x.dim();
187
188 let mut document_frequencies = Array1::zeros(n_features);
190 for j in 0..n_features {
191 let mut df = 0.0;
192 for i in 0..n_documents {
193 if x[[i, j]] > 0.0 {
194 df += 1.0;
195 }
196 }
197 document_frequencies[j] = df / n_documents as f64;
198 }
199
200 let mut valid_features = Vec::new();
202 for (j, &df) in document_frequencies.iter().enumerate() {
203 if df >= self.min_df && df <= self.max_df {
204 valid_features.push(j);
205 }
206 }
207
208 if valid_features.is_empty() {
209 return Err(SklearsError::InvalidInput(
210 "No features pass the document frequency filters".to_string(),
211 ));
212 }
213
214 let mut idf_scores = Array1::zeros(valid_features.len());
216 for (idx, &j) in valid_features.iter().enumerate() {
217 let df = document_frequencies[j];
218 idf_scores[idx] = (n_documents as f64 / (1.0 + df * n_documents as f64)).ln();
219 }
220
221 let mut chi2_scores = Array1::zeros(valid_features.len());
223 for (idx, &j) in valid_features.iter().enumerate() {
224 let feature_col = x.column(j);
225 chi2_scores[idx] = compute_chi2_score(&feature_col, y);
226 }
227
228 let mut combined_scores = Array1::zeros(valid_features.len());
230 for i in 0..combined_scores.len() {
231 combined_scores[i] = 0.6 * idf_scores[i] + 0.4 * chi2_scores[i];
232 }
233
234 let mut scored_features: Vec<(usize, Float)> = combined_scores
236 .indexed_iter()
237 .map(|(i, &score)| (valid_features[i], score))
238 .collect();
239
240 scored_features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
241
242 let selected_features = if let Some(max_feat) = self.max_features {
243 scored_features
244 .iter()
245 .take(max_feat.min(scored_features.len()))
246 .map(|(i, _)| *i)
247 .collect::<Vec<_>>()
248 } else {
249 scored_features.iter().map(|(i, _)| *i).collect()
250 };
251
252 let feature_names: Vec<String> = selected_features
254 .iter()
255 .map(|&i| format!("term_{}", i))
256 .collect();
257
258 let vocabulary: HashMap<String, usize> = feature_names
259 .iter()
260 .enumerate()
261 .map(|(i, name)| (name.clone(), i))
262 .collect();
263
264 Ok(TextFeatureSelector {
265 min_df: self.min_df,
266 max_df: self.max_df,
267 max_features: self.max_features,
268 ngram_range: self.ngram_range,
269 include_pos: self.include_pos,
270 include_syntax: self.include_syntax,
271 state: PhantomData,
272 vocabulary_: Some(vocabulary),
273 idf_scores_: Some(idf_scores),
274 selected_features_: Some(selected_features),
275 feature_names_: Some(feature_names),
276 })
277 }
278}
279
280impl Transform<Array2<Float>> for TextFeatureSelector<Trained> {
281 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
282 let selected_features = self.selected_features_.as_ref().unwrap();
283 if selected_features.is_empty() {
284 return Err(SklearsError::InvalidInput(
285 "No features were selected".to_string(),
286 ));
287 }
288
289 let selected_indices: Vec<usize> = selected_features.to_vec();
290 Ok(x.select(Axis(1), &selected_indices))
291 }
292}
293
impl SelectorMixin for TextFeatureSelector<Trained> {
    /// Build a boolean mask with `true` at every selected column index.
    ///
    /// NOTE(review): the fitted input's true `n_features` is not stored on the
    /// struct, so the mask length is *estimated* as
    /// `idf_scores.len() + max(selected index) + 1`. That is only an upper
    /// bound and generally does not equal the original feature count — confirm
    /// whether callers depend on the mask length matching `n_features`.
    fn get_support(&self) -> SklResult<Array1<bool>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        let n_features = self.idf_scores_.as_ref().unwrap().len()
            + selected_features.iter().max().unwrap_or(&0)
            + 1;
        let mut support = Array1::from_elem(n_features, false);
        for &idx in selected_features {
            // Always true given the sizing above; kept as a defensive check.
            if idx < n_features {
                support[idx] = true;
            }
        }
        Ok(support)
    }

    /// Map original-space feature indices to their positions in the reduced
    /// (post-transform) space; indices that were not selected are dropped.
    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        Ok(indices
            .iter()
            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
            .collect())
    }
}
317
impl TextFeatureSelector<Trained> {
    // The `unwrap()`s below rely on the typed-state invariant: a `Trained`
    // value can only be produced by `fit`, which populates every `Option` field.

    /// Term-name -> position map over the selected features.
    pub fn vocabulary(&self) -> &HashMap<String, usize> {
        self.vocabulary_.as_ref().unwrap()
    }

    /// Smoothed IDF scores for the document-frequency-filtered features,
    /// in original column order (not limited to the selected subset).
    pub fn idf_scores(&self) -> &Array1<Float> {
        self.idf_scores_.as_ref().unwrap()
    }

    /// Synthetic feature names ("term_{i}") aligned with `selected_features()`.
    pub fn feature_names(&self) -> &[String] {
        self.feature_names_.as_ref().unwrap()
    }

    /// Original column indices of the selected features, sorted by combined score.
    pub fn selected_features(&self) -> &[usize] {
        self.selected_features_.as_ref().unwrap()
    }

    /// Number of features retained by the fit.
    pub fn n_features_selected(&self) -> usize {
        self.selected_features_.as_ref().unwrap().len()
    }

    /// Per-feature `(original index, name, score)` tuples, sorted by index.
    ///
    /// NOTE(review): `selected_features` is sorted by combined score while
    /// `idf_scores` is indexed by the df-filtered list in column order, so the
    /// positional `zip` below can pair a feature with another feature's IDF
    /// score and silently truncates to the shorter side — verify the intended
    /// alignment before relying on the reported scores.
    pub fn feature_summary(&self) -> Vec<(usize, &str, Float)> {
        let indices = self.selected_features();
        let names = self.feature_names();
        let scores = self.idf_scores();

        let mut summary: Vec<(usize, &str, Float)> = indices
            .iter()
            .zip(names.iter())
            .zip(scores.iter())
            .map(|((&idx, name), &score)| (idx, name.as_str(), score))
            .collect();

        summary.sort_by_key(|&(idx, _, _)| idx);
        summary
    }
}
375
376fn compute_chi2_score(feature: &ArrayView1<Float>, target: &Array1<Float>) -> Float {
389 let feature_mean = feature.mean().unwrap_or(0.0);
392 let target_mean = target.mean().unwrap_or(0.0);
393
394 let mut chi2 = 0.0;
395 let n = feature.len();
396
397 for i in 0..n {
398 let f_i = if feature[i] > feature_mean { 1.0 } else { 0.0 };
399 let t_i = if target[i] > target_mean { 1.0 } else { 0.0 };
400
401 let observed = f_i * t_i;
402 let expected = (feature.sum() / n as Float) * (target.sum() / n as Float);
403
404 if expected > 0.0 {
405 chi2 += (observed - expected).powi(2) / expected;
406 }
407 }
408
409 chi2
410}
411
412fn compute_document_frequency(term_vector: &ArrayView1<Float>) -> Float {
417 let n_documents = term_vector.len() as Float;
418 let documents_with_term = term_vector.iter().filter(|&&count| count > 0.0).count() as Float;
419 documents_with_term / n_documents
420}
421
422fn compute_tfidf_score(
428 term_frequency: Float,
429 document_frequency: Float,
430 n_documents: usize,
431) -> Float {
432 let idf = (n_documents as Float / (1.0 + document_frequency * n_documents as Float)).ln();
433 term_frequency * idf
434}
435
436pub fn create_text_feature_selector() -> TextFeatureSelector<Untrained> {
438 TextFeatureSelector::new()
439}
440
441pub fn create_short_text_selector() -> TextFeatureSelector<Untrained> {
446 TextFeatureSelector::new()
447 .min_df(0.005) .max_df(0.8) .max_features(Some(500))
450 .ngram_range((1, 2)) }
452
453pub fn create_long_text_selector() -> TextFeatureSelector<Untrained> {
458 TextFeatureSelector::new()
459 .min_df(0.02) .max_df(0.95) .max_features(Some(2000))
462 .ngram_range((1, 3)) }
464
465pub fn create_multilingual_selector() -> TextFeatureSelector<Untrained> {
470 TextFeatureSelector::new()
471 .min_df(0.01)
472 .max_df(0.9) .max_features(Some(1500))
474 .include_pos(true) }
476
477pub fn create_classification_selector() -> TextFeatureSelector<Untrained> {
482 TextFeatureSelector::new()
483 .min_df(0.02)
484 .max_df(0.9)
485 .max_features(Some(1000))
486 .ngram_range((1, 2))
487}
488
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_document_frequency_computation() {
        // Three of the five documents contain the term.
        let counts = array![1.0, 0.0, 2.0, 0.0, 1.0];
        let df = compute_document_frequency(&counts.view());
        assert!((df - 0.6).abs() < 1e-10);
    }

    #[test]
    fn test_tfidf_computation() {
        // A moderately frequent term in a 100-document corpus must score positive.
        let tfidf = compute_tfidf_score(3.0, 0.5, 100);
        assert!(tfidf > 0.0);
    }

    #[test]
    fn test_chi2_score_computation() {
        let feature = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let target = array![1.0, 1.0, 0.0, 0.0, 0.0];
        // Chi-square-style statistics are non-negative by construction.
        assert!(compute_chi2_score(&feature.view(), &target) >= 0.0);
    }

    #[test]
    fn test_text_feature_selector_basic() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 0.0, 2.0, //
                0.0, 1.0, 1.0, //
                1.0, 1.0, 0.0, //
                2.0, 0.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 0.0, 1.0, 0.0];

        let fitted = TextFeatureSelector::new()
            .min_df(0.1)
            .max_features(Some(2))
            .fit(&x, &y)
            .unwrap();

        assert!(fitted.n_features_selected() <= 2);
        assert!(fitted.transform(&x).unwrap().ncols() <= 2);
    }

    #[test]
    fn test_document_frequency_filtering() {
        // Column 1 appears in only 1/5 documents and must be filtered out.
        let x = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 0.0, 1.0, //
                1.0, 0.0, 1.0, //
                1.0, 1.0, 1.0, //
                1.0, 0.0, 1.0, //
                0.0, 0.0, 1.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 0.0, 1.0, 0.0, 1.0];

        let fitted = TextFeatureSelector::new()
            .min_df(0.6)
            .max_df(1.0)
            .fit(&x, &y)
            .unwrap();

        assert!(fitted.n_features_selected() <= 2);
    }
}