sklears_feature_selection/domain_specific/text_features.rs

//! Text feature selection module
//!
//! This module provides specialized feature selection algorithms for text data,
//! including TF-IDF analysis, document frequency filtering, and linguistic features.

use crate::base::SelectorMixin;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use sklears_core::{
    error::{validate, Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;
use std::marker::PhantomData;

/// Text feature selection using TF-IDF weights and linguistic analysis
///
/// This selector analyzes text features represented as term frequency matrices
/// and applies various text-specific selection criteria:
/// - Document frequency filtering (min_df, max_df)
/// - TF-IDF scoring for term importance
/// - Chi-squared statistical tests with target variables
/// - N-gram analysis (when configured)
/// - Part-of-speech and syntactic features (when enabled)
///
/// # Input Format
///
/// The input matrix `X` should be structured as:
/// - Rows: Documents
/// - Columns: Terms/features (e.g., from TF-IDF vectorization)
/// - Values: Term frequencies or TF-IDF scores
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_feature_selection::domain_specific::text_features::TextFeatureSelector;
/// use sklears_core::traits::{Fit, Transform};
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let selector = TextFeatureSelector::new()
///     .min_df(0.02)              // Minimum 2% document frequency
///     .max_df(0.90)              // Maximum 90% document frequency
///     .max_features(Some(500))   // Select top 500 features
///     .ngram_range((1, 2))       // Include unigrams and bigrams
///     .include_pos(true);        // Include part-of-speech features
///
/// let x = Array2::zeros((100, 1000)); // 100 documents, 1000 terms
/// let y = Array1::zeros(100);         // Document labels
///
/// let fitted_selector = selector.fit(&x, &y)?;
/// let transformed_x = fitted_selector.transform(&x)?;
/// ```
#[derive(Debug, Clone)]
pub struct TextFeatureSelector<State = Untrained> {
    /// Minimum document frequency for a term to be considered
    min_df: f64,
    /// Maximum document frequency for a term to be considered
    max_df: f64,
    /// Maximum number of features to select
    max_features: Option<usize>,
    /// N-gram range as (min_n, max_n); e.g. (1, 2) for unigrams and bigrams
    ngram_range: (usize, usize),
    /// Whether to include part-of-speech features
    include_pos: bool,
    /// Whether to include syntactic features
    include_syntax: bool,
    state: PhantomData<State>,
    // Trained state
    n_features_in_: Option<usize>,
    vocabulary_: Option<HashMap<String, usize>>,
    idf_scores_: Option<Array1<Float>>,
    selected_features_: Option<Vec<usize>>,
    feature_names_: Option<Vec<String>>,
}

impl TextFeatureSelector<Untrained> {
    /// Create a new selector with default settings
    pub fn new() -> Self {
        Self {
            min_df: 0.01,
            max_df: 0.95,
            max_features: Some(1000),
            ngram_range: (1, 1),
            include_pos: false,
            include_syntax: false,
            state: PhantomData,
            n_features_in_: None,
            vocabulary_: None,
            idf_scores_: None,
            selected_features_: None,
            feature_names_: None,
        }
    }

    /// Set the minimum document frequency threshold
    ///
    /// Terms appearing in less than a `min_df` fraction of documents
    /// will be filtered out. This helps remove very rare terms that
    /// may not be reliable predictors.
    ///
    /// # Arguments
    /// * `min_df` - Fraction between 0.0 and 1.0
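    ///
    /// # Examples
    ///
    /// A minimal sketch (`rust,ignore`, not compiled):
    ///
    /// ```rust,ignore
    /// // Keep only terms that appear in at least 5% of documents.
    /// let selector = TextFeatureSelector::new().min_df(0.05);
    /// ```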
    pub fn min_df(mut self, min_df: f64) -> Self {
        self.min_df = min_df;
        self
    }

    /// Set the maximum document frequency threshold
    ///
    /// Terms appearing in more than a `max_df` fraction of documents
    /// will be filtered out. This helps remove very common terms
    /// (like stop words) that may not be discriminative.
    ///
    /// # Arguments
    /// * `max_df` - Fraction between 0.0 and 1.0
    pub fn max_df(mut self, max_df: f64) -> Self {
        self.max_df = max_df;
        self
    }

    /// Set the maximum number of features to select
    ///
    /// When set to `Some(n)`, selects the top n features by combined score.
    /// When set to `None`, uses document frequency filtering only.
    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
        self.max_features = max_features;
        self
    }

    /// Set the n-gram range for feature extraction
    ///
    /// - (1, 1): Unigrams only
    /// - (1, 2): Unigrams and bigrams
    /// - (2, 3): Bigrams and trigrams
    /// - etc.
    ///
    /// Note: This parameter is informational for compatibility;
    /// actual n-gram extraction should be done during preprocessing.
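    ///
    /// A minimal sketch (`rust,ignore`, not compiled):
    ///
    /// ```rust,ignore
    /// // Record that upstream vectorization produced unigrams and bigrams.
    /// let selector = TextFeatureSelector::new().ngram_range((1, 2));
    /// ```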
    pub fn ngram_range(mut self, ngram_range: (usize, usize)) -> Self {
        self.ngram_range = ngram_range;
        self
    }

    /// Enable or disable part-of-speech features
    ///
    /// When enabled, the selector will give preference to features
    /// that represent important part-of-speech categories.
    ///
    /// Note: This requires preprocessing to extract POS features.
    pub fn include_pos(mut self, include_pos: bool) -> Self {
        self.include_pos = include_pos;
        self
    }

    /// Enable or disable syntactic features
    ///
    /// When enabled, the selector will consider syntactic relationships
    /// and dependency parsing features.
    ///
    /// Note: This requires preprocessing to extract syntactic features.
    pub fn include_syntax(mut self, include_syntax: bool) -> Self {
        self.include_syntax = include_syntax;
        self
    }
}

impl Default for TextFeatureSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for TextFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for TextFeatureSelector<Untrained> {
    type Fitted = TextFeatureSelector<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
        validate::check_consistent_length(x, y)?;

        let (n_documents, n_features) = x.dim();

        // Compute the document frequency of each feature (term): the fraction
        // of documents in which the term appears at least once.
        let mut document_frequencies = Array1::zeros(n_features);
        for j in 0..n_features {
            let mut df = 0.0;
            for i in 0..n_documents {
                if x[[i, j]] > 0.0 {
                    df += 1.0;
                }
            }
            document_frequencies[j] = df / n_documents as f64;
        }

        // Keep only features whose document frequency lies in [min_df, max_df].
        let mut valid_features = Vec::new();
        for (j, &df) in document_frequencies.iter().enumerate() {
            if df >= self.min_df && df <= self.max_df {
                valid_features.push(j);
            }
        }

        if valid_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No features pass the document frequency filters".to_string(),
            ));
        }

        // Compute smoothed IDF scores for the remaining features:
        // idf = ln(N / (1 + df * N)).
        let mut idf_scores = Array1::zeros(valid_features.len());
        for (idx, &j) in valid_features.iter().enumerate() {
            let df = document_frequencies[j];
            idf_scores[idx] = (n_documents as f64 / (1.0 + df * n_documents as f64)).ln();
        }

        // Score each remaining feature against the target with the simplified
        // chi-squared statistic.
        let mut chi2_scores = Array1::zeros(valid_features.len());
        for (idx, &j) in valid_features.iter().enumerate() {
            let feature_col = x.column(j);
            chi2_scores[idx] = compute_chi2_score(&feature_col, y);
        }

        // Combine IDF and chi-squared scores with a fixed heuristic weighting.
        let mut combined_scores = Array1::zeros(valid_features.len());
        for i in 0..combined_scores.len() {
            combined_scores[i] = 0.6 * idf_scores[i] + 0.4 * chi2_scores[i];
        }

        // Rank features by combined score, highest first; NaN scores compare
        // as equal rather than panicking.
        let mut scored_features: Vec<(usize, Float)> = combined_scores
            .indexed_iter()
            .map(|(i, &score)| (valid_features[i], score))
            .collect();

        scored_features
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let selected_features: Vec<usize> = if let Some(max_feat) = self.max_features {
            scored_features
                .iter()
                .take(max_feat.min(scored_features.len()))
                .map(|(i, _)| *i)
                .collect()
        } else {
            scored_features.iter().map(|(i, _)| *i).collect()
        };

        // Store IDF scores aligned with the selected features, so that
        // idf_scores()[i] is the score of selected_features()[i].
        let idf_by_feature: HashMap<usize, Float> = valid_features
            .iter()
            .zip(idf_scores.iter())
            .map(|(&j, &idf)| (j, idf))
            .collect();
        let selected_idf: Array1<Float> = selected_features
            .iter()
            .map(|j| idf_by_feature[j])
            .collect();

        // Create feature names (simplified - a real implementation would use
        // the actual vocabulary from the vectorizer).
        let feature_names: Vec<String> = selected_features
            .iter()
            .map(|&i| format!("term_{}", i))
            .collect();

        let vocabulary: HashMap<String, usize> = feature_names
            .iter()
            .enumerate()
            .map(|(i, name)| (name.clone(), i))
            .collect();

        Ok(TextFeatureSelector {
            min_df: self.min_df,
            max_df: self.max_df,
            max_features: self.max_features,
            ngram_range: self.ngram_range,
            include_pos: self.include_pos,
            include_syntax: self.include_syntax,
            state: PhantomData,
            n_features_in_: Some(n_features),
            vocabulary_: Some(vocabulary),
            idf_scores_: Some(selected_idf),
            selected_features_: Some(selected_features),
            feature_names_: Some(feature_names),
        })
    }
}

impl Transform<Array2<Float>> for TextFeatureSelector<Trained> {
    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        if selected_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No features were selected".to_string(),
            ));
        }

        // Guard against inputs with fewer columns than were seen during fit;
        // `select` would panic on an out-of-range index.
        if let Some(&max_idx) = selected_features.iter().max() {
            if max_idx >= x.ncols() {
                return Err(SklearsError::InvalidInput(format!(
                    "Input has {} features, but feature index {} was selected during fit",
                    x.ncols(),
                    max_idx
                )));
            }
        }

        Ok(x.select(Axis(1), selected_features))
    }
}

impl SelectorMixin for TextFeatureSelector<Trained> {
    fn get_support(&self) -> SklResult<Array1<bool>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        // One entry per input feature, as recorded during fit.
        let n_features = self.n_features_in_.unwrap();
        let mut support = Array1::from_elem(n_features, false);
        for &idx in selected_features {
            if idx < n_features {
                support[idx] = true;
            }
        }
        Ok(support)
    }

    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let selected_features = self.selected_features_.as_ref().unwrap();
        Ok(indices
            .iter()
            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
            .collect())
    }
}

impl TextFeatureSelector<Trained> {
    /// Get the vocabulary mapping from terms to feature indices
    ///
    /// Returns a reference to the vocabulary dictionary where keys are
    /// term names and values are their corresponding feature indices.
    pub fn vocabulary(&self) -> &HashMap<String, usize> {
        self.vocabulary_.as_ref().unwrap()
    }

    /// Get the IDF (Inverse Document Frequency) scores
    ///
    /// Returns an array where each element is the IDF score for the
    /// corresponding selected feature.
    pub fn idf_scores(&self) -> &Array1<Float> {
        self.idf_scores_.as_ref().unwrap()
    }

    /// Get the names of selected features
    ///
    /// Returns a reference to the vector of feature names that were selected.
    pub fn feature_names(&self) -> &[String] {
        self.feature_names_.as_ref().unwrap()
    }

    /// Get the indices of selected features
    ///
    /// Returns a reference to the vector of original feature indices
    /// that were selected during fitting.
    pub fn selected_features(&self) -> &[usize] {
        self.selected_features_.as_ref().unwrap()
    }

    /// Get the number of selected features
    pub fn n_features_selected(&self) -> usize {
        self.selected_features_.as_ref().unwrap().len()
    }

    /// Get feature information as a structured summary
    ///
    /// Returns a vector of tuples containing (feature_index, feature_name, idf_score)
    /// for all selected features, sorted by feature index.
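    ///
    /// A minimal sketch (`rust,ignore`, not compiled), assuming a fitted
    /// selector as in the type-level docs:
    ///
    /// ```rust,ignore
    /// for (idx, name, idf) in fitted.feature_summary() {
    ///     println!("feature {idx} ({name}): idf = {idf:.3}");
    /// }
    /// ```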
    pub fn feature_summary(&self) -> Vec<(usize, &str, Float)> {
        let indices = self.selected_features();
        let names = self.feature_names();
        let scores = self.idf_scores();

        let mut summary: Vec<(usize, &str, Float)> = indices
            .iter()
            .zip(names.iter())
            .zip(scores.iter())
            .map(|((&idx, name), &score)| (idx, name.as_str(), score))
            .collect();

        summary.sort_by_key(|&(idx, _, _)| idx);
        summary
    }
}

// ================================================================================================
// Helper Functions
// ================================================================================================

/// Compute a simplified chi-squared score for feature selection
///
/// This function binarizes the (continuous) feature and target at their
/// respective means, builds the resulting 2x2 contingency table, and computes
/// the standard chi-squared statistic of independence on it. In practice, more
/// sophisticated discretization methods (like equal-frequency binning) would
/// be preferred.
///
/// Higher scores indicate stronger association between feature and target.
fn compute_chi2_score(feature: &ArrayView1<Float>, target: &Array1<Float>) -> Float {
    let n = feature.len();
    if n == 0 {
        return 0.0;
    }

    let feature_mean = feature.mean().unwrap_or(0.0);
    let target_mean = target.mean().unwrap_or(0.0);

    // Observed counts: observed[f][t] over the binarized feature/target pairs.
    let mut observed = [[0.0; 2]; 2];
    for i in 0..n {
        let f = usize::from(feature[i] > feature_mean);
        let t = usize::from(target[i] > target_mean);
        observed[f][t] += 1.0;
    }

    let row_totals = [
        observed[0][0] + observed[0][1],
        observed[1][0] + observed[1][1],
    ];
    let col_totals = [
        observed[0][0] + observed[1][0],
        observed[0][1] + observed[1][1],
    ];
    let total = n as Float;

    // chi2 = sum over cells of (observed - expected)^2 / expected,
    // with expected = row_total * col_total / n.
    let mut chi2 = 0.0;
    for f in 0..2 {
        for t in 0..2 {
            let expected = row_totals[f] * col_totals[t] / total;
            if expected > 0.0 {
                chi2 += (observed[f][t] - expected).powi(2) / expected;
            }
        }
    }

    chi2
}

/// Compute document frequency for a term vector
///
/// Document frequency is the number of documents containing the term
/// divided by the total number of documents.
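///
/// A minimal sketch (`rust,ignore`, not compiled):
///
/// ```rust,ignore
/// // The term appears in 3 of 5 documents, so df = 0.6.
/// let term_vector = array![1.0, 0.0, 2.0, 0.0, 1.0];
/// assert!((compute_document_frequency(&term_vector.view()) - 0.6).abs() < 1e-10);
/// ```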
fn compute_document_frequency(term_vector: &ArrayView1<Float>) -> Float {
    let n_documents = term_vector.len() as Float;
    let documents_with_term = term_vector.iter().filter(|&&count| count > 0.0).count() as Float;
    documents_with_term / n_documents
}

/// Compute TF-IDF score for a term
///
/// TF-IDF (Term Frequency-Inverse Document Frequency) is calculated as:
/// tf-idf(t,d) = tf(t,d) * idf(t)
/// where this implementation uses the smoothed variant
/// idf(t) = ln(N / (1 + df(t) * N)), with df(t) given as a fraction of documents.
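///
/// A worked example of the smoothed formula above (`rust,ignore`, not
/// compiled; the numbers are illustrative):
///
/// ```rust,ignore
/// // tf = 3, df = 0.5, N = 100 documents:
/// // idf = ln(100 / (1 + 0.5 * 100)) = ln(100 / 51) ≈ 0.673
/// // tf-idf ≈ 3.0 * 0.673 ≈ 2.02
/// let score = compute_tfidf_score(3.0, 0.5, 100);
/// assert!((score - 2.02).abs() < 0.01);
/// ```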
fn compute_tfidf_score(
    term_frequency: Float,
    document_frequency: Float,
    n_documents: usize,
) -> Float {
    let idf = (n_documents as Float / (1.0 + document_frequency * n_documents as Float)).ln();
    term_frequency * idf
}

/// Create a new text feature selector
pub fn create_text_feature_selector() -> TextFeatureSelector<Untrained> {
    TextFeatureSelector::new()
}

/// Create a text feature selector optimized for short documents
///
/// Suitable for tweets, short articles, or other brief text content
/// where term frequency is typically low.
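///
/// A minimal sketch (`rust,ignore`, not compiled; `x` and `y` are assumed to
/// be a term-document matrix and labels prepared elsewhere):
///
/// ```rust,ignore
/// let fitted = create_short_text_selector().fit(&x, &y)?;
/// let reduced = fitted.transform(&x)?;
/// ```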
pub fn create_short_text_selector() -> TextFeatureSelector<Untrained> {
    TextFeatureSelector::new()
        .min_df(0.005) // Lower minimum frequency for short texts
        .max_df(0.8) // Lower maximum to filter common words
        .max_features(Some(500))
        .ngram_range((1, 2)) // Include bigrams for context
}

/// Create a text feature selector optimized for long documents
///
/// Suitable for articles, papers, books, or other lengthy text content
/// where term frequencies are higher and vocabulary is richer.
pub fn create_long_text_selector() -> TextFeatureSelector<Untrained> {
    TextFeatureSelector::new()
        .min_df(0.02) // Higher minimum frequency
        .max_df(0.95) // Higher maximum frequency
        .max_features(Some(2000))
        .ngram_range((1, 3)) // Include up to trigrams
}

/// Create a text feature selector for multilingual content
///
/// Suitable for text in multiple languages where stop words and
/// common patterns may vary significantly.
pub fn create_multilingual_selector() -> TextFeatureSelector<Untrained> {
    TextFeatureSelector::new()
        .min_df(0.01)
        .max_df(0.9) // Conservative max_df for varied languages
        .max_features(Some(1500))
        .include_pos(true) // POS tags are language-agnostic
}

/// Create a text feature selector for classification tasks
///
/// Optimized for discriminative feature selection in classification
/// problems where the goal is to distinguish between classes.
pub fn create_classification_selector() -> TextFeatureSelector<Untrained> {
    TextFeatureSelector::new()
        .min_df(0.02)
        .max_df(0.9)
        .max_features(Some(1000))
        .ngram_range((1, 2))
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_document_frequency_computation() {
        let term_vector = array![1.0, 0.0, 2.0, 0.0, 1.0]; // Term appears in 3/5 documents
        let df = compute_document_frequency(&term_vector.view());
        assert!((df - 0.6).abs() < 1e-10);
    }

    #[test]
    fn test_tfidf_computation() {
        let tf = 3.0;
        let df = 0.5; // Term appears in 50% of documents
        let n_docs = 100;
        let tfidf = compute_tfidf_score(tf, df, n_docs);

        // Should be positive (tf * ln(N / (1 + df * N)))
        assert!(tfidf > 0.0);
    }

    #[test]
    fn test_chi2_score_computation() {
        let feature = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let target = array![1.0, 1.0, 0.0, 0.0, 0.0];
        let chi2 = compute_chi2_score(&feature.view(), &target);

        // The chi-squared statistic is always non-negative
        assert!(chi2 >= 0.0);
    }

    #[test]
    fn test_text_feature_selector_basic() {
        let selector = TextFeatureSelector::new().min_df(0.1).max_features(Some(2));

        // Create a simple term-document matrix
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 0.0, 2.0, // Doc 1: term1, term3
                0.0, 1.0, 1.0, // Doc 2: term2, term3
                1.0, 1.0, 0.0, // Doc 3: term1, term2
                2.0, 0.0, 1.0, // Doc 4: term1, term3
            ],
        )
        .unwrap();
        let y = array![1.0, 0.0, 1.0, 0.0];

        let fitted = selector.fit(&x, &y).unwrap();
        assert!(fitted.n_features_selected() <= 2);

        let transformed = fitted.transform(&x).unwrap();
        assert!(transformed.ncols() <= 2);
    }

    #[test]
    fn test_document_frequency_filtering() {
        let selector = TextFeatureSelector::new()
            .min_df(0.6) // Require term in at least 60% of documents
            .max_df(1.0);

        // term1 appears in 4/5 documents, term2 in 1/5, term3 in 5/5
        let x = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 0.0, 1.0, // Doc 1
                1.0, 0.0, 1.0, // Doc 2
                1.0, 1.0, 1.0, // Doc 3
                1.0, 0.0, 1.0, // Doc 4
                0.0, 0.0, 1.0, // Doc 5
            ],
        )
        .unwrap();
        let y = array![1.0, 0.0, 1.0, 0.0, 1.0];

        let fitted = selector.fit(&x, &y).unwrap();

        // Only term1 (4/5 = 80%) and term3 (5/5 = 100%) pass the 60% threshold
        assert_eq!(fitted.n_features_selected(), 2);
    }
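
    #[test]
    fn test_get_support_mask() {
        // The support mask should have one entry per input feature and
        // exactly as many `true` values as selected features.
        let selector = TextFeatureSelector::new().min_df(0.1).max_features(Some(2));
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 0.0, 2.0, // Doc 1
                0.0, 1.0, 1.0, // Doc 2
                1.0, 1.0, 0.0, // Doc 3
                2.0, 0.0, 1.0, // Doc 4
            ],
        )
        .unwrap();
        let y = array![1.0, 0.0, 1.0, 0.0];

        let fitted = selector.fit(&x, &y).unwrap();
        let support = fitted.get_support().unwrap();
        assert_eq!(support.len(), 3);
        assert_eq!(
            support.iter().filter(|&&s| s).count(),
            fitted.n_features_selected()
        );
    }

    #[test]
    fn test_chi2_association_vs_independence() {
        // A feature that is perfectly associated with the target (after
        // mean-binarization) should score strictly higher than one that is
        // independent of it.
        let target = array![1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0];
        let associated = array![10.0, 10.0, 0.0, 0.0, 10.0, 10.0, 0.0, 0.0];
        let independent = array![10.0, 0.0, 10.0, 0.0, 10.0, 0.0, 10.0, 0.0];

        let chi2_assoc = compute_chi2_score(&associated.view(), &target);
        let chi2_indep = compute_chi2_score(&independent.view(), &target);
        assert!(chi2_assoc > chi2_indep);
    }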
}