1use crate::error::{Result, SklearsError};
73use crate::traits::{Fit, Predict};
74use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Dimension};
76use serde::{Deserialize, Serialize};
77use std::collections::HashMap;
78use std::fmt;
79
80pub mod sklearn {
82 use super::*;
83 use crate::traits::{Estimator, Score};
84
    /// Scikit-learn style parameter management (`get_params` / `set_params`).
    ///
    /// Mirrors the Python sklearn estimator API so models can be configured
    /// through string-keyed parameters instead of typed builder methods.
    pub trait SklearnCompatible {
        /// Set a single parameter by name. Accepts anything convertible into a
        /// [`ParamValue`] (bool, i64, f64, String, &str).
        fn set_param(&mut self, param: &str, value: impl Into<ParamValue>) -> Result<()>;

        /// Get a single parameter by name; errors when the name is unknown.
        fn get_param(&self, param: &str) -> Result<ParamValue>;

        /// Return all parameters. `deep` mirrors sklearn's flag for including
        /// nested estimator parameters (may be a no-op for flat models).
        fn get_params(&self, deep: bool) -> HashMap<String, ParamValue>;

        /// Set multiple parameters at once from a name -> value map.
        fn set_params(&mut self, params: HashMap<String, ParamValue>) -> Result<()>;
    }
99
    /// Dynamically-typed hyperparameter value, analogous to the Python scalar
    /// types sklearn accepts (bool / int / float / str / list / None).
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub enum ParamValue {
        Bool(bool),
        Int(i64),
        Float(f64),
        String(String),
        /// Numeric array parameter (e.g. per-class weights).
        Array(Vec<f64>),
        /// Python `None` equivalent: a parameter that is explicitly unset.
        None,
    }
110
111 impl From<bool> for ParamValue {
112 fn from(value: bool) -> Self {
113 ParamValue::Bool(value)
114 }
115 }
116
117 impl From<i64> for ParamValue {
118 fn from(value: i64) -> Self {
119 ParamValue::Int(value)
120 }
121 }
122
123 impl From<f64> for ParamValue {
124 fn from(value: f64) -> Self {
125 ParamValue::Float(value)
126 }
127 }
128
129 impl From<String> for ParamValue {
130 fn from(value: String) -> Self {
131 ParamValue::String(value)
132 }
133 }
134
135 impl From<&str> for ParamValue {
136 fn from(value: &str) -> Self {
137 ParamValue::String(value.to_string())
138 }
139 }
140
    /// An unfitted model configured through sklearn-style string parameters.
    ///
    /// This is a lightweight compatibility shim: it stores the model kind and
    /// its parameter map, but no trained state (see `FittedScikitLearnModel`).
    #[derive(Debug, Clone)]
    pub struct ScikitLearnModel {
        // sklearn class name, e.g. "LinearRegression", "SVC".
        model_type: String,
        // Current hyperparameters keyed by sklearn parameter name.
        parameters: HashMap<String, ParamValue>,
        // Whether `fit` has been called (flipped just before the fitted model is built).
        fitted: bool,
    }
148
149 impl ScikitLearnModel {
150 pub fn linear_regression() -> Self {
152 let mut params = HashMap::new();
153 params.insert("fit_intercept".to_string(), ParamValue::Bool(true));
154 params.insert("normalize".to_string(), ParamValue::Bool(false));
155 params.insert("copy_X".to_string(), ParamValue::Bool(true));
156 params.insert("n_jobs".to_string(), ParamValue::None);
157
158 Self {
159 model_type: "LinearRegression".to_string(),
160 parameters: params,
161 fitted: false,
162 }
163 }
164
165 pub fn random_forest_classifier() -> Self {
167 let mut params = HashMap::new();
168 params.insert("n_estimators".to_string(), ParamValue::Int(100));
169 params.insert(
170 "criterion".to_string(),
171 ParamValue::String("gini".to_string()),
172 );
173 params.insert("max_depth".to_string(), ParamValue::None);
174 params.insert("min_samples_split".to_string(), ParamValue::Int(2));
175 params.insert("min_samples_leaf".to_string(), ParamValue::Int(1));
176 params.insert(
177 "max_features".to_string(),
178 ParamValue::String("auto".to_string()),
179 );
180 params.insert("bootstrap".to_string(), ParamValue::Bool(true));
181 params.insert("oob_score".to_string(), ParamValue::Bool(false));
182 params.insert("n_jobs".to_string(), ParamValue::None);
183 params.insert("random_state".to_string(), ParamValue::None);
184
185 Self {
186 model_type: "RandomForestClassifier".to_string(),
187 parameters: params,
188 fitted: false,
189 }
190 }
191
192 pub fn svm_classifier() -> Self {
194 let mut params = HashMap::new();
195 params.insert("C".to_string(), ParamValue::Float(1.0));
196 params.insert("kernel".to_string(), ParamValue::String("rbf".to_string()));
197 params.insert("degree".to_string(), ParamValue::Int(3));
198 params.insert("gamma".to_string(), ParamValue::String("scale".to_string()));
199 params.insert("coef0".to_string(), ParamValue::Float(0.0));
200 params.insert("shrinking".to_string(), ParamValue::Bool(true));
201 params.insert("probability".to_string(), ParamValue::Bool(false));
202 params.insert("tol".to_string(), ParamValue::Float(1e-3));
203 params.insert("cache_size".to_string(), ParamValue::Float(200.0));
204 params.insert("class_weight".to_string(), ParamValue::None);
205 params.insert("verbose".to_string(), ParamValue::Bool(false));
206 params.insert("max_iter".to_string(), ParamValue::Int(-1));
207 params.insert(
208 "decision_function_shape".to_string(),
209 ParamValue::String("ovr".to_string()),
210 );
211 params.insert("break_ties".to_string(), ParamValue::Bool(false));
212 params.insert("random_state".to_string(), ParamValue::None);
213
214 Self {
215 model_type: "SVC".to_string(),
216 parameters: params,
217 fitted: false,
218 }
219 }
220 }
221
222 impl SklearnCompatible for ScikitLearnModel {
223 fn set_param(&mut self, param: &str, value: impl Into<ParamValue>) -> Result<()> {
224 self.parameters.insert(param.to_string(), value.into());
225 Ok(())
226 }
227
228 fn get_param(&self, param: &str) -> Result<ParamValue> {
229 self.parameters
230 .get(param)
231 .cloned()
232 .ok_or_else(|| SklearsError::InvalidInput(format!("Parameter '{param}' not found")))
233 }
234
235 fn get_params(&self, deep: bool) -> HashMap<String, ParamValue> {
236 if deep {
237 self.parameters.clone()
240 } else {
241 self.parameters.clone()
242 }
243 }
244
245 fn set_params(&mut self, params: HashMap<String, ParamValue>) -> Result<()> {
246 for (key, value) in params {
247 self.parameters.insert(key, value);
248 }
249 Ok(())
250 }
251 }
252
    impl Estimator for ScikitLearnModel {
        // The estimator's "config" is simply its sklearn parameter map.
        type Config = HashMap<String, ParamValue>;
        type Error = SklearsError;
        type Float = f64;

        // Borrow the current parameter map as the estimator configuration.
        fn config(&self) -> &Self::Config {
            &self.parameters
        }
    }
262
263 impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for ScikitLearnModel {
264 type Fitted = FittedScikitLearnModel;
265
266 fn fit(mut self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
267 if x.nrows() != y.len() {
269 return Err(SklearsError::ShapeMismatch {
270 expected: format!("({}, n_features)", y.len()),
271 actual: format!("({}, {})", x.nrows(), x.ncols()),
272 });
273 }
274
275 self.fitted = true;
276
277 Ok(FittedScikitLearnModel {
278 model: self,
279 training_shape: (x.nrows(), x.ncols()),
280 feature_importances: vec![0.1; x.ncols()], classes: get_unique_classes(y),
282 })
283 }
284 }
285
    /// A fitted model produced by `ScikitLearnModel::fit`.
    ///
    /// Stores the configured model plus training metadata; predictions are
    /// mock values (see the `Predict` impl), not real learned behavior.
    #[derive(Debug, Clone)]
    pub struct FittedScikitLearnModel {
        // The (now `fitted = true`) model configuration.
        model: ScikitLearnModel,
        // (n_samples, n_features) of the training data.
        training_shape: (usize, usize),
        // Placeholder per-feature importances (uniform 0.1 per feature).
        feature_importances: Vec<f64>,
        // Distinct target values seen during fitting, sorted ascending.
        classes: Vec<f64>,
    }
294
295 impl FittedScikitLearnModel {
296 pub fn feature_importances(&self) -> &[f64] {
298 &self.feature_importances
299 }
300
301 pub fn classes(&self) -> &[f64] {
303 &self.classes
304 }
305
306 pub fn n_features_in(&self) -> usize {
308 self.training_shape.1
309 }
310 }
311
312 impl<'a> Predict<ArrayView2<'a, f64>, Array1<f64>> for FittedScikitLearnModel {
313 fn predict(&self, x: &ArrayView2<'a, f64>) -> Result<Array1<f64>> {
314 if x.ncols() != self.training_shape.1 {
315 return Err(SklearsError::FeatureMismatch {
316 expected: self.training_shape.1,
317 actual: x.ncols(),
318 });
319 }
320
321 let predictions = match self.model.model_type.as_str() {
323 "LinearRegression" => {
324 Array1::from_iter(x.rows().into_iter().map(|row| row.sum() * 0.1))
326 }
327 "RandomForestClassifier" | "SVC" => {
328 let most_common_class = self.classes.first().copied().unwrap_or(0.0);
330 Array1::from_elem(x.nrows(), most_common_class)
331 }
332 _ => Array1::zeros(x.nrows()),
333 };
334
335 Ok(predictions)
336 }
337 }
338
339 impl<'a> Score<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for FittedScikitLearnModel {
340 type Float = f64;
341
342 fn score(&self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<f64> {
343 let predictions = self.predict(x)?;
344
345 match self.model.model_type.as_str() {
346 "LinearRegression" => {
347 let y_mean = y.mean().unwrap_or(0.0);
349 let ss_res = predictions
350 .iter()
351 .zip(y.iter())
352 .map(|(pred, actual)| (actual - pred).powi(2))
353 .sum::<f64>();
354 let ss_tot = y
355 .iter()
356 .map(|actual| (actual - y_mean).powi(2))
357 .sum::<f64>();
358
359 if ss_tot == 0.0 {
360 Ok(1.0)
361 } else {
362 Ok(1.0 - (ss_res / ss_tot))
363 }
364 }
365 _ => {
366 let correct = predictions
368 .iter()
369 .zip(y.iter())
370 .map(|(pred, actual)| {
371 if (pred - actual).abs() < 0.5 {
372 1.0
373 } else {
374 0.0
375 }
376 })
377 .sum::<f64>();
378 Ok(correct / y.len() as f64)
379 }
380 }
381 }
382 }
383
384 fn get_unique_classes(y: &ArrayView1<f64>) -> Vec<f64> {
386 let mut classes: Vec<f64> = y.iter().copied().collect();
387 classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
388 classes.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
389 classes
390 }
391}
392
393pub mod numpy {
395 use super::*;
396 use bytemuck::{Pod, Zeroable};
397
    /// Owned n-dimensional array in NumPy's memory layout, convertible to the
    /// `.npy` byte format.
    #[derive(Debug, Clone)]
    pub struct NumpyArray<T: Pod + Zeroable> {
        // Elements in logical (row-major) order.
        data: Vec<T>,
        // Dimension sizes, outermost first.
        shape: Vec<usize>,
        // Per-dimension strides. NOTE(review): `from_raw` stores byte strides
        // (via `calculate_strides`) while `from_ndarray` copies the source
        // array's strides verbatim — confirm the intended unit; see the impl.
        strides: Vec<usize>,
        // NumPy dtype descriptor string, e.g. "<f8".
        dtype: String,
        // True when the data is in Fortran (column-major) order.
        fortran_order: bool,
    }
407
408 impl<T: Pod + Zeroable + fmt::Debug> NumpyArray<T> {
409 pub fn from_ndarray<D: Dimension>(
411 array: &scirs2_core::ndarray::ArrayBase<scirs2_core::ndarray::OwnedRepr<T>, D>,
412 ) -> Result<Self> {
413 let shape = array.shape().to_vec();
414 let strides = array.strides().iter().map(|&s| s as usize).collect();
415 let data = array.iter().cloned().collect();
416 let dtype = Self::get_dtype_string();
417
418 Ok(Self {
419 data,
420 shape,
421 strides,
422 dtype,
423 fortran_order: false,
424 })
425 }
426
427 pub fn from_raw(data: Vec<T>, shape: Vec<usize>) -> Result<Self> {
429 let expected_size = shape.iter().product::<usize>();
430 if data.len() != expected_size {
431 return Err(SklearsError::ShapeMismatch {
432 expected: format!("{expected_size} elements"),
433 actual: format!("{} elements", data.len()),
434 });
435 }
436
437 let strides = Self::calculate_strides(&shape, false);
438 let dtype = Self::get_dtype_string();
439
440 Ok(Self {
441 data,
442 shape,
443 strides,
444 dtype,
445 fortran_order: false,
446 })
447 }
448
449 pub fn to_bytes(&self) -> Result<Vec<u8>> {
451 let header = self.create_numpy_header()?;
452 let data_bytes = bytemuck::cast_slice(&self.data);
453
454 let mut result = Vec::new();
455 result.extend_from_slice(&header);
456 result.extend_from_slice(data_bytes);
457
458 Ok(result)
459 }
460
461 pub fn shape(&self) -> &[usize] {
463 &self.shape
464 }
465
466 pub fn strides(&self) -> &[usize] {
468 &self.strides
469 }
470
471 pub fn dtype(&self) -> &str {
473 &self.dtype
474 }
475
476 pub fn data(&self) -> &[T] {
478 &self.data
479 }
480
481 pub fn to_ndarray(&self) -> Result<Array2<T>> {
483 if self.shape.len() != 2 {
484 return Err(SklearsError::InvalidInput(
485 "Only 2D arrays are currently supported for conversion back to ndarray"
486 .to_string(),
487 ));
488 }
489
490 Array2::from_shape_vec((self.shape[0], self.shape[1]), self.data.clone())
491 .map_err(|e| SklearsError::InvalidInput(format!("Failed to create ndarray: {e}")))
492 }
493
494 fn get_dtype_string() -> String {
495 if std::mem::size_of::<T>() == 8 {
496 "<f8".to_string() } else if std::mem::size_of::<T>() == 4 {
498 "<f4".to_string() } else {
500 "<i8".to_string() }
502 }
503
504 fn calculate_strides(shape: &[usize], fortran_order: bool) -> Vec<usize> {
505 let mut strides = vec![0; shape.len()];
506 let item_size = std::mem::size_of::<T>();
507
508 if fortran_order {
509 let mut stride = item_size;
511 for i in 0..shape.len() {
512 strides[i] = stride;
513 stride *= shape[i];
514 }
515 } else {
516 let mut stride = item_size;
518 for i in (0..shape.len()).rev() {
519 strides[i] = stride;
520 stride *= shape[i];
521 }
522 }
523
524 strides
525 }
526
527 fn create_numpy_header(&self) -> Result<Vec<u8>> {
528 let header_dict = format!(
530 "{{'descr': '{}', 'fortran_order': {}, 'shape': ({},)}}",
531 self.dtype,
532 self.fortran_order,
533 self.shape
534 .iter()
535 .map(|x| x.to_string())
536 .collect::<Vec<_>>()
537 .join(", ")
538 );
539
540 let mut header = header_dict.into_bytes();
541
542 while header.len() % 64 != 0 {
544 header.push(b' ');
545 }
546 header.push(b'\n');
547
548 Ok(header)
549 }
550 }
551
552 }
555
556pub mod pandas {
558 use super::*;
559 use std::collections::BTreeMap;
560
    /// Minimal column-oriented table mimicking a pandas DataFrame.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct DataFrame {
        // Column names in insertion order (the BTreeMap below is key-sorted,
        // so this vector preserves the user-visible ordering).
        columns: Vec<String>,
        // Column name -> cell values; every column has `index.len()` rows.
        data: BTreeMap<String, Vec<DataValue>>,
        // Row labels (stringified positions by default).
        index: Vec<String>,
    }
568
    /// A single DataFrame cell, mirroring the scalar types pandas stores.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub enum DataValue {
        Float(f64),
        Int(i64),
        String(String),
        Bool(bool),
        /// Missing value (pandas NaN / None).
        None,
    }
578
579 impl DataFrame {
580 pub fn new() -> Self {
582 Self {
583 columns: Vec::new(),
584 data: BTreeMap::new(),
585 index: Vec::new(),
586 }
587 }
588
589 pub fn from_ndarray(array: &Array2<f64>, columns: Option<Vec<String>>) -> Result<Self> {
591 let n_cols = array.ncols();
592 let n_rows = array.nrows();
593
594 let columns =
595 columns.unwrap_or_else(|| (0..n_cols).map(|i| format!("col_{i}")).collect());
596
597 if columns.len() != n_cols {
598 return Err(SklearsError::ShapeMismatch {
599 expected: format!("{n_cols} columns"),
600 actual: format!("{} column names", columns.len()),
601 });
602 }
603
604 let mut data = BTreeMap::new();
605 for (col_idx, col_name) in columns.iter().enumerate() {
606 let column_data: Vec<DataValue> = (0..n_rows)
607 .map(|row_idx| DataValue::Float(array[[row_idx, col_idx]]))
608 .collect();
609 data.insert(col_name.clone(), column_data);
610 }
611
612 let index: Vec<String> = (0..n_rows).map(|i| i.to_string()).collect();
613
614 Ok(Self {
615 columns,
616 data,
617 index,
618 })
619 }
620
621 pub fn add_column(&mut self, name: String, values: Vec<DataValue>) -> Result<()> {
623 if !self.data.is_empty() && values.len() != self.index.len() {
624 return Err(SklearsError::ShapeMismatch {
625 expected: format!("{} rows", self.index.len()),
626 actual: format!("{} values", values.len()),
627 });
628 }
629
630 if self.data.is_empty() {
631 self.index = (0..values.len()).map(|i| i.to_string()).collect();
632 }
633
634 self.columns.push(name.clone());
635 self.data.insert(name, values);
636 Ok(())
637 }
638
639 pub fn columns(&self) -> &[String] {
641 &self.columns
642 }
643
644 pub fn get_column(&self, name: &str) -> Option<&Vec<DataValue>> {
646 self.data.get(name)
647 }
648
649 pub fn shape(&self) -> (usize, usize) {
651 (self.index.len(), self.columns.len())
652 }
653
654 pub fn to_ndarray(&self) -> Result<Array2<f64>> {
656 let (n_rows, n_cols) = self.shape();
657 let mut array = Array2::zeros((n_rows, n_cols));
658
659 for (col_idx, col_name) in self.columns.iter().enumerate() {
660 if let Some(column) = self.data.get(col_name) {
661 for (row_idx, value) in column.iter().enumerate() {
662 match value {
663 DataValue::Float(f) => array[[row_idx, col_idx]] = *f,
664 DataValue::Int(i) => array[[row_idx, col_idx]] = *i as f64,
665 DataValue::Bool(b) => {
666 array[[row_idx, col_idx]] = if *b { 1.0 } else { 0.0 }
667 }
668 _ => {
669 return Err(SklearsError::InvalidInput(format!(
670 "Non-numeric value in column '{col_name}' at row {row_idx}"
671 )))
672 }
673 }
674 }
675 }
676 }
677
678 Ok(array)
679 }
680
681 pub fn describe(&self) -> Result<DataFrame> {
683 let mut stats_df = DataFrame::new();
684 let stats = ["count", "mean", "std", "min", "25%", "50%", "75%", "max"];
685
686 for stat in &stats {
687 stats_df.add_column(stat.to_string(), Vec::new())?;
688 }
689
690 for col_name in &self.columns {
691 if let Some(column) = self.data.get(col_name) {
692 let numeric_values: Vec<f64> = column
693 .iter()
694 .filter_map(|v| match v {
695 DataValue::Float(f) => Some(*f),
696 DataValue::Int(i) => Some(*i as f64),
697 _ => None,
698 })
699 .collect();
700
701 if !numeric_values.is_empty() {
702 let count = numeric_values.len() as f64;
703 let mean = numeric_values.iter().sum::<f64>() / count;
704 let variance = numeric_values
705 .iter()
706 .map(|x| (x - mean).powi(2))
707 .sum::<f64>()
708 / count;
709 let _std = variance.sqrt();
710
711 let mut sorted = numeric_values.clone();
712 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
713
714 let _min = sorted[0];
715 let _max = sorted[sorted.len() - 1];
716 let _q25 = sorted[sorted.len() / 4];
717 let _q50 = sorted[sorted.len() / 2];
718 let _q75 = sorted[3 * sorted.len() / 4];
719
720 }
723 }
724 }
725
726 Ok(stats_df)
727 }
728 }
729
    impl Default for DataFrame {
        // Equivalent to `DataFrame::new()`: an empty frame.
        fn default() -> Self {
            Self::new()
        }
    }
735}
736
737pub mod pytorch {
739 use super::*;
740 use bytemuck::{Pod, Zeroable};
741
    /// Description of a tensor's layout, mirroring PyTorch tensor attributes.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct TensorMetadata {
        /// Dimension sizes, outermost first.
        pub shape: Vec<usize>,
        /// PyTorch dtype name, e.g. "float64".
        pub dtype: String,
        /// Whether autograd should track operations on the tensor.
        pub requires_grad: bool,
        /// Target device string (always "cpu" here).
        pub device: String,
    }
750
751 pub fn ndarray_to_pytorch_tensor<T: Pod + Zeroable>(
753 array: &Array2<T>,
754 requires_grad: bool,
755 ) -> Result<(Vec<u8>, TensorMetadata)> {
756 let shape = array.shape().to_vec();
757 let data_bytes = bytemuck::cast_slice(array.as_slice().unwrap());
758 let dtype = if std::mem::size_of::<T>() == 8 {
759 "float64"
760 } else {
761 "float32"
762 }
763 .to_string();
764
765 let metadata = TensorMetadata {
766 shape,
767 dtype,
768 requires_grad,
769 device: "cpu".to_string(),
770 };
771
772 Ok((data_bytes.to_vec(), metadata))
773 }
774}
775
776pub mod serialization {
778 use super::*;
779
    /// Supported on-disk model formats across ML ecosystems.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum ModelFormat {
        /// Python pickle as produced by scikit-learn.
        SklearnPickle,
        /// XGBoost's JSON model dump.
        XGBoostJson,
        /// LightGBM's text model dump.
        LightGBMText,
        /// TensorFlow SavedModel format.
        TensorFlowSavedModel,
        /// PyTorch `state_dict` checkpoint.
        PyTorchStateDict,
        /// ONNX protobuf graph.
        OnnxProtobuf,
    }
790
    /// Serialize/deserialize a model to and from foreign model formats.
    pub trait ModelSerialization {
        /// Encode the model in the given format.
        fn serialize(&self, format: ModelFormat) -> Result<Vec<u8>>;

        /// Decode a model from bytes in the given format.
        fn deserialize(data: &[u8], format: ModelFormat) -> Result<Self>
        where
            Self: Sized;

        /// The formats this type can round-trip.
        fn supported_formats() -> Vec<ModelFormat>;
    }
804
    /// Format-agnostic model representation used for cross-framework exchange.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct CrossPlatformModel {
        /// Model class name, e.g. "LinearRegression".
        pub model_type: String,
        /// Producing library/model version string.
        pub version: String,
        /// Hyperparameters as free-form JSON values.
        pub parameters: HashMap<String, serde_json::Value>,
        /// Flattened model weights.
        pub weights: Vec<f64>,
        /// Arbitrary extra key/value annotations.
        pub metadata: HashMap<String, String>,
    }
814
815 impl CrossPlatformModel {
816 pub fn to_sklearn_metadata(&self) -> Result<HashMap<String, serde_json::Value>> {
818 let mut sklearn_meta = HashMap::new();
819 sklearn_meta.insert(
820 "__class__".to_string(),
821 serde_json::Value::String(self.model_type.clone()),
822 );
823 sklearn_meta.insert(
824 "__version__".to_string(),
825 serde_json::Value::String(self.version.clone()),
826 );
827 sklearn_meta.extend(self.parameters.clone());
828 Ok(sklearn_meta)
829 }
830
831 pub fn from_sklearn_metadata(metadata: HashMap<String, serde_json::Value>) -> Result<Self> {
833 let model_type = metadata
834 .get("__class__")
835 .and_then(|v| v.as_str())
836 .unwrap_or("unknown")
837 .to_string();
838
839 let version = metadata
840 .get("__version__")
841 .and_then(|v| v.as_str())
842 .unwrap_or("unknown")
843 .to_string();
844
845 let mut parameters = metadata;
846 parameters.remove("__class__");
847 parameters.remove("__version__");
848
849 Ok(Self {
850 model_type,
851 version,
852 parameters,
853 weights: Vec::new(),
854 metadata: HashMap::new(),
855 })
856 }
857 }
858}
859
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::numpy::*;
    use super::pandas::*;
    use super::sklearn::*;
    use super::*;
    use crate::traits::Fit;

    // A parameter set via `set_param` round-trips through `get_param`.
    #[test]
    fn test_sklearn_linear_regression() {
        let mut model = ScikitLearnModel::linear_regression();
        assert!(model.set_param("fit_intercept", false).is_ok());
        assert_eq!(
            model.get_param("fit_intercept").unwrap(),
            ParamValue::Bool(false)
        );
    }

    // Integer parameters are accepted and `get_params` exposes the stored keys.
    #[test]
    fn test_sklearn_parameter_management() {
        let mut model = ScikitLearnModel::random_forest_classifier();

        assert!(model.set_param("n_estimators", 200).is_ok());
        assert!(model.set_param("max_depth", 10).is_ok());

        assert_eq!(
            model.get_param("n_estimators").unwrap(),
            ParamValue::Int(200)
        );
        assert_eq!(model.get_param("max_depth").unwrap(), ParamValue::Int(10));

        let params = model.get_params(false);
        assert!(params.contains_key("n_estimators"));
        assert!(params.contains_key("max_depth"));
    }

    // ndarray -> NumpyArray preserves shape and element count.
    #[test]
    fn test_numpy_array_conversion() {
        let array = Array2::<f64>::zeros((10, 5));
        let numpy_array = NumpyArray::from_ndarray(&array);
        assert!(numpy_array.is_ok());

        let numpy_array = numpy_array.unwrap();
        assert_eq!(numpy_array.shape(), &[10, 5]);
        assert_eq!(numpy_array.data().len(), 50);
    }

    // Adding the first column to an empty DataFrame defines its row count.
    #[test]
    fn test_pandas_dataframe() {
        let mut df = DataFrame::new();

        let values = vec![
            DataValue::Float(1.0),
            DataValue::Float(2.0),
            DataValue::Float(3.0),
        ];

        assert!(df.add_column("test_col".to_string(), values).is_ok());
        assert_eq!(df.shape(), (3, 1));
        assert_eq!(df.columns(), &["test_col"]);
    }

    // ndarray -> DataFrame -> ndarray round-trips values and shape.
    #[test]
    fn test_dataframe_to_ndarray() {
        let array = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let df = DataFrame::from_ndarray(&array, None).unwrap();

        let converted = df.to_ndarray().unwrap();
        assert_eq!(converted.shape(), [2, 2]);
        assert_eq!(converted[[0, 0]], 1.0);
        assert_eq!(converted[[1, 1]], 4.0);
    }

    // Fitting records the training feature count on the fitted model.
    #[test]
    fn test_sklearn_model_fitting() {
        let model = ScikitLearnModel::linear_regression();
        let features = Array2::zeros((10, 3));
        let targets = Array1::zeros(10);

        let fitted = model.fit(&features.view(), &targets.view());
        assert!(fitted.is_ok());

        let fitted = fitted.unwrap();
        assert_eq!(fitted.n_features_in(), 3);
    }

    // to_sklearn_metadata exposes the model type under the "__class__" key.
    #[test]
    fn test_cross_platform_model() {
        use serialization::CrossPlatformModel;

        let model = CrossPlatformModel {
            model_type: "LinearRegression".to_string(),
            version: "1.0".to_string(),
            parameters: HashMap::new(),
            weights: vec![1.0, 2.0, 3.0],
            metadata: HashMap::new(),
        };

        let sklearn_meta = model.to_sklearn_metadata();
        assert!(sklearn_meta.is_ok());

        let meta = sklearn_meta.unwrap();
        assert_eq!(
            meta.get("__class__").unwrap().as_str().unwrap(),
            "LinearRegression"
        );
    }
}