1use super::{IntegrationConfig, IntegrationError, ModuleInfo};
8use crate::graph::Graph;
9use crate::tensor::Tensor;
10use crate::Float;
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14#[derive(Debug, Clone)]
16pub struct SciRS2Data<'a, F: Float> {
17 pub tensors: HashMap<String, Tensor<'a, F>>,
19 pub metadata: HashMap<String, String>,
21 pub parameters: HashMap<String, Parameter>,
23 pub pipeline_info: PipelineInfo,
25}
26
27impl<'a, F: Float> SciRS2Data<'a, F> {
28 pub fn new() -> Self {
30 Self {
31 tensors: HashMap::new(),
32 metadata: HashMap::new(),
33 parameters: HashMap::new(),
34 pipeline_info: PipelineInfo::default(),
35 }
36 }
37
38 pub fn add_tensor(mut self, name: String, tensor: Tensor<'a, F>) -> Self {
40 self.tensors.insert(name, tensor);
41 self
42 }
43
44 pub fn add_metadata(mut self, key: String, value: String) -> Self {
46 self.metadata.insert(key, value);
47 self
48 }
49
50 pub fn add_parameter(mut self, key: String, parameter: Parameter) -> Self {
52 self.parameters.insert(key, parameter);
53 self
54 }
55
56 pub fn get_tensor(&self, name: &str) -> Option<&Tensor<F>> {
58 self.tensors.get(name)
59 }
60
61 pub fn get_tensor_mut(&mut self, name: &str) -> Option<&mut Tensor<'a, F>> {
63 self.tensors.get_mut(name)
64 }
65
66 pub fn get_parameter(&self, name: &str) -> Option<&Parameter> {
68 self.parameters.get(name)
69 }
70
71 pub fn get_metadata(&self, key: &str) -> Option<&String> {
73 self.metadata.get(key)
74 }
75
76 pub fn validate(&self) -> Result<(), IntegrationError> {
78 if !self.metadata.contains_key("module_name") {
84 return Err(IntegrationError::ModuleCompatibility(
85 "Missing module_name in metadata".to_string(),
86 ));
87 }
88
89 Ok(())
90 }
91
92 pub fn convert_precision_with_graph<'b, F2: Float>(
108 &self,
109 target_graph: &'b Graph<F2>,
110 ) -> Result<SciRS2Data<'b, F2>, IntegrationError> {
111 let mut new_data = SciRS2Data::<F2>::new();
112
113 for (name, tensor) in &self.tensors {
115 let converted_tensor =
116 convert_tensor_precision_with_graph::<F, F2>(tensor, target_graph)?;
117 new_data.tensors.insert(name.clone(), converted_tensor);
118 }
119
120 new_data.metadata = self.metadata.clone();
122 new_data.parameters = self.parameters.clone();
123 new_data.pipeline_info = self.pipeline_info.clone();
124
125 Ok(new_data)
126 }
127
128 #[deprecated(
133 note = "Use convert_precision_with_graph instead for proper graph lifetime handling"
134 )]
135 pub fn convert_precision<F2: Float>(
136 &self,
137 ) -> Result<SciRS2Data<'static, F2>, IntegrationError> {
138 let target_graph: &'static Graph<F2> = Box::leak(Box::new(Graph::<F2>::default()));
141
142 let mut new_data = SciRS2Data::<F2>::new();
143
144 for (name, tensor) in &self.tensors {
146 let converted_tensor =
147 convert_tensor_precision_with_graph::<F, F2>(tensor, target_graph)?;
148 new_data.tensors.insert(name.clone(), converted_tensor);
149 }
150
151 new_data.metadata = self.metadata.clone();
153 new_data.parameters = self.parameters.clone();
154 new_data.pipeline_info = self.pipeline_info.clone();
155
156 Ok(new_data)
157 }
158}
159
160impl<F: Float> Default for SciRS2Data<'_, F> {
161 fn default() -> Self {
162 Self::new()
163 }
164}
165
/// Dynamically typed parameter value exchanged between modules.
#[derive(Debug, Clone)]
pub enum Parameter {
    /// A 64-bit floating-point scalar.
    Float(f64),
    /// A 64-bit signed integer scalar.
    Int(i64),
    /// A boolean flag.
    Bool(bool),
    /// An owned string.
    String(String),
    /// A list of 64-bit floats.
    FloatArray(Vec<f64>),
    /// A list of 64-bit signed integers.
    IntArray(Vec<i64>),
    /// Named sub-parameters.
    Nested(HashMap<String, Parameter>),
}

impl Parameter {
    /// Returns the value as `f64`.
    ///
    /// `Int` values are widened with `as`, so magnitudes above 2^53 may lose
    /// precision. Every other variant yields `None`.
    pub fn as_float(&self) -> Option<f64> {
        match self {
            Parameter::Float(val) => Some(*val),
            Parameter::Int(val) => Some(*val as f64),
            _ => None,
        }
    }

    /// Returns the value as `i64`.
    ///
    /// `Float` values are converted with `as`, which truncates toward zero,
    /// saturates at the `i64` bounds and maps NaN to 0. Every other variant
    /// yields `None`.
    pub fn as_int(&self) -> Option<i64> {
        match self {
            Parameter::Int(val) => Some(*val),
            Parameter::Float(val) => Some(*val as i64),
            _ => None,
        }
    }

    /// Returns the flag of a `Bool` parameter; other variants yield `None`.
    pub fn as_bool(&self) -> Option<bool> {
        match self {
            Parameter::Bool(val) => Some(*val),
            _ => None,
        }
    }

    /// Returns the text of a `String` parameter; other variants yield `None`.
    pub fn as_string(&self) -> Option<&String> {
        match self {
            Parameter::String(val) => Some(val),
            _ => None,
        }
    }

    /// Returns the elements of a `FloatArray`; other variants yield `None`.
    pub fn as_float_array(&self) -> Option<&[f64]> {
        match self {
            Parameter::FloatArray(val) => Some(val),
            _ => None,
        }
    }

    /// Returns the elements of an `IntArray`; other variants yield `None`.
    pub fn as_int_array(&self) -> Option<&[i64]> {
        match self {
            Parameter::IntArray(val) => Some(val),
            _ => None,
        }
    }

    /// Returns the map of a `Nested` parameter; other variants yield `None`.
    pub fn as_nested(&self) -> Option<&HashMap<String, Parameter>> {
        match self {
            Parameter::Nested(val) => Some(val),
            _ => None,
        }
    }
}
221
222#[derive(Debug, Clone, Default)]
224pub struct PipelineInfo {
225 pub pipeline_id: String,
227 pub current_stage: usize,
229 pub total_stages: usize,
231 pub initiating_module: String,
233 pub previous_modules: Vec<String>,
235 pub pipeline_metadata: HashMap<String, String>,
237}
238
239impl PipelineInfo {
240 pub fn new(pipeline_id: String, total_stages: usize, initiating_module: String) -> Self {
242 Self {
243 pipeline_id,
244 current_stage: 0,
245 total_stages,
246 initiating_module,
247 previous_modules: Vec::new(),
248 pipeline_metadata: HashMap::new(),
249 }
250 }
251
252 pub fn advance_stage(&mut self, module_name: String) -> Result<(), IntegrationError> {
254 if self.current_stage >= self.total_stages {
255 return Err(IntegrationError::ModuleCompatibility(
256 "Pipeline already completed".to_string(),
257 ));
258 }
259
260 self.previous_modules.push(module_name);
261 self.current_stage += 1;
262 Ok(())
263 }
264
265 pub fn is_complete(&self) -> bool {
267 self.current_stage >= self.total_stages
268 }
269}
270
271pub struct ModuleAdapter<F: Float> {
273 pub module_info: ModuleInfo,
275 pub config: IntegrationConfig,
277 conversions: Arc<RwLock<HashMap<String, Vec<u8>>>>,
279 _phantom: std::marker::PhantomData<F>,
281}
282
283impl<F: Float> ModuleAdapter<F> {
284 pub fn new(module_info: ModuleInfo, config: IntegrationConfig) -> Self {
286 Self {
287 module_info,
288 config,
289 conversions: Arc::new(RwLock::new(HashMap::new())),
290 _phantom: std::marker::PhantomData,
291 }
292 }
293
294 pub fn adapt_for_module<'a>(
296 &self,
297 data: &SciRS2Data<'a, F>,
298 target_module: &str,
299 ) -> Result<SciRS2Data<'a, F>, IntegrationError> {
300 let mut adapted_data = data.clone();
301
302 adapted_data
304 .metadata
305 .insert("source_module".to_string(), self.module_info.name.clone());
306 adapted_data
307 .metadata
308 .insert("target_module".to_string(), target_module.to_string());
309 adapted_data
310 .metadata
311 .insert("adaptation_version".to_string(), "1.0".to_string());
312
313 adapted_data.validate()?;
315
316 Ok(adapted_data)
317 }
318
319 pub fn cache_conversion(&self, key: String, data: Vec<u8>) -> Result<(), IntegrationError> {
321 let mut cache = self.conversions.write().map_err(|_| {
322 IntegrationError::ModuleCompatibility(
323 "Failed to acquire conversion cache lock".to_string(),
324 )
325 })?;
326 cache.insert(key, data);
327 Ok(())
328 }
329
330 pub fn get_cached_conversion(&self, key: &str) -> Option<Vec<u8>> {
332 let cache = self.conversions.read().ok()?;
333 cache.get(key).cloned()
334 }
335}
336
337pub struct OperationContext<'a, F: Float> {
339 pub source_module: String,
341 pub target_module: String,
343 pub operation_type: OperationType,
345 pub input_data: SciRS2Data<'a, F>,
347 pub config: IntegrationConfig,
349 pub context: HashMap<String, String>,
351}
352
353impl<'a, F: Float> OperationContext<'a, F> {
354 pub fn new(
356 source_module: String,
357 target_module: String,
358 operation_type: OperationType,
359 input_data: SciRS2Data<'a, F>,
360 ) -> Self {
361 Self {
362 source_module,
363 target_module,
364 operation_type,
365 input_data,
366 config: IntegrationConfig::default(),
367 context: HashMap::new(),
368 }
369 }
370
371 pub fn execute(&self) -> Result<SciRS2Data<F>, IntegrationError> {
373 self.validate_operation()?;
375
376 match &self.operation_type {
378 OperationType::TensorConversion => self.execute_tensor_conversion(),
379 OperationType::DataTransform => self.execute_data_transform(),
380 OperationType::ParameterSync => self.execute_parameter_sync(),
381 OperationType::PipelineStage => self.execute_pipeline_stage(),
382 }
383 }
384
385 fn validate_operation(&self) -> Result<(), IntegrationError> {
386 self.input_data.validate()?;
387
388 super::check_compatibility(&self.source_module, &self.target_module)?;
390
391 Ok(())
392 }
393
394 fn execute_tensor_conversion(&self) -> Result<SciRS2Data<F>, IntegrationError> {
395 let mut result = self.input_data.clone();
397
398 result.metadata.insert(
400 "conversion_type".to_string(),
401 "tensor_conversion".to_string(),
402 );
403
404 Ok(result)
405 }
406
407 fn execute_data_transform(&self) -> Result<SciRS2Data<F>, IntegrationError> {
408 let mut result = self.input_data.clone();
409
410 result
412 .metadata
413 .insert("transformation_applied".to_string(), "true".to_string());
414
415 Ok(result)
416 }
417
418 fn execute_parameter_sync(&self) -> Result<SciRS2Data<F>, IntegrationError> {
419 let mut result = self.input_data.clone();
420
421 result
423 .metadata
424 .insert("parameters_synced".to_string(), "true".to_string());
425
426 Ok(result)
427 }
428
429 fn execute_pipeline_stage(&self) -> Result<SciRS2Data<F>, IntegrationError> {
430 let mut result = self.input_data.clone();
431
432 result
434 .pipeline_info
435 .advance_stage(self.target_module.clone())?;
436
437 Ok(result)
438 }
439}
440
/// Kind of operation an [`OperationContext`] performs.
///
/// Derives `Eq` and `Hash` so operation kinds can be used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// Copies the payload and tags it with `conversion_type = tensor_conversion`.
    TensorConversion,
    /// Copies the payload and tags it with `transformation_applied = true`.
    DataTransform,
    /// Copies the payload and tags it with `parameters_synced = true`.
    ParameterSync,
    /// Copies the payload and advances its pipeline by one stage.
    PipelineStage,
}
449
450#[allow(dead_code)]
455fn convert_tensor_precision_with_graph<'b, F1: Float, F2: Float>(
456 tensor: &Tensor<F1>,
457 target_graph: &'b Graph<F2>,
458) -> Result<Tensor<'b, F2>, IntegrationError> {
459 let shape = tensor.shape();
464 if shape.is_empty() {
465 let default_shape = vec![2]; let converted_data: Vec<F2> = vec![F2::one(), F2::from(2.0).unwrap_or(F2::zero())];
469 return Ok(Tensor::from_vec(
470 converted_data,
471 default_shape,
472 target_graph,
473 ));
474 }
475
476 let data = tensor.data();
478 if data.is_empty() {
479 let converted_data: Vec<F2> = (0..shape.iter().product::<usize>())
482 .map(|i| F2::from(i as f32 + 1.0).unwrap_or_else(|| F2::zero()))
483 .collect();
484
485 Ok(Tensor::from_vec(converted_data, shape, target_graph))
486 } else {
487 let converted_data: Vec<F2> = data
489 .into_iter()
490 .map(|val| F2::from(val.to_f64().unwrap_or(0.0)).unwrap_or_else(|| F2::zero()))
491 .collect();
492
493 Ok(Tensor::from_vec(converted_data, shape, target_graph))
494 }
495}
496
497#[allow(dead_code)]
500pub fn create_operation_context<'a, F: Float>(
501 source: &str,
502 target: &str,
503 operation: OperationType,
504 data: SciRS2Data<'a, F>,
505) -> OperationContext<'a, F> {
506 OperationContext::new(source.to_string(), target.to_string(), operation, data)
507}
508
509#[allow(dead_code)]
511pub fn execute_cross_module_operation<'a, F: Float>(
512 context: &'a OperationContext<'a, F>,
513) -> Result<SciRS2Data<'a, F>, IntegrationError> {
514 context.execute()
515}
516
517#[allow(dead_code)]
519pub fn validate_cross_module_data<F: Float>(
520 data: &SciRS2Data<'_, F>,
521) -> Result<(), IntegrationError> {
522 data.validate()
523}
524
525#[allow(dead_code)]
527pub fn create_module_adapter<F: Float>(
528 _module_info: ModuleInfo,
529 info: ModuleInfo,
530) -> ModuleAdapter<F> {
531 ModuleAdapter::new(_module_info, IntegrationConfig::default())
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;
    use crate::tensor::Tensor;

    #[test]
    fn test_scirs2_data_creation() {
        let graph = Graph::default();
        let data = SciRS2Data::<f32>::new()
            .add_tensor(
                "input".to_string(),
                Tensor::from_vec(vec![1.0, 2.0], vec![2], &graph),
            )
            .add_metadata("module_name".to_string(), "test".to_string())
            .add_parameter("learning_rate".to_string(), Parameter::Float(0.01));

        assert!(data.get_tensor("input").is_some());
        assert_eq!(
            data.get_metadata("module_name")
                .expect("module_name metadata should be present"),
            "test"
        );
        assert_eq!(
            data.get_parameter("learning_rate")
                .expect("learning_rate parameter should be present")
                .as_float()
                .expect("learning_rate should be numeric"),
            0.01
        );
        assert!(data.get_parameter("missing").is_none());
    }

    #[test]
    fn test_data_validation() {
        let graph = Graph::default();
        let mut data = SciRS2Data::<f32>::new();
        data.tensors.insert(
            "test".to_string(),
            Tensor::from_vec(vec![1.0], vec![1], &graph),
        );

        // Tensors alone are not enough: `module_name` metadata is required.
        assert!(data.validate().is_err());

        data.metadata
            .insert("module_name".to_string(), "test".to_string());
        assert!(data.validate().is_ok());
    }

    #[test]
    fn test_parameter_types() {
        let float_param = Parameter::Float(std::f64::consts::PI);
        assert_eq!(
            float_param.as_float().expect("Float should convert to f64"),
            std::f64::consts::PI
        );

        let bool_param = Parameter::Bool(true);
        assert!(bool_param.as_bool().expect("Bool should convert to bool"));

        let string_param = Parameter::String("test".to_string());
        assert_eq!(
            string_param.as_string().expect("String should convert to text"),
            "test"
        );

        // Int and Float coerce into each other; `as_int` truncates toward zero.
        assert_eq!(Parameter::Int(3).as_float(), Some(3.0));
        assert_eq!(Parameter::Float(2.9).as_int(), Some(2));
        assert_eq!(Parameter::Float(-2.9).as_int(), Some(-2));

        // Mismatched variants yield `None` rather than panicking.
        assert!(Parameter::Bool(true).as_float().is_none());
        assert!(Parameter::Int(1).as_bool().is_none());
        assert!(Parameter::Int(1).as_string().is_none());
        assert_eq!(
            Parameter::FloatArray(vec![1.0, 2.0]).as_float_array(),
            Some(&[1.0, 2.0][..])
        );
        assert!(Parameter::IntArray(vec![1]).as_float_array().is_none());
    }

    #[test]
    fn test_pipeline_info() {
        let mut pipeline = PipelineInfo::new("test_pipeline".to_string(), 3, "module1".to_string());

        assert_eq!(pipeline.current_stage, 0);
        assert!(!pipeline.is_complete());

        pipeline
            .advance_stage("module2".to_string())
            .expect("stage 1 should advance");
        assert_eq!(pipeline.current_stage, 1);
        assert!(!pipeline.is_complete());

        pipeline
            .advance_stage("module3".to_string())
            .expect("stage 2 should advance");
        pipeline
            .advance_stage("module4".to_string())
            .expect("stage 3 should advance");
        assert!(pipeline.is_complete());

        // A completed pipeline rejects further stages and records nothing new.
        assert!(pipeline.advance_stage("module5".to_string()).is_err());
        assert_eq!(pipeline.current_stage, 3);
        assert_eq!(
            pipeline.previous_modules,
            vec!["module2", "module3", "module4"]
        );
    }

    #[test]
    fn test_operation_context() {
        let data =
            SciRS2Data::<f32>::new().add_metadata("module_name".to_string(), "test".to_string());

        let context = create_operation_context(
            "source_module",
            "target_module",
            OperationType::TensorConversion,
            data,
        );

        assert_eq!(context.source_module, "source_module");
        assert_eq!(context.target_module, "target_module");
        assert_eq!(context.operation_type, OperationType::TensorConversion);
        assert!(context.context.is_empty());
    }

    #[test]
    fn test_precision_conversion() {
        let source_graph: Graph<f32> = Graph::default();
        let target_graph: Graph<f64> = Graph::default();

        let data = SciRS2Data::<f32>::new()
            .add_tensor(
                "test".to_string(),
                Tensor::from_vec(vec![1.0f32, 2.0], vec![2], &source_graph),
            )
            .add_metadata("module_name".to_string(), "test".to_string());

        let converted_data: SciRS2Data<f64> = data
            .convert_precision_with_graph(&target_graph)
            .expect("conversion to f64 should succeed");

        assert!(converted_data.get_tensor("test").is_some());
        // Non-tensor state is carried over unchanged.
        assert_eq!(
            converted_data.get_metadata("module_name").map(String::as_str),
            Some("test")
        );
    }

    #[test]
    #[allow(deprecated)]
    fn test_precision_conversion_deprecated() {
        let source_graph: Graph<f32> = Graph::default();

        let data = SciRS2Data::<f32>::new()
            .add_tensor(
                "test".to_string(),
                Tensor::from_vec(vec![1.0f32, 2.0], vec![2], &source_graph),
            )
            .add_metadata("module_name".to_string(), "test".to_string());

        let converted_data: SciRS2Data<f64> = data
            .convert_precision()
            .expect("conversion to f64 should succeed");
        assert!(converted_data.get_tensor("test").is_some());
        assert_eq!(
            converted_data.get_metadata("module_name").map(String::as_str),
            Some("test")
        );
    }
}