scirs2_transform/encoding.rs

1//! Categorical data encoding utilities
2//!
3//! This module provides methods for encoding categorical data into numerical
4//! formats suitable for machine learning algorithms.
5
6use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8use std::collections::HashMap;
9
10use crate::error::{Result, TransformError};
11
/// Simple sparse matrix representation in COO (coordinate) format.
///
/// The three vectors are parallel: stored entry `k` holds `values[k]` at position
/// `(row_indices[k], col_indices[k])`. Every position that is not stored is zero.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseMatrix {
    /// Shape of the matrix as `(rows, cols)`
    pub shape: (usize, usize),
    /// Row index of each stored entry
    pub row_indices: Vec<usize>,
    /// Column index of each stored entry
    pub col_indices: Vec<usize>,
    /// Value of each stored entry
    pub values: Vec<f64>,
}
24
25impl SparseMatrix {
26    /// Create a new empty sparse matrix
27    pub fn new(shape: (usize, usize)) -> Self {
28        SparseMatrix {
29            shape,
30            row_indices: Vec::new(),
31            col_indices: Vec::new(),
32            values: Vec::new(),
33        }
34    }
35
36    /// Add a non-zero value at (row, col)
37    pub fn push(&mut self, row: usize, col: usize, value: f64) {
38        if row < self.shape.0 && col < self.shape.1 && value != 0.0 {
39            self.row_indices.push(row);
40            self.col_indices.push(col);
41            self.values.push(value);
42        }
43    }
44
45    /// Convert to dense Array2
46    pub fn to_dense(&self) -> Array2<f64> {
47        let mut dense = Array2::zeros(self.shape);
48        for ((&row, &col), &val) in self
49            .row_indices
50            .iter()
51            .zip(self.col_indices.iter())
52            .zip(self.values.iter())
53        {
54            dense[[row, col]] = val;
55        }
56        dense
57    }
58
59    /// Get number of non-zero elements
60    pub fn nnz(&self) -> usize {
61        self.values.len()
62    }
63}
64
/// Output format for encoded data.
///
/// `OneHotEncoder::transform` returns `Sparse` when the encoder was created with
/// `sparse = true` and `Dense` otherwise.
#[derive(Debug, Clone)]
pub enum EncodedOutput {
    /// Dense matrix representation
    Dense(Array2<f64>),
    /// Sparse matrix representation (COO format)
    Sparse(SparseMatrix),
}
73
74impl EncodedOutput {
75    /// Convert to dense matrix (creates copy if sparse)
76    pub fn to_dense(&self) -> Array2<f64> {
77        match self {
78            EncodedOutput::Dense(arr) => arr.clone(),
79            EncodedOutput::Sparse(sparse) => sparse.to_dense(),
80        }
81    }
82
83    /// Get shape of the output
84    pub fn shape(&self) -> (usize, usize) {
85        match self {
86            EncodedOutput::Dense(arr) => (arr.nrows(), arr.ncols()),
87            EncodedOutput::Sparse(sparse) => sparse.shape,
88        }
89    }
90}
91
/// OneHotEncoder for converting categorical features to binary features
///
/// This transformer converts categorical features into a one-hot encoded representation,
/// where each category is represented by a binary feature.
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// Sorted, deduplicated categories for each feature (learned during fit)
    categories_: Option<Vec<Vec<u64>>>,
    /// Drop strategy: `Some("first")`, `Some("if_binary")`, or `None` to keep every category
    drop: Option<String>,
    /// How unknown categories are handled during transform: `"error"` or `"ignore"`
    handleunknown: String,
    /// Whether transform returns sparse output
    sparse: bool,
}
106
107impl OneHotEncoder {
108    /// Creates a new OneHotEncoder
109    ///
110    /// # Arguments
111    /// * `drop` - Strategy for dropping categories ('first', 'if_binary', or None)
112    /// * `handleunknown` - How to handle unknown categories ('error' or 'ignore')
113    /// * `sparse` - Whether to return sparse arrays
114    ///
115    /// # Returns
116    /// * A new OneHotEncoder instance
117    pub fn new(_drop: Option<String>, handleunknown: &str, sparse: bool) -> Result<Self> {
118        if let Some(ref drop_strategy) = _drop {
119            if drop_strategy != "first" && drop_strategy != "if_binary" {
120                return Err(TransformError::InvalidInput(
121                    "_drop must be 'first', 'if_binary', or None".to_string(),
122                ));
123            }
124        }
125
126        if handleunknown != "error" && handleunknown != "ignore" {
127            return Err(TransformError::InvalidInput(
128                "handleunknown must be 'error' or 'ignore'".to_string(),
129            ));
130        }
131
132        Ok(OneHotEncoder {
133            categories_: None,
134            drop: _drop,
135            handleunknown: handleunknown.to_string(),
136            sparse,
137        })
138    }
139
140    /// Creates a OneHotEncoder with default settings
141    pub fn with_defaults() -> Self {
142        Self::new(None, "error", false).unwrap()
143    }
144
145    /// Fits the OneHotEncoder to the input data
146    ///
147    /// # Arguments
148    /// * `x` - The input categorical data, shape (n_samples, n_features)
149    ///
150    /// # Returns
151    /// * `Result<()>` - Ok if successful, Err otherwise
152    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
153    where
154        S: Data,
155        S::Elem: Float + NumCast,
156    {
157        let x_u64 = x.mapv(|x| {
158            let val_f64 = NumCast::from(x).unwrap_or(0.0);
159            val_f64 as u64
160        });
161
162        let n_samples = x_u64.shape()[0];
163        let n_features = x_u64.shape()[1];
164
165        if n_samples == 0 || n_features == 0 {
166            return Err(TransformError::InvalidInput("Empty input data".to_string()));
167        }
168
169        let mut categories = Vec::with_capacity(n_features);
170
171        for j in 0..n_features {
172            // Collect unique values for this feature
173            let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
174            unique_values.sort_unstable();
175            unique_values.dedup();
176
177            categories.push(unique_values);
178        }
179
180        self.categories_ = Some(categories);
181        Ok(())
182    }
183
184    /// Transforms the input data using the fitted OneHotEncoder
185    ///
186    /// # Arguments
187    /// * `x` - The input categorical data, shape (n_samples, n_features)
188    ///
189    /// # Returns
190    /// * `Result<EncodedOutput>` - The one-hot encoded data (dense or sparse)
191    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<EncodedOutput>
192    where
193        S: Data,
194        S::Elem: Float + NumCast,
195    {
196        let x_u64 = x.mapv(|x| {
197            let val_f64 = NumCast::from(x).unwrap_or(0.0);
198            val_f64 as u64
199        });
200
201        let n_samples = x_u64.shape()[0];
202        let n_features = x_u64.shape()[1];
203
204        if self.categories_.is_none() {
205            return Err(TransformError::TransformationError(
206                "OneHotEncoder has not been fitted".to_string(),
207            ));
208        }
209
210        let categories = self.categories_.as_ref().unwrap();
211
212        if n_features != categories.len() {
213            return Err(TransformError::InvalidInput(format!(
214                "x has {} features, but OneHotEncoder was fitted with {} features",
215                n_features,
216                categories.len()
217            )));
218        }
219
220        // Calculate total number of output features
221        let mut total_features = 0;
222        for (j, feature_categories) in categories.iter().enumerate() {
223            let n_cats = feature_categories.len();
224
225            // Apply drop strategy
226            let n_output_cats = match &self.drop {
227                Some(strategy) if strategy == "first" => n_cats.saturating_sub(1),
228                Some(strategy) if strategy == "if_binary" && n_cats == 2 => 1,
229                _ => n_cats,
230            };
231
232            if n_output_cats == 0 {
233                return Err(TransformError::InvalidInput(format!(
234                    "Feature {j} has only one category after dropping"
235                )));
236            }
237
238            total_features += n_output_cats;
239        }
240
241        // Create mappings from category values to column indices
242        let mut category_mappings = Vec::new();
243        let mut current_col = 0;
244
245        for feature_categories in categories.iter() {
246            let mut mapping = HashMap::new();
247            let n_cats = feature_categories.len();
248
249            // Determine how many categories to keep
250            let (start_idx, n_output_cats) = match &self.drop {
251                Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
252                Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
253                _ => (0, n_cats),
254            };
255
256            for (cat_idx, &category) in feature_categories.iter().enumerate() {
257                if cat_idx >= start_idx && cat_idx < start_idx + n_output_cats {
258                    mapping.insert(category, current_col + cat_idx - start_idx);
259                }
260            }
261
262            category_mappings.push(mapping);
263            current_col += n_output_cats;
264        }
265
266        // Create output based on sparse setting
267        if self.sparse {
268            // Sparse output
269            let mut sparse_matrix = SparseMatrix::new((n_samples, total_features));
270
271            for i in 0..n_samples {
272                for j in 0..n_features {
273                    let value = x_u64[[i, j]];
274
275                    if let Some(&col_idx) = category_mappings[j].get(&value) {
276                        sparse_matrix.push(i, col_idx, 1.0);
277                    } else {
278                        // Check if this is a dropped category
279                        let feature_categories = &categories[j];
280                        let is_dropped_category = match &self.drop {
281                            Some(strategy) if strategy == "first" => {
282                                !feature_categories.is_empty() && value == feature_categories[0]
283                            }
284                            Some(strategy)
285                                if strategy == "if_binary" && feature_categories.len() == 2 =>
286                            {
287                                feature_categories.len() == 2 && value == feature_categories[1]
288                            }
289                            _ => false,
290                        };
291
292                        if !is_dropped_category && self.handleunknown == "error" {
293                            return Err(TransformError::InvalidInput(format!(
294                                "Found unknown category {value} in feature {j}"
295                            )));
296                        }
297                        // If it's a dropped category or handleunknown == "ignore", we don't add anything (sparse)
298                    }
299                }
300            }
301
302            Ok(EncodedOutput::Sparse(sparse_matrix))
303        } else {
304            // Dense output
305            let mut transformed = Array2::zeros((n_samples, total_features));
306
307            for i in 0..n_samples {
308                for j in 0..n_features {
309                    let value = x_u64[[i, j]];
310
311                    if let Some(&col_idx) = category_mappings[j].get(&value) {
312                        transformed[[i, col_idx]] = 1.0;
313                    } else {
314                        // Check if this is a dropped category
315                        let feature_categories = &categories[j];
316                        let is_dropped_category = match &self.drop {
317                            Some(strategy) if strategy == "first" => {
318                                !feature_categories.is_empty() && value == feature_categories[0]
319                            }
320                            Some(strategy)
321                                if strategy == "if_binary" && feature_categories.len() == 2 =>
322                            {
323                                feature_categories.len() == 2 && value == feature_categories[1]
324                            }
325                            _ => false,
326                        };
327
328                        if !is_dropped_category && self.handleunknown == "error" {
329                            return Err(TransformError::InvalidInput(format!(
330                                "Found unknown category {value} in feature {j}"
331                            )));
332                        }
333                        // If it's a dropped category or handleunknown == "ignore", we just leave it as 0
334                    }
335                }
336            }
337
338            Ok(EncodedOutput::Dense(transformed))
339        }
340    }
341
342    /// Fits the OneHotEncoder to the input data and transforms it
343    ///
344    /// # Arguments
345    /// * `x` - The input categorical data, shape (n_samples, n_features)
346    ///
347    /// # Returns
348    /// * `Result<EncodedOutput>` - The one-hot encoded data (dense or sparse)
349    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<EncodedOutput>
350    where
351        S: Data,
352        S::Elem: Float + NumCast,
353    {
354        self.fit(x)?;
355        self.transform(x)
356    }
357
358    /// Convenience method that always returns dense output for backward compatibility
359    ///
360    /// # Arguments
361    /// * `x` - The input categorical data, shape (n_samples, n_features)
362    ///
363    /// # Returns
364    /// * `Result<Array2<f64>>` - The one-hot encoded data as dense matrix
365    pub fn transform_dense<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
366    where
367        S: Data,
368        S::Elem: Float + NumCast,
369    {
370        Ok(self.transform(x)?.to_dense())
371    }
372
373    /// Convenience method that fits and transforms returning dense output
374    ///
375    /// # Arguments
376    /// * `x` - The input categorical data, shape (n_samples, n_features)
377    ///
378    /// # Returns
379    /// * `Result<Array2<f64>>` - The one-hot encoded data as dense matrix
380    pub fn fit_transform_dense<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
381    where
382        S: Data,
383        S::Elem: Float + NumCast,
384    {
385        Ok(self.fit_transform(x)?.to_dense())
386    }
387
388    /// Returns the categories for each feature
389    ///
390    /// # Returns
391    /// * `Option<&Vec<Vec<u64>>>` - The categories for each feature
392    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
393        self.categories_.as_ref()
394    }
395
396    /// Gets the feature names for the transformed output
397    ///
398    /// # Arguments
399    /// * `inputfeatures` - Names of input features
400    ///
401    /// # Returns
402    /// * `Result<Vec<String>>` - Names of output features
403    pub fn get_feature_names(&self, inputfeatures: Option<&[String]>) -> Result<Vec<String>> {
404        if self.categories_.is_none() {
405            return Err(TransformError::TransformationError(
406                "OneHotEncoder has not been fitted".to_string(),
407            ));
408        }
409
410        let categories = self.categories_.as_ref().unwrap();
411        let mut feature_names = Vec::new();
412
413        for (j, feature_categories) in categories.iter().enumerate() {
414            let feature_name = if let Some(names) = inputfeatures {
415                if j < names.len() {
416                    names[j].clone()
417                } else {
418                    format!("x{j}")
419                }
420            } else {
421                format!("x{j}")
422            };
423
424            let n_cats = feature_categories.len();
425
426            // Determine which categories to include based on drop strategy
427            let (start_idx, n_output_cats) = match &self.drop {
428                Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
429                Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
430                _ => (0, n_cats),
431            };
432
433            for &category in feature_categories
434                .iter()
435                .skip(start_idx)
436                .take(n_output_cats)
437            {
438                feature_names.push(format!("{feature_name}_cat_{category}"));
439            }
440        }
441
442        Ok(feature_names)
443    }
444}
445
/// OrdinalEncoder for converting categorical features to ordinal integers
///
/// This transformer converts categorical features into ordinal integers,
/// where each category is assigned a unique integer.
#[derive(Debug, Clone)]
pub struct OrdinalEncoder {
    /// Sorted, deduplicated categories for each feature (learned during fit); a category's
    /// ordinal is its index in this list
    categories_: Option<Vec<Vec<u64>>>,
    /// How to handle unknown categories: `"error"` or `"use_encoded_value"`
    handleunknown: String,
    /// Value emitted for unknown categories when `handleunknown == "use_encoded_value"`
    unknownvalue: Option<f64>,
}
458
459impl OrdinalEncoder {
460    /// Creates a new OrdinalEncoder
461    ///
462    /// # Arguments
463    /// * `handleunknown` - How to handle unknown categories ('error' or 'use_encoded_value')
464    /// * `unknownvalue` - Value to use for unknown categories (when handleunknown='use_encoded_value')
465    ///
466    /// # Returns
467    /// * A new OrdinalEncoder instance
468    pub fn new(handleunknown: &str, unknownvalue: Option<f64>) -> Result<Self> {
469        if handleunknown != "error" && handleunknown != "use_encoded_value" {
470            return Err(TransformError::InvalidInput(
471                "handleunknown must be 'error' or 'use_encoded_value'".to_string(),
472            ));
473        }
474
475        if handleunknown == "use_encoded_value" && unknownvalue.is_none() {
476            return Err(TransformError::InvalidInput(
477                "unknownvalue must be specified when handleunknown='use_encoded_value'".to_string(),
478            ));
479        }
480
481        Ok(OrdinalEncoder {
482            categories_: None,
483            handleunknown: handleunknown.to_string(),
484            unknownvalue,
485        })
486    }
487
488    /// Creates an OrdinalEncoder with default settings
489    pub fn with_defaults() -> Self {
490        Self::new("error", None).unwrap()
491    }
492
493    /// Fits the OrdinalEncoder to the input data
494    ///
495    /// # Arguments
496    /// * `x` - The input categorical data, shape (n_samples, n_features)
497    ///
498    /// # Returns
499    /// * `Result<()>` - Ok if successful, Err otherwise
500    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
501    where
502        S: Data,
503        S::Elem: Float + NumCast,
504    {
505        let x_u64 = x.mapv(|x| {
506            let val_f64 = NumCast::from(x).unwrap_or(0.0);
507            val_f64 as u64
508        });
509
510        let n_samples = x_u64.shape()[0];
511        let n_features = x_u64.shape()[1];
512
513        if n_samples == 0 || n_features == 0 {
514            return Err(TransformError::InvalidInput("Empty input data".to_string()));
515        }
516
517        let mut categories = Vec::with_capacity(n_features);
518
519        for j in 0..n_features {
520            // Collect unique values for this feature
521            let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
522            unique_values.sort_unstable();
523            unique_values.dedup();
524
525            categories.push(unique_values);
526        }
527
528        self.categories_ = Some(categories);
529        Ok(())
530    }
531
532    /// Transforms the input data using the fitted OrdinalEncoder
533    ///
534    /// # Arguments
535    /// * `x` - The input categorical data, shape (n_samples, n_features)
536    ///
537    /// # Returns
538    /// * `Result<Array2<f64>>` - The ordinally encoded data
539    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
540    where
541        S: Data,
542        S::Elem: Float + NumCast,
543    {
544        let x_u64 = x.mapv(|x| {
545            let val_f64 = NumCast::from(x).unwrap_or(0.0);
546            val_f64 as u64
547        });
548
549        let n_samples = x_u64.shape()[0];
550        let n_features = x_u64.shape()[1];
551
552        if self.categories_.is_none() {
553            return Err(TransformError::TransformationError(
554                "OrdinalEncoder has not been fitted".to_string(),
555            ));
556        }
557
558        let categories = self.categories_.as_ref().unwrap();
559
560        if n_features != categories.len() {
561            return Err(TransformError::InvalidInput(format!(
562                "x has {} features, but OrdinalEncoder was fitted with {} features",
563                n_features,
564                categories.len()
565            )));
566        }
567
568        let mut transformed = Array2::zeros((n_samples, n_features));
569
570        // Create mappings from category values to ordinal values
571        let mut category_mappings = Vec::new();
572        for feature_categories in categories {
573            let mut mapping = HashMap::new();
574            for (ordinal, &category) in feature_categories.iter().enumerate() {
575                mapping.insert(category, ordinal as f64);
576            }
577            category_mappings.push(mapping);
578        }
579
580        // Fill the transformed array
581        for i in 0..n_samples {
582            for j in 0..n_features {
583                let value = x_u64[[i, j]];
584
585                if let Some(&ordinal_value) = category_mappings[j].get(&value) {
586                    transformed[[i, j]] = ordinal_value;
587                } else if self.handleunknown == "error" {
588                    return Err(TransformError::InvalidInput(format!(
589                        "Found unknown category {value} in feature {j}"
590                    )));
591                } else {
592                    // handleunknown == "use_encoded_value"
593                    transformed[[i, j]] = self.unknownvalue.unwrap();
594                }
595            }
596        }
597
598        Ok(transformed)
599    }
600
601    /// Fits the OrdinalEncoder to the input data and transforms it
602    ///
603    /// # Arguments
604    /// * `x` - The input categorical data, shape (n_samples, n_features)
605    ///
606    /// # Returns
607    /// * `Result<Array2<f64>>` - The ordinally encoded data
608    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
609    where
610        S: Data,
611        S::Elem: Float + NumCast,
612    {
613        self.fit(x)?;
614        self.transform(x)
615    }
616
617    /// Returns the categories for each feature
618    ///
619    /// # Returns
620    /// * `Option<&Vec<Vec<u64>>>` - The categories for each feature
621    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
622        self.categories_.as_ref()
623    }
624}
625
/// TargetEncoder for supervised categorical encoding
///
/// This encoder transforms categorical features using the target variable values,
/// encoding each category with a statistic (mean, median, etc.) of the target values
/// for that category. This is useful for high-cardinality categorical features.
///
/// For the `"mean"` and `"median"` strategies, a positive `smoothing` blends the
/// per-category statistic with the global target mean:
/// `(count * stat + smoothing * global_mean) / (count + smoothing)`.
///
/// # Examples
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_transform::encoding::TargetEncoder;
///
/// let x = Array2::from_shape_vec((6, 1), vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0]).unwrap();
/// let y = vec![1.0, 2.0, 3.0, 1.5, 2.5, 3.5];
///
/// let mut encoder = TargetEncoder::new("mean", 1.0, 0.0).unwrap();
/// let encoded = encoder.fit_transform(&x, &y).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct TargetEncoder {
    /// Encoding strategy ('mean', 'median', 'count', 'sum')
    strategy: String,
    /// Smoothing parameter for regularization (higher = more smoothing toward global mean);
    /// only applied by the 'mean' and 'median' strategies, and only when positive
    smoothing: f64,
    /// Fallback value for categories not seen during fit. Not used for smoothing.
    /// `0.0` acts as "unset" and selects `global_mean_` instead, so an explicit fallback of
    /// exactly zero cannot be requested.
    globalstat: f64,
    /// Mappings from categories to encoded values for each feature
    encodings_: Option<Vec<HashMap<u64, f64>>>,
    /// Whether the encoder has been fitted
    is_fitted: bool,
    /// Global mean of target values (computed during fit); the smoothing target and the
    /// default unknown-category fallback
    global_mean_: f64,
}
658
659impl TargetEncoder {
660    /// Creates a new TargetEncoder
661    ///
662    /// # Arguments
663    /// * `strategy` - Encoding strategy ('mean', 'median', 'count', 'sum')
664    /// * `smoothing` - Smoothing parameter (0.0 = no smoothing, higher = more smoothing)
665    /// * `globalstat` - Global statistic fallback for unknown categories
666    ///
667    /// # Returns
668    /// * A new TargetEncoder instance
669    pub fn new(_strategy: &str, smoothing: f64, globalstat: f64) -> Result<Self> {
670        if !["mean", "median", "count", "sum"].contains(&_strategy) {
671            return Err(TransformError::InvalidInput(
672                "_strategy must be 'mean', 'median', 'count', or 'sum'".to_string(),
673            ));
674        }
675
676        if smoothing < 0.0 {
677            return Err(TransformError::InvalidInput(
678                "smoothing parameter must be non-negative".to_string(),
679            ));
680        }
681
682        Ok(TargetEncoder {
683            strategy: _strategy.to_string(),
684            smoothing,
685            globalstat,
686            encodings_: None,
687            is_fitted: false,
688            global_mean_: 0.0,
689        })
690    }
691
692    /// Creates a TargetEncoder with mean strategy and default smoothing
693    pub fn with_mean(smoothing: f64) -> Self {
694        TargetEncoder {
695            strategy: "mean".to_string(),
696            smoothing,
697            globalstat: 0.0,
698            encodings_: None,
699            is_fitted: false,
700            global_mean_: 0.0,
701        }
702    }
703
704    /// Creates a TargetEncoder with median strategy
705    pub fn with_median(smoothing: f64) -> Self {
706        TargetEncoder {
707            strategy: "median".to_string(),
708            smoothing,
709            globalstat: 0.0,
710            encodings_: None,
711            is_fitted: false,
712            global_mean_: 0.0,
713        }
714    }
715
716    /// Fits the TargetEncoder to the input data and target values
717    ///
718    /// # Arguments
719    /// * `x` - The input categorical data, shape (n_samples, n_features)
720    /// * `y` - The target values, length n_samples
721    ///
722    /// # Returns
723    /// * `Result<()>` - Ok if successful, Err otherwise
724    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<()>
725    where
726        S: Data,
727        S::Elem: Float + NumCast,
728    {
729        let x_u64 = x.mapv(|x| {
730            let val_f64 = NumCast::from(x).unwrap_or(0.0);
731            val_f64 as u64
732        });
733
734        let n_samples = x_u64.shape()[0];
735        let n_features = x_u64.shape()[1];
736
737        if n_samples == 0 || n_features == 0 {
738            return Err(TransformError::InvalidInput("Empty input data".to_string()));
739        }
740
741        if y.len() != n_samples {
742            return Err(TransformError::InvalidInput(
743                "Number of target values must match number of samples".to_string(),
744            ));
745        }
746
747        // Compute global mean for smoothing
748        self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
749
750        let mut encodings = Vec::with_capacity(n_features);
751
752        for j in 0..n_features {
753            // Group target values by category for this feature
754            let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
755
756            for i in 0..n_samples {
757                let category = x_u64[[i, j]];
758                category_targets.entry(category).or_default().push(y[i]);
759            }
760
761            // Compute encoding for each category
762            let mut category_encoding = HashMap::new();
763
764            for (category, targets) in category_targets.iter() {
765                let encoded_value = match self.strategy.as_str() {
766                    "mean" => {
767                        let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
768                        let count = targets.len() as f64;
769
770                        // Apply smoothing: (count * category_mean + smoothing * global_mean) / (count + smoothing)
771                        if self.smoothing > 0.0 {
772                            (count * category_mean + self.smoothing * self.global_mean_)
773                                / (count + self.smoothing)
774                        } else {
775                            category_mean
776                        }
777                    }
778                    "median" => {
779                        let mut sorted_targets = targets.clone();
780                        sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
781
782                        let median = if sorted_targets.len() % 2 == 0 {
783                            let mid = sorted_targets.len() / 2;
784                            (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
785                        } else {
786                            sorted_targets[sorted_targets.len() / 2]
787                        };
788
789                        // Apply smoothing toward global mean
790                        if self.smoothing > 0.0 {
791                            let count = targets.len() as f64;
792                            (count * median + self.smoothing * self.global_mean_)
793                                / (count + self.smoothing)
794                        } else {
795                            median
796                        }
797                    }
798                    "count" => targets.len() as f64,
799                    "sum" => targets.iter().sum::<f64>(),
800                    _ => unreachable!(),
801                };
802
803                category_encoding.insert(*category, encoded_value);
804            }
805
806            encodings.push(category_encoding);
807        }
808
809        self.encodings_ = Some(encodings);
810        self.is_fitted = true;
811        Ok(())
812    }
813
814    /// Transforms the input data using the fitted TargetEncoder
815    ///
816    /// # Arguments
817    /// * `x` - The input categorical data, shape (n_samples, n_features)
818    ///
819    /// # Returns
820    /// * `Result<Array2<f64>>` - The target-encoded data
821    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
822    where
823        S: Data,
824        S::Elem: Float + NumCast,
825    {
826        if !self.is_fitted {
827            return Err(TransformError::TransformationError(
828                "TargetEncoder has not been fitted".to_string(),
829            ));
830        }
831
832        let x_u64 = x.mapv(|x| {
833            let val_f64 = NumCast::from(x).unwrap_or(0.0);
834            val_f64 as u64
835        });
836
837        let n_samples = x_u64.shape()[0];
838        let n_features = x_u64.shape()[1];
839
840        let encodings = self.encodings_.as_ref().unwrap();
841
842        if n_features != encodings.len() {
843            return Err(TransformError::InvalidInput(format!(
844                "x has {} features, but TargetEncoder was fitted with {} features",
845                n_features,
846                encodings.len()
847            )));
848        }
849
850        let mut transformed = Array2::zeros((n_samples, n_features));
851
852        for i in 0..n_samples {
853            for j in 0..n_features {
854                let category = x_u64[[i, j]];
855
856                if let Some(&encoded_value) = encodings[j].get(&category) {
857                    transformed[[i, j]] = encoded_value;
858                } else {
859                    // Use global statistic for unknown categories
860                    transformed[[i, j]] = if self.globalstat != 0.0 {
861                        self.globalstat
862                    } else {
863                        self.global_mean_
864                    };
865                }
866            }
867        }
868
869        Ok(transformed)
870    }
871
872    /// Fits the TargetEncoder and transforms the data in one step
873    ///
874    /// # Arguments
875    /// * `x` - The input categorical data, shape (n_samples, n_features)
876    /// * `y` - The target values, length n_samples
877    ///
878    /// # Returns
879    /// * `Result<Array2<f64>>` - The target-encoded data
880    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<Array2<f64>>
881    where
882        S: Data,
883        S::Elem: Float + NumCast,
884    {
885        self.fit(x, y)?;
886        self.transform(x)
887    }
888
889    /// Returns the learned encodings for each feature
890    ///
891    /// # Returns
892    /// * `Option<&Vec<HashMap<u64, f64>>>` - The category encodings for each feature
893    pub fn encodings(&self) -> Option<&Vec<HashMap<u64, f64>>> {
894        self.encodings_.as_ref()
895    }
896
897    /// Returns whether the encoder has been fitted
898    pub fn is_fitted(&self) -> bool {
899        self.is_fitted
900    }
901
902    /// Returns the global mean computed during fitting
903    pub fn global_mean(&self) -> f64 {
904        self.global_mean_
905    }
906
907    /// Applies cross-validation target encoding to prevent overfitting
908    ///
909    /// This method performs k-fold cross-validation to compute target encodings,
910    /// which helps prevent overfitting when the same data is used for both
911    /// fitting and transforming.
912    ///
913    /// # Arguments
914    /// * `x` - The input categorical data, shape (n_samples, n_features)
915    /// * `y` - The target values, length n_samples
916    /// * `cv_folds` - Number of cross-validation folds (default: 5)
917    ///
918    /// # Returns
919    /// * `Result<Array2<f64>>` - The cross-validated target-encoded data
920    pub fn fit_transform_cv<S>(
921        &mut self,
922        x: &ArrayBase<S, Ix2>,
923        y: &[f64],
924        cv_folds: usize,
925    ) -> Result<Array2<f64>>
926    where
927        S: Data,
928        S::Elem: Float + NumCast,
929    {
930        let x_u64 = x.mapv(|x| {
931            let val_f64 = NumCast::from(x).unwrap_or(0.0);
932            val_f64 as u64
933        });
934
935        let n_samples = x_u64.shape()[0];
936        let n_features = x_u64.shape()[1];
937
938        if n_samples == 0 || n_features == 0 {
939            return Err(TransformError::InvalidInput("Empty input data".to_string()));
940        }
941
942        if y.len() != n_samples {
943            return Err(TransformError::InvalidInput(
944                "Number of target values must match number of samples".to_string(),
945            ));
946        }
947
948        if cv_folds < 2 {
949            return Err(TransformError::InvalidInput(
950                "cv_folds must be at least 2".to_string(),
951            ));
952        }
953
954        let mut transformed = Array2::zeros((n_samples, n_features));
955
956        // Compute global mean
957        self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
958
959        // Create fold indices
960        let fold_size = n_samples / cv_folds;
961        let mut fold_indices = Vec::new();
962        for fold in 0..cv_folds {
963            let start = fold * fold_size;
964            let end = if fold == cv_folds - 1 {
965                n_samples
966            } else {
967                (fold + 1) * fold_size
968            };
969            fold_indices.push((start, end));
970        }
971
972        // For each fold, train on other _folds and predict on this fold
973        for fold in 0..cv_folds {
974            let (val_start, val_end) = fold_indices[fold];
975
976            // Collect training data (all _folds except current)
977            let mut train_indices = Vec::new();
978            for (other_fold, &(start, end)) in fold_indices.iter().enumerate().take(cv_folds) {
979                if other_fold != fold {
980                    train_indices.extend(start..end);
981                }
982            }
983
984            // For each feature, compute encodings on training data
985            for j in 0..n_features {
986                let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
987
988                // Collect target values by category for training data
989                for &train_idx in &train_indices {
990                    let category = x_u64[[train_idx, j]];
991                    category_targets
992                        .entry(category)
993                        .or_default()
994                        .push(y[train_idx]);
995                }
996
997                // Compute encodings for this fold
998                let mut category_encoding = HashMap::new();
999                for (category, targets) in category_targets.iter() {
1000                    let encoded_value = match self.strategy.as_str() {
1001                        "mean" => {
1002                            let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
1003                            let count = targets.len() as f64;
1004
1005                            if self.smoothing > 0.0 {
1006                                (count * category_mean + self.smoothing * self.global_mean_)
1007                                    / (count + self.smoothing)
1008                            } else {
1009                                category_mean
1010                            }
1011                        }
1012                        "median" => {
1013                            let mut sorted_targets = targets.clone();
1014                            sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
1015
1016                            let median = if sorted_targets.len() % 2 == 0 {
1017                                let mid = sorted_targets.len() / 2;
1018                                (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
1019                            } else {
1020                                sorted_targets[sorted_targets.len() / 2]
1021                            };
1022
1023                            if self.smoothing > 0.0 {
1024                                let count = targets.len() as f64;
1025                                (count * median + self.smoothing * self.global_mean_)
1026                                    / (count + self.smoothing)
1027                            } else {
1028                                median
1029                            }
1030                        }
1031                        "count" => targets.len() as f64,
1032                        "sum" => targets.iter().sum::<f64>(),
1033                        _ => unreachable!(),
1034                    };
1035
1036                    category_encoding.insert(*category, encoded_value);
1037                }
1038
1039                // Apply encodings to validation fold
1040                for val_idx in val_start..val_end {
1041                    let category = x_u64[[val_idx, j]];
1042
1043                    if let Some(&encoded_value) = category_encoding.get(&category) {
1044                        transformed[[val_idx, j]] = encoded_value;
1045                    } else {
1046                        // Use global mean for unknown categories
1047                        transformed[[val_idx, j]] = self.global_mean_;
1048                    }
1049                }
1050            }
1051        }
1052
1053        // Now fit on the full data for future transforms
1054        self.fit(x, y)?;
1055
1056        Ok(transformed)
1057    }
1058}
1059
/// BinaryEncoder for converting categorical features to binary representations
///
/// This transformer converts categorical features into binary representations,
/// where each category is encoded as a unique binary number. This is more
/// memory-efficient than one-hot encoding for high-cardinality categorical features.
///
/// For n unique categories, ceil(log2(n)) binary features are created (at least
/// one). Categories are numbered by ascending value, so the smallest category
/// is encoded as all zeros. With `handleunknown = "ignore"`, unknown categories
/// are also all zeros and therefore indistinguishable from that category.
#[derive(Debug, Clone)]
pub struct BinaryEncoder {
    /// Mappings from categories to binary codes (most significant bit first) for each feature
    categories_: Option<Vec<HashMap<u64, Vec<u8>>>>,
    /// Number of binary features per original feature
    n_binary_features_: Option<Vec<usize>>,
    /// Policy for categories unseen during fitting: "error" or "ignore"
    handleunknown: String,
    /// Whether the encoder has been fitted
    is_fitted: bool,
}
1078
1079impl BinaryEncoder {
1080    /// Creates a new BinaryEncoder
1081    ///
1082    /// # Arguments
1083    /// * `handleunknown` - How to handle unknown categories ('error' or 'ignore')
1084    ///   - 'error': Raise an error if unknown categories are encountered
1085    ///   - 'ignore': Encode unknown categories as all zeros
1086    ///
1087    /// # Returns
1088    /// * `Result<BinaryEncoder>` - The new encoder instance
1089    pub fn new(handleunknown: &str) -> Result<Self> {
1090        if handleunknown != "error" && handleunknown != "ignore" {
1091            return Err(TransformError::InvalidInput(
1092                "handleunknown must be 'error' or 'ignore'".to_string(),
1093            ));
1094        }
1095
1096        Ok(BinaryEncoder {
1097            categories_: None,
1098            n_binary_features_: None,
1099            handleunknown: handleunknown.to_string(),
1100            is_fitted: false,
1101        })
1102    }
1103
1104    /// Creates a BinaryEncoder with default settings (handleunknown='error')
1105    pub fn with_defaults() -> Self {
1106        Self::new("error").unwrap()
1107    }
1108
1109    /// Fits the BinaryEncoder to the input data
1110    ///
1111    /// # Arguments
1112    /// * `x` - The input categorical data, shape (n_samples, n_features)
1113    ///
1114    /// # Returns
1115    /// * `Result<()>` - Ok if successful, Err otherwise
1116    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
1117    where
1118        S: Data,
1119        S::Elem: Float + NumCast,
1120    {
1121        let x_u64 = x.mapv(|x| {
1122            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1123            val_f64 as u64
1124        });
1125
1126        let n_samples = x_u64.shape()[0];
1127        let n_features = x_u64.shape()[1];
1128
1129        if n_samples == 0 || n_features == 0 {
1130            return Err(TransformError::InvalidInput("Empty input data".to_string()));
1131        }
1132
1133        let mut categories = Vec::with_capacity(n_features);
1134        let mut n_binary_features = Vec::with_capacity(n_features);
1135
1136        for j in 0..n_features {
1137            // Collect unique categories for this feature
1138            let mut unique_categories: Vec<u64> = x_u64.column(j).to_vec();
1139            unique_categories.sort_unstable();
1140            unique_categories.dedup();
1141
1142            if unique_categories.is_empty() {
1143                return Err(TransformError::InvalidInput(
1144                    "Feature has no valid categories".to_string(),
1145                ));
1146            }
1147
1148            // Calculate number of binary features needed
1149            let n_cats = unique_categories.len();
1150            let nbits = if n_cats <= 1 {
1151                1
1152            } else {
1153                (n_cats as f64).log2().ceil() as usize
1154            };
1155
1156            // Create binary mappings
1157            let mut category_map = HashMap::new();
1158            for (idx, &category) in unique_categories.iter().enumerate() {
1159                let binary_code = Self::int_to_binary(idx, nbits);
1160                category_map.insert(category, binary_code);
1161            }
1162
1163            categories.push(category_map);
1164            n_binary_features.push(nbits);
1165        }
1166
1167        self.categories_ = Some(categories);
1168        self.n_binary_features_ = Some(n_binary_features);
1169        self.is_fitted = true;
1170
1171        Ok(())
1172    }
1173
1174    /// Transforms the input data using the fitted encoder
1175    ///
1176    /// # Arguments
1177    /// * `x` - The input categorical data to transform
1178    ///
1179    /// # Returns
1180    /// * `Result<Array2<f64>>` - The binary-encoded data
1181    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1182    where
1183        S: Data,
1184        S::Elem: Float + NumCast,
1185    {
1186        if !self.is_fitted {
1187            return Err(TransformError::InvalidInput(
1188                "Encoder has not been fitted yet".to_string(),
1189            ));
1190        }
1191
1192        let categories = self.categories_.as_ref().unwrap();
1193        let n_binary_features = self.n_binary_features_.as_ref().unwrap();
1194
1195        let x_u64 = x.mapv(|x| {
1196            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1197            val_f64 as u64
1198        });
1199
1200        let n_samples = x_u64.shape()[0];
1201        let n_features = x_u64.shape()[1];
1202
1203        if n_features != categories.len() {
1204            return Err(TransformError::InvalidInput(format!(
1205                "Number of features ({}) does not match fitted features ({})",
1206                n_features,
1207                categories.len()
1208            )));
1209        }
1210
1211        // Calculate total number of output features
1212        let total_binary_features: usize = n_binary_features.iter().sum();
1213        let mut result = Array2::<f64>::zeros((n_samples, total_binary_features));
1214
1215        let mut output_col = 0;
1216        for j in 0..n_features {
1217            let category_map = &categories[j];
1218            let nbits = n_binary_features[j];
1219
1220            for i in 0..n_samples {
1221                let category = x_u64[[i, j]];
1222
1223                if let Some(binary_code) = category_map.get(&category) {
1224                    // Known category: use binary code
1225                    for (bit_idx, &bit_val) in binary_code.iter().enumerate() {
1226                        result[[i, output_col + bit_idx]] = bit_val as f64;
1227                    }
1228                } else {
1229                    // Unknown category
1230                    match self.handleunknown.as_str() {
1231                        "error" => {
1232                            return Err(TransformError::InvalidInput(format!(
1233                                "Unknown category {category} in feature {j}"
1234                            )));
1235                        }
1236                        "ignore" => {
1237                            // Set all bits to zero (already initialized)
1238                        }
1239                        _ => unreachable!(),
1240                    }
1241                }
1242            }
1243
1244            output_col += nbits;
1245        }
1246
1247        Ok(result)
1248    }
1249
1250    /// Fits the encoder and transforms the data in one step
1251    ///
1252    /// # Arguments
1253    /// * `x` - The input categorical data
1254    ///
1255    /// # Returns
1256    /// * `Result<Array2<f64>>` - The binary-encoded data
1257    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1258    where
1259        S: Data,
1260        S::Elem: Float + NumCast,
1261    {
1262        self.fit(x)?;
1263        self.transform(x)
1264    }
1265
1266    /// Returns whether the encoder has been fitted
1267    pub fn is_fitted(&self) -> bool {
1268        self.is_fitted
1269    }
1270
1271    /// Returns the category mappings if fitted
1272    pub fn categories(&self) -> Option<&Vec<HashMap<u64, Vec<u8>>>> {
1273        self.categories_.as_ref()
1274    }
1275
1276    /// Returns the number of binary features per original feature
1277    pub fn n_binary_features(&self) -> Option<&Vec<usize>> {
1278        self.n_binary_features_.as_ref()
1279    }
1280
1281    /// Returns the total number of output features
1282    pub fn n_output_features(&self) -> Option<usize> {
1283        self.n_binary_features_.as_ref().map(|v| v.iter().sum())
1284    }
1285
1286    /// Converts an integer to binary representation
1287    fn int_to_binary(_value: usize, nbits: usize) -> Vec<u8> {
1288        let mut binary = Vec::with_capacity(nbits);
1289        let mut val = _value;
1290
1291        for _ in 0..nbits {
1292            binary.push((val & 1) as u8);
1293            val >>= 1;
1294        }
1295
1296        binary.reverse(); // Most significant bit first
1297        binary
1298    }
1299}
1300
/// FrequencyEncoder for converting categorical features to frequency counts
///
/// Each category is replaced by how often it occurred in the training data,
/// either as a raw count or, when normalization is enabled, as a fraction of the
/// number of training rows. Models can then exploit how common a category is.
#[derive(Debug, Clone)]
pub struct FrequencyEncoder {
    /// Per-feature table from category code to its (possibly normalized) frequency
    frequency_maps_: Option<Vec<HashMap<u64, f64>>>,
    /// When true, frequencies are divided by the number of training rows
    normalize: bool,
    /// Policy for unseen categories: "error", "ignore" or "use_encoded_value"
    handleunknown: String,
    /// Encoding emitted for unseen categories under "use_encoded_value"
    unknownvalue: f64,
    /// Set once `fit` has succeeded
    is_fitted: bool,
}
1319
1320impl FrequencyEncoder {
1321    /// Creates a new FrequencyEncoder
1322    ///
1323    /// # Arguments
1324    /// * `normalize` - Whether to normalize frequencies to [0, 1] range
1325    /// * `handleunknown` - How to handle unknown categories ('error', 'ignore', or 'use_encoded_value')
1326    /// * `unknownvalue` - Value to use for unknown categories (when handleunknown="use_encoded_value")
1327    ///
1328    /// # Returns
1329    /// * `Result<FrequencyEncoder>` - The new encoder instance
1330    pub fn new(normalize: bool, handleunknown: &str, unknownvalue: f64) -> Result<Self> {
1331        if !["error", "ignore", "use_encoded_value"].contains(&handleunknown) {
1332            return Err(TransformError::InvalidInput(
1333                "handleunknown must be 'error', 'ignore', or 'use_encoded_value'".to_string(),
1334            ));
1335        }
1336
1337        Ok(FrequencyEncoder {
1338            frequency_maps_: None,
1339            normalize,
1340            handleunknown: handleunknown.to_string(),
1341            unknownvalue,
1342            is_fitted: false,
1343        })
1344    }
1345
1346    /// Creates a FrequencyEncoder with default settings
1347    pub fn with_defaults() -> Self {
1348        Self::new(false, "error", 0.0).unwrap()
1349    }
1350
1351    /// Creates a FrequencyEncoder with normalized frequencies
1352    pub fn with_normalization() -> Self {
1353        Self::new(true, "error", 0.0).unwrap()
1354    }
1355
1356    /// Fits the FrequencyEncoder to the input data
1357    ///
1358    /// # Arguments
1359    /// * `x` - The input categorical data, shape (n_samples, n_features)
1360    ///
1361    /// # Returns
1362    /// * `Result<()>` - Ok if successful, Err otherwise
1363    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
1364    where
1365        S: Data,
1366        S::Elem: Float + NumCast,
1367    {
1368        let x_u64 = x.mapv(|x| {
1369            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1370            val_f64 as u64
1371        });
1372
1373        let n_samples = x_u64.shape()[0];
1374        let n_features = x_u64.shape()[1];
1375
1376        if n_samples == 0 || n_features == 0 {
1377            return Err(TransformError::InvalidInput("Empty input data".to_string()));
1378        }
1379
1380        let mut frequency_maps = Vec::with_capacity(n_features);
1381
1382        for j in 0..n_features {
1383            // Count frequency of each category
1384            let mut category_counts: HashMap<u64, usize> = HashMap::new();
1385            for i in 0..n_samples {
1386                let category = x_u64[[i, j]];
1387                *category_counts.entry(category).or_insert(0) += 1;
1388            }
1389
1390            // Convert counts to frequencies
1391            let mut frequency_map = HashMap::new();
1392            for (category, count) in category_counts {
1393                let frequency = if self.normalize {
1394                    count as f64 / n_samples as f64
1395                } else {
1396                    count as f64
1397                };
1398                frequency_map.insert(category, frequency);
1399            }
1400
1401            frequency_maps.push(frequency_map);
1402        }
1403
1404        self.frequency_maps_ = Some(frequency_maps);
1405        self.is_fitted = true;
1406        Ok(())
1407    }
1408
1409    /// Transforms the input data using the fitted FrequencyEncoder
1410    ///
1411    /// # Arguments
1412    /// * `x` - The input categorical data, shape (n_samples, n_features)
1413    ///
1414    /// # Returns
1415    /// * `Result<Array2<f64>>` - The frequency-encoded data
1416    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1417    where
1418        S: Data,
1419        S::Elem: Float + NumCast,
1420    {
1421        if !self.is_fitted {
1422            return Err(TransformError::TransformationError(
1423                "FrequencyEncoder has not been fitted".to_string(),
1424            ));
1425        }
1426
1427        let frequency_maps = self.frequency_maps_.as_ref().unwrap();
1428
1429        let x_u64 = x.mapv(|x| {
1430            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1431            val_f64 as u64
1432        });
1433
1434        let n_samples = x_u64.shape()[0];
1435        let n_features = x_u64.shape()[1];
1436
1437        if n_features != frequency_maps.len() {
1438            return Err(TransformError::InvalidInput(format!(
1439                "x has {} features, but FrequencyEncoder was fitted with {} features",
1440                n_features,
1441                frequency_maps.len()
1442            )));
1443        }
1444
1445        let mut transformed = Array2::zeros((n_samples, n_features));
1446
1447        for i in 0..n_samples {
1448            for j in 0..n_features {
1449                let category = x_u64[[i, j]];
1450
1451                if let Some(&frequency) = frequency_maps[j].get(&category) {
1452                    transformed[[i, j]] = frequency;
1453                } else {
1454                    // Handle unknown category
1455                    match self.handleunknown.as_str() {
1456                        "error" => {
1457                            return Err(TransformError::InvalidInput(format!(
1458                                "Unknown category {category} in feature {j}"
1459                            )));
1460                        }
1461                        "ignore" => {
1462                            transformed[[i, j]] = 0.0;
1463                        }
1464                        "use_encoded_value" => {
1465                            transformed[[i, j]] = self.unknownvalue;
1466                        }
1467                        _ => unreachable!(),
1468                    }
1469                }
1470            }
1471        }
1472
1473        Ok(transformed)
1474    }
1475
1476    /// Fits the encoder and transforms the data in one step
1477    ///
1478    /// # Arguments
1479    /// * `x` - The input categorical data
1480    ///
1481    /// # Returns
1482    /// * `Result<Array2<f64>>` - The frequency-encoded data
1483    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1484    where
1485        S: Data,
1486        S::Elem: Float + NumCast,
1487    {
1488        self.fit(x)?;
1489        self.transform(x)
1490    }
1491
1492    /// Returns whether the encoder has been fitted
1493    pub fn is_fitted(&self) -> bool {
1494        self.is_fitted
1495    }
1496
1497    /// Returns the learned frequency mappings if fitted
1498    pub fn frequency_maps(&self) -> Option<&Vec<HashMap<u64, f64>>> {
1499        self.frequency_maps_.as_ref()
1500    }
1501}
1502
/// Weight of Evidence (WOE) Encoder for converting categorical features using target information
///
/// WOE encoding transforms categorical features based on the relationship between
/// each category and a binary target variable. It's particularly useful for credit
/// scoring and other binary classification tasks.
///
/// For a category `c`, ignoring regularization:
///
/// WOE(c) = ln(P(category = c | target = 1) / P(category = c | target = 0))
///
/// This is the log of the category's share of all events (target 1) divided by
/// its share of all non-events (target 0). A category whose event rate equals
/// the overall rate gets a WOE of 0. The regularization pseudo-count is added to
/// every category's event and non-event counts before the shares are computed.
#[derive(Debug, Clone)]
pub struct WOEEncoder {
    /// WOE mappings for each feature
    woe_maps_: Option<Vec<HashMap<u64, f64>>>,
    /// Information Value (IV) for each feature: sum over categories of
    /// (event share - non-event share) * WOE
    information_values_: Option<Vec<f64>>,
    /// Pseudo-count added to each category's events and non-events so that
    /// categories with zero events/non-events get a finite WOE
    regularization: f64,
    /// How to handle unknown categories ("error", "global_woe" or "use_encoded_value")
    handleunknown: String,
    /// Value to use for unknown categories (when handleunknown="use_encoded_value")
    unknownvalue: f64,
    /// Overall log-odds of the training target, ln(events / non_events); used
    /// for unknown categories when handleunknown="global_woe"
    global_woe_: f64,
    /// Whether the encoder has been fitted
    is_fitted: bool,
}
1527
1528impl WOEEncoder {
1529    /// Creates a new WOEEncoder
1530    ///
1531    /// # Arguments
1532    /// * `regularization` - Small value added to prevent division by zero (default: 0.5)
1533    /// * `handleunknown` - How to handle unknown categories ('error', 'global_woe', or 'use_encoded_value')
1534    /// * `unknownvalue` - Value to use for unknown categories (when handleunknown="use_encoded_value")
1535    ///
1536    /// # Returns
1537    /// * `Result<WOEEncoder>` - The new encoder instance
1538    pub fn new(regularization: f64, handleunknown: &str, unknownvalue: f64) -> Result<Self> {
1539        if regularization < 0.0 {
1540            return Err(TransformError::InvalidInput(
1541                "regularization must be non-negative".to_string(),
1542            ));
1543        }
1544
1545        if !["error", "global_woe", "use_encoded_value"].contains(&handleunknown) {
1546            return Err(TransformError::InvalidInput(
1547                "handleunknown must be 'error', 'global_woe', or 'use_encoded_value'".to_string(),
1548            ));
1549        }
1550
1551        Ok(WOEEncoder {
1552            woe_maps_: None,
1553            information_values_: None,
1554            regularization,
1555            handleunknown: handleunknown.to_string(),
1556            unknownvalue,
1557            global_woe_: 0.0,
1558            is_fitted: false,
1559        })
1560    }
1561
1562    /// Creates a WOEEncoder with default settings
1563    pub fn with_defaults() -> Self {
1564        Self::new(0.5, "global_woe", 0.0).unwrap()
1565    }
1566
1567    /// Creates a WOEEncoder with custom regularization
1568    pub fn with_regularization(regularization: f64) -> Result<Self> {
1569        Self::new(regularization, "global_woe", 0.0)
1570    }
1571
1572    /// Fits the WOEEncoder to the input data
1573    ///
1574    /// # Arguments
1575    /// * `x` - The input categorical data, shape (n_samples, n_features)
1576    /// * `y` - The binary target values (0 or 1), length n_samples
1577    ///
1578    /// # Returns
1579    /// * `Result<()>` - Ok if successful, Err otherwise
1580    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<()>
1581    where
1582        S: Data,
1583        S::Elem: Float + NumCast,
1584    {
1585        let x_u64 = x.mapv(|x| {
1586            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1587            val_f64 as u64
1588        });
1589
1590        let n_samples = x_u64.shape()[0];
1591        let n_features = x_u64.shape()[1];
1592
1593        if n_samples == 0 || n_features == 0 {
1594            return Err(TransformError::InvalidInput("Empty input data".to_string()));
1595        }
1596
1597        if y.len() != n_samples {
1598            return Err(TransformError::InvalidInput(
1599                "Number of target values must match number of samples".to_string(),
1600            ));
1601        }
1602
1603        // Validate that target is binary
1604        for &target in y {
1605            if target != 0.0 && target != 1.0 {
1606                return Err(TransformError::InvalidInput(
1607                    "Target values must be binary (0 or 1)".to_string(),
1608                ));
1609            }
1610        }
1611
1612        // Calculate global statistics
1613        let total_events: f64 = y.iter().sum();
1614        let total_non_events = n_samples as f64 - total_events;
1615
1616        if total_events == 0.0 || total_non_events == 0.0 {
1617            return Err(TransformError::InvalidInput(
1618                "Target must contain both 0 and 1 values".to_string(),
1619            ));
1620        }
1621
1622        // Global WOE (overall log-odds)
1623        self.global_woe_ = (total_events / total_non_events).ln();
1624
1625        let mut woe_maps = Vec::with_capacity(n_features);
1626        let mut information_values = Vec::with_capacity(n_features);
1627
1628        for j in 0..n_features {
1629            // Collect target values by category
1630            let mut category_stats: HashMap<u64, (f64, f64)> = HashMap::new(); // (events, non_events)
1631
1632            for i in 0..n_samples {
1633                let category = x_u64[[i, j]];
1634                let target = y[i];
1635
1636                let (events, non_events) = category_stats.entry(category).or_insert((0.0, 0.0));
1637                if target == 1.0 {
1638                    *events += 1.0;
1639                } else {
1640                    *non_events += 1.0;
1641                }
1642            }
1643
1644            // Calculate WOE and IV for each category
1645            let mut woe_map = HashMap::new();
1646            let mut feature_iv = 0.0;
1647
1648            for (category, (events, non_events)) in category_stats.iter() {
1649                // Add regularization to handle zero counts
1650                let reg_events = events + self.regularization;
1651                let reg_non_events = non_events + self.regularization;
1652                let reg_total_events =
1653                    total_events + self.regularization * category_stats.len() as f64;
1654                let reg_total_non_events =
1655                    total_non_events + self.regularization * category_stats.len() as f64;
1656
1657                // Calculate distribution percentages
1658                let event_rate = reg_events / reg_total_events;
1659                let non_event_rate = reg_non_events / reg_total_non_events;
1660
1661                // Calculate WOE
1662                let woe = (event_rate / non_event_rate).ln();
1663                woe_map.insert(*category, woe);
1664
1665                // Calculate Information Value contribution
1666                let iv_contribution = (event_rate - non_event_rate) * woe;
1667                feature_iv += iv_contribution;
1668            }
1669
1670            woe_maps.push(woe_map);
1671            information_values.push(feature_iv);
1672        }
1673
1674        self.woe_maps_ = Some(woe_maps);
1675        self.information_values_ = Some(information_values);
1676        self.is_fitted = true;
1677        Ok(())
1678    }
1679
1680    /// Transforms the input data using the fitted WOEEncoder
1681    ///
1682    /// # Arguments
1683    /// * `x` - The input categorical data, shape (n_samples, n_features)
1684    ///
1685    /// # Returns
1686    /// * `Result<Array2<f64>>` - The WOE-encoded data
1687    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1688    where
1689        S: Data,
1690        S::Elem: Float + NumCast,
1691    {
1692        if !self.is_fitted {
1693            return Err(TransformError::TransformationError(
1694                "WOEEncoder has not been fitted".to_string(),
1695            ));
1696        }
1697
1698        let woe_maps = self.woe_maps_.as_ref().unwrap();
1699
1700        let x_u64 = x.mapv(|x| {
1701            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1702            val_f64 as u64
1703        });
1704
1705        let n_samples = x_u64.shape()[0];
1706        let n_features = x_u64.shape()[1];
1707
1708        if n_features != woe_maps.len() {
1709            return Err(TransformError::InvalidInput(format!(
1710                "x has {} features, but WOEEncoder was fitted with {} features",
1711                n_features,
1712                woe_maps.len()
1713            )));
1714        }
1715
1716        let mut transformed = Array2::zeros((n_samples, n_features));
1717
1718        for i in 0..n_samples {
1719            for j in 0..n_features {
1720                let category = x_u64[[i, j]];
1721
1722                if let Some(&woe_value) = woe_maps[j].get(&category) {
1723                    transformed[[i, j]] = woe_value;
1724                } else {
1725                    // Handle unknown category
1726                    match self.handleunknown.as_str() {
1727                        "error" => {
1728                            return Err(TransformError::InvalidInput(format!(
1729                                "Unknown category {category} in feature {j}"
1730                            )));
1731                        }
1732                        "global_woe" => {
1733                            transformed[[i, j]] = self.global_woe_;
1734                        }
1735                        "use_encoded_value" => {
1736                            transformed[[i, j]] = self.unknownvalue;
1737                        }
1738                        _ => unreachable!(),
1739                    }
1740                }
1741            }
1742        }
1743
1744        Ok(transformed)
1745    }
1746
1747    /// Fits the encoder and transforms the data in one step
1748    ///
1749    /// # Arguments
1750    /// * `x` - The input categorical data
1751    /// * `y` - The binary target values
1752    ///
1753    /// # Returns
1754    /// * `Result<Array2<f64>>` - The WOE-encoded data
1755    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<Array2<f64>>
1756    where
1757        S: Data,
1758        S::Elem: Float + NumCast,
1759    {
1760        self.fit(x, y)?;
1761        self.transform(x)
1762    }
1763
1764    /// Returns whether the encoder has been fitted
1765    pub fn is_fitted(&self) -> bool {
1766        self.is_fitted
1767    }
1768
1769    /// Returns the learned WOE mappings if fitted
1770    pub fn woe_maps(&self) -> Option<&Vec<HashMap<u64, f64>>> {
1771        self.woe_maps_.as_ref()
1772    }
1773
1774    /// Returns the Information Values for each feature if fitted
1775    ///
1776    /// Information Value interpretation:
1777    /// - < 0.02: Not useful for prediction
1778    /// - 0.02 - 0.1: Weak predictive power
1779    /// - 0.1 - 0.3: Medium predictive power  
1780    /// - 0.3 - 0.5: Strong predictive power
1781    /// - > 0.5: Suspicious, too good to be true
1782    pub fn information_values(&self) -> Option<&Vec<f64>> {
1783        self.information_values_.as_ref()
1784    }
1785
1786    /// Returns the global WOE value (overall log-odds)
1787    pub fn global_woe(&self) -> f64 {
1788        self.global_woe_
1789    }
1790
1791    /// Returns features ranked by Information Value (descending order)
1792    ///
1793    /// # Returns
1794    /// * `Option<Vec<(usize, f64)>>` - Vector of (feature_index, information_value) pairs
1795    pub fn feature_importance_ranking(&self) -> Option<Vec<(usize, f64)>> {
1796        self.information_values_.as_ref().map(|ivs| {
1797            let mut ranking: Vec<(usize, f64)> =
1798                ivs.iter().enumerate().map(|(idx, &iv)| (idx, iv)).collect();
1799            ranking.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1800            ranking
1801        })
1802    }
1803}
1804
// Unit tests for this module live in the sibling file `encoding_tests.rs` (selected via
// the `#[path]` attribute) and are compiled only when `cfg(test)` is enabled.
#[cfg(test)]
#[path = "encoding_tests.rs"]
mod tests;