1use crate::{UtilsError, UtilsResult};
8use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::ffi::{CStr, CString};
12use std::fmt;
13use std::os::raw::c_char;
14
/// Namespace for Python/NumPy interoperability helpers: converting `ndarray`
/// arrays to and from [`PyArrayBuffer`] and generating Python source snippets.
pub struct PythonInterop;
17
18impl PythonInterop {
19 pub fn array_to_python_buffer(array: &Array1<f64>) -> PyArrayBuffer {
21 PyArrayBuffer {
22 data: array.as_slice().unwrap().to_vec(),
23 shape: vec![array.len()],
24 dtype: "float64".to_string(),
25 order: "C".to_string(),
26 }
27 }
28
29 pub fn array2_to_python_buffer(array: &Array2<f64>) -> PyArrayBuffer {
31 let (rows, cols) = array.dim();
32 PyArrayBuffer {
33 data: array.as_slice().unwrap().to_vec(),
34 shape: vec![rows, cols],
35 dtype: "float64".to_string(),
36 order: "C".to_string(),
37 }
38 }
39
40 pub fn python_buffer_to_array(buffer: &PyArrayBuffer) -> UtilsResult<Array1<f64>> {
42 if buffer.shape.len() != 1 {
43 return Err(UtilsError::InvalidParameter(
44 "Expected 1D array".to_string(),
45 ));
46 }
47
48 if buffer.dtype != "float64" {
49 return Err(UtilsError::InvalidParameter(format!(
50 "Unsupported dtype: {}",
51 buffer.dtype
52 )));
53 }
54
55 Array1::from_vec(buffer.data.clone())
56 .into_shape_with_order(buffer.shape[0])
57 .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
58 }
59
60 pub fn python_buffer_to_array2(buffer: &PyArrayBuffer) -> UtilsResult<Array2<f64>> {
62 if buffer.shape.len() != 2 {
63 return Err(UtilsError::InvalidParameter(
64 "Expected 2D array".to_string(),
65 ));
66 }
67
68 if buffer.dtype != "float64" {
69 return Err(UtilsError::InvalidParameter(format!(
70 "Unsupported dtype: {}",
71 buffer.dtype
72 )));
73 }
74
75 Array2::from_shape_vec((buffer.shape[0], buffer.shape[1]), buffer.data.clone())
76 .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
77 }
78
79 pub fn generate_numpy_import_code(array_name: &str, buffer: &PyArrayBuffer) -> String {
81 format!(
82 r#"
83import numpy as np
84
85# Data generated by sklears-utils
86{} = np.array({:?}, dtype='{}').reshape({:?})
87"#,
88 array_name, buffer.data, buffer.dtype, buffer.shape
89 )
90 }
91
92 pub fn generate_function_call_template(
94 function_name: &str,
95 parameters: &[PythonParameter],
96 ) -> String {
97 let param_strings: Vec<String> = parameters
98 .iter()
99 .map(|p| match &p.value {
100 PythonValue::String(s) => format!("{}='{}'", p.name, s),
101 PythonValue::Number(n) => format!("{}={}", p.name, n),
102 PythonValue::Boolean(b) => {
103 format!("{}={}", p.name, if *b { "True" } else { "False" })
104 }
105 PythonValue::Array(name) => format!("{}={}", p.name, name),
106 })
107 .collect();
108
109 format!("{}({})", function_name, param_strings.join(", "))
110 }
111
112 pub fn create_ml_script(
114 model_type: &str,
115 training_data: &PyArrayBuffer,
116 labels: &PyArrayBuffer,
117 hyperparameters: &HashMap<String, f64>,
118 ) -> UtilsResult<String> {
119 let mut script = String::new();
120
121 script.push_str("import numpy as np\n");
123 script.push_str("from sklearn.model_selection import train_test_split\n");
124
125 match model_type {
126 "linear_regression" => {
127 script.push_str("from sklearn.linear_model import LinearRegression\n")
128 }
129 "random_forest" => {
130 script.push_str("from sklearn.ensemble import RandomForestRegressor\n")
131 }
132 "svm" => script.push_str("from sklearn.svm import SVC\n"),
133 _ => {
134 return Err(UtilsError::InvalidParameter(format!(
135 "Unsupported model type: {model_type}"
136 )))
137 }
138 }
139
140 script.push_str("\n# Data preparation\n");
141 script.push_str(&Self::generate_numpy_import_code("X", training_data));
142 script.push_str(&Self::generate_numpy_import_code("y", labels));
143
144 script.push_str("\n# Train-test split\n");
145 script.push_str("X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n");
146
147 script.push_str("\n# Model creation and training\n");
148 let model_creation = match model_type {
149 "linear_regression" => "model = LinearRegression()".to_string(),
150 "random_forest" => {
151 let n_estimators = hyperparameters.get("n_estimators").unwrap_or(&100.0);
152 format!(
153 "model = RandomForestRegressor(n_estimators={})",
154 *n_estimators as i32
155 )
156 }
157 "svm" => {
158 let c = hyperparameters.get("C").unwrap_or(&1.0);
159 format!("model = SVC(C={c})")
160 }
161 _ => {
162 return Err(UtilsError::InvalidParameter(format!(
163 "Unsupported model type: {model_type}"
164 )))
165 }
166 };
167
168 script.push_str(&model_creation);
169 script.push_str("\nmodel.fit(X_train, y_train)\n");
170
171 script.push_str("\n# Evaluation\n");
172 script.push_str("score = model.score(X_test, y_test)\n");
173 script.push_str("print(f'Model score: {score:.4f}')\n");
174
175 Ok(script)
176 }
177}
178
/// Namespace for WebAssembly helpers that generate `wasm_bindgen` source snippets,
/// build configuration and `package.json` content.
pub struct WasmUtils;
181
182impl WasmUtils {
183 pub fn generate_wasm_signature(
185 function_name: &str,
186 parameters: &[WasmParameter],
187 return_type: WasmType,
188 ) -> String {
189 let param_strings: Vec<String> = parameters
190 .iter()
191 .map(|p| format!("{}: {}", p.name, p.param_type))
192 .collect();
193
194 format!(
195 "#[wasm_bindgen]\npub fn {}({}) -> {} {{",
196 function_name,
197 param_strings.join(", "),
198 return_type
199 )
200 }
201
202 pub fn generate_memory_helpers() -> String {
204 r#"
205use wasm_bindgen::prelude::*;
206
207// Memory management helpers for WASM
208#[wasm_bindgen]
209pub fn alloc(size: usize) -> *mut u8 {
210 let mut buf = Vec::with_capacity(size);
211 let ptr = buf.as_mut_ptr();
212 std::mem::forget(buf);
213 ptr
214}
215
216#[wasm_bindgen]
217pub fn dealloc(ptr: *mut u8, size: usize) {
218 unsafe {
219 let _ = Vec::from_raw_parts(ptr, size, size);
220 }
221}
222
223// Array helpers
224#[wasm_bindgen]
225pub struct Float64Array {
226 data: Vec<f64>,
227}
228
229#[wasm_bindgen]
230impl Float64Array {
231 #[wasm_bindgen(constructor)]
232 pub fn new(size: usize) -> Float64Array {
233 Float64Array {
234 data: vec![0.0; size],
235 }
236 }
237
238 #[wasm_bindgen(getter)]
239 pub fn length(&self) -> usize {
240 self.data.len()
241 }
242
243 #[wasm_bindgen]
244 pub fn get(&self, index: usize) -> f64 {
245 self.data.get(index).copied().unwrap_or(0.0)
246 }
247
248 #[wasm_bindgen]
249 pub fn set(&mut self, index: usize, value: f64) {
250 if index < self.data.len() {
251 self.data[index] = value;
252 }
253 }
254
255 #[wasm_bindgen]
256 pub fn as_ptr(&self) -> *const f64 {
257 self.data.as_ptr()
258 }
259}
260"#
261 .to_string()
262 }
263
264 pub fn generate_ml_bindings() -> String {
266 r#"
267use wasm_bindgen::prelude::*;
268
269// Linear algebra operations
270#[wasm_bindgen]
271pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
272 if a.len() != b.len() {
273 return 0.0;
274 }
275 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
276}
277
278#[wasm_bindgen]
279pub fn matrix_multiply(
280 a: &[f64], a_rows: usize, a_cols: usize,
281 b: &[f64], b_rows: usize, b_cols: usize,
282 result: &mut [f64]
283) -> bool {
284 if a_cols != b_rows || result.len() != a_rows * b_cols {
285 return false;
286 }
287
288 for i in 0..a_rows {
289 for j in 0..b_cols {
290 let mut sum = 0.0;
291 for k in 0..a_cols {
292 sum += a[i * a_cols + k] * b[k * b_cols + j];
293 }
294 result[i * b_cols + j] = sum;
295 }
296 }
297 true
298}
299
300// Statistical functions
301#[wasm_bindgen]
302pub fn mean(data: &[f64]) -> f64 {
303 if data.is_empty() {
304 return 0.0;
305 }
306 data.iter().sum::<f64>() / data.len() as f64
307}
308
309#[wasm_bindgen]
310pub fn variance(data: &[f64]) -> f64 {
311 if data.len() < 2 {
312 return 0.0;
313 }
314 let mean_val = mean(data);
315 let variance = data.iter()
316 .map(|x| (x - mean_val).powi(2))
317 .sum::<f64>() / (data.len() - 1) as f64;
318 variance
319}
320
321#[wasm_bindgen]
322pub fn standard_deviation(data: &[f64]) -> f64 {
323 variance(data).sqrt()
324}
325"#
326 .to_string()
327 }
328
329 pub fn create_wasm_build_config() -> WasmBuildConfig {
331 WasmBuildConfig {
332 target: "wasm32-unknown-unknown".to_string(),
333 features: vec![
334 "wasm-bindgen".to_string(),
335 "console_error_panic_hook".to_string(),
336 ],
337 optimization: WasmOptimization::Size,
338 debug: false,
339 typescript_bindings: true,
340 }
341 }
342
343 pub fn generate_package_json(project_name: &str, version: &str) -> String {
345 format!(
346 r#"{{
347 "name": "{project_name}",
348 "version": "{version}",
349 "description": "WASM bindings for sklears ML utilities",
350 "main": "index.js",
351 "types": "index.d.ts",
352 "scripts": {{
353 "build": "wasm-pack build --target web --out-dir pkg",
354 "build:nodejs": "wasm-pack build --target nodejs --out-dir pkg-node",
355 "test": "wasm-pack test --headless --chrome"
356 }},
357 "devDependencies": {{
358 "wasm-pack": "^0.12.0"
359 }},
360 "files": [
361 "pkg/"
362 ],
363 "keywords": [
364 "wasm",
365 "machine-learning",
366 "sklears",
367 "linear-algebra"
368 ]
369}}"#
370 )
371 }
372}
373
/// Namespace for R interoperability helpers: converting arrays to and from
/// [`RVector`]/[`RMatrix`] and generating R source for models and analyses.
pub struct RInterop;
376
377impl RInterop {
378 pub fn array_to_r_vector(array: &Array1<f64>) -> RVector {
380 RVector {
381 data: array.as_slice().unwrap().to_vec(),
382 length: array.len(),
383 r_type: RType::Numeric,
384 }
385 }
386
387 pub fn array2_to_r_matrix(array: &Array2<f64>) -> RMatrix {
389 let (rows, cols) = array.dim();
390
391 let mut col_major_data = vec![0.0; rows * cols];
393 for i in 0..rows {
394 for j in 0..cols {
395 col_major_data[j * rows + i] = array[[i, j]];
396 }
397 }
398
399 RMatrix {
400 data: col_major_data,
401 nrow: rows,
402 ncol: cols,
403 byrow: false, r_type: RType::Numeric,
405 }
406 }
407
408 pub fn r_vector_to_array(vector: &RVector) -> UtilsResult<Array1<f64>> {
410 if vector.r_type != RType::Numeric {
411 return Err(UtilsError::InvalidParameter(format!(
412 "Expected numeric vector, got {:?}",
413 vector.r_type
414 )));
415 }
416
417 Array1::from_vec(vector.data.clone())
418 .into_shape_with_order(vector.length)
419 .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
420 }
421
422 pub fn r_matrix_to_array2(matrix: &RMatrix) -> UtilsResult<Array2<f64>> {
424 if matrix.r_type != RType::Numeric {
425 return Err(UtilsError::InvalidParameter(format!(
426 "Expected numeric matrix, got {:?}",
427 matrix.r_type
428 )));
429 }
430
431 if matrix.byrow {
432 Array2::from_shape_vec((matrix.nrow, matrix.ncol), matrix.data.clone())
434 .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
435 } else {
436 let mut row_major_data = vec![0.0; matrix.data.len()];
438 for i in 0..matrix.nrow {
439 for j in 0..matrix.ncol {
440 row_major_data[i * matrix.ncol + j] = matrix.data[j * matrix.nrow + i];
441 }
442 }
443 Array2::from_shape_vec((matrix.nrow, matrix.ncol), row_major_data)
444 .map_err(|e| UtilsError::InvalidParameter(format!("Shape error: {e}")))
445 }
446 }
447
448 pub fn generate_r_vector_code(vector_name: &str, vector: &RVector) -> String {
450 format!(
451 "{} <- c({})",
452 vector_name,
453 vector
454 .data
455 .iter()
456 .map(|x| x.to_string())
457 .collect::<Vec<_>>()
458 .join(", ")
459 )
460 }
461
462 pub fn generate_r_matrix_code(matrix_name: &str, matrix: &RMatrix) -> String {
464 let data_str = matrix
465 .data
466 .iter()
467 .map(|x| x.to_string())
468 .collect::<Vec<_>>()
469 .join(", ");
470
471 format!(
472 "{} <- matrix(c({}), nrow = {}, ncol = {}, byrow = {})",
473 matrix_name,
474 data_str,
475 matrix.nrow,
476 matrix.ncol,
477 if matrix.byrow { "TRUE" } else { "FALSE" }
478 )
479 }
480
481 pub fn generate_r_dataframe_code(
483 df_name: &str,
484 columns: &HashMap<String, RVector>,
485 ) -> UtilsResult<String> {
486 let mut column_defs = Vec::new();
487
488 let first_length = columns.values().next().map(|v| v.length).unwrap_or(0);
490
491 for (name, vector) in columns {
492 if vector.length != first_length {
493 return Err(UtilsError::InvalidParameter(
494 "All columns must have the same length".to_string(),
495 ));
496 }
497
498 let data_str = vector
499 .data
500 .iter()
501 .map(|x| x.to_string())
502 .collect::<Vec<_>>()
503 .join(", ");
504
505 column_defs.push(format!("{name} = c({data_str})"));
506 }
507
508 Ok(format!(
509 "{} <- data.frame({})",
510 df_name,
511 column_defs.join(", ")
512 ))
513 }
514
515 pub fn generate_r_package_imports(packages: &[&str]) -> String {
517 let mut imports = String::new();
518
519 for package in packages {
520 imports.push_str(&format!("library({package})\n"));
521 }
522
523 imports
524 }
525
526 pub fn generate_r_function_call(function_name: &str, parameters: &[RParameter]) -> String {
528 let param_strings: Vec<String> = parameters
529 .iter()
530 .map(|p| match &p.value {
531 RValue::String(s) => format!("{} = \"{}\"", p.name, s),
532 RValue::Number(n) => format!("{} = {}", p.name, n),
533 RValue::Boolean(b) => format!("{} = {}", p.name, if *b { "TRUE" } else { "FALSE" }),
534 RValue::Vector(name) => format!("{} = {}", p.name, name),
535 RValue::Matrix(name) => format!("{} = {}", p.name, name),
536 RValue::DataFrame(name) => format!("{} = {}", p.name, name),
537 })
538 .collect();
539
540 format!("{}({})", function_name, param_strings.join(", "))
541 }
542
543 pub fn create_r_ml_script(
545 model_type: &str,
546 training_data: &RMatrix,
547 response_var: &RVector,
548 hyperparameters: &HashMap<String, f64>,
549 ) -> UtilsResult<String> {
550 let mut script = String::new();
551
552 match model_type {
554 "linear_regression" => {
555 script.push_str("# Linear regression using base R\n");
556 }
557 "random_forest" => {
558 script.push_str(&Self::generate_r_package_imports(&["randomForest"]));
559 }
560 "svm" => {
561 script.push_str(&Self::generate_r_package_imports(&["e1071"]));
562 }
563 "glm" => {
564 script.push_str("# Generalized linear model using base R\n");
565 }
566 "tree" => {
567 script.push_str(&Self::generate_r_package_imports(&["tree"]));
568 }
569 _ => {
570 return Err(UtilsError::InvalidParameter(format!(
571 "Unsupported R model type: {model_type}"
572 )))
573 }
574 }
575
576 script.push_str("\n# Data preparation\n");
577 script.push_str(&Self::generate_r_matrix_code("X", training_data));
578 script.push('\n');
579 script.push_str(&Self::generate_r_vector_code("y", response_var));
580 script.push('\n');
581
582 if matches!(model_type, "glm" | "tree") {
584 script.push_str("\n# Create data frame\n");
585 script.push_str("df <- data.frame(y = y, X)\n");
586 script.push_str("colnames(df) <- c('response', paste0('X', 1:ncol(X)))\n");
587 }
588
589 script.push_str("\n# Train-test split\n");
590 script.push_str("set.seed(42)\n");
591 script.push_str("train_indices <- sample(1:nrow(X), size = 0.8 * nrow(X))\n");
592 script.push_str("X_train <- X[train_indices, ]\n");
593 script.push_str("X_test <- X[-train_indices, ]\n");
594 script.push_str("y_train <- y[train_indices]\n");
595 script.push_str("y_test <- y[-train_indices]\n");
596
597 script.push_str("\n# Model creation and training\n");
598 let model_creation = match model_type {
599 "linear_regression" => "model <- lm(y_train ~ X_train)".to_string(),
600 "random_forest" => {
601 let ntree = hyperparameters.get("ntree").unwrap_or(&500.0);
602 let mtry = hyperparameters.get("mtry").unwrap_or(&3.0);
603 format!(
604 "model <- randomForest(x = X_train, y = y_train, ntree = {}, mtry = {})",
605 *ntree as i32, *mtry as i32
606 )
607 }
608 "svm" => {
609 let cost = hyperparameters.get("cost").unwrap_or(&1.0);
610 let gamma = hyperparameters.get("gamma").unwrap_or(&0.1);
611 format!("model <- svm(x = X_train, y = y_train, cost = {cost}, gamma = {gamma})")
612 }
613 "glm" => {
614 let family = "gaussian"; format!("df_train <- df[train_indices, ]\nmodel <- glm(response ~ ., data = df_train, family = {family})")
616 }
617 "tree" => {
618 "df_train <- df[train_indices, ]\nmodel <- tree(response ~ ., data = df_train)"
619 .to_string()
620 }
621 _ => {
622 return Err(UtilsError::InvalidParameter(format!(
623 "Unsupported R model type: {model_type}"
624 )))
625 }
626 };
627
628 script.push_str(&model_creation);
629 script.push('\n');
630
631 script.push_str("\n# Prediction and evaluation\n");
632 let prediction_code = match model_type {
633 "linear_regression" => {
634 "predictions <- predict(model, data.frame(X_test))\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
635 },
636 "random_forest" => {
637 "predictions <- predict(model, X_test)\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
638 },
639 "svm" => {
640 "predictions <- predict(model, X_test)\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
641 },
642 "glm" | "tree" => {
643 "df_test <- df[-train_indices, ]\npredictions <- predict(model, df_test)\nrmse <- sqrt(mean((predictions - y_test)^2))\ncat('RMSE:', rmse, '\\n')"
644 },
645 _ => {
646 return Err(UtilsError::InvalidParameter(format!(
647 "Unsupported R model type for prediction: {model_type}"
648 )))
649 }
650 };
651
652 script.push_str(prediction_code);
653 script.push('\n');
654
655 script.push_str("\n# Model summary\n");
656 script.push_str("print(summary(model))\n");
657
658 Ok(script)
659 }
660
661 pub fn create_r_statistical_analysis(
663 data: &RMatrix,
664 analysis_type: &str,
665 ) -> UtilsResult<String> {
666 let mut script = String::new();
667
668 script.push_str("# Statistical analysis generated by sklears-utils\n");
669 script.push_str(&Self::generate_r_matrix_code("data", data));
670 script.push('\n');
671
672 match analysis_type {
673 "descriptive" => {
674 script.push_str("\n# Descriptive statistics\n");
675 script.push_str("summary(data)\n");
676 script.push_str("apply(data, 2, sd) # Standard deviations\n");
677 script.push_str("cor(data) # Correlation matrix\n");
678 }
679 "pca" => {
680 script.push_str("\n# Principal Component Analysis\n");
681 script.push_str("pca_result <- prcomp(data, center = TRUE, scale. = TRUE)\n");
682 script.push_str("summary(pca_result)\n");
683 script
684 .push_str("plot(pca_result$x[,1:2]) # Scatter plot of first two components\n");
685 }
686 "clustering" => {
687 script.push_str("\n# K-means clustering\n");
688 script.push_str("set.seed(42)\n");
689 script.push_str("kmeans_result <- kmeans(data, centers = 3)\n");
690 script.push_str("print(kmeans_result)\n");
691 script.push_str("plot(data, col = kmeans_result$cluster)\n");
692 }
693 "normality_test" => {
694 script.push_str("\n# Normality tests\n");
695 script.push_str("for(i in 1:ncol(data)) {\n");
696 script.push_str(" cat('Column', i, '\\n')\n");
697 script.push_str(" print(shapiro.test(data[,i]))\n");
698 script.push_str("}\n");
699 }
700 _ => {
701 return Err(UtilsError::InvalidParameter(format!(
702 "Unsupported analysis type: {analysis_type}"
703 )))
704 }
705 }
706
707 Ok(script)
708 }
709
710 pub fn convert_r_output(output: &str, expected_type: ROutputType) -> UtilsResult<ROutputValue> {
712 match expected_type {
713 ROutputType::Vector => {
714 let cleaned = output.replace("[1]", "").trim().to_string();
716 let values: Result<Vec<f64>, _> = cleaned
717 .split_whitespace()
718 .map(|s| s.parse::<f64>())
719 .collect();
720
721 match values {
722 Ok(v) => Ok(ROutputValue::Vector(v)),
723 Err(_) => Err(UtilsError::InvalidParameter(
724 "Failed to parse R vector output".to_string(),
725 )),
726 }
727 }
728 ROutputType::Scalar => {
729 let cleaned = output.replace("[1]", "");
731 let cleaned = cleaned.trim();
732 match cleaned.parse::<f64>() {
733 Ok(v) => Ok(ROutputValue::Scalar(v)),
734 Err(_) => Err(UtilsError::InvalidParameter(
735 "Failed to parse R scalar output".to_string(),
736 )),
737 }
738 }
739 ROutputType::String => Ok(ROutputValue::String(output.to_string())),
740 }
741 }
742}
743
/// Namespace for C FFI helpers: signature/header generation, C string
/// conversion and raw array hand-off.
pub struct FFIUtils;
746
747impl FFIUtils {
748 pub fn create_c_signature(
750 function_name: &str,
751 parameters: &[CParameter],
752 return_type: CType,
753 ) -> String {
754 let param_strings: Vec<String> = parameters
755 .iter()
756 .map(|p| format!("{} {}", p.param_type, p.name))
757 .collect();
758
759 format!(
760 "extern \"C\" fn {}({}) -> {}",
761 function_name,
762 param_strings.join(", "),
763 return_type
764 )
765 }
766
767 pub fn generate_c_header(library_name: &str, functions: &[CFunctionSignature]) -> String {
769 let mut header = String::new();
770
771 let library_upper = library_name.to_uppercase();
772 header.push_str(&format!("#ifndef {library_upper}_H\n"));
773 header.push_str(&format!("#define {library_upper}_H\n\n"));
774 header.push_str("#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n");
775
776 for func in functions {
777 header.push_str(&format!("{};\n", func.signature));
778 }
779
780 header.push_str("\n#ifdef __cplusplus\n}\n#endif\n\n");
781 header.push_str(&format!("#endif // {library_upper}_H\n"));
782
783 header
784 }
785
786 pub fn rust_string_to_c(s: &str) -> UtilsResult<*mut c_char> {
788 let c_string = CString::new(s)
789 .map_err(|e| UtilsError::InvalidParameter(format!("Invalid C string: {e}")))?;
790 Ok(c_string.into_raw())
791 }
792
793 pub unsafe fn c_string_to_rust(ptr: *const c_char) -> UtilsResult<String> {
802 if ptr.is_null() {
803 return Err(UtilsError::InvalidParameter("Null pointer".to_string()));
804 }
805
806 let c_str = CStr::from_ptr(ptr);
807 c_str
808 .to_str()
809 .map(|s| s.to_string())
810 .map_err(|e| UtilsError::InvalidParameter(format!("Invalid UTF-8: {e}")))
811 }
812
813 pub fn create_array_transfer(data: &[f64]) -> ArrayTransfer {
815 ArrayTransfer {
816 data: data.as_ptr(),
817 length: data.len(),
818 capacity: data.len(),
819 }
820 }
821
822 pub fn generate_ffi_examples() -> String {
824 r#"
825// Example FFI functions for machine learning operations
826
827use std::os::raw::{c_double, c_int};
828use std::slice;
829
830#[repr(C)]
831pub struct ArrayTransfer {
832 pub data: *const f64,
833 pub length: usize,
834 pub capacity: usize,
835}
836
837// Linear regression example
838#[no_mangle]
839pub extern "C" fn linear_regression_fit(
840 x_data: *const c_double,
841 y_data: *const c_double,
842 n_samples: c_int,
843 coefficients: *mut c_double,
844 intercept: *mut c_double,
845) -> c_int {
846 if x_data.is_null() || y_data.is_null() || coefficients.is_null() || intercept.is_null() {
847 return -1; // Error: null pointer
848 }
849
850 unsafe {
851 let x_slice = slice::from_raw_parts(x_data, n_samples as usize);
852 let y_slice = slice::from_raw_parts(y_data, n_samples as usize);
853
854 // Simple linear regression calculation
855 let n = n_samples as f64;
856 let sum_x: f64 = x_slice.iter().sum();
857 let sum_y: f64 = y_slice.iter().sum();
858 let sum_xy: f64 = x_slice.iter().zip(y_slice.iter()).map(|(x, y)| x * y).sum();
859 let sum_xx: f64 = x_slice.iter().map(|x| x * x).sum();
860
861 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x * sum_x);
862 let intercept_val = (sum_y - slope * sum_x) / n;
863
864 *coefficients = slope;
865 *intercept = intercept_val;
866
867 0 // Success
868 }
869}
870
871// Array operations example
872#[no_mangle]
873pub extern "C" fn array_mean(
874 data: *const c_double,
875 length: c_int,
876 result: *mut c_double,
877) -> c_int {
878 if data.is_null() || result.is_null() || length <= 0 {
879 return -1;
880 }
881
882 unsafe {
883 let slice = slice::from_raw_parts(data, length as usize);
884 let mean = slice.iter().sum::<f64>() / length as f64;
885 *result = mean;
886 0
887 }
888}
889"#
890 .to_string()
891 }
892}
893
/// Flat, serialisable description of a NumPy-style array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyArrayBuffer {
    /// Flattened element values.
    pub data: Vec<f64>,
    /// Array dimensions (one entry per axis).
    pub shape: Vec<usize>,
    /// NumPy dtype name; the conversion helpers only accept `"float64"`.
    pub dtype: String,
    /// Memory order string; the buffers produced by `PythonInterop` use `"C"`.
    pub order: String,
}
903
/// A keyword argument (`name=value`) for a generated Python call.
#[derive(Debug, Clone)]
pub struct PythonParameter {
    /// Keyword name, emitted verbatim.
    pub name: String,
    /// Value rendered according to its variant.
    pub value: PythonValue,
}
909
/// Value of a generated Python keyword argument.
#[derive(Debug, Clone)]
pub enum PythonValue {
    /// Rendered as a quoted Python string.
    String(String),
    /// Rendered as a numeric literal.
    Number(f64),
    /// Rendered as `True` / `False`.
    Boolean(bool),
    /// Name of an existing Python variable, inserted verbatim.
    Array(String),
}
917
/// A parameter (`name: type`) of a generated `wasm_bindgen` function.
#[derive(Debug, Clone)]
pub struct WasmParameter {
    /// Parameter name, emitted verbatim.
    pub name: String,
    /// Parameter type, rendered via its `Display` impl.
    pub param_type: WasmType,
}
923
/// Rust types usable in generated WASM signatures; `Display` yields the
/// Rust spelling (`f64`, `String`, ...).
#[derive(Debug, Clone)]
pub enum WasmType {
    F64,
    F32,
    I32,
    U32,
    Bool,
    String,
}
933
934impl fmt::Display for WasmType {
935 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
936 match self {
937 WasmType::F64 => write!(f, "f64"),
938 WasmType::F32 => write!(f, "f32"),
939 WasmType::I32 => write!(f, "i32"),
940 WasmType::U32 => write!(f, "u32"),
941 WasmType::Bool => write!(f, "bool"),
942 WasmType::String => write!(f, "String"),
943 }
944 }
945}
946
/// Build settings for a WASM package (see `WasmUtils::create_wasm_build_config`).
#[derive(Debug, Clone)]
pub struct WasmBuildConfig {
    /// Compilation target triple, e.g. `wasm32-unknown-unknown`.
    pub target: String,
    /// Cargo feature / crate names to enable.
    pub features: Vec<String>,
    /// Optimisation goal.
    pub optimization: WasmOptimization,
    /// Whether to build with debug info.
    pub debug: bool,
    /// Whether to emit TypeScript bindings.
    pub typescript_bindings: bool,
}
955
/// Optimisation goal for a WASM build.
#[derive(Debug, Clone)]
pub enum WasmOptimization {
    None,
    Size,
    Speed,
}
962
/// A parameter of a generated C-style signature, rendered as `type name`.
#[derive(Debug, Clone)]
pub struct CParameter {
    /// Parameter name, emitted verbatim.
    pub name: String,
    /// Parameter type, rendered via its `Display` impl.
    pub param_type: CType,
}
968
/// C types usable in generated signatures; `Display` yields the C spelling
/// (`int`, `const double*`, ...).
#[derive(Debug, Clone)]
pub enum CType {
    Int,
    Double,
    Float,
    CharPtr,
    VoidPtr,
    ConstCharPtr,
    ConstDoublePtr,
}
979
980impl fmt::Display for CType {
981 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
982 match self {
983 CType::Int => write!(f, "int"),
984 CType::Double => write!(f, "double"),
985 CType::Float => write!(f, "float"),
986 CType::CharPtr => write!(f, "char*"),
987 CType::VoidPtr => write!(f, "void*"),
988 CType::ConstCharPtr => write!(f, "const char*"),
989 CType::ConstDoublePtr => write!(f, "const double*"),
990 }
991 }
992}
993
/// A function declaration to include in a generated C header.
#[derive(Debug, Clone)]
pub struct CFunctionSignature {
    /// Function name (not read by `FFIUtils::generate_c_header`).
    pub name: String,
    /// Full C declaration; the header generator emits it verbatim followed by `;`.
    pub signature: String,
    /// Human-readable description (not read by `FFIUtils::generate_c_header`).
    pub description: String,
}
1000
/// C-compatible, non-owning view of an `f64` buffer.
///
/// `data` is a raw pointer; as built by `FFIUtils::create_array_transfer` it
/// borrows the source slice and is only valid while that slice is alive.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct ArrayTransfer {
    /// Pointer to the first element.
    pub data: *const f64,
    /// Number of elements.
    pub length: usize,
    /// Capacity reported to the consumer (equal to `length` from `create_array_transfer`).
    pub capacity: usize,
}
1008
/// Serialisable description of an R atomic vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RVector {
    /// Element values.
    pub data: Vec<f64>,
    /// Declared length; conversions expect it to equal `data.len()`.
    pub length: usize,
    /// R storage type; array conversions accept only `RType::Numeric`.
    pub r_type: RType,
}
1017
/// Serialisable description of an R matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RMatrix {
    /// Flattened values, ordered according to `byrow`.
    pub data: Vec<f64>,
    /// Number of rows.
    pub nrow: usize,
    /// Number of columns.
    pub ncol: usize,
    /// `true` if `data` is row-major, `false` if column-major (R's default).
    pub byrow: bool,
    /// R storage type; array conversions accept only `RType::Numeric`.
    pub r_type: RType,
}
1026
/// R storage modes. Only `Numeric` is accepted by the array conversions in
/// `RInterop`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RType {
    Numeric,
    Integer,
    Character,
    Logical,
    Factor,
}
1035
/// A named argument (`name = value`) for a generated R call.
#[derive(Debug, Clone)]
pub struct RParameter {
    /// Argument name, emitted verbatim.
    pub name: String,
    /// Value rendered according to its variant.
    pub value: RValue,
}
1041
/// Value of a generated R argument.
#[derive(Debug, Clone)]
pub enum RValue {
    /// Rendered as a double-quoted R string.
    String(String),
    /// Rendered as a numeric literal.
    Number(f64),
    /// Rendered as `TRUE` / `FALSE`.
    Boolean(bool),
    /// Name of an existing R vector variable, inserted verbatim.
    Vector(String),
    /// Name of an existing R matrix variable, inserted verbatim.
    Matrix(String),
    /// Name of an existing R data frame variable, inserted verbatim.
    DataFrame(String),
}
1051
/// Expected shape of printed R output passed to `RInterop::convert_r_output`.
#[derive(Debug, Clone)]
pub enum ROutputType {
    Vector,
    Scalar,
    String,
}
1058
/// Parsed result of `RInterop::convert_r_output`, one variant per `ROutputType`.
#[derive(Debug, Clone)]
pub enum ROutputValue {
    Vector(Vec<f64>),
    Scalar(f64),
    String(String),
}
1065
1066#[allow(non_snake_case)]
1067#[cfg(test)]
1068mod tests {
1069 use super::*;
1070 use scirs2_core::ndarray::array;
1071
1072 #[test]
1073 fn test_python_array_conversion() {
1074 let arr = array![1.0, 2.0, 3.0, 4.0];
1075 let buffer = PythonInterop::array_to_python_buffer(&arr);
1076
1077 assert_eq!(buffer.data, vec![1.0, 2.0, 3.0, 4.0]);
1078 assert_eq!(buffer.shape, vec![4]);
1079 assert_eq!(buffer.dtype, "float64");
1080 assert_eq!(buffer.order, "C");
1081
1082 let converted = PythonInterop::python_buffer_to_array(&buffer).unwrap();
1084 assert_eq!(converted, arr);
1085 }
1086
1087 #[test]
1088 fn test_python_array2_conversion() {
1089 let arr = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
1090 let buffer = PythonInterop::array2_to_python_buffer(&arr);
1091
1092 assert_eq!(buffer.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1093 assert_eq!(buffer.shape, vec![3, 2]);
1094
1095 let converted = PythonInterop::python_buffer_to_array2(&buffer).unwrap();
1097 assert_eq!(converted, arr);
1098 }
1099
1100 #[test]
1101 fn test_numpy_code_generation() {
1102 let buffer = PyArrayBuffer {
1103 data: vec![1.0, 2.0, 3.0],
1104 shape: vec![3],
1105 dtype: "float64".to_string(),
1106 order: "C".to_string(),
1107 };
1108
1109 let code = PythonInterop::generate_numpy_import_code("my_array", &buffer);
1110
1111 assert!(code.contains("import numpy as np"));
1112 assert!(code.contains("my_array = np.array"));
1113 assert!(code.contains("[1.0, 2.0, 3.0]"));
1114 assert!(code.contains("dtype='float64'"));
1115 assert!(code.contains("reshape([3])"));
1116 }
1117
1118 #[test]
1119 fn test_python_function_call_template() {
1120 let params = vec![
1121 PythonParameter {
1122 name: "n_estimators".to_string(),
1123 value: PythonValue::Number(100.0),
1124 },
1125 PythonParameter {
1126 name: "random_state".to_string(),
1127 value: PythonValue::Number(42.0),
1128 },
1129 PythonParameter {
1130 name: "verbose".to_string(),
1131 value: PythonValue::Boolean(true),
1132 },
1133 ];
1134
1135 let call = PythonInterop::generate_function_call_template("RandomForestRegressor", ¶ms);
1136
1137 assert!(call.contains("RandomForestRegressor("));
1138 assert!(call.contains("n_estimators=100"));
1139 assert!(call.contains("random_state=42"));
1140 assert!(call.contains("verbose=True"));
1141 }
1142
1143 #[test]
1144 fn test_ml_script_generation() {
1145 let training_data = PyArrayBuffer {
1146 data: vec![1.0, 2.0, 3.0, 4.0],
1147 shape: vec![2, 2],
1148 dtype: "float64".to_string(),
1149 order: "C".to_string(),
1150 };
1151
1152 let labels = PyArrayBuffer {
1153 data: vec![0.0, 1.0],
1154 shape: vec![2],
1155 dtype: "float64".to_string(),
1156 order: "C".to_string(),
1157 };
1158
1159 let mut hyperparams = HashMap::new();
1160 hyperparams.insert("n_estimators".to_string(), 50.0);
1161
1162 let script =
1163 PythonInterop::create_ml_script("random_forest", &training_data, &labels, &hyperparams)
1164 .unwrap();
1165
1166 assert!(script.contains("import numpy as np"));
1167 assert!(script.contains("from sklearn.ensemble import RandomForestRegressor"));
1168 assert!(script.contains("train_test_split"));
1169 assert!(script.contains("RandomForestRegressor(n_estimators=50)"));
1170 assert!(script.contains("model.fit(X_train, y_train)"));
1171 assert!(script.contains("model.score(X_test, y_test)"));
1172 }
1173
1174 #[test]
1175 fn test_wasm_signature_generation() {
1176 let params = vec![
1177 WasmParameter {
1178 name: "a".to_string(),
1179 param_type: WasmType::F64,
1180 },
1181 WasmParameter {
1182 name: "b".to_string(),
1183 param_type: WasmType::F64,
1184 },
1185 ];
1186
1187 let signature = WasmUtils::generate_wasm_signature("add", ¶ms, WasmType::F64);
1188
1189 assert!(signature.contains("#[wasm_bindgen]"));
1190 assert!(signature.contains("pub fn add(a: f64, b: f64) -> f64"));
1191 }
1192
1193 #[test]
1194 fn test_wasm_memory_helpers() {
1195 let helpers = WasmUtils::generate_memory_helpers();
1196
1197 assert!(helpers.contains("pub fn alloc(size: usize)"));
1198 assert!(helpers.contains("pub fn dealloc(ptr: *mut u8, size: usize)"));
1199 assert!(helpers.contains("pub struct Float64Array"));
1200 assert!(helpers.contains("#[wasm_bindgen]"));
1201 }
1202
1203 #[test]
1204 fn test_wasm_ml_bindings() {
1205 let bindings = WasmUtils::generate_ml_bindings();
1206
1207 assert!(bindings.contains("pub fn dot_product"));
1208 assert!(bindings.contains("pub fn matrix_multiply"));
1209 assert!(bindings.contains("pub fn mean"));
1210 assert!(bindings.contains("pub fn variance"));
1211 assert!(bindings.contains("pub fn standard_deviation"));
1212 }
1213
1214 #[test]
1215 fn test_wasm_build_config() {
1216 let config = WasmUtils::create_wasm_build_config();
1217
1218 assert_eq!(config.target, "wasm32-unknown-unknown");
1219 assert!(config.features.contains(&"wasm-bindgen".to_string()));
1220 assert!(config.typescript_bindings);
1221 assert!(!config.debug);
1222 }
1223
1224 #[test]
1225 fn test_package_json_generation() {
1226 let json = WasmUtils::generate_package_json("my-ml-wasm", "0.1.0");
1227
1228 assert!(json.contains("\"name\": \"my-ml-wasm\""));
1229 assert!(json.contains("\"version\": \"0.1.0\""));
1230 assert!(json.contains("wasm-pack"));
1231 assert!(json.contains("\"machine-learning\""));
1232 }
1233
/// Checks that `create_c_signature` renders an `extern "C"` function whose
/// parameters and return type use the C spellings of their `CType`s.
#[test]
fn test_c_signature_creation() {
    let params = vec![
        CParameter {
            name: "data".to_string(),
            param_type: CType::ConstDoublePtr,
        },
        CParameter {
            name: "length".to_string(),
            param_type: CType::Int,
        },
    ];

    let signature = FFIUtils::create_c_signature("compute_mean", &params, CType::Double);

    assert!(signature.contains("extern \"C\" fn compute_mean"));
    assert!(signature.contains("const double* data"));
    assert!(signature.contains("int length"));
    assert!(signature.contains("-> double"));
}
1254
/// The generated C header has an include guard derived from the library
/// name, an `extern "C"` block, and one prototype per function.
#[test]
fn test_c_header_generation() {
    let functions: Vec<CFunctionSignature> = [
        ("add", "double add(double a, double b)", "Add two numbers"),
        (
            "multiply",
            "double multiply(double a, double b)",
            "Multiply two numbers",
        ),
    ]
    .into_iter()
    .map(|(name, signature, description)| CFunctionSignature {
        name: name.to_string(),
        signature: signature.to_string(),
        description: description.to_string(),
    })
    .collect();

    let header = FFIUtils::generate_c_header("mylib", &functions);

    for line in [
        "#ifndef MYLIB_H",
        "#define MYLIB_H",
        "extern \"C\" {",
        "double add(double a, double b);",
        "double multiply(double a, double b);",
        "#endif // MYLIB_H",
    ] {
        assert!(header.contains(line));
    }
}
1279
/// A Rust string survives a round trip through a C string pointer unchanged.
#[test]
fn test_rust_to_c_string() {
    let original = "Hello, World!";
    let raw = FFIUtils::rust_string_to_c(original).unwrap();

    // SAFETY: `raw` was just returned by `rust_string_to_c` and has not been
    // freed yet, so it still points at live string data.
    let round_tripped = unsafe { FFIUtils::c_string_to_rust(raw).unwrap() };
    assert_eq!(round_tripped, original);

    // SAFETY: this test assumes `rust_string_to_c` hands out a pointer obtained
    // from `CString::into_raw`. Reclaiming it here frees the allocation exactly
    // once, and `raw` is not used afterwards.
    unsafe {
        drop(CString::from_raw(raw));
    }
}
1294
/// Wrapping a four-element buffer yields a transfer with matching length and
/// capacity and a non-null data pointer.
#[test]
fn test_array_transfer_creation() {
    let values = vec![1.0, 2.0, 3.0, 4.0];
    let handle = FFIUtils::create_array_transfer(&values);

    assert_eq!(handle.length, 4);
    assert_eq!(handle.capacity, 4);
    let pointer_is_null = handle.data.is_null();
    assert!(!pointer_is_null);
}
1304
/// The FFI example snippet demonstrates the exported ML entry points and the
/// `#[no_mangle]` / `extern "C"` / `ArrayTransfer` conventions.
#[test]
fn test_ffi_examples_generation() {
    let snippet = FFIUtils::generate_ffi_examples();

    let markers = [
        "linear_regression_fit",
        "array_mean",
        "#[no_mangle]",
        r#"extern "C""#,
        "ArrayTransfer",
    ];
    for marker in markers {
        assert!(snippet.contains(marker));
    }
}
1315
/// Each `PythonValue` constructor produces the variant it names.
#[test]
fn test_python_value_variants() {
    let text = PythonValue::String("test".to_string());
    let number = PythonValue::Number(42.0);
    let flag = PythonValue::Boolean(true);
    let array_ref = PythonValue::Array("my_array".to_string());

    let PythonValue::String(_) = text else {
        panic!()
    };
    let PythonValue::Number(_) = number else {
        panic!()
    };
    let PythonValue::Boolean(_) = flag else {
        panic!()
    };
    let PythonValue::Array(_) = array_ref else {
        panic!()
    };
}
1341
/// `WasmType`'s `Display` impl prints the corresponding Rust type name.
#[test]
fn test_wasm_type_display() {
    let table = [
        (WasmType::F64, "f64"),
        (WasmType::F32, "f32"),
        (WasmType::I32, "i32"),
        (WasmType::U32, "u32"),
        (WasmType::Bool, "bool"),
        (WasmType::String, "String"),
    ];
    for (ty, rendered) in table {
        assert_eq!(ty.to_string(), rendered);
    }
}
1351
/// `CType`'s `Display` impl prints the C spelling of each type.
#[test]
fn test_c_type_display() {
    let table = [
        (CType::Int, "int"),
        (CType::Double, "double"),
        (CType::Float, "float"),
        (CType::CharPtr, "char*"),
        (CType::VoidPtr, "void*"),
        (CType::ConstCharPtr, "const char*"),
        (CType::ConstDoublePtr, "const double*"),
    ];
    for (c_type, spelling) in table {
        assert_eq!(c_type.to_string(), spelling);
    }
}
1362
/// A 1-D array converts to a numeric R vector and back without loss.
#[test]
fn test_r_array_conversion() {
    let input = array![1.0, 2.0, 3.0, 4.0];
    let vector = RInterop::array_to_r_vector(&input);

    assert_eq!(vector.data, vec![1.0, 2.0, 3.0, 4.0]);
    assert_eq!(vector.length, 4);
    assert_eq!(vector.r_type, RType::Numeric);

    // Converting back must reproduce the original array exactly.
    let round_trip = RInterop::r_vector_to_array(&vector).unwrap();
    assert_eq!(round_trip, input);
}
1377 }
1378
/// A 3x2 array converts to an R matrix stored column by column (`byrow = FALSE`)
/// and converts back to the same array.
#[test]
fn test_r_matrix_conversion() {
    let input = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
    let matrix = RInterop::array2_to_r_matrix(&input);

    // Column-major layout: first column (1, 3, 5), then second column (2, 4, 6).
    assert_eq!(matrix.data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    assert_eq!(matrix.nrow, 3);
    assert_eq!(matrix.ncol, 2);
    assert!(!matrix.byrow);
    assert_eq!(matrix.r_type, RType::Numeric);

    let round_trip = RInterop::r_matrix_to_array2(&matrix).unwrap();
    assert_eq!(round_trip, input);
}
1394
/// A numeric vector is rendered as an R `c(...)` assignment.
#[test]
fn test_r_vector_code_generation() {
    let expected = "my_vector <- c(1, 2, 3)";

    let vector = RVector {
        data: vec![1.0, 2.0, 3.0],
        length: 3,
        r_type: RType::Numeric,
    };
    let rendered = RInterop::generate_r_vector_code("my_vector", &vector);

    assert_eq!(rendered, expected);
}
1407
/// A matrix is rendered as an R `matrix(...)` call carrying its dimensions
/// and `byrow` flag.
#[test]
fn test_r_matrix_code_generation() {
    let expected = "my_matrix <- matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2, byrow = FALSE)";

    let matrix = RMatrix {
        data: vec![1.0, 2.0, 3.0, 4.0],
        nrow: 2,
        ncol: 2,
        byrow: false,
        r_type: RType::Numeric,
    };
    let rendered = RInterop::generate_r_matrix_code("my_matrix", &matrix);

    assert_eq!(rendered, expected);
}
1425
/// A map of named columns is rendered as an R `data.frame(...)` call.
#[test]
fn test_r_dataframe_code_generation() {
    let mut columns = HashMap::new();
    for (column, values) in [("x", vec![1.0, 2.0, 3.0]), ("y", vec![4.0, 5.0, 6.0])] {
        columns.insert(
            column.to_string(),
            RVector {
                data: values,
                length: 3,
                r_type: RType::Numeric,
            },
        );
    }

    let code = RInterop::generate_r_dataframe_code("my_df", &columns).unwrap();

    // `HashMap` iteration order is unspecified, so either column order is valid.
    let accepted = [
        "my_df <- data.frame(x = c(1, 2, 3), y = c(4, 5, 6))",
        "my_df <- data.frame(y = c(4, 5, 6), x = c(1, 2, 3))",
    ];
    assert!(accepted.iter().any(|candidate| code == *candidate));
}
1455
/// Every requested package gets its own `library(...)` line.
#[test]
fn test_r_package_imports() {
    let packages = &["randomForest", "e1071", "ggplot2"];
    let imports = RInterop::generate_r_package_imports(packages);

    for package in packages {
        let expected_line = format!("library({package})");
        assert!(imports.contains(&expected_line));
    }
}
1465
/// Checks that `generate_r_function_call` renders each named argument in R
/// syntax: numbers without a trailing `.0`, booleans as `TRUE`/`FALSE`, and
/// matrix references as bare identifiers.
#[test]
fn test_r_function_call_generation() {
    let params = vec![
        RParameter {
            name: "ntree".to_string(),
            value: RValue::Number(500.0),
        },
        RParameter {
            name: "mtry".to_string(),
            value: RValue::Number(3.0),
        },
        RParameter {
            name: "importance".to_string(),
            value: RValue::Boolean(true),
        },
        RParameter {
            name: "x".to_string(),
            value: RValue::Matrix("X_train".to_string()),
        },
    ];

    let call = RInterop::generate_r_function_call("randomForest", &params);

    assert!(call.contains("randomForest("));
    assert!(call.contains("ntree = 500"));
    assert!(call.contains("mtry = 3"));
    assert!(call.contains("importance = TRUE"));
    assert!(call.contains("x = X_train"));
}
1495
/// A random-forest script loads the package, defines the data, fits the model
/// with the supplied hyperparameters, and predicts and summarizes the result.
#[test]
fn test_r_ml_script_generation() {
    let training_data = RMatrix {
        data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        nrow: 3,
        ncol: 2,
        byrow: false,
        r_type: RType::Numeric,
    };
    let response_var = RVector {
        data: vec![1.0, 0.0, 1.0],
        length: 3,
        r_type: RType::Numeric,
    };

    let mut hyperparams = HashMap::new();
    for (key, value) in [("ntree", 100.0), ("mtry", 1.0)] {
        hyperparams.insert(key.to_string(), value);
    }

    let script = RInterop::create_r_ml_script(
        "random_forest",
        &training_data,
        &response_var,
        &hyperparams,
    )
    .unwrap();

    let expected_fragments = [
        "library(randomForest)",
        "X <- matrix",
        "y <- c",
        "randomForest(x = X_train, y = y_train, ntree = 100, mtry = 1)",
        "train_indices",
        "predictions <- predict",
        "summary(model)",
    ];
    for fragment in expected_fragments {
        assert!(script.contains(fragment));
    }
}
1532
/// A linear-regression script uses base R's `lm` and loads no extra packages.
#[test]
fn test_r_linear_regression_script() {
    let features = RMatrix {
        data: vec![1.0, 2.0, 3.0, 4.0],
        nrow: 2,
        ncol: 2,
        byrow: false,
        r_type: RType::Numeric,
    };
    let target = RVector {
        data: vec![1.0, 2.0],
        length: 2,
        r_type: RType::Numeric,
    };
    let hyperparams = HashMap::new();

    let script =
        RInterop::create_r_ml_script("linear_regression", &features, &target, &hyperparams)
            .unwrap();

    assert!(script.contains("# Linear regression using base R"));
    assert!(script.contains("model <- lm(y_train ~ X_train)"));
    // Base R only: no `library(...)` call should be emitted.
    assert!(!script.contains("library("));
}
1563
/// Each supported analysis kind produces its characteristic R calls.
#[test]
fn test_r_statistical_analysis() {
    let data = RMatrix {
        data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        nrow: 3,
        ncol: 2,
        byrow: false,
        r_type: RType::Numeric,
    };

    let cases: [(&str, &[&str]); 3] = [
        (
            "descriptive",
            &["summary(data)", "apply(data, 2, sd)", "cor(data)"],
        ),
        (
            "pca",
            &[
                "prcomp(data, center = TRUE, scale. = TRUE)",
                "plot(pca_result$x[,1:2])",
            ],
        ),
        (
            "clustering",
            &[
                "kmeans(data, centers = 3)",
                "plot(data, col = kmeans_result$cluster)",
            ],
        ),
    ];

    for (analysis, fragments) in cases {
        let script = RInterop::create_r_statistical_analysis(&data, analysis).unwrap();
        for &fragment in fragments {
            assert!(script.contains(fragment));
        }
    }
}
1590
/// R console output is parsed into vector, scalar and string values.
#[test]
fn test_r_output_conversion() {
    // Console-style output, including the leading `[1]` index marker.
    let parsed = RInterop::convert_r_output("[1] 1.0 2.5 3.8", ROutputType::Vector).unwrap();
    let ROutputValue::Vector(values) = parsed else {
        panic!("Expected vector output");
    };
    assert_eq!(values, vec![1.0, 2.5, 3.8]);

    let parsed = RInterop::convert_r_output("[1] 42.5", ROutputType::Scalar).unwrap();
    let ROutputValue::Scalar(scalar) = parsed else {
        panic!("Expected scalar output");
    };
    assert_eq!(scalar, 42.5);

    let parsed =
        RInterop::convert_r_output("This is a test string", ROutputType::String).unwrap();
    let ROutputValue::String(text) = parsed else {
        panic!("Expected string output");
    };
    assert_eq!(text, "This is a test string");
}
1617
/// With `byrow: true`, the flat data is read row by row when converting back
/// to a 2-D array.
#[test]
fn test_r_matrix_row_major_conversion() {
    let row_major = RMatrix {
        data: vec![1.0, 2.0, 3.0, 4.0],
        nrow: 2,
        ncol: 2,
        byrow: true,
        r_type: RType::Numeric,
    };

    let converted = RInterop::r_matrix_to_array2(&row_major).unwrap();

    // The expected rows are consecutive pairs of the flat data.
    assert_eq!(converted, array![[1.0, 2.0], [3.0, 4.0]]);
}
1635
/// Every `RType` variant compares equal to itself.
#[test]
fn test_r_type_variants() {
    let checks = [
        (RType::Numeric, RType::Numeric),
        (RType::Integer, RType::Integer),
        (RType::Character, RType::Character),
        (RType::Logical, RType::Logical),
        (RType::Factor, RType::Factor),
    ];
    for (built, expected) in checks {
        assert_eq!(built, expected);
    }
}
1651
/// Each `RValue` constructor produces the variant it names.
#[test]
fn test_r_value_variants() {
    let text = RValue::String("test".to_string());
    let number = RValue::Number(42.0);
    let flag = RValue::Boolean(true);
    let vector_ref = RValue::Vector("my_vector".to_string());
    let matrix_ref = RValue::Matrix("my_matrix".to_string());
    let frame_ref = RValue::DataFrame("my_df".to_string());

    let RValue::String(_) = text else {
        panic!()
    };
    let RValue::Number(_) = number else {
        panic!()
    };
    let RValue::Boolean(_) = flag else {
        panic!()
    };
    let RValue::Vector(_) = vector_ref else {
        panic!()
    };
    let RValue::Matrix(_) = matrix_ref else {
        panic!()
    };
    let RValue::DataFrame(_) = frame_ref else {
        panic!()
    };
}
1687
/// An unknown model name is rejected with an "Unsupported R model type" error.
#[test]
fn test_r_unsupported_model_type() {
    let features = RMatrix {
        data: vec![1.0, 2.0, 3.0, 4.0],
        nrow: 2,
        ncol: 2,
        byrow: false,
        r_type: RType::Numeric,
    };
    let target = RVector {
        data: vec![1.0, 2.0],
        length: 2,
        r_type: RType::Numeric,
    };
    let hyperparams = HashMap::new();

    let outcome =
        RInterop::create_r_ml_script("unsupported_model", &features, &target, &hyperparams);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Unsupported R model type"));
}
1719
/// An unknown analysis kind is rejected with an "Unsupported analysis type" error.
#[test]
fn test_r_unsupported_analysis_type() {
    let matrix = RMatrix {
        data: vec![1.0, 2.0, 3.0, 4.0],
        nrow: 2,
        ncol: 2,
        byrow: false,
        r_type: RType::Numeric,
    };

    let outcome = RInterop::create_r_statistical_analysis(&matrix, "unsupported_analysis");

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Unsupported analysis type"));
}
1737}