1use scirs2_core::ndarray::{s, Array2, ArrayView1, ArrayView2};
6use sklears_core::{
7 error::{Result as SklResult, SklearsError},
8 traits::{Estimator, Fit, Transform, Untrained},
9 types::Float,
10};
11use std::collections::HashMap;
14
15use crate::{MockTransformer, PipelineStep};
16
/// Applies a separate transformer to each selected subset of columns of a 2-D
/// input and concatenates the per-subset outputs column-wise.
///
/// `S` is the type-state: `Untrained` before `fit`, and
/// `ColumnTransformerTrained` in the value `fit` returns.
#[derive(Debug, Clone)]
pub struct ColumnTransformer<S = Untrained> {
    /// Type-state payload; holds the fitted steps once trained.
    state: S,
    /// Transformer names, parallel to `transformer_columns` (always pushed together).
    transformer_names: Vec<String>,
    /// Column indices each named transformer reads from the input.
    transformer_columns: Vec<Vec<usize>>,
    /// Handling of columns no transformer selects: `fit` keeps them only when this
    /// is `"passthrough"`; any other value drops them. Defaults to `"drop"`.
    remainder: String,
    /// Fraction of exact zeros in the `fit` input at or above which the fitted
    /// model records `sparse_output = true`.
    sparse_threshold: f64,
    /// Requested parallelism. NOTE(review): stored and carried through `fit`,
    /// but not read by any code visible in this file chunk.
    n_jobs: Option<i32>,
    /// Optional per-transformer multiplicative weights, keyed by transformer name.
    transformer_weights: Option<HashMap<String, f64>>,
}
53
/// Trained state of a [`ColumnTransformer`], produced by `fit`.
#[derive(Debug)]
pub struct ColumnTransformerTrained {
    /// One entry per fitted step: (name, fitted step, input column indices).
    /// A step named `"remainder"` is present when the remainder was passed through.
    fitted_transformers: Vec<(String, Box<dyn PipelineStep>, Vec<usize>)>,
    /// Per step, `0..n_input_columns`. NOTE(review): this assumes each step's
    /// output width equals its input width.
    output_indices: Vec<Vec<usize>>,
    /// Number of input columns seen in `fit`; `transform` rejects other widths.
    n_features_in: usize,
    /// Input feature names; `fit` currently always stores `None`.
    feature_names_in: Option<Vec<String>>,
    /// Sparsity decision taken in `fit`. NOTE(review): `transform_output` does not
    /// consult it and always returns a dense result.
    sparse_output: bool,
}
63
64impl ColumnTransformer<Untrained> {
65 #[must_use]
67 pub fn new() -> Self {
68 Self {
69 state: Untrained,
70 transformer_names: Vec::new(),
71 transformer_columns: Vec::new(),
72 remainder: "drop".to_string(),
73 sparse_threshold: 0.3,
74 n_jobs: None,
75 transformer_weights: None,
76 }
77 }
78
79 #[must_use]
81 pub fn builder() -> ColumnTransformerBuilder {
82 ColumnTransformerBuilder::new()
83 }
84
85 pub fn add_transformer(&mut self, name: String, columns: Vec<usize>) {
87 self.transformer_names.push(name);
88 self.transformer_columns.push(columns);
89 }
90
91 pub fn add_transformer_step(
93 &mut self,
94 name: String,
95 transformer: Box<dyn PipelineStep>,
96 columns: Vec<usize>,
97 ) {
98 self.transformer_names.push(name);
99 self.transformer_columns.push(columns);
100 }
101
102 #[must_use]
104 pub fn remainder(mut self, remainder: String) -> Self {
105 self.remainder = remainder;
106 self
107 }
108
109 #[must_use]
111 pub fn sparse_threshold(mut self, threshold: f64) -> Self {
112 self.sparse_threshold = threshold;
113 self
114 }
115
116 #[must_use]
118 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
119 self.n_jobs = n_jobs;
120 self
121 }
122
123 #[must_use]
125 pub fn transformer_weights(mut self, weights: HashMap<String, f64>) -> Self {
126 self.transformer_weights = Some(weights);
127 self
128 }
129
130 fn extract_columns(
132 &self,
133 x: &ArrayView2<'_, Float>,
134 columns: &[usize],
135 ) -> SklResult<Array2<Float>> {
136 if columns.is_empty() {
137 return Ok(Array2::zeros((x.nrows(), 0)));
138 }
139
140 let mut result = Array2::zeros((x.nrows(), columns.len()));
141 for (col_idx, &original_col) in columns.iter().enumerate() {
142 if original_col >= x.ncols() {
143 return Err(SklearsError::InvalidInput(format!(
144 "Column index {original_col} out of bounds"
145 )));
146 }
147 result.column_mut(col_idx).assign(&x.column(original_col));
148 }
149 Ok(result)
150 }
151
152 fn should_output_sparse(&self, x: &ArrayView2<'_, Float>) -> bool {
154 let total_elements = x.nrows() * x.ncols();
155 if total_elements == 0 {
156 return false;
157 }
158
159 let zero_count = x.iter().filter(|&&val| val == 0.0).count();
160 let sparsity = zero_count as f64 / total_elements as f64;
161
162 sparsity >= self.sparse_threshold
163 }
164}
165
166impl Default for ColumnTransformer<Untrained> {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
/// Configuration type exposed through `Estimator::Config` for `ColumnTransformer`.
///
/// The fields mirror the `ColumnTransformer` fields of the same names.
#[derive(Debug, Clone)]
pub struct ColumnTransformerConfig {
    /// Handling of unselected columns (`"drop"` by default, or `"passthrough"`).
    pub remainder: String,
    /// Zero-fraction threshold for the sparsity decision (0.3 by default).
    pub sparse_threshold: f64,
    /// Requested parallelism (`None` by default).
    pub n_jobs: Option<i32>,
    /// Optional per-transformer weights keyed by transformer name (`None` by default).
    pub transformer_weights: Option<HashMap<String, f64>>,
}
180
181impl Default for ColumnTransformerConfig {
182 fn default() -> Self {
183 Self {
184 remainder: "drop".to_string(),
185 sparse_threshold: 0.3,
186 n_jobs: None,
187 transformer_weights: None,
188 }
189 }
190}
191
192impl Estimator for ColumnTransformer<Untrained> {
193 type Config = ColumnTransformerConfig;
194 type Error = SklearsError;
195 type Float = Float;
196
197 fn config(&self) -> &Self::Config {
198 static DEFAULT_CONFIG: ColumnTransformerConfig = ColumnTransformerConfig {
201 remainder: String::new(),
202 sparse_threshold: 0.3,
203 n_jobs: None,
204 transformer_weights: None,
205 };
206 &DEFAULT_CONFIG
207 }
208}
209
210impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for ColumnTransformer<Untrained> {
211 type Fitted = ColumnTransformer<ColumnTransformerTrained>;
212
213 fn fit(
214 self,
215 x: &ArrayView2<'_, Float>,
216 y: &Option<&ArrayView1<'_, Float>>,
217 ) -> SklResult<Self::Fitted> {
218 let n_features_in = x.ncols();
219 let mut fitted_transformers = Vec::new();
220 let mut output_indices = Vec::new();
221 let mut used_columns = vec![false; n_features_in];
222
223 for (name, columns) in self
225 .transformer_names
226 .iter()
227 .zip(self.transformer_columns.iter())
228 {
229 for &col in columns {
231 if col >= n_features_in {
232 return Err(SklearsError::InvalidInput(format!(
233 "Column index {col} out of bounds for {n_features_in} features"
234 )));
235 }
236 used_columns[col] = true;
237 }
238
239 let x_subset = self.extract_columns(x, columns)?;
241
242 let mut transformer = Box::new(MockTransformer::new()) as Box<dyn PipelineStep>;
244 transformer.fit(&x_subset.view(), y.as_ref().copied())?;
245
246 fitted_transformers.push((name.clone(), transformer, columns.clone()));
247 output_indices.push((0..columns.len()).collect()); }
249
250 let remainder_columns: Vec<usize> =
252 (0..n_features_in).filter(|&i| !used_columns[i]).collect();
253
254 if !remainder_columns.is_empty() && self.remainder == "passthrough" {
255 let x_remainder = self.extract_columns(x, &remainder_columns)?;
256 let mut remainder_transformer =
257 Box::new(MockTransformer::new()) as Box<dyn PipelineStep>;
258 remainder_transformer.fit(&x_remainder.view(), y.as_ref().copied())?;
259 fitted_transformers.push((
260 "remainder".to_string(),
261 remainder_transformer,
262 remainder_columns.clone(),
263 ));
264 output_indices.push((0..remainder_columns.len()).collect());
265 }
266
267 let sparse_output = self.should_output_sparse(x);
269
270 Ok(ColumnTransformer {
271 state: ColumnTransformerTrained {
272 fitted_transformers,
273 output_indices,
274 n_features_in,
275 feature_names_in: None,
276 sparse_output,
277 },
278 transformer_names: self.transformer_names,
279 transformer_columns: self.transformer_columns,
280 remainder: self.remainder,
281 sparse_threshold: self.sparse_threshold,
282 n_jobs: self.n_jobs,
283 transformer_weights: self.transformer_weights,
284 })
285 }
286}
287
288impl Transform<ArrayView2<'_, Float>, Array2<f64>> for ColumnTransformer<ColumnTransformerTrained> {
290 fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
291 if x.ncols() != self.state.n_features_in {
292 return Err(SklearsError::InvalidInput(format!(
293 "Input has {} features, expected {}",
294 x.ncols(),
295 self.state.n_features_in
296 )));
297 }
298
299 if self.state.fitted_transformers.is_empty() {
300 return Ok(x.mapv(|v| v));
301 }
302
303 let mut transformed_results = Vec::new();
304
305 for (name, transformer, columns) in &self.state.fitted_transformers {
307 let x_subset = self.extract_columns(x, columns)?;
308 let mut transformed = transformer.transform(&x_subset.view())?;
309
310 if let Some(ref weights) = self.transformer_weights {
312 if let Some(&weight) = weights.get(name) {
313 transformed.mapv_inplace(|v| v * weight);
314 }
315 }
316
317 transformed_results.push(transformed);
318 }
319
320 if transformed_results.is_empty() {
321 return Ok(Array2::zeros((x.nrows(), 0)));
322 }
323
324 if transformed_results.len() == 1 {
326 Ok(transformed_results.into_iter().next().unwrap())
327 } else {
328 self.concatenate_results(transformed_results)
329 }
330 }
331}
332
/// Result container returned by `ColumnTransformer::transform_output`.
///
/// Only a dense variant exists.
#[derive(Debug, Clone)]
pub enum ColumnTransformerOutput {
    /// Dense 2-D output.
    Dense(Array2<f64>),
}
341
342impl ColumnTransformer<ColumnTransformerTrained> {
343 pub fn transform_output(
345 &self,
346 x: &ArrayView2<'_, Float>,
347 ) -> SklResult<ColumnTransformerOutput> {
348 let dense_result = self.transform(x)?;
349
350 Ok(ColumnTransformerOutput::Dense(dense_result))
357 }
359
360 fn concatenate_results(&self, results: Vec<Array2<f64>>) -> SklResult<Array2<f64>> {
378 let n_samples = results[0].nrows();
379 let total_features: usize = results
380 .iter()
381 .map(scirs2_core::ndarray::ArrayBase::ncols)
382 .sum();
383
384 let mut concatenated = Array2::zeros((n_samples, total_features));
385 let mut col_idx = 0;
386
387 for result in results {
388 if result.nrows() != n_samples {
389 return Err(SklearsError::InvalidInput(
390 "All transformer outputs must have the same number of samples".to_string(),
391 ));
392 }
393
394 let end_idx = col_idx + result.ncols();
395 concatenated
396 .slice_mut(s![.., col_idx..end_idx])
397 .assign(&result);
398 col_idx = end_idx;
399 }
400
401 Ok(concatenated)
402 }
403
404 fn extract_columns(
406 &self,
407 x: &ArrayView2<'_, Float>,
408 columns: &[usize],
409 ) -> SklResult<Array2<Float>> {
410 if columns.is_empty() {
411 return Ok(Array2::zeros((x.nrows(), 0)));
412 }
413
414 let mut result = Array2::zeros((x.nrows(), columns.len()));
415 for (col_idx, &original_col) in columns.iter().enumerate() {
416 if original_col >= x.ncols() {
417 return Err(SklearsError::InvalidInput(format!(
418 "Column index {original_col} out of bounds"
419 )));
420 }
421 result.column_mut(col_idx).assign(&x.column(original_col));
422 }
423 Ok(result)
424 }
425
426 #[must_use]
428 pub fn get_transformer_info(&self) -> Vec<(String, Vec<usize>)> {
429 self.state
430 .fitted_transformers
431 .iter()
432 .map(|(name, _, columns)| (name.clone(), columns.clone()))
433 .collect()
434 }
435
436 #[must_use]
438 pub fn n_features_out(&self) -> usize {
439 self.state
440 .output_indices
441 .iter()
442 .map(std::vec::Vec::len)
443 .sum()
444 }
445}
446
/// Fluent builder for an untrained [`ColumnTransformer`].
#[derive(Debug, Clone)]
pub struct ColumnTransformerBuilder {
    /// Transformer names, parallel to `transformer_columns`.
    transformer_names: Vec<String>,
    /// Column indices for each named transformer.
    transformer_columns: Vec<Vec<usize>>,
    /// Handling of unselected columns (`"drop"` by default).
    remainder: String,
    /// Zero-fraction threshold for the sparsity decision (0.3 by default).
    sparse_threshold: f64,
    /// Requested parallelism (`None` by default).
    n_jobs: Option<i32>,
    /// Optional per-transformer weights keyed by transformer name.
    transformer_weights: Option<HashMap<String, f64>>,
}
457
458impl ColumnTransformerBuilder {
459 #[must_use]
461 pub fn new() -> Self {
462 Self {
463 transformer_names: Vec::new(),
464 transformer_columns: Vec::new(),
465 remainder: "drop".to_string(),
466 sparse_threshold: 0.3,
467 n_jobs: None,
468 transformer_weights: None,
469 }
470 }
471
472 #[must_use]
474 pub fn transformer(mut self, name: String, columns: Vec<usize>) -> Self {
475 self.transformer_names.push(name);
476 self.transformer_columns.push(columns);
477 self
478 }
479
480 #[must_use]
482 pub fn remainder(mut self, remainder: String) -> Self {
483 self.remainder = remainder;
484 self
485 }
486
487 #[must_use]
489 pub fn sparse_threshold(mut self, threshold: f64) -> Self {
490 self.sparse_threshold = threshold;
491 self
492 }
493
494 #[must_use]
496 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
497 self.n_jobs = n_jobs;
498 self
499 }
500
501 #[must_use]
503 pub fn transformer_weights(mut self, weights: HashMap<String, f64>) -> Self {
504 self.transformer_weights = Some(weights);
505 self
506 }
507
508 #[must_use]
510 pub fn build(self) -> ColumnTransformer<Untrained> {
511 ColumnTransformer {
513 state: Untrained,
514 transformer_names: self.transformer_names,
515 transformer_columns: self.transformer_columns,
516 remainder: self.remainder,
517 sparse_threshold: self.sparse_threshold,
518 n_jobs: self.n_jobs,
519 transformer_weights: self.transformer_weights,
520 }
521 }
522}
523
524impl Default for ColumnTransformerBuilder {
525 fn default() -> Self {
526 Self::new()
527 }
528}