scirs2_transform/
encoding.rs

1//! Categorical data encoding utilities
2//!
3//! This module provides methods for encoding categorical data into numerical
4//! formats suitable for machine learning algorithms.
5
6use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8use std::collections::HashMap;
9
10use crate::error::{Result, TransformError};
11
/// A minimal sparse matrix stored in COO (coordinate) format.
///
/// The `k`-th stored entry holds `values[k]` at position `(row_indices[k], col_indices[k])`.
/// `SparseMatrix::push` appends to all three vectors together, keeping them the same length.
/// Any position without a stored entry reads as zero.
#[derive(Clone, Debug)]
pub struct SparseMatrix {
    /// Matrix dimensions as `(rows, cols)`
    pub shape: (usize, usize),
    /// Row index of each stored entry
    pub row_indices: Vec<usize>,
    /// Column index of each stored entry
    pub col_indices: Vec<usize>,
    /// Value of each stored entry
    pub values: Vec<f64>,
}
24
25impl SparseMatrix {
26    /// Create a new empty sparse matrix
27    pub fn new(shape: (usize, usize)) -> Self {
28        SparseMatrix {
29            shape,
30            row_indices: Vec::new(),
31            col_indices: Vec::new(),
32            values: Vec::new(),
33        }
34    }
35
36    /// Add a non-zero value at (row, col)
37    pub fn push(&mut self, row: usize, col: usize, value: f64) {
38        if row < self.shape.0 && col < self.shape.1 && value != 0.0 {
39            self.row_indices.push(row);
40            self.col_indices.push(col);
41            self.values.push(value);
42        }
43    }
44
45    /// Convert to dense Array2
46    pub fn to_dense(&self) -> Array2<f64> {
47        let mut dense = Array2::zeros(self.shape);
48        for ((&row, &col), &val) in self
49            .row_indices
50            .iter()
51            .zip(self.col_indices.iter())
52            .zip(self.values.iter())
53        {
54            dense[[row, col]] = val;
55        }
56        dense
57    }
58
59    /// Get number of non-zero elements
60    pub fn nnz(&self) -> usize {
61        self.values.len()
62    }
63}
64
65/// Output format for encoded data
66#[derive(Debug, Clone)]
67pub enum EncodedOutput {
68    /// Dense matrix representation
69    Dense(Array2<f64>),
70    /// Sparse matrix representation
71    Sparse(SparseMatrix),
72}
73
74impl EncodedOutput {
75    /// Convert to dense matrix (creates copy if sparse)
76    pub fn to_dense(&self) -> Array2<f64> {
77        match self {
78            EncodedOutput::Dense(arr) => arr.clone(),
79            EncodedOutput::Sparse(sparse) => sparse.to_dense(),
80        }
81    }
82
83    /// Get shape of the output
84    pub fn shape(&self) -> (usize, usize) {
85        match self {
86            EncodedOutput::Dense(arr) => (arr.nrows(), arr.ncols()),
87            EncodedOutput::Sparse(sparse) => sparse.shape,
88        }
89    }
90}
91
/// OneHotEncoder for converting categorical features to binary features
///
/// Each input column of integer category codes is expanded into one binary output column
/// per kept category seen during `fit`. A row has `1.0` in the column of its category and
/// `0.0` in the rest of that feature's block.
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// Sorted, de-duplicated categories of each feature (learned during `fit`)
    categories_: Option<Vec<Vec<u64>>>,
    /// Drop strategy: `"first"`, `"if_binary"`, or `None` to keep every category
    drop: Option<String>,
    /// Handling of categories unseen during `fit`: `"error"` or `"ignore"`
    handleunknown: String,
    /// Whether `transform` returns sparse (`true`) or dense (`false`) output
    sparse: bool,
}
106
107impl OneHotEncoder {
108    /// Creates a new OneHotEncoder
109    ///
110    /// # Arguments
111    /// * `drop` - Strategy for dropping categories ('first', 'if_binary', or None)
112    /// * `handleunknown` - How to handle unknown categories ('error' or 'ignore')
113    /// * `sparse` - Whether to return sparse arrays
114    ///
115    /// # Returns
116    /// * A new OneHotEncoder instance
117    pub fn new(_drop: Option<String>, handleunknown: &str, sparse: bool) -> Result<Self> {
118        if let Some(ref drop_strategy) = _drop {
119            if drop_strategy != "first" && drop_strategy != "if_binary" {
120                return Err(TransformError::InvalidInput(
121                    "_drop must be 'first', 'if_binary', or None".to_string(),
122                ));
123            }
124        }
125
126        if handleunknown != "error" && handleunknown != "ignore" {
127            return Err(TransformError::InvalidInput(
128                "handleunknown must be 'error' or 'ignore'".to_string(),
129            ));
130        }
131
132        Ok(OneHotEncoder {
133            categories_: None,
134            drop: _drop,
135            handleunknown: handleunknown.to_string(),
136            sparse,
137        })
138    }
139
140    /// Creates a OneHotEncoder with default settings
141    pub fn with_defaults() -> Self {
142        Self::new(None, "error", false).unwrap()
143    }
144
145    /// Fits the OneHotEncoder to the input data
146    ///
147    /// # Arguments
148    /// * `x` - The input categorical data, shape (n_samples, n_features)
149    ///
150    /// # Returns
151    /// * `Result<()>` - Ok if successful, Err otherwise
152    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
153    where
154        S: Data,
155        S::Elem: Float + NumCast,
156    {
157        let x_u64 = x.mapv(|x| {
158            let val_f64 = NumCast::from(x).unwrap_or(0.0);
159            val_f64 as u64
160        });
161
162        let n_samples = x_u64.shape()[0];
163        let n_features = x_u64.shape()[1];
164
165        if n_samples == 0 || n_features == 0 {
166            return Err(TransformError::InvalidInput("Empty input data".to_string()));
167        }
168
169        let mut categories = Vec::with_capacity(n_features);
170
171        for j in 0..n_features {
172            // Collect unique values for this feature
173            let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
174            unique_values.sort_unstable();
175            unique_values.dedup();
176
177            categories.push(unique_values);
178        }
179
180        self.categories_ = Some(categories);
181        Ok(())
182    }
183
184    /// Transforms the input data using the fitted OneHotEncoder
185    ///
186    /// # Arguments
187    /// * `x` - The input categorical data, shape (n_samples, n_features)
188    ///
189    /// # Returns
190    /// * `Result<EncodedOutput>` - The one-hot encoded data (dense or sparse)
191    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<EncodedOutput>
192    where
193        S: Data,
194        S::Elem: Float + NumCast,
195    {
196        let x_u64 = x.mapv(|x| {
197            let val_f64 = NumCast::from(x).unwrap_or(0.0);
198            val_f64 as u64
199        });
200
201        let n_samples = x_u64.shape()[0];
202        let n_features = x_u64.shape()[1];
203
204        if self.categories_.is_none() {
205            return Err(TransformError::TransformationError(
206                "OneHotEncoder has not been fitted".to_string(),
207            ));
208        }
209
210        let categories = self.categories_.as_ref().unwrap();
211
212        if n_features != categories.len() {
213            return Err(TransformError::InvalidInput(format!(
214                "x has {} features, but OneHotEncoder was fitted with {} features",
215                n_features,
216                categories.len()
217            )));
218        }
219
220        // Calculate total number of output features
221        let mut total_features = 0;
222        for (j, feature_categories) in categories.iter().enumerate() {
223            let n_cats = feature_categories.len();
224
225            // Apply drop strategy
226            let n_output_cats = match &self.drop {
227                Some(strategy) if strategy == "first" => n_cats.saturating_sub(1),
228                Some(strategy) if strategy == "if_binary" && n_cats == 2 => 1,
229                _ => n_cats,
230            };
231
232            if n_output_cats == 0 {
233                return Err(TransformError::InvalidInput(format!(
234                    "Feature {j} has only one category after dropping"
235                )));
236            }
237
238            total_features += n_output_cats;
239        }
240
241        // Create mappings from category values to column indices
242        let mut category_mappings = Vec::new();
243        let mut current_col = 0;
244
245        for feature_categories in categories.iter() {
246            let mut mapping = HashMap::new();
247            let n_cats = feature_categories.len();
248
249            // Determine how many categories to keep
250            let (start_idx, n_output_cats) = match &self.drop {
251                Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
252                Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
253                _ => (0, n_cats),
254            };
255
256            for (cat_idx, &category) in feature_categories.iter().enumerate() {
257                if cat_idx >= start_idx && cat_idx < start_idx + n_output_cats {
258                    mapping.insert(category, current_col + cat_idx - start_idx);
259                }
260            }
261
262            category_mappings.push(mapping);
263            current_col += n_output_cats;
264        }
265
266        // Create output based on sparse setting
267        if self.sparse {
268            // Sparse output
269            let mut sparse_matrix = SparseMatrix::new((n_samples, total_features));
270
271            for i in 0..n_samples {
272                for j in 0..n_features {
273                    let value = x_u64[[i, j]];
274
275                    if let Some(&col_idx) = category_mappings[j].get(&value) {
276                        sparse_matrix.push(i, col_idx, 1.0);
277                    } else {
278                        // Check if this is a dropped category
279                        let feature_categories = &categories[j];
280                        let is_dropped_category = match &self.drop {
281                            Some(strategy) if strategy == "first" => {
282                                !feature_categories.is_empty() && value == feature_categories[0]
283                            }
284                            Some(strategy)
285                                if strategy == "if_binary" && feature_categories.len() == 2 =>
286                            {
287                                feature_categories.len() == 2 && value == feature_categories[1]
288                            }
289                            _ => false,
290                        };
291
292                        if !is_dropped_category && self.handleunknown == "error" {
293                            return Err(TransformError::InvalidInput(format!(
294                                "Found unknown category {value} in feature {j}"
295                            )));
296                        }
297                        // If it's a dropped category or handleunknown == "ignore", we don't add anything (sparse)
298                    }
299                }
300            }
301
302            Ok(EncodedOutput::Sparse(sparse_matrix))
303        } else {
304            // Dense output
305            let mut transformed = Array2::zeros((n_samples, total_features));
306
307            for i in 0..n_samples {
308                for j in 0..n_features {
309                    let value = x_u64[[i, j]];
310
311                    if let Some(&col_idx) = category_mappings[j].get(&value) {
312                        transformed[[i, col_idx]] = 1.0;
313                    } else {
314                        // Check if this is a dropped category
315                        let feature_categories = &categories[j];
316                        let is_dropped_category = match &self.drop {
317                            Some(strategy) if strategy == "first" => {
318                                !feature_categories.is_empty() && value == feature_categories[0]
319                            }
320                            Some(strategy)
321                                if strategy == "if_binary" && feature_categories.len() == 2 =>
322                            {
323                                feature_categories.len() == 2 && value == feature_categories[1]
324                            }
325                            _ => false,
326                        };
327
328                        if !is_dropped_category && self.handleunknown == "error" {
329                            return Err(TransformError::InvalidInput(format!(
330                                "Found unknown category {value} in feature {j}"
331                            )));
332                        }
333                        // If it's a dropped category or handleunknown == "ignore", we just leave it as 0
334                    }
335                }
336            }
337
338            Ok(EncodedOutput::Dense(transformed))
339        }
340    }
341
342    /// Fits the OneHotEncoder to the input data and transforms it
343    ///
344    /// # Arguments
345    /// * `x` - The input categorical data, shape (n_samples, n_features)
346    ///
347    /// # Returns
348    /// * `Result<EncodedOutput>` - The one-hot encoded data (dense or sparse)
349    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<EncodedOutput>
350    where
351        S: Data,
352        S::Elem: Float + NumCast,
353    {
354        self.fit(x)?;
355        self.transform(x)
356    }
357
358    /// Convenience method that always returns dense output for backward compatibility
359    ///
360    /// # Arguments
361    /// * `x` - The input categorical data, shape (n_samples, n_features)
362    ///
363    /// # Returns
364    /// * `Result<Array2<f64>>` - The one-hot encoded data as dense matrix
365    pub fn transform_dense<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
366    where
367        S: Data,
368        S::Elem: Float + NumCast,
369    {
370        Ok(self.transform(x)?.to_dense())
371    }
372
373    /// Convenience method that fits and transforms returning dense output
374    ///
375    /// # Arguments
376    /// * `x` - The input categorical data, shape (n_samples, n_features)
377    ///
378    /// # Returns
379    /// * `Result<Array2<f64>>` - The one-hot encoded data as dense matrix
380    pub fn fit_transform_dense<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
381    where
382        S: Data,
383        S::Elem: Float + NumCast,
384    {
385        Ok(self.fit_transform(x)?.to_dense())
386    }
387
388    /// Returns the categories for each feature
389    ///
390    /// # Returns
391    /// * `Option<&Vec<Vec<u64>>>` - The categories for each feature
392    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
393        self.categories_.as_ref()
394    }
395
396    /// Gets the feature names for the transformed output
397    ///
398    /// # Arguments
399    /// * `inputfeatures` - Names of input features
400    ///
401    /// # Returns
402    /// * `Result<Vec<String>>` - Names of output features
403    pub fn get_feature_names(&self, inputfeatures: Option<&[String]>) -> Result<Vec<String>> {
404        if self.categories_.is_none() {
405            return Err(TransformError::TransformationError(
406                "OneHotEncoder has not been fitted".to_string(),
407            ));
408        }
409
410        let categories = self.categories_.as_ref().unwrap();
411        let mut feature_names = Vec::new();
412
413        for (j, feature_categories) in categories.iter().enumerate() {
414            let feature_name = if let Some(names) = inputfeatures {
415                if j < names.len() {
416                    names[j].clone()
417                } else {
418                    format!("x{j}")
419                }
420            } else {
421                format!("x{j}")
422            };
423
424            let n_cats = feature_categories.len();
425
426            // Determine which categories to include based on drop strategy
427            let (start_idx, n_output_cats) = match &self.drop {
428                Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
429                Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
430                _ => (0, n_cats),
431            };
432
433            for &category in feature_categories
434                .iter()
435                .skip(start_idx)
436                .take(n_output_cats)
437            {
438                feature_names.push(format!("{feature_name}_cat_{category}"));
439            }
440        }
441
442        Ok(feature_names)
443    }
444}
445
/// OrdinalEncoder for converting categorical features to ordinal integers
///
/// Each category seen during `fit` is replaced by its position, as an `f64`, in that
/// feature's sorted list of categories.
#[derive(Debug, Clone)]
pub struct OrdinalEncoder {
    /// Sorted, de-duplicated categories of each feature (learned during `fit`)
    categories_: Option<Vec<Vec<u64>>>,
    /// Handling of categories unseen during `fit`: `"error"` or `"use_encoded_value"`
    handleunknown: String,
    /// Encoded value for unseen categories when `handleunknown == "use_encoded_value"`
    unknownvalue: Option<f64>,
}
458
459impl OrdinalEncoder {
460    /// Creates a new OrdinalEncoder
461    ///
462    /// # Arguments
463    /// * `handleunknown` - How to handle unknown categories ('error' or 'use_encoded_value')
464    /// * `unknownvalue` - Value to use for unknown categories (when handleunknown='use_encoded_value')
465    ///
466    /// # Returns
467    /// * A new OrdinalEncoder instance
468    pub fn new(handleunknown: &str, unknownvalue: Option<f64>) -> Result<Self> {
469        if handleunknown != "error" && handleunknown != "use_encoded_value" {
470            return Err(TransformError::InvalidInput(
471                "handleunknown must be 'error' or 'use_encoded_value'".to_string(),
472            ));
473        }
474
475        if handleunknown == "use_encoded_value" && unknownvalue.is_none() {
476            return Err(TransformError::InvalidInput(
477                "unknownvalue must be specified when handleunknown='use_encoded_value'".to_string(),
478            ));
479        }
480
481        Ok(OrdinalEncoder {
482            categories_: None,
483            handleunknown: handleunknown.to_string(),
484            unknownvalue,
485        })
486    }
487
488    /// Creates an OrdinalEncoder with default settings
489    pub fn with_defaults() -> Self {
490        Self::new("error", None).unwrap()
491    }
492
493    /// Fits the OrdinalEncoder to the input data
494    ///
495    /// # Arguments
496    /// * `x` - The input categorical data, shape (n_samples, n_features)
497    ///
498    /// # Returns
499    /// * `Result<()>` - Ok if successful, Err otherwise
500    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
501    where
502        S: Data,
503        S::Elem: Float + NumCast,
504    {
505        let x_u64 = x.mapv(|x| {
506            let val_f64 = NumCast::from(x).unwrap_or(0.0);
507            val_f64 as u64
508        });
509
510        let n_samples = x_u64.shape()[0];
511        let n_features = x_u64.shape()[1];
512
513        if n_samples == 0 || n_features == 0 {
514            return Err(TransformError::InvalidInput("Empty input data".to_string()));
515        }
516
517        let mut categories = Vec::with_capacity(n_features);
518
519        for j in 0..n_features {
520            // Collect unique values for this feature
521            let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
522            unique_values.sort_unstable();
523            unique_values.dedup();
524
525            categories.push(unique_values);
526        }
527
528        self.categories_ = Some(categories);
529        Ok(())
530    }
531
532    /// Transforms the input data using the fitted OrdinalEncoder
533    ///
534    /// # Arguments
535    /// * `x` - The input categorical data, shape (n_samples, n_features)
536    ///
537    /// # Returns
538    /// * `Result<Array2<f64>>` - The ordinally encoded data
539    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
540    where
541        S: Data,
542        S::Elem: Float + NumCast,
543    {
544        let x_u64 = x.mapv(|x| {
545            let val_f64 = NumCast::from(x).unwrap_or(0.0);
546            val_f64 as u64
547        });
548
549        let n_samples = x_u64.shape()[0];
550        let n_features = x_u64.shape()[1];
551
552        if self.categories_.is_none() {
553            return Err(TransformError::TransformationError(
554                "OrdinalEncoder has not been fitted".to_string(),
555            ));
556        }
557
558        let categories = self.categories_.as_ref().unwrap();
559
560        if n_features != categories.len() {
561            return Err(TransformError::InvalidInput(format!(
562                "x has {} features, but OrdinalEncoder was fitted with {} features",
563                n_features,
564                categories.len()
565            )));
566        }
567
568        let mut transformed = Array2::zeros((n_samples, n_features));
569
570        // Create mappings from category values to ordinal values
571        let mut category_mappings = Vec::new();
572        for feature_categories in categories {
573            let mut mapping = HashMap::new();
574            for (ordinal, &category) in feature_categories.iter().enumerate() {
575                mapping.insert(category, ordinal as f64);
576            }
577            category_mappings.push(mapping);
578        }
579
580        // Fill the transformed array
581        for i in 0..n_samples {
582            for j in 0..n_features {
583                let value = x_u64[[i, j]];
584
585                if let Some(&ordinal_value) = category_mappings[j].get(&value) {
586                    transformed[[i, j]] = ordinal_value;
587                } else if self.handleunknown == "error" {
588                    return Err(TransformError::InvalidInput(format!(
589                        "Found unknown category {value} in feature {j}"
590                    )));
591                } else {
592                    // handleunknown == "use_encoded_value"
593                    transformed[[i, j]] = self.unknownvalue.unwrap();
594                }
595            }
596        }
597
598        Ok(transformed)
599    }
600
601    /// Fits the OrdinalEncoder to the input data and transforms it
602    ///
603    /// # Arguments
604    /// * `x` - The input categorical data, shape (n_samples, n_features)
605    ///
606    /// # Returns
607    /// * `Result<Array2<f64>>` - The ordinally encoded data
608    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
609    where
610        S: Data,
611        S::Elem: Float + NumCast,
612    {
613        self.fit(x)?;
614        self.transform(x)
615    }
616
617    /// Returns the categories for each feature
618    ///
619    /// # Returns
620    /// * `Option<&Vec<Vec<u64>>>` - The categories for each feature
621    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
622        self.categories_.as_ref()
623    }
624}
625
/// TargetEncoder for supervised categorical encoding
///
/// Replaces each category code with a statistic of the target values observed for that
/// category during `fit`. The statistic is the mean, median, count, or sum. Mean and
/// median can optionally be shrunk toward the global target mean. This is useful for
/// high-cardinality categorical features.
///
/// # Examples
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_transform::encoding::TargetEncoder;
///
/// let x = Array2::from_shape_vec((6, 1), vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0]).unwrap();
/// let y = vec![1.0, 2.0, 3.0, 1.5, 2.5, 3.5];
///
/// let mut encoder = TargetEncoder::new("mean", 1.0, 0.0).unwrap();
/// let encoded = encoder.fit_transform(&x, &y).unwrap();
/// ```
#[derive(Clone, Debug)]
pub struct TargetEncoder {
    /// Statistic used per category: `"mean"`, `"median"`, `"count"` or `"sum"`
    strategy: String,
    /// Shrinkage weight of the global mean for `"mean"`/`"median"`; `0.0` disables it
    smoothing: f64,
    /// Encoding for categories unseen during `fit`; `0.0` means "use the global mean"
    globalstat: f64,
    /// Per-feature map from category code to its encoded value (set by `fit`)
    encodings_: Option<Vec<HashMap<u64, f64>>>,
    /// `true` once `fit` has completed
    is_fitted: bool,
    /// Mean of the target values seen by `fit`
    global_mean_: f64,
}
658
659impl TargetEncoder {
660    /// Creates a new TargetEncoder
661    ///
662    /// # Arguments
663    /// * `strategy` - Encoding strategy ('mean', 'median', 'count', 'sum')
664    /// * `smoothing` - Smoothing parameter (0.0 = no smoothing, higher = more smoothing)
665    /// * `globalstat` - Global statistic fallback for unknown categories
666    ///
667    /// # Returns
668    /// * A new TargetEncoder instance
669    pub fn new(_strategy: &str, smoothing: f64, globalstat: f64) -> Result<Self> {
670        if !["mean", "median", "count", "sum"].contains(&_strategy) {
671            return Err(TransformError::InvalidInput(
672                "_strategy must be 'mean', 'median', 'count', or 'sum'".to_string(),
673            ));
674        }
675
676        if smoothing < 0.0 {
677            return Err(TransformError::InvalidInput(
678                "smoothing parameter must be non-negative".to_string(),
679            ));
680        }
681
682        Ok(TargetEncoder {
683            strategy: _strategy.to_string(),
684            smoothing,
685            globalstat,
686            encodings_: None,
687            is_fitted: false,
688            global_mean_: 0.0,
689        })
690    }
691
692    /// Creates a TargetEncoder with mean strategy and default smoothing
693    pub fn with_mean(smoothing: f64) -> Self {
694        TargetEncoder {
695            strategy: "mean".to_string(),
696            smoothing,
697            globalstat: 0.0,
698            encodings_: None,
699            is_fitted: false,
700            global_mean_: 0.0,
701        }
702    }
703
704    /// Creates a TargetEncoder with median strategy
705    pub fn with_median(smoothing: f64) -> Self {
706        TargetEncoder {
707            strategy: "median".to_string(),
708            smoothing,
709            globalstat: 0.0,
710            encodings_: None,
711            is_fitted: false,
712            global_mean_: 0.0,
713        }
714    }
715
716    /// Fits the TargetEncoder to the input data and target values
717    ///
718    /// # Arguments
719    /// * `x` - The input categorical data, shape (n_samples, n_features)
720    /// * `y` - The target values, length n_samples
721    ///
722    /// # Returns
723    /// * `Result<()>` - Ok if successful, Err otherwise
724    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<()>
725    where
726        S: Data,
727        S::Elem: Float + NumCast,
728    {
729        let x_u64 = x.mapv(|x| {
730            let val_f64 = NumCast::from(x).unwrap_or(0.0);
731            val_f64 as u64
732        });
733
734        let n_samples = x_u64.shape()[0];
735        let n_features = x_u64.shape()[1];
736
737        if n_samples == 0 || n_features == 0 {
738            return Err(TransformError::InvalidInput("Empty input data".to_string()));
739        }
740
741        if y.len() != n_samples {
742            return Err(TransformError::InvalidInput(
743                "Number of target values must match number of samples".to_string(),
744            ));
745        }
746
747        // Compute global mean for smoothing
748        self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
749
750        let mut encodings = Vec::with_capacity(n_features);
751
752        for j in 0..n_features {
753            // Group target values by category for this feature
754            let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
755
756            for i in 0..n_samples {
757                let category = x_u64[[i, j]];
758                category_targets.entry(category).or_default().push(y[i]);
759            }
760
761            // Compute encoding for each category
762            let mut category_encoding = HashMap::new();
763
764            for (category, targets) in category_targets.iter() {
765                let encoded_value = match self.strategy.as_str() {
766                    "mean" => {
767                        let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
768                        let count = targets.len() as f64;
769
770                        // Apply smoothing: (count * category_mean + smoothing * global_mean) / (count + smoothing)
771                        if self.smoothing > 0.0 {
772                            (count * category_mean + self.smoothing * self.global_mean_)
773                                / (count + self.smoothing)
774                        } else {
775                            category_mean
776                        }
777                    }
778                    "median" => {
779                        let mut sorted_targets = targets.clone();
780                        sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
781
782                        let median = if sorted_targets.len() % 2 == 0 {
783                            let mid = sorted_targets.len() / 2;
784                            (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
785                        } else {
786                            sorted_targets[sorted_targets.len() / 2]
787                        };
788
789                        // Apply smoothing toward global mean
790                        if self.smoothing > 0.0 {
791                            let count = targets.len() as f64;
792                            (count * median + self.smoothing * self.global_mean_)
793                                / (count + self.smoothing)
794                        } else {
795                            median
796                        }
797                    }
798                    "count" => targets.len() as f64,
799                    "sum" => targets.iter().sum::<f64>(),
800                    _ => unreachable!(),
801                };
802
803                category_encoding.insert(*category, encoded_value);
804            }
805
806            encodings.push(category_encoding);
807        }
808
809        self.encodings_ = Some(encodings);
810        self.is_fitted = true;
811        Ok(())
812    }
813
814    /// Transforms the input data using the fitted TargetEncoder
815    ///
816    /// # Arguments
817    /// * `x` - The input categorical data, shape (n_samples, n_features)
818    ///
819    /// # Returns
820    /// * `Result<Array2<f64>>` - The target-encoded data
821    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
822    where
823        S: Data,
824        S::Elem: Float + NumCast,
825    {
826        if !self.is_fitted {
827            return Err(TransformError::TransformationError(
828                "TargetEncoder has not been fitted".to_string(),
829            ));
830        }
831
832        let x_u64 = x.mapv(|x| {
833            let val_f64 = NumCast::from(x).unwrap_or(0.0);
834            val_f64 as u64
835        });
836
837        let n_samples = x_u64.shape()[0];
838        let n_features = x_u64.shape()[1];
839
840        let encodings = self.encodings_.as_ref().unwrap();
841
842        if n_features != encodings.len() {
843            return Err(TransformError::InvalidInput(format!(
844                "x has {} features, but TargetEncoder was fitted with {} features",
845                n_features,
846                encodings.len()
847            )));
848        }
849
850        let mut transformed = Array2::zeros((n_samples, n_features));
851
852        for i in 0..n_samples {
853            for j in 0..n_features {
854                let category = x_u64[[i, j]];
855
856                if let Some(&encoded_value) = encodings[j].get(&category) {
857                    transformed[[i, j]] = encoded_value;
858                } else {
859                    // Use global statistic for unknown categories
860                    transformed[[i, j]] = if self.globalstat != 0.0 {
861                        self.globalstat
862                    } else {
863                        self.global_mean_
864                    };
865                }
866            }
867        }
868
869        Ok(transformed)
870    }
871
872    /// Fits the TargetEncoder and transforms the data in one step
873    ///
874    /// # Arguments
875    /// * `x` - The input categorical data, shape (n_samples, n_features)
876    /// * `y` - The target values, length n_samples
877    ///
878    /// # Returns
879    /// * `Result<Array2<f64>>` - The target-encoded data
880    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<Array2<f64>>
881    where
882        S: Data,
883        S::Elem: Float + NumCast,
884    {
885        self.fit(x, y)?;
886        self.transform(x)
887    }
888
889    /// Returns the learned encodings for each feature
890    ///
891    /// # Returns
892    /// * `Option<&Vec<HashMap<u64, f64>>>` - The category encodings for each feature
893    pub fn encodings(&self) -> Option<&Vec<HashMap<u64, f64>>> {
894        self.encodings_.as_ref()
895    }
896
897    /// Returns whether the encoder has been fitted
898    pub fn is_fitted(&self) -> bool {
899        self.is_fitted
900    }
901
902    /// Returns the global mean computed during fitting
903    pub fn global_mean(&self) -> f64 {
904        self.global_mean_
905    }
906
907    /// Applies cross-validation target encoding to prevent overfitting
908    ///
909    /// This method performs k-fold cross-validation to compute target encodings,
910    /// which helps prevent overfitting when the same data is used for both
911    /// fitting and transforming.
912    ///
913    /// # Arguments
914    /// * `x` - The input categorical data, shape (n_samples, n_features)
915    /// * `y` - The target values, length n_samples
916    /// * `cv_folds` - Number of cross-validation folds (default: 5)
917    ///
918    /// # Returns
919    /// * `Result<Array2<f64>>` - The cross-validated target-encoded data
920    pub fn fit_transform_cv<S>(
921        &mut self,
922        x: &ArrayBase<S, Ix2>,
923        y: &[f64],
924        cv_folds: usize,
925    ) -> Result<Array2<f64>>
926    where
927        S: Data,
928        S::Elem: Float + NumCast,
929    {
930        let x_u64 = x.mapv(|x| {
931            let val_f64 = NumCast::from(x).unwrap_or(0.0);
932            val_f64 as u64
933        });
934
935        let n_samples = x_u64.shape()[0];
936        let n_features = x_u64.shape()[1];
937
938        if n_samples == 0 || n_features == 0 {
939            return Err(TransformError::InvalidInput("Empty input data".to_string()));
940        }
941
942        if y.len() != n_samples {
943            return Err(TransformError::InvalidInput(
944                "Number of target values must match number of samples".to_string(),
945            ));
946        }
947
948        if cv_folds < 2 {
949            return Err(TransformError::InvalidInput(
950                "cv_folds must be at least 2".to_string(),
951            ));
952        }
953
954        let mut transformed = Array2::zeros((n_samples, n_features));
955
956        // Compute global mean
957        self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
958
959        // Create fold indices
960        let fold_size = n_samples / cv_folds;
961        let mut fold_indices = Vec::new();
962        for fold in 0..cv_folds {
963            let start = fold * fold_size;
964            let end = if fold == cv_folds - 1 {
965                n_samples
966            } else {
967                (fold + 1) * fold_size
968            };
969            fold_indices.push((start, end));
970        }
971
972        // For each fold, train on other _folds and predict on this fold
973        for fold in 0..cv_folds {
974            let (val_start, val_end) = fold_indices[fold];
975
976            // Collect training data (all _folds except current)
977            let mut train_indices = Vec::new();
978            for (other_fold, &(start, end)) in fold_indices.iter().enumerate().take(cv_folds) {
979                if other_fold != fold {
980                    train_indices.extend(start..end);
981                }
982            }
983
984            // For each feature, compute encodings on training data
985            for j in 0..n_features {
986                let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
987
988                // Collect target values by category for training data
989                for &train_idx in &train_indices {
990                    let category = x_u64[[train_idx, j]];
991                    category_targets
992                        .entry(category)
993                        .or_default()
994                        .push(y[train_idx]);
995                }
996
997                // Compute encodings for this fold
998                let mut category_encoding = HashMap::new();
999                for (category, targets) in category_targets.iter() {
1000                    let encoded_value = match self.strategy.as_str() {
1001                        "mean" => {
1002                            let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
1003                            let count = targets.len() as f64;
1004
1005                            if self.smoothing > 0.0 {
1006                                (count * category_mean + self.smoothing * self.global_mean_)
1007                                    / (count + self.smoothing)
1008                            } else {
1009                                category_mean
1010                            }
1011                        }
1012                        "median" => {
1013                            let mut sorted_targets = targets.clone();
1014                            sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
1015
1016                            let median = if sorted_targets.len() % 2 == 0 {
1017                                let mid = sorted_targets.len() / 2;
1018                                (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
1019                            } else {
1020                                sorted_targets[sorted_targets.len() / 2]
1021                            };
1022
1023                            if self.smoothing > 0.0 {
1024                                let count = targets.len() as f64;
1025                                (count * median + self.smoothing * self.global_mean_)
1026                                    / (count + self.smoothing)
1027                            } else {
1028                                median
1029                            }
1030                        }
1031                        "count" => targets.len() as f64,
1032                        "sum" => targets.iter().sum::<f64>(),
1033                        _ => unreachable!(),
1034                    };
1035
1036                    category_encoding.insert(*category, encoded_value);
1037                }
1038
1039                // Apply encodings to validation fold
1040                for val_idx in val_start..val_end {
1041                    let category = x_u64[[val_idx, j]];
1042
1043                    if let Some(&encoded_value) = category_encoding.get(&category) {
1044                        transformed[[val_idx, j]] = encoded_value;
1045                    } else {
1046                        // Use global mean for unknown categories
1047                        transformed[[val_idx, j]] = self.global_mean_;
1048                    }
1049                }
1050            }
1051        }
1052
1053        // Now fit on the full data for future transforms
1054        self.fit(x, y)?;
1055
1056        Ok(transformed)
1057    }
1058}
1059
/// BinaryEncoder for converting categorical features to binary representations
///
/// This transformer converts categorical features into binary representations,
/// where each category is encoded as a unique binary number. This is more
/// memory-efficient than one-hot encoding for high-cardinality categorical features.
///
/// Input values are cast to `u64` category codes before encoding (through `f64`,
/// so fractional parts are truncated and e.g. 1.0 and 1.7 share a code). During
/// fitting, each feature's distinct codes are sorted and numbered. Each number is
/// written as a fixed-width bit vector, most significant bit first, and each bit
/// becomes one output column. With `handleunknown = "error"`, n distinct
/// categories use max(1, ceil(log2(n))) bit columns. `n_binary_features()`
/// reports the exact width of each feature.
#[derive(Debug, Clone)]
pub struct BinaryEncoder {
    /// Per-feature mapping from category code to its bit vector (most significant bit first)
    categories_: Option<Vec<HashMap<u64, Vec<u8>>>>,
    /// Number of bit columns emitted for each original feature
    n_binary_features_: Option<Vec<usize>>,
    /// How to handle categories not seen during fitting: "error" or "ignore" (all-zero bits)
    handleunknown: String,
    /// Whether the encoder has been fitted
    is_fitted: bool,
}
1078
1079impl BinaryEncoder {
1080    /// Creates a new BinaryEncoder
1081    ///
1082    /// # Arguments
1083    /// * `handleunknown` - How to handle unknown categories ('error' or 'ignore')
1084    ///   - 'error': Raise an error if unknown categories are encountered
1085    ///   - 'ignore': Encode unknown categories as all zeros
1086    ///
1087    /// # Returns
1088    /// * `Result<BinaryEncoder>` - The new encoder instance
1089    pub fn new(handleunknown: &str) -> Result<Self> {
1090        if handleunknown != "error" && handleunknown != "ignore" {
1091            return Err(TransformError::InvalidInput(
1092                "handleunknown must be 'error' or 'ignore'".to_string(),
1093            ));
1094        }
1095
1096        Ok(BinaryEncoder {
1097            categories_: None,
1098            n_binary_features_: None,
1099            handleunknown: handleunknown.to_string(),
1100            is_fitted: false,
1101        })
1102    }
1103
1104    /// Creates a BinaryEncoder with default settings (handleunknown='error')
1105    pub fn with_defaults() -> Self {
1106        Self::new("error").unwrap()
1107    }
1108
1109    /// Fits the BinaryEncoder to the input data
1110    ///
1111    /// # Arguments
1112    /// * `x` - The input categorical data, shape (n_samples, n_features)
1113    ///
1114    /// # Returns
1115    /// * `Result<()>` - Ok if successful, Err otherwise
1116    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
1117    where
1118        S: Data,
1119        S::Elem: Float + NumCast,
1120    {
1121        let x_u64 = x.mapv(|x| {
1122            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1123            val_f64 as u64
1124        });
1125
1126        let n_samples = x_u64.shape()[0];
1127        let n_features = x_u64.shape()[1];
1128
1129        if n_samples == 0 || n_features == 0 {
1130            return Err(TransformError::InvalidInput("Empty input data".to_string()));
1131        }
1132
1133        let mut categories = Vec::with_capacity(n_features);
1134        let mut n_binary_features = Vec::with_capacity(n_features);
1135
1136        for j in 0..n_features {
1137            // Collect unique categories for this feature
1138            let mut unique_categories: Vec<u64> = x_u64.column(j).to_vec();
1139            unique_categories.sort_unstable();
1140            unique_categories.dedup();
1141
1142            if unique_categories.is_empty() {
1143                return Err(TransformError::InvalidInput(
1144                    "Feature has no valid categories".to_string(),
1145                ));
1146            }
1147
1148            // Calculate number of binary features needed
1149            let n_cats = unique_categories.len();
1150            let nbits = if n_cats <= 1 {
1151                1
1152            } else {
1153                (n_cats as f64).log2().ceil() as usize
1154            };
1155
1156            // Create binary mappings
1157            let mut category_map = HashMap::new();
1158            for (idx, &category) in unique_categories.iter().enumerate() {
1159                let binary_code = Self::int_to_binary(idx, nbits);
1160                category_map.insert(category, binary_code);
1161            }
1162
1163            categories.push(category_map);
1164            n_binary_features.push(nbits);
1165        }
1166
1167        self.categories_ = Some(categories);
1168        self.n_binary_features_ = Some(n_binary_features);
1169        self.is_fitted = true;
1170
1171        Ok(())
1172    }
1173
1174    /// Transforms the input data using the fitted encoder
1175    ///
1176    /// # Arguments
1177    /// * `x` - The input categorical data to transform
1178    ///
1179    /// # Returns
1180    /// * `Result<Array2<f64>>` - The binary-encoded data
1181    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1182    where
1183        S: Data,
1184        S::Elem: Float + NumCast,
1185    {
1186        if !self.is_fitted {
1187            return Err(TransformError::InvalidInput(
1188                "Encoder has not been fitted yet".to_string(),
1189            ));
1190        }
1191
1192        let categories = self.categories_.as_ref().unwrap();
1193        let n_binary_features = self.n_binary_features_.as_ref().unwrap();
1194
1195        let x_u64 = x.mapv(|x| {
1196            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1197            val_f64 as u64
1198        });
1199
1200        let n_samples = x_u64.shape()[0];
1201        let n_features = x_u64.shape()[1];
1202
1203        if n_features != categories.len() {
1204            return Err(TransformError::InvalidInput(format!(
1205                "Number of features ({}) does not match fitted features ({})",
1206                n_features,
1207                categories.len()
1208            )));
1209        }
1210
1211        // Calculate total number of output features
1212        let total_binary_features: usize = n_binary_features.iter().sum();
1213        let mut result = Array2::<f64>::zeros((n_samples, total_binary_features));
1214
1215        let mut output_col = 0;
1216        for j in 0..n_features {
1217            let category_map = &categories[j];
1218            let nbits = n_binary_features[j];
1219
1220            for i in 0..n_samples {
1221                let category = x_u64[[i, j]];
1222
1223                if let Some(binary_code) = category_map.get(&category) {
1224                    // Known category: use binary code
1225                    for (bit_idx, &bit_val) in binary_code.iter().enumerate() {
1226                        result[[i, output_col + bit_idx]] = bit_val as f64;
1227                    }
1228                } else {
1229                    // Unknown category
1230                    match self.handleunknown.as_str() {
1231                        "error" => {
1232                            return Err(TransformError::InvalidInput(format!(
1233                                "Unknown category {category} in feature {j}"
1234                            )));
1235                        }
1236                        "ignore" => {
1237                            // Set all bits to zero (already initialized)
1238                        }
1239                        _ => unreachable!(),
1240                    }
1241                }
1242            }
1243
1244            output_col += nbits;
1245        }
1246
1247        Ok(result)
1248    }
1249
1250    /// Fits the encoder and transforms the data in one step
1251    ///
1252    /// # Arguments
1253    /// * `x` - The input categorical data
1254    ///
1255    /// # Returns
1256    /// * `Result<Array2<f64>>` - The binary-encoded data
1257    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1258    where
1259        S: Data,
1260        S::Elem: Float + NumCast,
1261    {
1262        self.fit(x)?;
1263        self.transform(x)
1264    }
1265
1266    /// Returns whether the encoder has been fitted
1267    pub fn is_fitted(&self) -> bool {
1268        self.is_fitted
1269    }
1270
1271    /// Returns the category mappings if fitted
1272    pub fn categories(&self) -> Option<&Vec<HashMap<u64, Vec<u8>>>> {
1273        self.categories_.as_ref()
1274    }
1275
1276    /// Returns the number of binary features per original feature
1277    pub fn n_binary_features(&self) -> Option<&Vec<usize>> {
1278        self.n_binary_features_.as_ref()
1279    }
1280
1281    /// Returns the total number of output features
1282    pub fn n_output_features(&self) -> Option<usize> {
1283        self.n_binary_features_.as_ref().map(|v| v.iter().sum())
1284    }
1285
1286    /// Converts an integer to binary representation
1287    fn int_to_binary(_value: usize, nbits: usize) -> Vec<u8> {
1288        let mut binary = Vec::with_capacity(nbits);
1289        let mut val = _value;
1290
1291        for _ in 0..nbits {
1292            binary.push((val & 1) as u8);
1293            val >>= 1;
1294        }
1295
1296        binary.reverse(); // Most significant bit first
1297        binary
1298    }
1299}
1300
/// FrequencyEncoder for converting categorical features to frequency counts
///
/// This transformer converts categorical features into their frequency of occurrence
/// in the training data. High-frequency categories get higher values, which can be
/// useful for models that can leverage frequency information.
///
/// Each value is replaced by the number of times its category occurred in that
/// feature's column during fitting. When `normalize` is true, that count is
/// divided by the number of training rows. Categories are compared after casting
/// to `u64`. Distinct categories with equal counts receive the same encoding.
#[derive(Debug, Clone)]
pub struct FrequencyEncoder {
    /// Per-feature mapping from category code to its (possibly normalized) frequency
    frequency_maps_: Option<Vec<HashMap<u64, f64>>>,
    /// Whether to normalize frequencies to [0, 1] (count divided by number of training rows)
    normalize: bool,
    /// How to handle unknown categories: "error", "ignore" (encode as 0.0), or "use_encoded_value"
    handleunknown: String,
    /// Value to use for unknown categories (when handleunknown="use_encoded_value")
    unknownvalue: f64,
    /// Whether the encoder has been fitted
    is_fitted: bool,
}
1319
1320impl FrequencyEncoder {
1321    /// Creates a new FrequencyEncoder
1322    ///
1323    /// # Arguments
1324    /// * `normalize` - Whether to normalize frequencies to [0, 1] range
1325    /// * `handleunknown` - How to handle unknown categories ('error', 'ignore', or 'use_encoded_value')
1326    /// * `unknownvalue` - Value to use for unknown categories (when handleunknown="use_encoded_value")
1327    ///
1328    /// # Returns
1329    /// * `Result<FrequencyEncoder>` - The new encoder instance
1330    pub fn new(normalize: bool, handleunknown: &str, unknownvalue: f64) -> Result<Self> {
1331        if !["error", "ignore", "use_encoded_value"].contains(&handleunknown) {
1332            return Err(TransformError::InvalidInput(
1333                "handleunknown must be 'error', 'ignore', or 'use_encoded_value'".to_string(),
1334            ));
1335        }
1336
1337        Ok(FrequencyEncoder {
1338            frequency_maps_: None,
1339            normalize,
1340            handleunknown: handleunknown.to_string(),
1341            unknownvalue,
1342            is_fitted: false,
1343        })
1344    }
1345
1346    /// Creates a FrequencyEncoder with default settings
1347    pub fn with_defaults() -> Self {
1348        Self::new(false, "error", 0.0).unwrap()
1349    }
1350
1351    /// Creates a FrequencyEncoder with normalized frequencies
1352    pub fn with_normalization() -> Self {
1353        Self::new(true, "error", 0.0).unwrap()
1354    }
1355
1356    /// Fits the FrequencyEncoder to the input data
1357    ///
1358    /// # Arguments
1359    /// * `x` - The input categorical data, shape (n_samples, n_features)
1360    ///
1361    /// # Returns
1362    /// * `Result<()>` - Ok if successful, Err otherwise
1363    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
1364    where
1365        S: Data,
1366        S::Elem: Float + NumCast,
1367    {
1368        let x_u64 = x.mapv(|x| {
1369            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1370            val_f64 as u64
1371        });
1372
1373        let n_samples = x_u64.shape()[0];
1374        let n_features = x_u64.shape()[1];
1375
1376        if n_samples == 0 || n_features == 0 {
1377            return Err(TransformError::InvalidInput("Empty input data".to_string()));
1378        }
1379
1380        let mut frequency_maps = Vec::with_capacity(n_features);
1381
1382        for j in 0..n_features {
1383            // Count frequency of each category
1384            let mut category_counts: HashMap<u64, usize> = HashMap::new();
1385            for i in 0..n_samples {
1386                let category = x_u64[[i, j]];
1387                *category_counts.entry(category).or_insert(0) += 1;
1388            }
1389
1390            // Convert counts to frequencies
1391            let mut frequency_map = HashMap::new();
1392            for (category, count) in category_counts {
1393                let frequency = if self.normalize {
1394                    count as f64 / n_samples as f64
1395                } else {
1396                    count as f64
1397                };
1398                frequency_map.insert(category, frequency);
1399            }
1400
1401            frequency_maps.push(frequency_map);
1402        }
1403
1404        self.frequency_maps_ = Some(frequency_maps);
1405        self.is_fitted = true;
1406        Ok(())
1407    }
1408
1409    /// Transforms the input data using the fitted FrequencyEncoder
1410    ///
1411    /// # Arguments
1412    /// * `x` - The input categorical data, shape (n_samples, n_features)
1413    ///
1414    /// # Returns
1415    /// * `Result<Array2<f64>>` - The frequency-encoded data
1416    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1417    where
1418        S: Data,
1419        S::Elem: Float + NumCast,
1420    {
1421        if !self.is_fitted {
1422            return Err(TransformError::TransformationError(
1423                "FrequencyEncoder has not been fitted".to_string(),
1424            ));
1425        }
1426
1427        let frequency_maps = self.frequency_maps_.as_ref().unwrap();
1428
1429        let x_u64 = x.mapv(|x| {
1430            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1431            val_f64 as u64
1432        });
1433
1434        let n_samples = x_u64.shape()[0];
1435        let n_features = x_u64.shape()[1];
1436
1437        if n_features != frequency_maps.len() {
1438            return Err(TransformError::InvalidInput(format!(
1439                "x has {} features, but FrequencyEncoder was fitted with {} features",
1440                n_features,
1441                frequency_maps.len()
1442            )));
1443        }
1444
1445        let mut transformed = Array2::zeros((n_samples, n_features));
1446
1447        for i in 0..n_samples {
1448            for j in 0..n_features {
1449                let category = x_u64[[i, j]];
1450
1451                if let Some(&frequency) = frequency_maps[j].get(&category) {
1452                    transformed[[i, j]] = frequency;
1453                } else {
1454                    // Handle unknown category
1455                    match self.handleunknown.as_str() {
1456                        "error" => {
1457                            return Err(TransformError::InvalidInput(format!(
1458                                "Unknown category {category} in feature {j}"
1459                            )));
1460                        }
1461                        "ignore" => {
1462                            transformed[[i, j]] = 0.0;
1463                        }
1464                        "use_encoded_value" => {
1465                            transformed[[i, j]] = self.unknownvalue;
1466                        }
1467                        _ => unreachable!(),
1468                    }
1469                }
1470            }
1471        }
1472
1473        Ok(transformed)
1474    }
1475
1476    /// Fits the encoder and transforms the data in one step
1477    ///
1478    /// # Arguments
1479    /// * `x` - The input categorical data
1480    ///
1481    /// # Returns
1482    /// * `Result<Array2<f64>>` - The frequency-encoded data
1483    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1484    where
1485        S: Data,
1486        S::Elem: Float + NumCast,
1487    {
1488        self.fit(x)?;
1489        self.transform(x)
1490    }
1491
1492    /// Returns whether the encoder has been fitted
1493    pub fn is_fitted(&self) -> bool {
1494        self.is_fitted
1495    }
1496
1497    /// Returns the learned frequency mappings if fitted
1498    pub fn frequency_maps(&self) -> Option<&Vec<HashMap<u64, f64>>> {
1499        self.frequency_maps_.as_ref()
1500    }
1501}
1502
/// Weight of Evidence (WOE) Encoder for converting categorical features using target information
///
/// WOE encoding transforms categorical features based on the relationship between
/// each category and a binary target variable. It's particularly useful for credit
/// scoring and other binary classification tasks.
///
/// For a category c, let `e_c` be its events (target = 1) and `n_c` its
/// non-events (target = 0). Let `r` be the regularization, `K` the number of
/// distinct categories in the feature, and `E` / `N` the total events and
/// non-events. Then:
///
/// WOE(c) = ln( ((e_c + r) / (E + r*K)) / ((n_c + r) / (N + r*K)) )
///
/// That is the log ratio of P(category | target=1) to P(category | target=0). It
/// is NOT the within-category log-odds ln(P(target=1|c) / P(target=0|c)). With
/// r = 0, a category whose event/non-event ratio equals the overall ratio gets
/// WOE 0.
///
/// The Information Value of a feature is the sum over its categories of
/// (event_rate - non_event_rate) * WOE.
#[derive(Debug, Clone)]
pub struct WOEEncoder {
    /// Per-feature mapping from category code to its WOE value
    woe_maps_: Option<Vec<HashMap<u64, f64>>>,
    /// Information Value (IV) for each feature
    information_values_: Option<Vec<f64>>,
    /// Pseudo-count added to each category's event and non-event counts so that
    /// categories with zero events/non-events get finite WOE values
    regularization: f64,
    /// How to handle unknown categories: "error", "global_woe" or "use_encoded_value"
    handleunknown: String,
    /// Value to use for unknown categories (when handleunknown="use_encoded_value")
    unknownvalue: f64,
    /// Overall log-odds ln(total_events / total_non_events). Used for unknown
    /// categories when handleunknown="global_woe".
    /// NOTE(review): on the WOE scale above a "neutral" category maps to 0, not
    /// to the overall log-odds — confirm this fallback is intended.
    global_woe_: f64,
    /// Whether the encoder has been fitted
    is_fitted: bool,
}
1527
1528impl WOEEncoder {
1529    /// Creates a new WOEEncoder
1530    ///
1531    /// # Arguments
1532    /// * `regularization` - Small value added to prevent division by zero (default: 0.5)
1533    /// * `handleunknown` - How to handle unknown categories ('error', 'global_woe', or 'use_encoded_value')
1534    /// * `unknownvalue` - Value to use for unknown categories (when handleunknown="use_encoded_value")
1535    ///
1536    /// # Returns
1537    /// * `Result<WOEEncoder>` - The new encoder instance
1538    pub fn new(regularization: f64, handleunknown: &str, unknownvalue: f64) -> Result<Self> {
1539        if regularization < 0.0 {
1540            return Err(TransformError::InvalidInput(
1541                "regularization must be non-negative".to_string(),
1542            ));
1543        }
1544
1545        if !["error", "global_woe", "use_encoded_value"].contains(&handleunknown) {
1546            return Err(TransformError::InvalidInput(
1547                "handleunknown must be 'error', 'global_woe', or 'use_encoded_value'".to_string(),
1548            ));
1549        }
1550
1551        Ok(WOEEncoder {
1552            woe_maps_: None,
1553            information_values_: None,
1554            regularization,
1555            handleunknown: handleunknown.to_string(),
1556            unknownvalue,
1557            global_woe_: 0.0,
1558            is_fitted: false,
1559        })
1560    }
1561
1562    /// Creates a WOEEncoder with default settings
1563    pub fn with_defaults() -> Self {
1564        Self::new(0.5, "global_woe", 0.0).unwrap()
1565    }
1566
1567    /// Creates a WOEEncoder with custom regularization
1568    pub fn with_regularization(regularization: f64) -> Result<Self> {
1569        Self::new(regularization, "global_woe", 0.0)
1570    }
1571
1572    /// Fits the WOEEncoder to the input data
1573    ///
1574    /// # Arguments
1575    /// * `x` - The input categorical data, shape (n_samples, n_features)
1576    /// * `y` - The binary target values (0 or 1), length n_samples
1577    ///
1578    /// # Returns
1579    /// * `Result<()>` - Ok if successful, Err otherwise
1580    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<()>
1581    where
1582        S: Data,
1583        S::Elem: Float + NumCast,
1584    {
1585        let x_u64 = x.mapv(|x| {
1586            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1587            val_f64 as u64
1588        });
1589
1590        let n_samples = x_u64.shape()[0];
1591        let n_features = x_u64.shape()[1];
1592
1593        if n_samples == 0 || n_features == 0 {
1594            return Err(TransformError::InvalidInput("Empty input data".to_string()));
1595        }
1596
1597        if y.len() != n_samples {
1598            return Err(TransformError::InvalidInput(
1599                "Number of target values must match number of samples".to_string(),
1600            ));
1601        }
1602
1603        // Validate that target is binary
1604        for &target in y {
1605            if target != 0.0 && target != 1.0 {
1606                return Err(TransformError::InvalidInput(
1607                    "Target values must be binary (0 or 1)".to_string(),
1608                ));
1609            }
1610        }
1611
1612        // Calculate global statistics
1613        let total_events: f64 = y.iter().sum();
1614        let total_non_events = n_samples as f64 - total_events;
1615
1616        if total_events == 0.0 || total_non_events == 0.0 {
1617            return Err(TransformError::InvalidInput(
1618                "Target must contain both 0 and 1 values".to_string(),
1619            ));
1620        }
1621
1622        // Global WOE (overall log-odds)
1623        self.global_woe_ = (total_events / total_non_events).ln();
1624
1625        let mut woe_maps = Vec::with_capacity(n_features);
1626        let mut information_values = Vec::with_capacity(n_features);
1627
1628        for j in 0..n_features {
1629            // Collect target values by category
1630            let mut category_stats: HashMap<u64, (f64, f64)> = HashMap::new(); // (events, non_events)
1631
1632            for i in 0..n_samples {
1633                let category = x_u64[[i, j]];
1634                let target = y[i];
1635
1636                let (events, non_events) = category_stats.entry(category).or_insert((0.0, 0.0));
1637                if target == 1.0 {
1638                    *events += 1.0;
1639                } else {
1640                    *non_events += 1.0;
1641                }
1642            }
1643
1644            // Calculate WOE and IV for each category
1645            let mut woe_map = HashMap::new();
1646            let mut feature_iv = 0.0;
1647
1648            for (category, (events, non_events)) in category_stats.iter() {
1649                // Add regularization to handle zero counts
1650                let reg_events = events + self.regularization;
1651                let reg_non_events = non_events + self.regularization;
1652                let reg_total_events =
1653                    total_events + self.regularization * category_stats.len() as f64;
1654                let reg_total_non_events =
1655                    total_non_events + self.regularization * category_stats.len() as f64;
1656
1657                // Calculate distribution percentages
1658                let event_rate = reg_events / reg_total_events;
1659                let non_event_rate = reg_non_events / reg_total_non_events;
1660
1661                // Calculate WOE
1662                let woe = (event_rate / non_event_rate).ln();
1663                woe_map.insert(*category, woe);
1664
1665                // Calculate Information Value contribution
1666                let iv_contribution = (event_rate - non_event_rate) * woe;
1667                feature_iv += iv_contribution;
1668            }
1669
1670            woe_maps.push(woe_map);
1671            information_values.push(feature_iv);
1672        }
1673
1674        self.woe_maps_ = Some(woe_maps);
1675        self.information_values_ = Some(information_values);
1676        self.is_fitted = true;
1677        Ok(())
1678    }
1679
1680    /// Transforms the input data using the fitted WOEEncoder
1681    ///
1682    /// # Arguments
1683    /// * `x` - The input categorical data, shape (n_samples, n_features)
1684    ///
1685    /// # Returns
1686    /// * `Result<Array2<f64>>` - The WOE-encoded data
1687    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1688    where
1689        S: Data,
1690        S::Elem: Float + NumCast,
1691    {
1692        if !self.is_fitted {
1693            return Err(TransformError::TransformationError(
1694                "WOEEncoder has not been fitted".to_string(),
1695            ));
1696        }
1697
1698        let woe_maps = self.woe_maps_.as_ref().unwrap();
1699
1700        let x_u64 = x.mapv(|x| {
1701            let val_f64 = NumCast::from(x).unwrap_or(0.0);
1702            val_f64 as u64
1703        });
1704
1705        let n_samples = x_u64.shape()[0];
1706        let n_features = x_u64.shape()[1];
1707
1708        if n_features != woe_maps.len() {
1709            return Err(TransformError::InvalidInput(format!(
1710                "x has {} features, but WOEEncoder was fitted with {} features",
1711                n_features,
1712                woe_maps.len()
1713            )));
1714        }
1715
1716        let mut transformed = Array2::zeros((n_samples, n_features));
1717
1718        for i in 0..n_samples {
1719            for j in 0..n_features {
1720                let category = x_u64[[i, j]];
1721
1722                if let Some(&woe_value) = woe_maps[j].get(&category) {
1723                    transformed[[i, j]] = woe_value;
1724                } else {
1725                    // Handle unknown category
1726                    match self.handleunknown.as_str() {
1727                        "error" => {
1728                            return Err(TransformError::InvalidInput(format!(
1729                                "Unknown category {category} in feature {j}"
1730                            )));
1731                        }
1732                        "global_woe" => {
1733                            transformed[[i, j]] = self.global_woe_;
1734                        }
1735                        "use_encoded_value" => {
1736                            transformed[[i, j]] = self.unknownvalue;
1737                        }
1738                        _ => unreachable!(),
1739                    }
1740                }
1741            }
1742        }
1743
1744        Ok(transformed)
1745    }
1746
1747    /// Fits the encoder and transforms the data in one step
1748    ///
1749    /// # Arguments
1750    /// * `x` - The input categorical data
1751    /// * `y` - The binary target values
1752    ///
1753    /// # Returns
1754    /// * `Result<Array2<f64>>` - The WOE-encoded data
1755    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<Array2<f64>>
1756    where
1757        S: Data,
1758        S::Elem: Float + NumCast,
1759    {
1760        self.fit(x, y)?;
1761        self.transform(x)
1762    }
1763
1764    /// Returns whether the encoder has been fitted
1765    pub fn is_fitted(&self) -> bool {
1766        self.is_fitted
1767    }
1768
1769    /// Returns the learned WOE mappings if fitted
1770    pub fn woe_maps(&self) -> Option<&Vec<HashMap<u64, f64>>> {
1771        self.woe_maps_.as_ref()
1772    }
1773
1774    /// Returns the Information Values for each feature if fitted
1775    ///
1776    /// Information Value interpretation:
1777    /// - < 0.02: Not useful for prediction
1778    /// - 0.02 - 0.1: Weak predictive power
1779    /// - 0.1 - 0.3: Medium predictive power  
1780    /// - 0.3 - 0.5: Strong predictive power
1781    /// - > 0.5: Suspicious, too good to be true
1782    pub fn information_values(&self) -> Option<&Vec<f64>> {
1783        self.information_values_.as_ref()
1784    }
1785
1786    /// Returns the global WOE value (overall log-odds)
1787    pub fn global_woe(&self) -> f64 {
1788        self.global_woe_
1789    }
1790
1791    /// Returns features ranked by Information Value (descending order)
1792    ///
1793    /// # Returns
1794    /// * `Option<Vec<(usize, f64)>>` - Vector of (feature_index, information_value) pairs
1795    pub fn feature_importance_ranking(&self) -> Option<Vec<(usize, f64)>> {
1796        self.information_values_.as_ref().map(|ivs| {
1797            let mut ranking: Vec<(usize, f64)> =
1798                ivs.iter().enumerate().map(|(idx, &iv)| (idx, iv)).collect();
1799            ranking.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1800            ranking
1801        })
1802    }
1803}
1804
1805#[cfg(test)]
1806mod tests {
1807    use super::*;
1808    use approx::assert_abs_diff_eq;
1809    use scirs2_core::ndarray::Array;
1810
1811    #[test]
1812    fn test_one_hot_encoder_basic() {
1813        // Create test data with categorical values
1814        let data = Array::from_shape_vec(
1815            (4, 2),
1816            vec![
1817                0.0, 1.0, // categories: [0, 1, 2] and [1, 2, 3]
1818                1.0, 2.0, 2.0, 3.0, 0.0, 1.0,
1819            ],
1820        )
1821        .unwrap();
1822
1823        let mut encoder = OneHotEncoder::with_defaults();
1824        let encoded = encoder.fit_transform(&data).unwrap();
1825
1826        // Should have 3 + 3 = 6 output features
1827        assert_eq!(encoded.shape(), (4, 6));
1828
1829        // Convert to dense for indexing
1830        let encoded_dense = encoded.to_dense();
1831
1832        // Check first row: category 0 in feature 0, category 1 in feature 1
1833        assert_abs_diff_eq!(encoded_dense[[0, 0]], 1.0, epsilon = 1e-10); // cat 0, feature 0
1834        assert_abs_diff_eq!(encoded_dense[[0, 1]], 0.0, epsilon = 1e-10); // cat 1, feature 0
1835        assert_abs_diff_eq!(encoded_dense[[0, 2]], 0.0, epsilon = 1e-10); // cat 2, feature 0
1836        assert_abs_diff_eq!(encoded_dense[[0, 3]], 1.0, epsilon = 1e-10); // cat 1, feature 1
1837        assert_abs_diff_eq!(encoded_dense[[0, 4]], 0.0, epsilon = 1e-10); // cat 2, feature 1
1838        assert_abs_diff_eq!(encoded_dense[[0, 5]], 0.0, epsilon = 1e-10); // cat 3, feature 1
1839
1840        // Check second row: category 1 in feature 0, category 2 in feature 1
1841        assert_abs_diff_eq!(encoded_dense[[1, 0]], 0.0, epsilon = 1e-10); // cat 0, feature 0
1842        assert_abs_diff_eq!(encoded_dense[[1, 1]], 1.0, epsilon = 1e-10); // cat 1, feature 0
1843        assert_abs_diff_eq!(encoded_dense[[1, 2]], 0.0, epsilon = 1e-10); // cat 2, feature 0
1844        assert_abs_diff_eq!(encoded_dense[[1, 3]], 0.0, epsilon = 1e-10); // cat 1, feature 1
1845        assert_abs_diff_eq!(encoded_dense[[1, 4]], 1.0, epsilon = 1e-10); // cat 2, feature 1
1846        assert_abs_diff_eq!(encoded_dense[[1, 5]], 0.0, epsilon = 1e-10); // cat 3, feature 1
1847    }
1848
1849    #[test]
1850    fn test_one_hot_encoder_drop_first() {
1851        // Create test data with categorical values
1852        let data = Array::from_shape_vec((3, 2), vec![0.0, 1.0, 1.0, 2.0, 2.0, 1.0]).unwrap();
1853
1854        let mut encoder = OneHotEncoder::new(Some("first".to_string()), "error", false).unwrap();
1855        let encoded = encoder.fit_transform(&data).unwrap();
1856
1857        // Should have (3-1) + (2-1) = 3 output features (dropped first category of each)
1858        assert_eq!(encoded.shape(), (3, 3));
1859
1860        // Categories: feature 0: [0, 1, 2] -> keep [1, 2]
1861        //            feature 1: [1, 2] -> keep [2]
1862        let encoded_dense = encoded.to_dense();
1863
1864        // First row: category 0 (dropped), category 1 (dropped)
1865        assert_abs_diff_eq!(encoded_dense[[0, 0]], 0.0, epsilon = 1e-10); // cat 1, feature 0
1866        assert_abs_diff_eq!(encoded_dense[[0, 1]], 0.0, epsilon = 1e-10); // cat 2, feature 0
1867        assert_abs_diff_eq!(encoded_dense[[0, 2]], 0.0, epsilon = 1e-10); // cat 2, feature 1
1868
1869        // Second row: category 1, category 2
1870        assert_abs_diff_eq!(encoded_dense[[1, 0]], 1.0, epsilon = 1e-10); // cat 1, feature 0
1871        assert_abs_diff_eq!(encoded_dense[[1, 1]], 0.0, epsilon = 1e-10); // cat 2, feature 0
1872        assert_abs_diff_eq!(encoded_dense[[1, 2]], 1.0, epsilon = 1e-10); // cat 2, feature 1
1873    }
1874
1875    #[test]
1876    fn test_ordinal_encoder() {
1877        // Create test data with categorical values
1878        let data = Array::from_shape_vec(
1879            (4, 2),
1880            vec![
1881                2.0, 10.0, // categories will be mapped to ordinals
1882                1.0, 20.0, 3.0, 10.0, 2.0, 30.0,
1883            ],
1884        )
1885        .unwrap();
1886
1887        let mut encoder = OrdinalEncoder::with_defaults();
1888        let encoded = encoder.fit_transform(&data).unwrap();
1889
1890        // Should preserve shape
1891        assert_eq!(encoded.shape(), &[4, 2]);
1892
1893        // Categories for feature 0: [1, 2, 3] -> ordinals [0, 1, 2]
1894        // Categories for feature 1: [10, 20, 30] -> ordinals [0, 1, 2]
1895
1896        // Check mappings
1897        assert_abs_diff_eq!(encoded[[0, 0]], 1.0, epsilon = 1e-10); // 2 -> ordinal 1
1898        assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10); // 10 -> ordinal 0
1899        assert_abs_diff_eq!(encoded[[1, 0]], 0.0, epsilon = 1e-10); // 1 -> ordinal 0
1900        assert_abs_diff_eq!(encoded[[1, 1]], 1.0, epsilon = 1e-10); // 20 -> ordinal 1
1901        assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10); // 3 -> ordinal 2
1902        assert_abs_diff_eq!(encoded[[2, 1]], 0.0, epsilon = 1e-10); // 10 -> ordinal 0
1903        assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10); // 2 -> ordinal 1
1904        assert_abs_diff_eq!(encoded[[3, 1]], 2.0, epsilon = 1e-10); // 30 -> ordinal 2
1905    }
1906
1907    #[test]
1908    fn test_unknown_category_handling() {
1909        let train_data = Array::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
1910
1911        let test_data = Array::from_shape_vec(
1912            (1, 1),
1913            vec![3.0], // Unknown category
1914        )
1915        .unwrap();
1916
1917        // Test error handling
1918        let mut encoder = OneHotEncoder::with_defaults(); // with_defaults is handleunknown="error"
1919        encoder.fit(&train_data).unwrap();
1920        assert!(encoder.transform(&test_data).is_err());
1921
1922        // Test ignore handling
1923        let mut encoder = OneHotEncoder::new(None, "ignore", false).unwrap();
1924        encoder.fit(&train_data).unwrap();
1925        let encoded = encoder.transform(&test_data).unwrap();
1926
1927        // Should be all zeros (ignored unknown category)
1928        assert_eq!(encoded.shape(), (1, 2));
1929        let encoded_dense = encoded.to_dense();
1930        assert_abs_diff_eq!(encoded_dense[[0, 0]], 0.0, epsilon = 1e-10);
1931        assert_abs_diff_eq!(encoded_dense[[0, 1]], 0.0, epsilon = 1e-10);
1932    }
1933
1934    #[test]
1935    fn test_ordinal_encoder_unknown_value() {
1936        let train_data = Array::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
1937
1938        let test_data = Array::from_shape_vec(
1939            (1, 1),
1940            vec![3.0], // Unknown category
1941        )
1942        .unwrap();
1943
1944        let mut encoder = OrdinalEncoder::new("use_encoded_value", Some(-1.0)).unwrap();
1945        encoder.fit(&train_data).unwrap();
1946        let encoded = encoder.transform(&test_data).unwrap();
1947
1948        // Should use the specified unknown value
1949        assert_eq!(encoded.shape(), &[1, 1]);
1950        assert_abs_diff_eq!(encoded[[0, 0]], -1.0, epsilon = 1e-10);
1951    }
1952
1953    #[test]
1954    fn test_get_feature_names() {
1955        let data = Array::from_shape_vec((2, 2), vec![1.0, 10.0, 2.0, 20.0]).unwrap();
1956
1957        let mut encoder = OneHotEncoder::with_defaults();
1958        encoder.fit(&data).unwrap();
1959
1960        let feature_names = encoder.get_feature_names(None).unwrap();
1961        assert_eq!(feature_names.len(), 4); // 2 cats per feature * 2 features
1962
1963        let custom_names = vec!["feat_a".to_string(), "feat_b".to_string()];
1964        let feature_names = encoder.get_feature_names(Some(&custom_names)).unwrap();
1965        assert!(feature_names[0].starts_with("feat_a_cat_"));
1966        assert!(feature_names[2].starts_with("feat_b_cat_"));
1967    }
1968
1969    #[test]
1970    fn test_target_encoder_mean_strategy() {
1971        // Create test data
1972        let x = Array::from_shape_vec((6, 1), vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0]).unwrap();
1973        let y = vec![1.0, 2.0, 3.0, 1.5, 2.5, 3.5];
1974
1975        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
1976        let encoded = encoder.fit_transform(&x, &y).unwrap();
1977
1978        // Should preserve shape
1979        assert_eq!(encoded.shape(), &[6, 1]);
1980
1981        // Check category encodings:
1982        // Category 0: targets [1.0, 1.5] -> mean = 1.25
1983        // Category 1: targets [2.0, 2.5] -> mean = 2.25
1984        // Category 2: targets [3.0, 3.5] -> mean = 3.25
1985
1986        assert_abs_diff_eq!(encoded[[0, 0]], 1.25, epsilon = 1e-10);
1987        assert_abs_diff_eq!(encoded[[1, 0]], 2.25, epsilon = 1e-10);
1988        assert_abs_diff_eq!(encoded[[2, 0]], 3.25, epsilon = 1e-10);
1989        assert_abs_diff_eq!(encoded[[3, 0]], 1.25, epsilon = 1e-10);
1990        assert_abs_diff_eq!(encoded[[4, 0]], 2.25, epsilon = 1e-10);
1991        assert_abs_diff_eq!(encoded[[5, 0]], 3.25, epsilon = 1e-10);
1992
1993        // Check global mean
1994        assert_abs_diff_eq!(encoder.global_mean(), 2.25, epsilon = 1e-10);
1995    }
1996
1997    #[test]
1998    fn test_target_encoder_median_strategy() {
1999        let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2000        let y = vec![1.0, 2.0, 3.0, 4.0];
2001
2002        let mut encoder = TargetEncoder::new("median", 0.0, 0.0).unwrap();
2003        let encoded = encoder.fit_transform(&x, &y).unwrap();
2004
2005        // Category 0: targets [1.0, 3.0] -> median = 2.0
2006        // Category 1: targets [2.0, 4.0] -> median = 3.0
2007
2008        assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
2009        assert_abs_diff_eq!(encoded[[1, 0]], 3.0, epsilon = 1e-10);
2010        assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10);
2011        assert_abs_diff_eq!(encoded[[3, 0]], 3.0, epsilon = 1e-10);
2012    }
2013
2014    #[test]
2015    fn test_target_encoder_count_strategy() {
2016        let x = Array::from_shape_vec((5, 1), vec![0.0, 1.0, 0.0, 2.0, 1.0]).unwrap();
2017        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2018
2019        let mut encoder = TargetEncoder::new("count", 0.0, 0.0).unwrap();
2020        let encoded = encoder.fit_transform(&x, &y).unwrap();
2021
2022        // Category 0: appears 2 times
2023        // Category 1: appears 2 times
2024        // Category 2: appears 1 time
2025
2026        assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
2027        assert_abs_diff_eq!(encoded[[1, 0]], 2.0, epsilon = 1e-10);
2028        assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10);
2029        assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10);
2030        assert_abs_diff_eq!(encoded[[4, 0]], 2.0, epsilon = 1e-10);
2031    }
2032
2033    #[test]
2034    fn test_target_encoder_sum_strategy() {
2035        let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2036        let y = vec![1.0, 2.0, 3.0, 4.0];
2037
2038        let mut encoder = TargetEncoder::new("sum", 0.0, 0.0).unwrap();
2039        let encoded = encoder.fit_transform(&x, &y).unwrap();
2040
2041        // Category 0: targets [1.0, 3.0] -> sum = 4.0
2042        // Category 1: targets [2.0, 4.0] -> sum = 6.0
2043
2044        assert_abs_diff_eq!(encoded[[0, 0]], 4.0, epsilon = 1e-10);
2045        assert_abs_diff_eq!(encoded[[1, 0]], 6.0, epsilon = 1e-10);
2046        assert_abs_diff_eq!(encoded[[2, 0]], 4.0, epsilon = 1e-10);
2047        assert_abs_diff_eq!(encoded[[3, 0]], 6.0, epsilon = 1e-10);
2048    }
2049
2050    #[test]
2051    fn test_target_encoder_smoothing() {
2052        let x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2053        let y = vec![1.0, 2.0, 3.0];
2054
2055        let mut encoder = TargetEncoder::new("mean", 1.0, 0.0).unwrap();
2056        let encoded = encoder.fit_transform(&x, &y).unwrap();
2057
2058        // Global mean = (1+2+3)/3 = 2.0
2059        // Category 0: count=1, mean=1.0 -> smoothed = (1*1.0 + 1.0*2.0)/(1+1) = 1.5
2060        // Category 1: count=1, mean=2.0 -> smoothed = (1*2.0 + 1.0*2.0)/(1+1) = 2.0
2061        // Category 2: count=1, mean=3.0 -> smoothed = (1*3.0 + 1.0*2.0)/(1+1) = 2.5
2062
2063        assert_abs_diff_eq!(encoded[[0, 0]], 1.5, epsilon = 1e-10);
2064        assert_abs_diff_eq!(encoded[[1, 0]], 2.0, epsilon = 1e-10);
2065        assert_abs_diff_eq!(encoded[[2, 0]], 2.5, epsilon = 1e-10);
2066    }
2067
2068    #[test]
2069    fn test_target_encoder_unknown_categories() {
2070        let train_x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2071        let train_y = vec![1.0, 2.0, 3.0];
2072
2073        let test_x = Array::from_shape_vec((2, 1), vec![3.0, 4.0]).unwrap(); // Unknown categories
2074
2075        let mut encoder = TargetEncoder::new("mean", 0.0, -1.0).unwrap();
2076        encoder.fit(&train_x, &train_y).unwrap();
2077        let encoded = encoder.transform(&test_x).unwrap();
2078
2079        // Should use globalstat for unknown categories
2080        assert_abs_diff_eq!(encoded[[0, 0]], -1.0, epsilon = 1e-10);
2081        assert_abs_diff_eq!(encoded[[1, 0]], -1.0, epsilon = 1e-10);
2082    }
2083
2084    #[test]
2085    fn test_target_encoder_unknown_categories_global_mean() {
2086        let train_x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2087        let train_y = vec![1.0, 2.0, 3.0];
2088
2089        let test_x = Array::from_shape_vec((1, 1), vec![3.0]).unwrap(); // Unknown category
2090
2091        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap(); // globalstat = 0.0
2092        encoder.fit(&train_x, &train_y).unwrap();
2093        let encoded = encoder.transform(&test_x).unwrap();
2094
2095        // Should use global_mean for unknown categories when globalstat == 0.0
2096        assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10); // Global mean = 2.0
2097    }
2098
2099    #[test]
2100    fn test_target_encoder_multi_feature() {
2101        let x =
2102            Array::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]).unwrap();
2103        let y = vec![1.0, 2.0, 3.0, 4.0];
2104
2105        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2106        let encoded = encoder.fit_transform(&x, &y).unwrap();
2107
2108        assert_eq!(encoded.shape(), &[4, 2]);
2109
2110        // Feature 0: Category 0 -> targets [1.0, 3.0] -> mean = 2.0
2111        //           Category 1 -> targets [2.0, 4.0] -> mean = 3.0
2112        // Feature 1: Category 0 -> targets [1.0, 4.0] -> mean = 2.5
2113        //           Category 1 -> targets [2.0, 3.0] -> mean = 2.5
2114
2115        assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
2116        assert_abs_diff_eq!(encoded[[0, 1]], 2.5, epsilon = 1e-10);
2117        assert_abs_diff_eq!(encoded[[1, 0]], 3.0, epsilon = 1e-10);
2118        assert_abs_diff_eq!(encoded[[1, 1]], 2.5, epsilon = 1e-10);
2119    }
2120
2121    #[test]
2122    fn test_target_encoder_cross_validation() {
2123        let x = Array::from_shape_vec(
2124            (10, 1),
2125            vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
2126        )
2127        .unwrap();
2128        let y = vec![1.0, 2.0, 1.5, 2.5, 1.2, 2.2, 1.3, 2.3, 1.1, 2.1];
2129
2130        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2131        let encoded = encoder.fit_transform_cv(&x, &y, 5).unwrap();
2132
2133        // Should have same shape
2134        assert_eq!(encoded.shape(), &[10, 1]);
2135
2136        // Results should be reasonable (not exact due to CV)
2137        // All category 0 samples should get similar values
2138        // All category 1 samples should get similar values
2139        assert!(encoded[[0, 0]] < encoded[[1, 0]]); // Category 0 < Category 1
2140        assert!(encoded[[2, 0]] < encoded[[3, 0]]);
2141    }
2142
2143    #[test]
2144    fn test_target_encoder_convenience_methods() {
2145        let _x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2146        let _y = [1.0, 2.0, 3.0, 4.0];
2147
2148        let encoder1 = TargetEncoder::with_mean(1.0);
2149        assert_eq!(encoder1.strategy, "mean");
2150        assert_abs_diff_eq!(encoder1.smoothing, 1.0, epsilon = 1e-10);
2151
2152        let encoder2 = TargetEncoder::with_median(0.5);
2153        assert_eq!(encoder2.strategy, "median");
2154        assert_abs_diff_eq!(encoder2.smoothing, 0.5, epsilon = 1e-10);
2155    }
2156
2157    #[test]
2158    fn test_target_encoder_validation_errors() {
2159        // Invalid strategy
2160        assert!(TargetEncoder::new("invalid", 0.0, 0.0).is_err());
2161
2162        // Negative smoothing
2163        assert!(TargetEncoder::new("mean", -1.0, 0.0).is_err());
2164
2165        // Mismatched target length
2166        let x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2167        let y = vec![1.0, 2.0]; // Wrong length
2168
2169        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2170        assert!(encoder.fit(&x, &y).is_err());
2171
2172        // Transform before fit
2173        let encoder2 = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2174        assert!(encoder2.transform(&x).is_err());
2175
2176        // Wrong number of features
2177        let train_x = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2178        let test_x = Array::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
2179        let train_y = vec![1.0, 2.0];
2180
2181        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2182        encoder.fit(&train_x, &train_y).unwrap();
2183        assert!(encoder.transform(&test_x).is_err());
2184
2185        // Invalid CV folds
2186        let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2187        let y = vec![1.0, 2.0, 3.0, 4.0];
2188        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2189        assert!(encoder.fit_transform_cv(&x, &y, 1).is_err()); // cv_folds < 2
2190    }
2191
2192    #[test]
2193    fn test_target_encoder_accessors() {
2194        let x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2195        let y = vec![1.0, 2.0, 3.0];
2196
2197        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2198
2199        assert!(!encoder.is_fitted());
2200        assert!(encoder.encodings().is_none());
2201
2202        encoder.fit(&x, &y).unwrap();
2203
2204        assert!(encoder.is_fitted());
2205        assert!(encoder.encodings().is_some());
2206        assert_abs_diff_eq!(encoder.global_mean(), 2.0, epsilon = 1e-10);
2207
2208        let encodings = encoder.encodings().unwrap();
2209        assert_eq!(encodings.len(), 1); // 1 feature
2210        assert_eq!(encodings[0].len(), 3); // 3 categories
2211    }
2212
2213    #[test]
2214    fn test_target_encoder_empty_data() {
2215        let empty_x = Array2::<f64>::zeros((0, 1));
2216        let empty_y = vec![];
2217
2218        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2219        assert!(encoder.fit(&empty_x, &empty_y).is_err());
2220    }
2221
2222    // ===== BinaryEncoder Tests =====
2223
2224    #[test]
2225    fn test_binary_encoder_basic() {
2226        // Test basic binary encoding with 4 categories (needs 2 bits)
2227        let data = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
2228
2229        let mut encoder = BinaryEncoder::with_defaults();
2230        let encoded = encoder.fit_transform(&data).unwrap();
2231
2232        // Should have 2 binary features (ceil(log2(4)) = 2)
2233        assert_eq!(encoded.shape(), &[4, 2]);
2234
2235        // Check binary codes: 0=00, 1=01, 2=10, 3=11
2236        assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10); // 0 -> 00
2237        assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10);
2238        assert_abs_diff_eq!(encoded[[1, 0]], 0.0, epsilon = 1e-10); // 1 -> 01
2239        assert_abs_diff_eq!(encoded[[1, 1]], 1.0, epsilon = 1e-10);
2240        assert_abs_diff_eq!(encoded[[2, 0]], 1.0, epsilon = 1e-10); // 2 -> 10
2241        assert_abs_diff_eq!(encoded[[2, 1]], 0.0, epsilon = 1e-10);
2242        assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10); // 3 -> 11
2243        assert_abs_diff_eq!(encoded[[3, 1]], 1.0, epsilon = 1e-10);
2244    }
2245
2246    #[test]
2247    fn test_binary_encoder_power_of_two() {
2248        // Test with exactly 8 categories (power of 2)
2249        let data =
2250            Array::from_shape_vec((8, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).unwrap();
2251
2252        let mut encoder = BinaryEncoder::with_defaults();
2253        let encoded = encoder.fit_transform(&data).unwrap();
2254
2255        // Should have 3 binary features (log2(8) = 3)
2256        assert_eq!(encoded.shape(), &[8, 3]);
2257
2258        // Check some specific encodings
2259        assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10); // 0 -> 000
2260        assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10);
2261        assert_abs_diff_eq!(encoded[[0, 2]], 0.0, epsilon = 1e-10);
2262
2263        assert_abs_diff_eq!(encoded[[7, 0]], 1.0, epsilon = 1e-10); // 7 -> 111
2264        assert_abs_diff_eq!(encoded[[7, 1]], 1.0, epsilon = 1e-10);
2265        assert_abs_diff_eq!(encoded[[7, 2]], 1.0, epsilon = 1e-10);
2266    }
2267
2268    #[test]
2269    fn test_binary_encoder_non_power_of_two() {
2270        // Test with 5 categories (not power of 2, needs 3 bits)
2271        let data = Array::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0]).unwrap();
2272
2273        let mut encoder = BinaryEncoder::with_defaults();
2274        let encoded = encoder.fit_transform(&data).unwrap();
2275
2276        // Should have 3 binary features (ceil(log2(5)) = 3)
2277        assert_eq!(encoded.shape(), &[5, 3]);
2278        assert_eq!(encoder.n_output_features().unwrap(), 3);
2279    }
2280
2281    #[test]
2282    fn test_binary_encoder_single_category() {
2283        // Test edge case with only 1 category
2284        let data = Array::from_shape_vec((3, 1), vec![5.0, 5.0, 5.0]).unwrap();
2285
2286        let mut encoder = BinaryEncoder::with_defaults();
2287        let encoded = encoder.fit_transform(&data).unwrap();
2288
2289        // Should have 1 binary feature for single category
2290        assert_eq!(encoded.shape(), &[3, 1]);
2291        assert_eq!(encoder.n_output_features().unwrap(), 1);
2292
2293        // All values should be encoded as 0
2294        for i in 0..3 {
2295            assert_abs_diff_eq!(encoded[[i, 0]], 0.0, epsilon = 1e-10);
2296        }
2297    }
2298
2299    #[test]
2300    fn test_binary_encoder_multi_feature() {
2301        // Test with multiple features
2302        let data = Array::from_shape_vec(
2303            (4, 2),
2304            vec![
2305                0.0, 10.0, // Feature 0: [0,1,2] (2 bits), Feature 1: [10,11] (1 bit)
2306                1.0, 11.0, 2.0, 10.0, 0.0, 11.0,
2307            ],
2308        )
2309        .unwrap();
2310
2311        let mut encoder = BinaryEncoder::with_defaults();
2312        let encoded = encoder.fit_transform(&data).unwrap();
2313
2314        // Feature 0: 3 categories need 2 bits, Feature 1: 2 categories need 1 bit
2315        // Total: 2 + 1 = 3 features
2316        assert_eq!(encoded.shape(), &[4, 3]);
2317        assert_eq!(encoder.n_output_features().unwrap(), 3);
2318
2319        let n_binary_features = encoder.n_binary_features().unwrap();
2320        assert_eq!(n_binary_features, &[2, 1]);
2321    }
2322
2323    #[test]
2324    fn test_binary_encoder_separate_fit_transform() {
2325        let train_data = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2326        let test_data = Array::from_shape_vec((2, 1), vec![1.0, 0.0]).unwrap();
2327
2328        let mut encoder = BinaryEncoder::with_defaults();
2329
2330        // Fit on training data
2331        encoder.fit(&train_data).unwrap();
2332        assert!(encoder.is_fitted());
2333
2334        // Transform test data
2335        let encoded = encoder.transform(&test_data).unwrap();
2336        assert_eq!(encoded.shape(), &[2, 2]); // 3 categories need 2 bits
2337
2338        // Check that mappings are consistent
2339        let train_encoded = encoder.transform(&train_data).unwrap();
2340        assert_abs_diff_eq!(encoded[[0, 0]], train_encoded[[1, 0]], epsilon = 1e-10); // Same category 1
2341        assert_abs_diff_eq!(encoded[[0, 1]], train_encoded[[1, 1]], epsilon = 1e-10);
2342    }
2343
2344    #[test]
2345    fn test_binary_encoder_unknown_categories_error() {
2346        let train_data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2347        let test_data = Array::from_shape_vec((1, 1), vec![2.0]).unwrap(); // Unknown category
2348
2349        let mut encoder = BinaryEncoder::new("error").unwrap();
2350        encoder.fit(&train_data).unwrap();
2351
2352        // Should error on unknown category
2353        assert!(encoder.transform(&test_data).is_err());
2354    }
2355
2356    #[test]
2357    fn test_binary_encoder_unknown_categories_ignore() {
2358        let train_data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2359        let test_data = Array::from_shape_vec((1, 1), vec![2.0]).unwrap(); // Unknown category
2360
2361        let mut encoder = BinaryEncoder::new("ignore").unwrap();
2362        encoder.fit(&train_data).unwrap();
2363        let encoded = encoder.transform(&test_data).unwrap();
2364
2365        // Unknown category should be encoded as all zeros
2366        assert_eq!(encoded.shape(), &[1, 1]); // 2 categories need 1 bit
2367        assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10);
2368    }
2369
2370    #[test]
2371    fn test_binary_encoder_categories_accessor() {
2372        let data = Array::from_shape_vec((3, 1), vec![10.0, 20.0, 30.0]).unwrap();
2373
2374        let mut encoder = BinaryEncoder::with_defaults();
2375
2376        // Before fitting
2377        assert!(!encoder.is_fitted());
2378        assert!(encoder.categories().is_none());
2379        assert!(encoder.n_binary_features().is_none());
2380        assert!(encoder.n_output_features().is_none());
2381
2382        encoder.fit(&data).unwrap();
2383
2384        // After fitting
2385        assert!(encoder.is_fitted());
2386        assert!(encoder.categories().is_some());
2387        assert!(encoder.n_binary_features().is_some());
2388        assert!(encoder.n_output_features().is_some());
2389
2390        let categories = encoder.categories().unwrap();
2391        assert_eq!(categories.len(), 1); // 1 feature
2392        assert_eq!(categories[0].len(), 3); // 3 categories
2393
2394        // Check that categories are mapped correctly
2395        let category_map = &categories[0];
2396        assert!(category_map.contains_key(&10));
2397        assert!(category_map.contains_key(&20));
2398        assert!(category_map.contains_key(&30));
2399    }
2400
2401    #[test]
2402    fn test_binary_encoder_int_to_binary() {
2403        // Test binary conversion utility function
2404        assert_eq!(BinaryEncoder::int_to_binary(0, 3), vec![0, 0, 0]);
2405        assert_eq!(BinaryEncoder::int_to_binary(1, 3), vec![0, 0, 1]);
2406        assert_eq!(BinaryEncoder::int_to_binary(2, 3), vec![0, 1, 0]);
2407        assert_eq!(BinaryEncoder::int_to_binary(3, 3), vec![0, 1, 1]);
2408        assert_eq!(BinaryEncoder::int_to_binary(7, 3), vec![1, 1, 1]);
2409
2410        // Test with different bit lengths
2411        assert_eq!(BinaryEncoder::int_to_binary(5, 4), vec![0, 1, 0, 1]);
2412        assert_eq!(BinaryEncoder::int_to_binary(1, 1), vec![1]);
2413    }
2414
2415    #[test]
2416    fn test_binary_encoder_validation_errors() {
2417        // Invalid handleunknown parameter
2418        assert!(BinaryEncoder::new("invalid").is_err());
2419
2420        // Empty data
2421        let empty_data = Array2::<f64>::zeros((0, 1));
2422        let mut encoder = BinaryEncoder::with_defaults();
2423        assert!(encoder.fit(&empty_data).is_err());
2424
2425        // Transform before fit
2426        let data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2427        let encoder = BinaryEncoder::with_defaults();
2428        assert!(encoder.transform(&data).is_err());
2429
2430        // Wrong number of features in transform
2431        let train_data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2432        let test_data = Array::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
2433
2434        let mut encoder = BinaryEncoder::with_defaults();
2435        encoder.fit(&train_data).unwrap();
2436        assert!(encoder.transform(&test_data).is_err());
2437    }
2438
2439    #[test]
2440    fn test_binary_encoder_consistency() {
2441        // Test that encoding is consistent across multiple calls
2442        let data = Array::from_shape_vec((4, 1), vec![3.0, 1.0, 4.0, 1.0]).unwrap();
2443
2444        let mut encoder = BinaryEncoder::with_defaults();
2445        let encoded1 = encoder.fit_transform(&data).unwrap();
2446        let encoded2 = encoder.transform(&data).unwrap();
2447
2448        // Both should be identical
2449        for i in 0..encoded1.shape()[0] {
2450            for j in 0..encoded1.shape()[1] {
2451                assert_abs_diff_eq!(encoded1[[i, j]], encoded2[[i, j]], epsilon = 1e-10);
2452            }
2453        }
2454
2455        // Same categories should have same encoding
2456        assert_abs_diff_eq!(encoded1[[1, 0]], encoded1[[3, 0]], epsilon = 1e-10); // Both category 1
2457        assert_abs_diff_eq!(encoded1[[1, 1]], encoded1[[3, 1]], epsilon = 1e-10);
2458    }
2459
2460    #[test]
2461    fn test_binary_encoder_memory_efficiency() {
2462        // Test that binary encoding is more memory efficient than one-hot
2463        // For 10 categories: one-hot needs 10 features, binary needs 4 features
2464        let data = Array::from_shape_vec(
2465            (10, 1),
2466            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
2467        )
2468        .unwrap();
2469
2470        let mut binary_encoder = BinaryEncoder::with_defaults();
2471        let binary_encoded = binary_encoder.fit_transform(&data).unwrap();
2472
2473        let mut onehot_encoder = OneHotEncoder::with_defaults();
2474        let onehot_encoded = onehot_encoder.fit_transform(&data).unwrap();
2475
2476        // Binary should use fewer features
2477        assert_eq!(binary_encoded.shape()[1], 4); // ceil(log2(10)) = 4
2478        assert_eq!(onehot_encoded.shape().1, 10); // 10 categories = 10 features
2479        assert!(binary_encoded.shape()[1] < onehot_encoded.shape().1);
2480    }
2481
2482    #[test]
2483    fn test_sparse_matrix_basic() {
2484        let mut sparse = SparseMatrix::new((3, 4));
2485        sparse.push(0, 1, 1.0);
2486        sparse.push(1, 2, 1.0);
2487        sparse.push(2, 0, 1.0);
2488
2489        assert_eq!(sparse.shape, (3, 4));
2490        assert_eq!(sparse.nnz(), 3);
2491
2492        let dense = sparse.to_dense();
2493        assert_eq!(dense.shape(), &[3, 4]);
2494        assert_eq!(dense[[0, 1]], 1.0);
2495        assert_eq!(dense[[1, 2]], 1.0);
2496        assert_eq!(dense[[2, 0]], 1.0);
2497        assert_eq!(dense[[0, 0]], 0.0); // Verify zeros
2498    }
2499
2500    #[test]
2501    fn test_onehot_sparse_output() {
2502        let data =
2503            Array::from_shape_vec((4, 2), vec![0.0, 1.0, 1.0, 2.0, 2.0, 0.0, 0.0, 1.0]).unwrap();
2504
2505        // Test sparse output
2506        let mut encoder_sparse = OneHotEncoder::new(None, "error", true).unwrap();
2507        let result_sparse = encoder_sparse.fit_transform(&data).unwrap();
2508
2509        match &result_sparse {
2510            EncodedOutput::Sparse(sparse) => {
2511                assert_eq!(sparse.shape, (4, 6)); // 3 categories + 3 categories = 6 features
2512                assert_eq!(sparse.nnz(), 8); // 4 samples * 2 features = 8 non-zeros
2513
2514                // Convert to dense for comparison
2515                let dense = sparse.to_dense();
2516
2517                // First sample [0, 1] should have [1,0,0,0,1,0] (category 0 in col0, category 1 in col1)
2518                assert_eq!(dense[[0, 0]], 1.0); // category 0 in feature 0
2519                assert_eq!(dense[[0, 4]], 1.0); // category 1 in feature 1
2520                assert_eq!(dense[[0, 1]], 0.0); // not category 1 in feature 0
2521            }
2522            EncodedOutput::Dense(_) => assert!(false, "Expected sparse output, got dense"),
2523        }
2524
2525        // Test dense output for comparison
2526        let mut encoder_dense = OneHotEncoder::new(None, "error", false).unwrap();
2527        let result_dense = encoder_dense.fit_transform(&data).unwrap();
2528
2529        match result_dense {
2530            EncodedOutput::Dense(dense) => {
2531                assert_eq!(dense.shape(), &[4, 6]);
2532                // Verify dense and sparse produce same results
2533                let sparse_as_dense = result_sparse.to_dense();
2534                for i in 0..4 {
2535                    for j in 0..6 {
2536                        assert_abs_diff_eq!(
2537                            dense[[i, j]],
2538                            sparse_as_dense[[i, j]],
2539                            epsilon = 1e-10
2540                        );
2541                    }
2542                }
2543            }
2544            EncodedOutput::Sparse(_) => assert!(false, "Expected dense output, got sparse"),
2545        }
2546    }
2547
2548    #[test]
2549    fn test_onehot_sparse_with_drop() {
2550        let data = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2551
2552        let mut encoder = OneHotEncoder::new(Some("first".to_string()), "error", true).unwrap();
2553        let result = encoder.fit_transform(&data).unwrap();
2554
2555        match result {
2556            EncodedOutput::Sparse(sparse) => {
2557                assert_eq!(sparse.shape, (3, 2)); // 3 categories - 1 dropped = 2 features
2558                assert_eq!(sparse.nnz(), 2); // Only categories 1 and 2 are encoded
2559
2560                let dense = sparse.to_dense();
2561                assert_eq!(dense[[0, 0]], 0.0); // Category 0 dropped, all zeros
2562                assert_eq!(dense[[0, 1]], 0.0);
2563                assert_eq!(dense[[1, 0]], 1.0); // Category 1 maps to first output
2564                assert_eq!(dense[[2, 1]], 1.0); // Category 2 maps to second output
2565            }
2566            EncodedOutput::Dense(_) => assert!(false, "Expected sparse output, got dense"),
2567        }
2568    }
2569
2570    #[test]
2571    fn test_onehot_sparse_backward_compatibility() {
2572        let data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2573
2574        let mut encoder = OneHotEncoder::new(None, "error", true).unwrap();
2575        encoder.fit(&data).unwrap();
2576
2577        // Test that the convenience methods work
2578        let dense_result = encoder.transform_dense(&data).unwrap();
2579        assert_eq!(dense_result.shape(), &[2, 2]);
2580        assert_eq!(dense_result[[0, 0]], 1.0);
2581        assert_eq!(dense_result[[1, 1]], 1.0);
2582
2583        let mut encoder2 = OneHotEncoder::new(None, "error", true).unwrap();
2584        let dense_result2 = encoder2.fit_transform_dense(&data).unwrap();
2585        assert_eq!(dense_result2.shape(), &[2, 2]);
2586
2587        // Results should be identical
2588        for i in 0..2 {
2589            for j in 0..2 {
2590                assert_abs_diff_eq!(dense_result[[i, j]], dense_result2[[i, j]], epsilon = 1e-10);
2591            }
2592        }
2593    }
2594
2595    #[test]
2596    fn test_encoded_output_methods() {
2597        let dense_array =
2598            Array::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
2599        let dense_output = EncodedOutput::Dense(dense_array);
2600
2601        let mut sparse_matrix = SparseMatrix::new((2, 3));
2602        sparse_matrix.push(0, 0, 1.0);
2603        sparse_matrix.push(1, 1, 1.0);
2604        let sparse_output = EncodedOutput::Sparse(sparse_matrix);
2605
2606        // Test shape method
2607        assert_eq!(dense_output.shape(), (2, 3));
2608        assert_eq!(sparse_output.shape(), (2, 3));
2609
2610        // Test to_dense method
2611        let dense_from_dense = dense_output.to_dense();
2612        let dense_from_sparse = sparse_output.to_dense();
2613
2614        assert_eq!(dense_from_dense.shape(), &[2, 3]);
2615        assert_eq!(dense_from_sparse.shape(), &[2, 3]);
2616
2617        // Verify values are equivalent
2618        assert_eq!(dense_from_dense[[0, 0]], 1.0);
2619        assert_eq!(dense_from_sparse[[0, 0]], 1.0);
2620        assert_eq!(dense_from_dense[[1, 1]], 1.0);
2621        assert_eq!(dense_from_sparse[[1, 1]], 1.0);
2622    }
2623}