sklears_utils/
external_integration.rs

1//! External integration utilities for machine learning interoperability
2//!
//! This module provides utilities for integrating with external systems,
//! including Python and R interoperability, WASM compilation support, and
//! foreign function interface (FFI) utilities.
6
7use crate::{UtilsError, UtilsResult};
8use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::ffi::{CStr, CString};
12use std::fmt;
13use std::os::raw::c_char;
14
/// Python interoperability utilities.
///
/// Acts purely as a namespace: all helpers are associated functions, so no
/// instance is needed to call them.
pub struct PythonInterop;
17
18impl PythonInterop {
19    /// Convert Rust array to Python-compatible format
20    pub fn array_to_python_buffer(array: &Array1<f64>) -> PyArrayBuffer {
21        PyArrayBuffer {
22            data: array.as_slice().unwrap().to_vec(),
23            shape: vec![array.len()],
24            dtype: "float64".to_string(),
25            order: "C".to_string(),
26        }
27    }
28
29    /// Convert 2D Rust array to Python-compatible format
30    pub fn array2_to_python_buffer(array: &Array2<f64>) -> PyArrayBuffer {
31        let (rows, cols) = array.dim();
32        PyArrayBuffer {
33            data: array.as_slice().unwrap().to_vec(),
34            shape: vec![rows, cols],
35            dtype: "float64".to_string(),
36            order: "C".to_string(),
37        }
38    }
39
40    /// Create Rust array from Python buffer
41    pub fn python_buffer_to_array(buffer: &PyArrayBuffer) -> UtilsResult<Array1<f64>> {
42        if buffer.shape.len() != 1 {
43            return Err(UtilsError::InvalidParameter(
44                "Expected 1D array".to_string(),
45            ));
46        }
47
48        if buffer.dtype != "float64" {
49            return Err(UtilsError::InvalidParameter(format!(
50                "Unsupported dtype: {}",
51                buffer.dtype
52            )));
53        }
54
55        Array1::from_vec(buffer.data.clone())
56            .into_shape_with_order(buffer.shape[0])
57            .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
58    }
59
60    /// Create 2D Rust array from Python buffer
61    pub fn python_buffer_to_array2(buffer: &PyArrayBuffer) -> UtilsResult<Array2<f64>> {
62        if buffer.shape.len() != 2 {
63            return Err(UtilsError::InvalidParameter(
64                "Expected 2D array".to_string(),
65            ));
66        }
67
68        if buffer.dtype != "float64" {
69            return Err(UtilsError::InvalidParameter(format!(
70                "Unsupported dtype: {}",
71                buffer.dtype
72            )));
73        }
74
75        Array2::from_shape_vec((buffer.shape[0], buffer.shape[1]), buffer.data.clone())
76            .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
77    }
78
79    /// Generate Python numpy import code
80    pub fn generate_numpy_import_code(array_name: &str, buffer: &PyArrayBuffer) -> String {
81        format!(
82            r#"
83import numpy as np
84
85# Data generated by sklears-utils
86{} = np.array({:?}, dtype='{}').reshape({:?})
87"#,
88            array_name, buffer.data, buffer.dtype, buffer.shape
89        )
90    }
91
92    /// Generate Python function call template
93    pub fn generate_function_call_template(
94        function_name: &str,
95        parameters: &[PythonParameter],
96    ) -> String {
97        let param_strings: Vec<String> = parameters
98            .iter()
99            .map(|p| match &p.value {
100                PythonValue::String(s) => format!("{}='{}'", p.name, s),
101                PythonValue::Number(n) => format!("{}={}", p.name, n),
102                PythonValue::Boolean(b) => {
103                    format!("{}={}", p.name, if *b { "True" } else { "False" })
104                }
105                PythonValue::Array(name) => format!("{}={}", p.name, name),
106            })
107            .collect();
108
109        format!("{}({})", function_name, param_strings.join(", "))
110    }
111
112    /// Create Python script for ML model
113    pub fn create_ml_script(
114        model_type: &str,
115        training_data: &PyArrayBuffer,
116        labels: &PyArrayBuffer,
117        hyperparameters: &HashMap<String, f64>,
118    ) -> UtilsResult<String> {
119        let mut script = String::new();
120
121        // Imports
122        script.push_str("import numpy as np\n");
123        script.push_str("from sklearn.model_selection import train_test_split\n");
124
125        match model_type {
126            "linear_regression" => {
127                script.push_str("from sklearn.linear_model import LinearRegression\n")
128            }
129            "random_forest" => {
130                script.push_str("from sklearn.ensemble import RandomForestRegressor\n")
131            }
132            "svm" => script.push_str("from sklearn.svm import SVC\n"),
133            _ => {
134                return Err(UtilsError::InvalidParameter(format!(
135                    "Unsupported model type: {model_type}"
136                )))
137            }
138        }
139
140        script.push_str("\n# Data preparation\n");
141        script.push_str(&Self::generate_numpy_import_code("X", training_data));
142        script.push_str(&Self::generate_numpy_import_code("y", labels));
143
144        script.push_str("\n# Train-test split\n");
145        script.push_str("X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n");
146
147        script.push_str("\n# Model creation and training\n");
148        let model_creation = match model_type {
149            "linear_regression" => "model = LinearRegression()".to_string(),
150            "random_forest" => {
151                let n_estimators = hyperparameters.get("n_estimators").unwrap_or(&100.0);
152                format!(
153                    "model = RandomForestRegressor(n_estimators={})",
154                    *n_estimators as i32
155                )
156            }
157            "svm" => {
158                let c = hyperparameters.get("C").unwrap_or(&1.0);
159                format!("model = SVC(C={c})")
160            }
161            _ => {
162                return Err(UtilsError::InvalidParameter(format!(
163                    "Unsupported model type: {model_type}"
164                )))
165            }
166        };
167
168        script.push_str(&model_creation);
169        script.push_str("\nmodel.fit(X_train, y_train)\n");
170
171        script.push_str("\n# Evaluation\n");
172        script.push_str("score = model.score(X_test, y_test)\n");
173        script.push_str("print(f'Model score: {score:.4f}')\n");
174
175        Ok(script)
176    }
177}
178
/// WASM compilation utilities.
///
/// Acts purely as a namespace: all helpers are associated functions that
/// return generated source text or configuration values.
pub struct WasmUtils;
181
182impl WasmUtils {
183    /// Generate WASM-compatible function signature
184    pub fn generate_wasm_signature(
185        function_name: &str,
186        parameters: &[WasmParameter],
187        return_type: WasmType,
188    ) -> String {
189        let param_strings: Vec<String> = parameters
190            .iter()
191            .map(|p| format!("{}: {}", p.name, p.param_type))
192            .collect();
193
194        format!(
195            "#[wasm_bindgen]\npub fn {}({}) -> {} {{",
196            function_name,
197            param_strings.join(", "),
198            return_type
199        )
200    }
201
202    /// Generate WASM memory management helpers
203    pub fn generate_memory_helpers() -> String {
204        r#"
205use wasm_bindgen::prelude::*;
206
207// Memory management helpers for WASM
208#[wasm_bindgen]
209pub fn alloc(size: usize) -> *mut u8 {
210    let mut buf = Vec::with_capacity(size);
211    let ptr = buf.as_mut_ptr();
212    std::mem::forget(buf);
213    ptr
214}
215
216#[wasm_bindgen]
217pub fn dealloc(ptr: *mut u8, size: usize) {
218    unsafe {
219        let _ = Vec::from_raw_parts(ptr, size, size);
220    }
221}
222
223// Array helpers
224#[wasm_bindgen]
225pub struct Float64Array {
226    data: Vec<f64>,
227}
228
229#[wasm_bindgen]
230impl Float64Array {
231    #[wasm_bindgen(constructor)]
232    pub fn new(size: usize) -> Float64Array {
233        Float64Array {
234            data: vec![0.0; size],
235        }
236    }
237
238    #[wasm_bindgen(getter)]
239    pub fn length(&self) -> usize {
240        self.data.len()
241    }
242
243    #[wasm_bindgen]
244    pub fn get(&self, index: usize) -> f64 {
245        self.data.get(index).copied().unwrap_or(0.0)
246    }
247
248    #[wasm_bindgen]
249    pub fn set(&mut self, index: usize, value: f64) {
250        if index < self.data.len() {
251            self.data[index] = value;
252        }
253    }
254
255    #[wasm_bindgen]
256    pub fn as_ptr(&self) -> *const f64 {
257        self.data.as_ptr()
258    }
259}
260"#
261        .to_string()
262    }
263
264    /// Generate WASM bindings for ML functions
265    pub fn generate_ml_bindings() -> String {
266        r#"
267use wasm_bindgen::prelude::*;
268
269// Linear algebra operations
270#[wasm_bindgen]
271pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
272    if a.len() != b.len() {
273        return 0.0;
274    }
275    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
276}
277
278#[wasm_bindgen]
279pub fn matrix_multiply(
280    a: &[f64], a_rows: usize, a_cols: usize,
281    b: &[f64], b_rows: usize, b_cols: usize,
282    result: &mut [f64]
283) -> bool {
284    if a_cols != b_rows || result.len() != a_rows * b_cols {
285        return false;
286    }
287
288    for i in 0..a_rows {
289        for j in 0..b_cols {
290            let mut sum = 0.0;
291            for k in 0..a_cols {
292                sum += a[i * a_cols + k] * b[k * b_cols + j];
293            }
294            result[i * b_cols + j] = sum;
295        }
296    }
297    true
298}
299
300// Statistical functions
301#[wasm_bindgen]
302pub fn mean(data: &[f64]) -> f64 {
303    if data.is_empty() {
304        return 0.0;
305    }
306    data.iter().sum::<f64>() / data.len() as f64
307}
308
309#[wasm_bindgen]
310pub fn variance(data: &[f64]) -> f64 {
311    if data.len() < 2 {
312        return 0.0;
313    }
314    let mean_val = mean(data);
315    let variance = data.iter()
316        .map(|x| (x - mean_val).powi(2))
317        .sum::<f64>() / (data.len() - 1) as f64;
318    variance
319}
320
321#[wasm_bindgen]
322pub fn standard_deviation(data: &[f64]) -> f64 {
323    variance(data).sqrt()
324}
325"#
326        .to_string()
327    }
328
329    /// Create build configuration for WASM
330    pub fn create_wasm_build_config() -> WasmBuildConfig {
331        WasmBuildConfig {
332            target: "wasm32-unknown-unknown".to_string(),
333            features: vec![
334                "wasm-bindgen".to_string(),
335                "console_error_panic_hook".to_string(),
336            ],
337            optimization: WasmOptimization::Size,
338            debug: false,
339            typescript_bindings: true,
340        }
341    }
342
343    /// Generate package.json for WASM project
344    pub fn generate_package_json(project_name: &str, version: &str) -> String {
345        format!(
346            r#"{{
347  "name": "{project_name}",
348  "version": "{version}",
349  "description": "WASM bindings for sklears ML utilities",
350  "main": "index.js",
351  "types": "index.d.ts",
352  "scripts": {{
353    "build": "wasm-pack build --target web --out-dir pkg",
354    "build:nodejs": "wasm-pack build --target nodejs --out-dir pkg-node",
355    "test": "wasm-pack test --headless --chrome"
356  }},
357  "devDependencies": {{
358    "wasm-pack": "^0.12.0"
359  }},
360  "files": [
361    "pkg/"
362  ],
363  "keywords": [
364    "wasm",
365    "machine-learning",
366    "sklears",
367    "linear-algebra"
368  ]
369}}"#
370        )
371    }
372}
373
/// R interoperability utilities.
///
/// Acts purely as a namespace: all helpers are associated functions for
/// converting data to and from R-style containers and for generating R code.
pub struct RInterop;
376
377impl RInterop {
378    /// Convert Rust array to R-compatible format
379    pub fn array_to_r_vector(array: &Array1<f64>) -> RVector {
380        RVector {
381            data: array.as_slice().unwrap().to_vec(),
382            length: array.len(),
383            r_type: RType::Numeric,
384        }
385    }
386
387    /// Convert 2D Rust array to R matrix
388    pub fn array2_to_r_matrix(array: &Array2<f64>) -> RMatrix {
389        let (rows, cols) = array.dim();
390
391        // Convert from row-major (ndarray) to column-major (R) storage
392        let mut col_major_data = vec![0.0; rows * cols];
393        for i in 0..rows {
394            for j in 0..cols {
395                col_major_data[j * rows + i] = array[[i, j]];
396            }
397        }
398
399        RMatrix {
400            data: col_major_data,
401            nrow: rows,
402            ncol: cols,
403            byrow: false, // R defaults to column-major
404            r_type: RType::Numeric,
405        }
406    }
407
408    /// Create Rust array from R vector
409    pub fn r_vector_to_array(vector: &RVector) -> UtilsResult<Array1<f64>> {
410        if vector.r_type != RType::Numeric {
411            return Err(UtilsError::InvalidParameter(format!(
412                "Expected numeric vector, got {:?}",
413                vector.r_type
414            )));
415        }
416
417        Array1::from_vec(vector.data.clone())
418            .into_shape_with_order(vector.length)
419            .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
420    }
421
422    /// Create 2D Rust array from R matrix
423    pub fn r_matrix_to_array2(matrix: &RMatrix) -> UtilsResult<Array2<f64>> {
424        if matrix.r_type != RType::Numeric {
425            return Err(UtilsError::InvalidParameter(format!(
426                "Expected numeric matrix, got {:?}",
427                matrix.r_type
428            )));
429        }
430
431        if matrix.byrow {
432            // Data is already in row-major format, can use directly
433            Array2::from_shape_vec((matrix.nrow, matrix.ncol), matrix.data.clone())
434                .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
435        } else {
436            // Convert from column-major (R) to row-major (ndarray) storage
437            let mut row_major_data = vec![0.0; matrix.data.len()];
438            for i in 0..matrix.nrow {
439                for j in 0..matrix.ncol {
440                    row_major_data[i * matrix.ncol + j] = matrix.data[j * matrix.nrow + i];
441                }
442            }
443            Array2::from_shape_vec((matrix.nrow, matrix.ncol), row_major_data)
444                .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
445        }
446    }
447
448    /// Generate R vector creation code
449    pub fn generate_r_vector_code(vector_name: &str, vector: &RVector) -> String {
450        format!(
451            "{} <- c({})",
452            vector_name,
453            vector
454                .data
455                .iter()
456                .map(|x| x.to_string())
457                .collect::<Vec<_>>()
458                .join(", ")
459        )
460    }
461
462    /// Generate R matrix creation code
463    pub fn generate_r_matrix_code(matrix_name: &str, matrix: &RMatrix) -> String {
464        let data_str = matrix
465            .data
466            .iter()
467            .map(|x| x.to_string())
468            .collect::<Vec<_>>()
469            .join(", ");
470
471        format!(
472            "{} <- matrix(c({}), nrow = {}, ncol = {}, byrow = {})",
473            matrix_name,
474            data_str,
475            matrix.nrow,
476            matrix.ncol,
477            if matrix.byrow { "TRUE" } else { "FALSE" }
478        )
479    }
480
481    /// Generate R data frame creation code
482    pub fn generate_r_dataframe_code(
483        df_name: &str,
484        columns: &HashMap<String, RVector>,
485    ) -> UtilsResult<String> {
486        let mut column_defs = Vec::new();
487
488        // Check that all columns have the same length
489        let first_length = columns.values().next().map(|v| v.length).unwrap_or(0);
490
491        for (name, vector) in columns {
492            if vector.length != first_length {
493                return Err(UtilsError::InvalidParameter(
494                    "All columns must have the same length".to_string(),
495                ));
496            }
497
498            let data_str = vector
499                .data
500                .iter()
501                .map(|x| x.to_string())
502                .collect::<Vec<_>>()
503                .join(", ");
504
505            column_defs.push(format!("{name} = c({data_str})"));
506        }
507
508        Ok(format!(
509            "{} <- data.frame({})",
510            df_name,
511            column_defs.join(", ")
512        ))
513    }
514
515    /// Generate R package loading code
516    pub fn generate_r_package_imports(packages: &[&str]) -> String {
517        let mut imports = String::new();
518
519        for package in packages {
520            imports.push_str(&format!("library({package})\n"));
521        }
522
523        imports
524    }
525
526    /// Generate R function call template
527    pub fn generate_r_function_call(function_name: &str, parameters: &[RParameter]) -> String {
528        let param_strings: Vec<String> = parameters
529            .iter()
530            .map(|p| match &p.value {
531                RValue::String(s) => format!("{} = \"{}\"", p.name, s),
532                RValue::Number(n) => format!("{} = {}", p.name, n),
533                RValue::Boolean(b) => format!("{} = {}", p.name, if *b { "TRUE" } else { "FALSE" }),
534                RValue::Vector(name) => format!("{} = {}", p.name, name),
535                RValue::Matrix(name) => format!("{} = {}", p.name, name),
536                RValue::DataFrame(name) => format!("{} = {}", p.name, name),
537            })
538            .collect();
539
540        format!("{}({})", function_name, param_strings.join(", "))
541    }
542
543    /// Create R script for ML model
544    pub fn create_r_ml_script(
545        model_type: &str,
546        training_data: &RMatrix,
547        response_var: &RVector,
548        hyperparameters: &HashMap<String, f64>,
549    ) -> UtilsResult<String> {
550        let mut script = String::new();
551
552        // Package imports based on model type
553        match model_type {
554            "linear_regression" => {
555                script.push_str("# Linear regression using base R\n");
556            }
557            "random_forest" => {
558                script.push_str(&Self::generate_r_package_imports(&["randomForest"]));
559            }
560            "svm" => {
561                script.push_str(&Self::generate_r_package_imports(&["e1071"]));
562            }
563            "glm" => {
564                script.push_str("# Generalized linear model using base R\n");
565            }
566            "tree" => {
567                script.push_str(&Self::generate_r_package_imports(&["tree"]));
568            }
569            _ => {
570                return Err(UtilsError::InvalidParameter(format!(
571                    "Unsupported R model type: {model_type}"
572                )))
573            }
574        }
575
576        script.push_str("\n# Data preparation\n");
577        script.push_str(&Self::generate_r_matrix_code("X", training_data));
578        script.push('\n');
579        script.push_str(&Self::generate_r_vector_code("y", response_var));
580        script.push('\n');
581
582        // Create data frame for some models
583        if matches!(model_type, "glm" | "tree") {
584            script.push_str("\n# Create data frame\n");
585            script.push_str("df <- data.frame(y = y, X)\n");
586            script.push_str("colnames(df) <- c('response', paste0('X', 1:ncol(X)))\n");
587        }
588
589        script.push_str("\n# Train-test split\n");
590        script.push_str("set.seed(42)\n");
591        script.push_str("train_indices <- sample(1:nrow(X), size = 0.8 * nrow(X))\n");
592        script.push_str("X_train <- X[train_indices, ]\n");
593        script.push_str("X_test <- X[-train_indices, ]\n");
594        script.push_str("y_train <- y[train_indices]\n");
595        script.push_str("y_test <- y[-train_indices]\n");
596
597        script.push_str("\n# Model creation and training\n");
598        let model_creation = match model_type {
599            "linear_regression" => "model <- lm(y_train ~ X_train)".to_string(),
600            "random_forest" => {
601                let ntree = hyperparameters.get("ntree").unwrap_or(&500.0);
602                let mtry = hyperparameters.get("mtry").unwrap_or(&3.0);
603                format!(
604                    "model <- randomForest(x = X_train, y = y_train, ntree = {}, mtry = {})",
605                    *ntree as i32, *mtry as i32
606                )
607            }
608            "svm" => {
609                let cost = hyperparameters.get("cost").unwrap_or(&1.0);
610                let gamma = hyperparameters.get("gamma").unwrap_or(&0.1);
611                format!("model <- svm(x = X_train, y = y_train, cost = {cost}, gamma = {gamma})")
612            }
613            "glm" => {
614                let family = "gaussian"; // Default, could be parameterized
615                format!("df_train <- df[train_indices, ]\nmodel <- glm(response ~ ., data = df_train, family = {family})")
616            }
617            "tree" => {
618                "df_train <- df[train_indices, ]\nmodel <- tree(response ~ ., data = df_train)"
619                    .to_string()
620            }
621            _ => {
622                return Err(UtilsError::InvalidParameter(format!(
623                    "Unsupported R model type: {model_type}"
624                )))
625            }
626        };
627
628        script.push_str(&model_creation);
629        script.push('\n');
630
631        script.push_str("\n# Prediction and evaluation\n");
632        let prediction_code = match model_type {
633            "linear_regression" => {
634                "predictions <- predict(model, data.frame(X_test))\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
635            },
636            "random_forest" => {
637                "predictions <- predict(model, X_test)\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
638            },
639            "svm" => {
640                "predictions <- predict(model, X_test)\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
641            },
642            "glm" | "tree" => {
643                "df_test <- df[-train_indices, ]\npredictions <- predict(model, df_test)\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
644            },
645            _ => {
646                return Err(UtilsError::InvalidParameter(format!(
647                    "Unsupported R model type for prediction: {model_type}"
648                )))
649            }
650        };
651
652        script.push_str(prediction_code);
653        script.push('\n');
654
655        script.push_str("\n# Model summary\n");
656        script.push_str("print(summary(model))\n");
657
658        Ok(script)
659    }
660
661    /// Generate R statistical analysis script
662    pub fn create_r_statistical_analysis(
663        data: &RMatrix,
664        analysis_type: &str,
665    ) -> UtilsResult<String> {
666        let mut script = String::new();
667
668        script.push_str("# Statistical analysis generated by sklears-utils\n");
669        script.push_str(&Self::generate_r_matrix_code("data", data));
670        script.push('\n');
671
672        match analysis_type {
673            "descriptive" => {
674                script.push_str("\n# Descriptive statistics\n");
675                script.push_str("summary(data)\n");
676                script.push_str("apply(data, 2, sd)  # Standard deviations\n");
677                script.push_str("cor(data)  # Correlation matrix\n");
678            }
679            "pca" => {
680                script.push_str("\n# Principal Component Analysis\n");
681                script.push_str("pca_result <- prcomp(data, center = TRUE, scale. = TRUE)\n");
682                script.push_str("summary(pca_result)\n");
683                script
684                    .push_str("plot(pca_result$x[,1:2])  # Scatter plot of first two components\n");
685            }
686            "clustering" => {
687                script.push_str("\n# K-means clustering\n");
688                script.push_str("set.seed(42)\n");
689                script.push_str("kmeans_result <- kmeans(data, centers = 3)\n");
690                script.push_str("print(kmeans_result)\n");
691                script.push_str("plot(data, col = kmeans_result$cluster)\n");
692            }
693            "normality_test" => {
694                script.push_str("\n# Normality tests\n");
695                script.push_str("for(i in 1:ncol(data)) {\n");
696                script.push_str("  cat('Column', i, '\\n')\n");
697                script.push_str("  print(shapiro.test(data[,i]))\n");
698                script.push_str("}\n");
699            }
700            _ => {
701                return Err(UtilsError::InvalidParameter(format!(
702                    "Unsupported analysis type: {analysis_type}"
703                )))
704            }
705        }
706
707        Ok(script)
708    }
709
710    /// Convert R data types to Rust-compatible format
711    pub fn convert_r_output(output: &str, expected_type: ROutputType) -> UtilsResult<ROutputValue> {
712        match expected_type {
713            ROutputType::Vector => {
714                // Parse R vector output: [1] 1.0 2.0 3.0
715                let cleaned = output.replace("[1]", "").trim().to_string();
716                let values: Result<Vec<f64>, _> = cleaned
717                    .split_whitespace()
718                    .map(|s| s.parse::<f64>())
719                    .collect();
720
721                match values {
722                    Ok(v) => Ok(ROutputValue::Vector(v)),
723                    Err(_) => Err(UtilsError::InvalidParameter(
724                        "Failed to parse R vector output".to_string(),
725                    )),
726                }
727            }
728            ROutputType::Scalar => {
729                // Parse single value
730                let cleaned = output.replace("[1]", "");
731                let cleaned = cleaned.trim();
732                match cleaned.parse::<f64>() {
733                    Ok(v) => Ok(ROutputValue::Scalar(v)),
734                    Err(_) => Err(UtilsError::InvalidParameter(
735                        "Failed to parse R scalar output".to_string(),
736                    )),
737                }
738            }
739            ROutputType::String => Ok(ROutputValue::String(output.to_string())),
740        }
741    }
742}
743
/// Foreign Function Interface (FFI) utilities.
///
/// Acts purely as a namespace: all helpers are associated functions for
/// string conversion, array hand-off and generating FFI-related source text.
pub struct FFIUtils;
746
747impl FFIUtils {
748    /// Create C-compatible function signature
749    pub fn create_c_signature(
750        function_name: &str,
751        parameters: &[CParameter],
752        return_type: CType,
753    ) -> String {
754        let param_strings: Vec<String> = parameters
755            .iter()
756            .map(|p| format!("{} {}", p.param_type, p.name))
757            .collect();
758
759        format!(
760            "extern \"C\" fn {}({}) -> {}",
761            function_name,
762            param_strings.join(", "),
763            return_type
764        )
765    }
766
767    /// Generate C header file
768    pub fn generate_c_header(library_name: &str, functions: &[CFunctionSignature]) -> String {
769        let mut header = String::new();
770
771        let library_upper = library_name.to_uppercase();
772        header.push_str(&format!("#ifndef {library_upper}_H\n"));
773        header.push_str(&format!("#define {library_upper}_H\n\n"));
774        header.push_str("#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n");
775
776        for func in functions {
777            header.push_str(&format!("{};\n", func.signature));
778        }
779
780        header.push_str("\n#ifdef __cplusplus\n}\n#endif\n\n");
781        header.push_str(&format!("#endif // {library_upper}_H\n"));
782
783        header
784    }
785
786    /// Convert Rust string to C string
787    pub fn rust_string_to_c(s: &str) -> UtilsResult<*mut c_char> {
788        let c_string = CString::new(s)
789            .map_err(|e| UtilsError::InvalidParameter(format!("Invalid C string: {e}")))?;
790        Ok(c_string.into_raw())
791    }
792
793    /// Convert C string to Rust string
794    ///
795    /// # Safety
796    ///
797    /// This function is unsafe because it dereferences a raw pointer. The caller must ensure:
798    /// - The pointer is valid and points to a null-terminated C string
799    /// - The pointer remains valid for the duration of this function call
800    /// - The memory pointed to by the pointer is not accessed by other threads during this call
801    pub unsafe fn c_string_to_rust(ptr: *const c_char) -> UtilsResult<String> {
802        if ptr.is_null() {
803            return Err(UtilsError::InvalidParameter("Null pointer".to_string()));
804        }
805
806        let c_str = CStr::from_ptr(ptr);
807        c_str
808            .to_str()
809            .map(|s| s.to_string())
810            .map_err(|e| UtilsError::InvalidParameter(format!("Invalid UTF-8: {e}")))
811    }
812
813    /// Create array transfer structure for FFI
814    pub fn create_array_transfer(data: &[f64]) -> ArrayTransfer {
815        ArrayTransfer {
816            data: data.as_ptr(),
817            length: data.len(),
818            capacity: data.len(),
819        }
820    }
821
822    /// Generate FFI binding examples
823    pub fn generate_ffi_examples() -> String {
824        r#"
825// Example FFI functions for machine learning operations
826
827use std::os::raw::{c_double, c_int};
828use std::slice;
829
830#[repr(C)]
831pub struct ArrayTransfer {
832    pub data: *const f64,
833    pub length: usize,
834    pub capacity: usize,
835}
836
837// Linear regression example
838#[no_mangle]
839pub extern "C" fn linear_regression_fit(
840    x_data: *const c_double,
841    y_data: *const c_double,
842    n_samples: c_int,
843    coefficients: *mut c_double,
844    intercept: *mut c_double,
845) -> c_int {
846    if x_data.is_null() || y_data.is_null() || coefficients.is_null() || intercept.is_null() {
847        return -1; // Error: null pointer
848    }
849
850    unsafe {
851        let x_slice = slice::from_raw_parts(x_data, n_samples as usize);
852        let y_slice = slice::from_raw_parts(y_data, n_samples as usize);
853        
854        // Simple linear regression calculation
855        let n = n_samples as f64;
856        let sum_x: f64 = x_slice.iter().sum();
857        let sum_y: f64 = y_slice.iter().sum();
858        let sum_xy: f64 = x_slice.iter().zip(y_slice.iter()).map(|(x, y)| x * y).sum();
859        let sum_xx: f64 = x_slice.iter().map(|x| x * x).sum();
860        
861        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x * sum_x);
862        let intercept_val = (sum_y - slope * sum_x) / n;
863        
864        *coefficients = slope;
865        *intercept = intercept_val;
866        
867        0 // Success
868    }
869}
870
871// Array operations example
872#[no_mangle]
873pub extern "C" fn array_mean(
874    data: *const c_double,
875    length: c_int,
876    result: *mut c_double,
877) -> c_int {
878    if data.is_null() || result.is_null() || length <= 0 {
879        return -1;
880    }
881
882    unsafe {
883        let slice = slice::from_raw_parts(data, length as usize);
884        let mean = slice.iter().sum::<f64>() / length as f64;
885        *result = mean;
886        0
887    }
888}
889"#
890        .to_string()
891    }
892}
893
894// Data structures for external integration
895
896#[derive(Debug, Clone, Serialize, Deserialize)]
897pub struct PyArrayBuffer {
898    pub data: Vec<f64>,
899    pub shape: Vec<usize>,
900    pub dtype: String,
901    pub order: String,
902}
903
904#[derive(Debug, Clone)]
905pub struct PythonParameter {
906    pub name: String,
907    pub value: PythonValue,
908}
909
/// Value kinds that can be bound to a [`PythonParameter`].
#[derive(Clone, Debug)]
pub enum PythonValue {
    /// Text value.
    String(String),
    /// Numeric value; always carried as `f64`.
    Number(f64),
    /// Boolean flag.
    Boolean(bool),
    /// Name of an array variable, referenced rather than embedded.
    Array(String),
}
917
918#[derive(Debug, Clone)]
919pub struct WasmParameter {
920    pub name: String,
921    pub param_type: WasmType,
922}
923
/// Types that may appear in a generated WASM function signature.
#[derive(Clone, Debug)]
pub enum WasmType {
    /// 64-bit float.
    F64,
    /// 32-bit float.
    F32,
    /// Signed 32-bit integer.
    I32,
    /// Unsigned 32-bit integer.
    U32,
    /// Boolean.
    Bool,
    /// Owned string.
    String,
}

/// Renders each variant as the Rust type name used in generated code.
impl fmt::Display for WasmType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the spelling first, then emit it with a single write.
        let rust_name = match self {
            Self::F64 => "f64",
            Self::F32 => "f32",
            Self::I32 => "i32",
            Self::U32 => "u32",
            Self::Bool => "bool",
            Self::String => "String",
        };
        f.write_str(rust_name)
    }
}
946
947#[derive(Debug, Clone)]
948pub struct WasmBuildConfig {
949    pub target: String,
950    pub features: Vec<String>,
951    pub optimization: WasmOptimization,
952    pub debug: bool,
953    pub typescript_bindings: bool,
954}
955
/// Optimization profile stored in [`WasmBuildConfig`].
///
/// How each variant maps onto actual build flags is decided by whatever
/// consumes the config; this type only names the choice.
#[derive(Clone, Debug)]
pub enum WasmOptimization {
    /// No optimization requested.
    None,
    /// Size-oriented profile.
    Size,
    /// Speed-oriented profile.
    Speed,
}
962
963#[derive(Debug, Clone)]
964pub struct CParameter {
965    pub name: String,
966    pub param_type: CType,
967}
968
/// C types that may appear in a generated FFI signature.
#[derive(Clone, Debug)]
pub enum CType {
    /// `int`
    Int,
    /// `double`
    Double,
    /// `float`
    Float,
    /// `char*`
    CharPtr,
    /// `void*`
    VoidPtr,
    /// `const char*`
    ConstCharPtr,
    /// `const double*`
    ConstDoublePtr,
}

/// Renders each variant as its C spelling.
impl fmt::Display for CType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the spelling first, then emit it with a single write.
        let c_spelling = match self {
            Self::Int => "int",
            Self::Double => "double",
            Self::Float => "float",
            Self::CharPtr => "char*",
            Self::VoidPtr => "void*",
            Self::ConstCharPtr => "const char*",
            Self::ConstDoublePtr => "const double*",
        };
        f.write_str(c_spelling)
    }
}
993
/// Description of one C function, as consumed by `FFIUtils::generate_c_header`.
#[derive(Clone, Debug)]
pub struct CFunctionSignature {
    /// Function identifier.
    pub name: String,
    /// C prototype text without a trailing semicolon
    /// (e.g. `double add(double a, double b)`).
    pub signature: String,
    /// Human-readable summary of what the function does.
    pub description: String,
}
1000
/// C-layout view of an `f64` buffer for passing across an FFI boundary.
///
/// The struct holds a raw pointer only: nothing ties `data` to the lifetime of
/// the storage it points at, and cloning copies the pointer, not the elements.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct ArrayTransfer {
    /// Pointer to the first element.
    pub data: *const f64,
    /// Number of elements reported for the buffer.
    pub length: usize,
    /// Capacity reported for the buffer.
    pub capacity: usize,
}
1008
1009// R integration data structures
1010
1011#[derive(Debug, Clone, Serialize, Deserialize)]
1012pub struct RVector {
1013    pub data: Vec<f64>,
1014    pub length: usize,
1015    pub r_type: RType,
1016}
1017
1018#[derive(Debug, Clone, Serialize, Deserialize)]
1019pub struct RMatrix {
1020    pub data: Vec<f64>,
1021    pub nrow: usize,
1022    pub ncol: usize,
1023    pub byrow: bool,
1024    pub r_type: RType,
1025}
1026
1027#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1028pub enum RType {
1029    Numeric,
1030    Integer,
1031    Character,
1032    Logical,
1033    Factor,
1034}
1035
1036#[derive(Debug, Clone)]
1037pub struct RParameter {
1038    pub name: String,
1039    pub value: RValue,
1040}
1041
/// Value kinds that can be bound to an [`RParameter`].
#[derive(Clone, Debug)]
pub enum RValue {
    /// Text value.
    String(String),
    /// Numeric value; always carried as `f64`.
    Number(f64),
    /// Boolean flag.
    Boolean(bool),
    /// Name of a vector variable, referenced rather than embedded.
    Vector(String),
    /// Name of a matrix variable, referenced rather than embedded.
    Matrix(String),
    /// Name of a data-frame variable, referenced rather than embedded.
    DataFrame(String),
}
1051
/// Shape the caller expects when handing R output text to
/// `RInterop::convert_r_output`.
#[derive(Clone, Debug)]
pub enum ROutputType {
    /// Several numbers.
    Vector,
    /// A single number.
    Scalar,
    /// Raw text.
    String,
}
1058
/// Parsed R output, one variant per [`ROutputType`].
#[derive(Clone, Debug)]
pub enum ROutputValue {
    /// Several parsed numbers.
    Vector(Vec<f64>),
    /// A single parsed number.
    Scalar(f64),
    /// Output kept as text.
    String(String),
}
1065
/// Unit tests for the Python, WASM, C-FFI and R interoperability helpers.
///
/// Most tests check generated text with `contains`, so they pin the presence
/// of key fragments rather than the full output.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Python interop tests

    #[test]
    fn test_python_array_conversion() {
        let arr = array![1.0, 2.0, 3.0, 4.0];
        let buffer = PythonInterop::array_to_python_buffer(&arr);

        assert_eq!(buffer.data, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(buffer.shape, vec![4]);
        assert_eq!(buffer.dtype, "float64");
        assert_eq!(buffer.order, "C");

        // Test conversion back
        let converted = PythonInterop::python_buffer_to_array(&buffer).unwrap();
        assert_eq!(converted, arr);
    }

    #[test]
    fn test_python_array2_conversion() {
        let arr = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let buffer = PythonInterop::array2_to_python_buffer(&arr);

        assert_eq!(buffer.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(buffer.shape, vec![3, 2]);

        // Test conversion back
        let converted = PythonInterop::python_buffer_to_array2(&buffer).unwrap();
        assert_eq!(converted, arr);
    }

    #[test]
    fn test_numpy_code_generation() {
        let buffer = PyArrayBuffer {
            data: vec![1.0, 2.0, 3.0],
            shape: vec![3],
            dtype: "float64".to_string(),
            order: "C".to_string(),
        };

        let code = PythonInterop::generate_numpy_import_code("my_array", &buffer);

        assert!(code.contains("import numpy as np"));
        assert!(code.contains("my_array = np.array"));
        assert!(code.contains("[1.0, 2.0, 3.0]"));
        assert!(code.contains("dtype='float64'"));
        assert!(code.contains("reshape([3])"));
    }

    #[test]
    fn test_python_function_call_template() {
        let params = vec![
            PythonParameter {
                name: "n_estimators".to_string(),
                value: PythonValue::Number(100.0),
            },
            PythonParameter {
                name: "random_state".to_string(),
                value: PythonValue::Number(42.0),
            },
            PythonParameter {
                name: "verbose".to_string(),
                value: PythonValue::Boolean(true),
            },
        ];

        let call = PythonInterop::generate_function_call_template("RandomForestRegressor", &params);

        assert!(call.contains("RandomForestRegressor("));
        assert!(call.contains("n_estimators=100"));
        assert!(call.contains("random_state=42"));
        assert!(call.contains("verbose=True"));
    }

    #[test]
    fn test_ml_script_generation() {
        let training_data = PyArrayBuffer {
            data: vec![1.0, 2.0, 3.0, 4.0],
            shape: vec![2, 2],
            dtype: "float64".to_string(),
            order: "C".to_string(),
        };

        let labels = PyArrayBuffer {
            data: vec![0.0, 1.0],
            shape: vec![2],
            dtype: "float64".to_string(),
            order: "C".to_string(),
        };

        let mut hyperparams = HashMap::new();
        hyperparams.insert("n_estimators".to_string(), 50.0);

        let script =
            PythonInterop::create_ml_script("random_forest", &training_data, &labels, &hyperparams)
                .unwrap();

        assert!(script.contains("import numpy as np"));
        assert!(script.contains("from sklearn.ensemble import RandomForestRegressor"));
        assert!(script.contains("train_test_split"));
        assert!(script.contains("RandomForestRegressor(n_estimators=50)"));
        assert!(script.contains("model.fit(X_train, y_train)"));
        assert!(script.contains("model.score(X_test, y_test)"));
    }

    // WASM tests

    #[test]
    fn test_wasm_signature_generation() {
        let params = vec![
            WasmParameter {
                name: "a".to_string(),
                param_type: WasmType::F64,
            },
            WasmParameter {
                name: "b".to_string(),
                param_type: WasmType::F64,
            },
        ];

        let signature = WasmUtils::generate_wasm_signature("add", &params, WasmType::F64);

        assert!(signature.contains("#[wasm_bindgen]"));
        assert!(signature.contains("pub fn add(a: f64, b: f64) -> f64"));
    }

    #[test]
    fn test_wasm_memory_helpers() {
        let helpers = WasmUtils::generate_memory_helpers();

        assert!(helpers.contains("pub fn alloc(size: usize)"));
        assert!(helpers.contains("pub fn dealloc(ptr: *mut u8, size: usize)"));
        assert!(helpers.contains("pub struct Float64Array"));
        assert!(helpers.contains("#[wasm_bindgen]"));
    }

    #[test]
    fn test_wasm_ml_bindings() {
        let bindings = WasmUtils::generate_ml_bindings();

        assert!(bindings.contains("pub fn dot_product"));
        assert!(bindings.contains("pub fn matrix_multiply"));
        assert!(bindings.contains("pub fn mean"));
        assert!(bindings.contains("pub fn variance"));
        assert!(bindings.contains("pub fn standard_deviation"));
    }

    #[test]
    fn test_wasm_build_config() {
        let config = WasmUtils::create_wasm_build_config();

        assert_eq!(config.target, "wasm32-unknown-unknown");
        assert!(config.features.contains(&"wasm-bindgen".to_string()));
        assert!(config.typescript_bindings);
        assert!(!config.debug);
    }

    #[test]
    fn test_package_json_generation() {
        let json = WasmUtils::generate_package_json("my-ml-wasm", "0.1.0");

        assert!(json.contains("\"name\": \"my-ml-wasm\""));
        assert!(json.contains("\"version\": \"0.1.0\""));
        assert!(json.contains("wasm-pack"));
        assert!(json.contains("\"machine-learning\""));
    }

    // C FFI tests

    #[test]
    fn test_c_signature_creation() {
        let params = vec![
            CParameter {
                name: "data".to_string(),
                param_type: CType::ConstDoublePtr,
            },
            CParameter {
                name: "length".to_string(),
                param_type: CType::Int,
            },
        ];

        let signature = FFIUtils::create_c_signature("compute_mean", &params, CType::Double);

        assert!(signature.contains("extern \"C\" fn compute_mean"));
        assert!(signature.contains("const double* data"));
        assert!(signature.contains("int length"));
        assert!(signature.contains("-> double"));
    }

    #[test]
    fn test_c_header_generation() {
        let functions = vec![
            CFunctionSignature {
                name: "add".to_string(),
                signature: "double add(double a, double b)".to_string(),
                description: "Add two numbers".to_string(),
            },
            CFunctionSignature {
                name: "multiply".to_string(),
                signature: "double multiply(double a, double b)".to_string(),
                description: "Multiply two numbers".to_string(),
            },
        ];

        let header = FFIUtils::generate_c_header("mylib", &functions);

        assert!(header.contains("#ifndef MYLIB_H"));
        assert!(header.contains("#define MYLIB_H"));
        assert!(header.contains("extern \"C\" {"));
        assert!(header.contains("double add(double a, double b);"));
        assert!(header.contains("double multiply(double a, double b);"));
        assert!(header.contains("#endif // MYLIB_H"));
    }

    #[test]
    fn test_rust_to_c_string() {
        let rust_str = "Hello, World!";
        let c_ptr = FFIUtils::rust_string_to_c(rust_str).unwrap();

        // Convert back to verify
        // SAFETY: `c_ptr` was returned by `rust_string_to_c` just above and has
        // not been freed yet.
        let converted = unsafe { FFIUtils::c_string_to_rust(c_ptr).unwrap() };
        assert_eq!(converted, rust_str);

        // Clean up
        // SAFETY: assumes `rust_string_to_c` hands out a pointer obtained from
        // `CString::into_raw`, which is what `CString::from_raw` requires.
        // NOTE(review): confirm against `rust_string_to_c`.
        unsafe {
            let _ = CString::from_raw(c_ptr);
        }
    }

    #[test]
    fn test_array_transfer_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let transfer = FFIUtils::create_array_transfer(&data);

        assert_eq!(transfer.length, 4);
        assert_eq!(transfer.capacity, 4);
        assert!(!transfer.data.is_null());
    }

    #[test]
    fn test_ffi_examples_generation() {
        let examples = FFIUtils::generate_ffi_examples();

        assert!(examples.contains("linear_regression_fit"));
        assert!(examples.contains("array_mean"));
        assert!(examples.contains("#[no_mangle]"));
        assert!(examples.contains("extern \"C\""));
        assert!(examples.contains("ArrayTransfer"));
    }

    #[test]
    fn test_python_value_variants() {
        let string_val = PythonValue::String("test".to_string());
        let number_val = PythonValue::Number(42.0);
        let bool_val = PythonValue::Boolean(true);
        let array_val = PythonValue::Array("my_array".to_string());

        // Each value must carry the variant it was built with.
        assert!(matches!(string_val, PythonValue::String(_)));
        assert!(matches!(number_val, PythonValue::Number(_)));
        assert!(matches!(bool_val, PythonValue::Boolean(_)));
        assert!(matches!(array_val, PythonValue::Array(_)));
    }

    #[test]
    fn test_wasm_type_display() {
        assert_eq!(WasmType::F64.to_string(), "f64");
        assert_eq!(WasmType::F32.to_string(), "f32");
        assert_eq!(WasmType::I32.to_string(), "i32");
        assert_eq!(WasmType::U32.to_string(), "u32");
        assert_eq!(WasmType::Bool.to_string(), "bool");
        assert_eq!(WasmType::String.to_string(), "String");
    }

    #[test]
    fn test_c_type_display() {
        assert_eq!(CType::Int.to_string(), "int");
        assert_eq!(CType::Double.to_string(), "double");
        assert_eq!(CType::Float.to_string(), "float");
        assert_eq!(CType::CharPtr.to_string(), "char*");
        assert_eq!(CType::VoidPtr.to_string(), "void*");
        assert_eq!(CType::ConstCharPtr.to_string(), "const char*");
        assert_eq!(CType::ConstDoublePtr.to_string(), "const double*");
    }

    // R integration tests

    #[test]
    fn test_r_array_conversion() {
        let arr = array![1.0, 2.0, 3.0, 4.0];
        let r_vector = RInterop::array_to_r_vector(&arr);

        assert_eq!(r_vector.data, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(r_vector.length, 4);
        assert_eq!(r_vector.r_type, RType::Numeric);

        // Test conversion back
        let converted = RInterop::r_vector_to_array(&r_vector).unwrap();
        assert_eq!(converted, arr);
    }

    #[test]
    fn test_r_matrix_conversion() {
        let arr = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let r_matrix = RInterop::array2_to_r_matrix(&arr);

        assert_eq!(r_matrix.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]); // Column-major
        assert_eq!(r_matrix.nrow, 3);
        assert_eq!(r_matrix.ncol, 2);
        assert!(!r_matrix.byrow);
        assert_eq!(r_matrix.r_type, RType::Numeric);

        // Test conversion back
        let converted = RInterop::r_matrix_to_array2(&r_matrix).unwrap();
        assert_eq!(converted, arr);
    }

    #[test]
    fn test_r_vector_code_generation() {
        let r_vector = RVector {
            data: vec![1.0, 2.0, 3.0],
            length: 3,
            r_type: RType::Numeric,
        };

        let code = RInterop::generate_r_vector_code("my_vector", &r_vector);

        assert_eq!(code, "my_vector <- c(1, 2, 3)");
    }

    #[test]
    fn test_r_matrix_code_generation() {
        let r_matrix = RMatrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            nrow: 2,
            ncol: 2,
            byrow: false,
            r_type: RType::Numeric,
        };

        let code = RInterop::generate_r_matrix_code("my_matrix", &r_matrix);

        assert_eq!(
            code,
            "my_matrix <- matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2, byrow = FALSE)"
        );
    }

    #[test]
    fn test_r_dataframe_code_generation() {
        let mut columns = HashMap::new();

        columns.insert(
            "x".to_string(),
            RVector {
                data: vec![1.0, 2.0, 3.0],
                length: 3,
                r_type: RType::Numeric,
            },
        );

        columns.insert(
            "y".to_string(),
            RVector {
                data: vec![4.0, 5.0, 6.0],
                length: 3,
                r_type: RType::Numeric,
            },
        );

        let code = RInterop::generate_r_dataframe_code("my_df", &columns).unwrap();

        // The order might vary due to HashMap, so check for both possibilities
        let expected1 = "my_df <- data.frame(x = c(1, 2, 3), y = c(4, 5, 6))";
        let expected2 = "my_df <- data.frame(y = c(4, 5, 6), x = c(1, 2, 3))";
        assert!(code == expected1 || code == expected2);
    }

    #[test]
    fn test_r_package_imports() {
        let packages = &["randomForest", "e1071", "ggplot2"];
        let imports = RInterop::generate_r_package_imports(packages);

        assert!(imports.contains("library(randomForest)"));
        assert!(imports.contains("library(e1071)"));
        assert!(imports.contains("library(ggplot2)"));
    }

    #[test]
    fn test_r_function_call_generation() {
        let params = vec![
            RParameter {
                name: "ntree".to_string(),
                value: RValue::Number(500.0),
            },
            RParameter {
                name: "mtry".to_string(),
                value: RValue::Number(3.0),
            },
            RParameter {
                name: "importance".to_string(),
                value: RValue::Boolean(true),
            },
            RParameter {
                name: "x".to_string(),
                value: RValue::Matrix("X_train".to_string()),
            },
        ];

        let call = RInterop::generate_r_function_call("randomForest", &params);

        assert!(call.contains("randomForest("));
        assert!(call.contains("ntree = 500"));
        assert!(call.contains("mtry = 3"));
        assert!(call.contains("importance = TRUE"));
        assert!(call.contains("x = X_train"));
    }

    #[test]
    fn test_r_ml_script_generation() {
        let training_data = RMatrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            nrow: 3,
            ncol: 2,
            byrow: false,
            r_type: RType::Numeric,
        };

        let response_var = RVector {
            data: vec![1.0, 0.0, 1.0],
            length: 3,
            r_type: RType::Numeric,
        };

        let mut hyperparams = HashMap::new();
        hyperparams.insert("ntree".to_string(), 100.0);
        hyperparams.insert("mtry".to_string(), 1.0);

        let script = RInterop::create_r_ml_script(
            "random_forest",
            &training_data,
            &response_var,
            &hyperparams,
        )
        .unwrap();

        assert!(script.contains("library(randomForest)"));
        assert!(script.contains("X <- matrix"));
        assert!(script.contains("y <- c"));
        assert!(script.contains("randomForest(x = X_train, y = y_train, ntree = 100, mtry = 1)"));
        assert!(script.contains("train_indices"));
        assert!(script.contains("predictions <- predict"));
        assert!(script.contains("summary(model)"));
    }

    #[test]
    fn test_r_linear_regression_script() {
        let training_data = RMatrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            nrow: 2,
            ncol: 2,
            byrow: false,
            r_type: RType::Numeric,
        };

        let response_var = RVector {
            data: vec![1.0, 2.0],
            length: 2,
            r_type: RType::Numeric,
        };

        let hyperparams = HashMap::new();

        let script = RInterop::create_r_ml_script(
            "linear_regression",
            &training_data,
            &response_var,
            &hyperparams,
        )
        .unwrap();

        assert!(script.contains("# Linear regression using base R"));
        assert!(script.contains("model <- lm(y_train ~ X_train)"));
        assert!(!script.contains("library(")); // Base R, no packages needed
    }

    #[test]
    fn test_r_statistical_analysis() {
        let data = RMatrix {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            nrow: 3,
            ncol: 2,
            byrow: false,
            r_type: RType::Numeric,
        };

        // Test descriptive statistics
        let script = RInterop::create_r_statistical_analysis(&data, "descriptive").unwrap();
        assert!(script.contains("summary(data)"));
        assert!(script.contains("apply(data, 2, sd)"));
        assert!(script.contains("cor(data)"));

        // Test PCA
        let script = RInterop::create_r_statistical_analysis(&data, "pca").unwrap();
        assert!(script.contains("prcomp(data, center = TRUE, scale. = TRUE)"));
        assert!(script.contains("plot(pca_result$x[,1:2])"));

        // Test clustering
        let script = RInterop::create_r_statistical_analysis(&data, "clustering").unwrap();
        assert!(script.contains("kmeans(data, centers = 3)"));
        assert!(script.contains("plot(data, col = kmeans_result$cluster)"));
    }

    #[test]
    fn test_r_output_conversion() {
        // Test vector parsing
        let vector_output = "[1] 1.0 2.5 3.8";
        let result = RInterop::convert_r_output(vector_output, ROutputType::Vector).unwrap();
        match result {
            ROutputValue::Vector(v) => assert_eq!(v, vec![1.0, 2.5, 3.8]),
            _ => panic!("Expected vector output"),
        }

        // Test scalar parsing
        let scalar_output = "[1] 42.5";
        let result = RInterop::convert_r_output(scalar_output, ROutputType::Scalar).unwrap();
        match result {
            ROutputValue::Scalar(s) => assert_eq!(s, 42.5),
            _ => panic!("Expected scalar output"),
        }

        // Test string parsing
        let string_output = "This is a test string";
        let result = RInterop::convert_r_output(string_output, ROutputType::String).unwrap();
        match result {
            ROutputValue::String(s) => assert_eq!(s, "This is a test string"),
            _ => panic!("Expected string output"),
        }
    }

    #[test]
    fn test_r_matrix_row_major_conversion() {
        let r_matrix = RMatrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            nrow: 2,
            ncol: 2,
            byrow: true, // Row-major
            r_type: RType::Numeric,
        };

        let converted = RInterop::r_matrix_to_array2(&r_matrix).unwrap();

        // When byrow=true, data is already in row-major format [1,2,3,4]
        // This should give us [[1.0, 2.0], [3.0, 4.0]]
        let expected = array![[1.0, 2.0], [3.0, 4.0]];
        assert_eq!(converted, expected);
    }

    #[test]
    fn test_r_type_variants() {
        let numeric = RType::Numeric;
        let integer = RType::Integer;
        let character = RType::Character;
        let logical = RType::Logical;
        let factor = RType::Factor;

        // Test that all variants can be created and compared
        assert_eq!(numeric, RType::Numeric);
        assert_eq!(integer, RType::Integer);
        assert_eq!(character, RType::Character);
        assert_eq!(logical, RType::Logical);
        assert_eq!(factor, RType::Factor);
    }

    #[test]
    fn test_r_value_variants() {
        let string_val = RValue::String("test".to_string());
        let number_val = RValue::Number(42.0);
        let bool_val = RValue::Boolean(true);
        let vector_val = RValue::Vector("my_vector".to_string());
        let matrix_val = RValue::Matrix("my_matrix".to_string());
        let df_val = RValue::DataFrame("my_df".to_string());

        // Each value must carry the variant it was built with.
        assert!(matches!(string_val, RValue::String(_)));
        assert!(matches!(number_val, RValue::Number(_)));
        assert!(matches!(bool_val, RValue::Boolean(_)));
        assert!(matches!(vector_val, RValue::Vector(_)));
        assert!(matches!(matrix_val, RValue::Matrix(_)));
        assert!(matches!(df_val, RValue::DataFrame(_)));
    }

    #[test]
    fn test_r_unsupported_model_type() {
        let training_data = RMatrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            nrow: 2,
            ncol: 2,
            byrow: false,
            r_type: RType::Numeric,
        };

        let response_var = RVector {
            data: vec![1.0, 2.0],
            length: 2,
            r_type: RType::Numeric,
        };

        let hyperparams = HashMap::new();

        let result = RInterop::create_r_ml_script(
            "unsupported_model",
            &training_data,
            &response_var,
            &hyperparams,
        );

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unsupported R model type"));
    }

    #[test]
    fn test_r_unsupported_analysis_type() {
        let data = RMatrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            nrow: 2,
            ncol: 2,
            byrow: false,
            r_type: RType::Numeric,
        };

        let result = RInterop::create_r_statistical_analysis(&data, "unsupported_analysis");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unsupported analysis type"));
    }
}