1use serde::{Deserialize, Serialize};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::collections::{BTreeMap, HashMap};
10
11use super::workflow_definitions::{Connection, ParameterValue, StepDefinition, WorkflowDefinition};
12
/// Editable workflow together with its canvas layout, last validation result
/// and undo/redo history, as used by a visual pipeline editor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualPipelineBuilder {
    /// Workflow being edited (its steps and connections).
    pub workflow: WorkflowDefinition,
    /// Canvas position of each step, keyed by step ID.
    pub component_positions: HashMap<String, Position>,
    /// Canvas size, zoom/pan and grid-snapping settings.
    pub canvas_config: CanvasConfig,
    /// Last validation result stored by the builder.
    pub validation_state: ValidationState,
    /// Undo/redo snapshots, oldest first; the builder drops the oldest entry
    /// once a fixed cap is exceeded.
    pub history: Vec<WorkflowSnapshot>,
    /// Index into `history` of the snapshot matching the current state.
    pub history_index: usize,
}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct Position {
33 pub x: f64,
35 pub y: f64,
37 pub width: f64,
39 pub height: f64,
41}
42
/// Display and snapping settings for the editing canvas.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanvasConfig {
    /// Canvas width (units presumably pixels — confirm with the renderer).
    pub width: f64,
    /// Canvas height (same units as `width`).
    pub height: f64,
    /// Spacing of the grid that positions snap to when `snap_to_grid` is set.
    pub grid_size: f64,
    /// Zoom factor; 1.0 by default.
    pub zoom: f64,
    /// Horizontal pan offset; 0.0 by default.
    pub pan_x: f64,
    /// Vertical pan offset; 0.0 by default.
    pub pan_y: f64,
    /// When `true`, moved components are snapped to the `grid_size` grid.
    pub snap_to_grid: bool,
    /// Grid display flag; not read by the builder itself.
    pub show_grid: bool,
}
63
/// Outcome of a workflow validation pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationState {
    /// `true` when `errors` is empty; warnings do not affect validity.
    pub is_valid: bool,
    /// Problems that make the workflow invalid.
    pub errors: Vec<ValidationError>,
    /// Non-fatal issues worth surfacing to the user.
    pub warnings: Vec<ValidationWarning>,
    /// RFC 3339 UTC timestamp of when this state was produced.
    pub last_validated: String,
}
76
/// A single validation error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationError {
    /// Category of the error.
    pub error_type: ValidationErrorType,
    /// Human-readable description.
    pub message: String,
    /// Offending step, when the error is tied to one.
    pub step_id: Option<String>,
    /// Offending connection, if any; `VisualPipelineBuilder::validate` always
    /// leaves this `None`.
    pub connection_id: Option<String>,
}
89
/// A single non-fatal validation finding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationWarning {
    /// Category of the warning.
    pub warning_type: ValidationWarningType,
    /// Human-readable description.
    pub message: String,
    /// Step the warning refers to, when applicable.
    pub step_id: Option<String>,
}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub enum ValidationErrorType {
104 MissingInput,
106 TypeMismatch,
108 CircularDependency,
110 DisconnectedComponent,
112 InvalidParameter,
114 MissingStep,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub enum ValidationWarningType {
121 UnusedOutput,
123 PerformanceConcern,
125 DeprecatedComponent,
127 SuboptimalConfiguration,
129}
130
/// Point-in-time copy of a workflow and its layout, used for undo/redo.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowSnapshot {
    /// Copy of the workflow at snapshot time.
    pub workflow: WorkflowDefinition,
    /// Copy of the component positions at snapshot time, keyed by step ID.
    pub positions: HashMap<String, Position>,
    /// Description of the action that produced this snapshot.
    pub action_description: String,
    /// RFC 3339 UTC timestamp of when the snapshot was taken.
    pub timestamp: String,
}
143
144impl VisualPipelineBuilder {
145 #[must_use]
147 pub fn new() -> Self {
148 let workflow = WorkflowDefinition::default();
149 let initial_snapshot = WorkflowSnapshot {
150 workflow: workflow.clone(),
151 positions: HashMap::new(),
152 action_description: "Initial state".to_string(),
153 timestamp: chrono::Utc::now().to_rfc3339(),
154 };
155
156 Self {
157 workflow,
158 component_positions: HashMap::new(),
159 canvas_config: CanvasConfig::default(),
160 validation_state: ValidationState::new(),
161 history: vec![initial_snapshot],
162 history_index: 0,
163 }
164 }
165
166 pub fn add_step(&mut self, step: StepDefinition) -> SklResult<()> {
168 if self.workflow.steps.iter().any(|s| s.id == step.id) {
170 return Err(SklearsError::InvalidInput(format!(
171 "Step with ID '{}' already exists",
172 step.id
173 )));
174 }
175
176 self.workflow.steps.push(step.clone());
178
179 let position = self.find_optimal_position(&step.id);
181 self.component_positions.insert(step.id.clone(), position);
182
183 self.create_snapshot(&format!("Added step '{}'", step.id));
185
186 let _ = self.validate();
188
189 Ok(())
190 }
191
192 pub fn remove_step(&mut self, step_id: &str) -> SklResult<()> {
194 let initial_count = self.workflow.steps.len();
196 self.workflow.steps.retain(|s| s.id != step_id);
197
198 if self.workflow.steps.len() == initial_count {
199 return Err(SklearsError::InvalidInput(format!(
200 "Step '{step_id}' not found"
201 )));
202 }
203
204 self.component_positions.remove(step_id);
206
207 self.workflow
209 .connections
210 .retain(|c| c.from_step != step_id && c.to_step != step_id);
211
212 self.create_snapshot(&format!("Removed step '{step_id}'"));
214
215 let _ = self.validate();
217
218 Ok(())
219 }
220
221 pub fn add_connection(&mut self, connection: Connection) -> SklResult<()> {
223 let from_exists = self
225 .workflow
226 .steps
227 .iter()
228 .any(|s| s.id == connection.from_step);
229 let to_exists = self
230 .workflow
231 .steps
232 .iter()
233 .any(|s| s.id == connection.to_step);
234
235 if !from_exists {
236 return Err(SklearsError::InvalidInput(format!(
237 "Source step '{}' not found",
238 connection.from_step
239 )));
240 }
241
242 if !to_exists {
243 return Err(SklearsError::InvalidInput(format!(
244 "Target step '{}' not found",
245 connection.to_step
246 )));
247 }
248
249 let duplicate = self.workflow.connections.iter().any(|c| {
251 c.from_step == connection.from_step
252 && c.from_output == connection.from_output
253 && c.to_step == connection.to_step
254 && c.to_input == connection.to_input
255 });
256
257 if duplicate {
258 return Err(SklearsError::InvalidInput(
259 "Connection already exists".to_string(),
260 ));
261 }
262
263 self.workflow.connections.push(connection.clone());
265
266 self.create_snapshot(&format!(
268 "Added connection from '{}' to '{}'",
269 connection.from_step, connection.to_step
270 ));
271
272 let _ = self.validate();
274
275 Ok(())
276 }
277
278 pub fn remove_connection(
280 &mut self,
281 from_step: &str,
282 from_output: &str,
283 to_step: &str,
284 to_input: &str,
285 ) -> SklResult<()> {
286 let initial_count = self.workflow.connections.len();
287
288 self.workflow.connections.retain(|c| {
289 !(c.from_step == from_step
290 && c.from_output == from_output
291 && c.to_step == to_step
292 && c.to_input == to_input)
293 });
294
295 if self.workflow.connections.len() == initial_count {
296 return Err(SklearsError::InvalidInput(
297 "Connection not found".to_string(),
298 ));
299 }
300
301 self.create_snapshot(&format!(
303 "Removed connection from '{from_step}' to '{to_step}'"
304 ));
305
306 let _ = self.validate();
308
309 Ok(())
310 }
311
312 pub fn move_component(&mut self, step_id: &str, position: Position) -> SklResult<()> {
314 if !self.workflow.steps.iter().any(|s| s.id == step_id) {
315 return Err(SklearsError::InvalidInput(format!(
316 "Step '{step_id}' not found"
317 )));
318 }
319
320 let final_position = if self.canvas_config.snap_to_grid {
321 self.snap_to_grid(position)
322 } else {
323 position
324 };
325
326 self.component_positions
327 .insert(step_id.to_string(), final_position);
328
329 Ok(())
333 }
334
335 pub fn update_step_parameters(
337 &mut self,
338 step_id: &str,
339 parameters: BTreeMap<String, ParameterValue>,
340 ) -> SklResult<()> {
341 let step = self
342 .workflow
343 .steps
344 .iter_mut()
345 .find(|s| s.id == step_id)
346 .ok_or_else(|| SklearsError::InvalidInput(format!("Step '{step_id}' not found")))?;
347
348 step.parameters = parameters;
349
350 self.create_snapshot(&format!("Updated parameters for '{step_id}'"));
352
353 let _ = self.validate();
355
356 Ok(())
357 }
358
359 #[must_use]
367 pub fn validate(&mut self) -> &ValidationState {
368 let mut errors = Vec::new();
369 let mut warnings = Vec::new();
370
371 for step in &self.workflow.steps {
373 let has_input = self
374 .workflow
375 .connections
376 .iter()
377 .any(|c| c.to_step == step.id);
378 let has_output = self
379 .workflow
380 .connections
381 .iter()
382 .any(|c| c.from_step == step.id);
383
384 if !has_input && !has_output && self.workflow.steps.len() > 1 {
385 errors.push(ValidationError {
386 error_type: ValidationErrorType::DisconnectedComponent,
387 message: format!(
388 "Step '{}' is not connected to any other components",
389 step.id
390 ),
391 step_id: Some(step.id.clone()),
392 connection_id: None,
393 });
394 }
395 }
396
397 if self.has_circular_dependency() {
399 errors.push(ValidationError {
400 error_type: ValidationErrorType::CircularDependency,
401 message: "Circular dependency detected in workflow".to_string(),
402 step_id: None,
403 connection_id: None,
404 });
405 }
406
407 for step in &self.workflow.steps {
409 for output in &step.outputs {
410 let is_used = self
411 .workflow
412 .connections
413 .iter()
414 .any(|c| c.from_step == step.id && c.from_output == *output);
415
416 if !is_used {
417 warnings.push(ValidationWarning {
418 warning_type: ValidationWarningType::UnusedOutput,
419 message: format!("Output '{}' of step '{}' is not used", output, step.id),
420 step_id: Some(step.id.clone()),
421 });
422 }
423 }
424 }
425
426 self.validation_state = ValidationState {
427 is_valid: errors.is_empty(),
428 errors,
429 warnings,
430 last_validated: chrono::Utc::now().to_rfc3339(),
431 };
432
433 &self.validation_state
434 }
435
436 fn has_circular_dependency(&self) -> bool {
438 let mut visited = HashMap::new();
439 let mut recursion_stack = HashMap::new();
440
441 for step in &self.workflow.steps {
442 if !visited.get(&step.id).unwrap_or(&false)
443 && self.has_cycle_util(&step.id, &mut visited, &mut recursion_stack)
444 {
445 return true;
446 }
447 }
448
449 false
450 }
451
452 fn has_cycle_util(
454 &self,
455 step_id: &str,
456 visited: &mut HashMap<String, bool>,
457 recursion_stack: &mut HashMap<String, bool>,
458 ) -> bool {
459 visited.insert(step_id.to_string(), true);
460 recursion_stack.insert(step_id.to_string(), true);
461
462 for connection in &self.workflow.connections {
464 if connection.from_step == step_id {
465 let next_step = &connection.to_step;
466
467 if !visited.get(next_step).unwrap_or(&false) {
468 if self.has_cycle_util(next_step, visited, recursion_stack) {
469 return true;
470 }
471 } else if *recursion_stack.get(next_step).unwrap_or(&false) {
472 return true;
473 }
474 }
475 }
476
477 recursion_stack.insert(step_id.to_string(), false);
478 false
479 }
480
481 fn find_optimal_position(&self, _step_id: &str) -> Position {
483 let num_components = self.component_positions.len();
485 let grid_cols = 4;
486 let component_width = 120.0;
487 let component_height = 80.0;
488 let spacing_x = 160.0;
489 let spacing_y = 120.0;
490
491 let col = num_components % grid_cols;
492 let row = num_components / grid_cols;
493
494 Position {
495 x: 50.0 + col as f64 * spacing_x,
496 y: 50.0 + row as f64 * spacing_y,
497 width: component_width,
498 height: component_height,
499 }
500 }
501
502 #[must_use]
504 pub fn snap_to_grid(&self, position: Position) -> Position {
505 let grid_size = self.canvas_config.grid_size;
506 Position {
507 x: (position.x / grid_size).round() * grid_size,
508 y: (position.y / grid_size).round() * grid_size,
509 width: position.width,
510 height: position.height,
511 }
512 }
513
514 fn create_snapshot(&mut self, action_description: &str) {
516 let snapshot = WorkflowSnapshot {
517 workflow: self.workflow.clone(),
518 positions: self.component_positions.clone(),
519 action_description: action_description.to_string(),
520 timestamp: chrono::Utc::now().to_rfc3339(),
521 };
522
523 self.history.truncate(self.history_index + 1);
525
526 self.history.push(snapshot);
528 self.history_index = self.history.len() - 1;
529
530 const MAX_HISTORY_SIZE: usize = 50;
532 if self.history.len() > MAX_HISTORY_SIZE {
533 self.history.remove(0);
534 self.history_index = self.history.len() - 1;
535 }
536 }
537
538 pub fn undo(&mut self) -> SklResult<()> {
540 if self.history_index == 0 {
541 return Err(SklearsError::InvalidInput("Nothing to undo".to_string()));
542 }
543
544 self.history_index -= 1;
545 let snapshot = &self.history[self.history_index];
546
547 self.workflow = snapshot.workflow.clone();
548 self.component_positions = snapshot.positions.clone();
549
550 Ok(())
551 }
552
553 pub fn redo(&mut self) -> SklResult<()> {
555 if self.history_index >= self.history.len() - 1 {
556 return Err(SklearsError::InvalidInput("Nothing to redo".to_string()));
557 }
558
559 self.history_index += 1;
560 let snapshot = &self.history[self.history_index];
561
562 self.workflow = snapshot.workflow.clone();
563 self.component_positions = snapshot.positions.clone();
564
565 Ok(())
566 }
567
568 #[must_use]
570 pub fn get_workflow(&self) -> &WorkflowDefinition {
571 &self.workflow
572 }
573
574 #[must_use]
576 pub fn get_component_positions(&self) -> &HashMap<String, Position> {
577 &self.component_positions
578 }
579
580 #[must_use]
582 pub fn get_validation_state(&self) -> &ValidationState {
583 &self.validation_state
584 }
585
586 pub fn clear(&mut self) {
588 self.workflow = WorkflowDefinition::default();
589 self.component_positions.clear();
590 self.validation_state = ValidationState::new();
591
592 self.create_snapshot("Cleared workflow");
594 }
595
596 pub fn set_canvas_config(&mut self, config: CanvasConfig) {
598 self.canvas_config = config;
599 }
600
601 #[must_use]
603 pub fn get_canvas_config(&self) -> &CanvasConfig {
604 &self.canvas_config
605 }
606}
607
608impl Default for VisualPipelineBuilder {
609 fn default() -> Self {
610 Self::new()
611 }
612}
613
614impl Default for CanvasConfig {
615 fn default() -> Self {
616 Self {
617 width: 1200.0,
618 height: 800.0,
619 grid_size: 20.0,
620 zoom: 1.0,
621 pan_x: 0.0,
622 pan_y: 0.0,
623 snap_to_grid: true,
624 show_grid: true,
625 }
626 }
627}
628
629impl ValidationState {
630 #[must_use]
631 pub fn new() -> Self {
632 Self {
633 is_valid: true,
634 errors: Vec::new(),
635 warnings: Vec::new(),
636 last_validated: chrono::Utc::now().to_rfc3339(),
637 }
638 }
639}
640
641impl Default for ValidationState {
642 fn default() -> Self {
643 Self::new()
644 }
645}
646
/// Alias of [`Position`] used when referring to a component's placement.
pub type ComponentPosition = Position;
649
/// Transient UI interaction state for the canvas (mode, selection, dragging).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanvasInteraction {
    /// Current interaction mode; `Select` by default.
    pub mode: InteractionMode,
    /// IDs of selected components.
    /// NOTE(review): overlaps with `selection_state.selected` — confirm which
    /// one is authoritative.
    pub selected_components: Vec<String>,
    /// Drag in progress, if any; `None` by default.
    pub drag_state: Option<DragState>,
    /// Selection details (selected IDs, optional box, multi-select flag).
    pub selection_state: SelectionState,
}
662
/// Interaction mode of the canvas; `CanvasInteraction::default` uses `Select`.
/// Variant meanings below are inferred from their names — confirm against the UI code.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InteractionMode {
    /// Presumably: selecting components.
    Select,
    /// Presumably: panning the canvas.
    Pan,
    /// Presumably: zooming the canvas.
    Zoom,
    /// Presumably: drawing connections between components.
    Connect,
    /// Presumably: placing new components.
    AddComponent,
}
677
/// State of a component drag; not used by the builder in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DragState {
    /// ID of the component being dragged.
    pub component_id: String,
    /// Position when the drag started (presumably — confirm with the UI).
    pub start_position: Position,
    /// Latest position during the drag (presumably — confirm with the UI).
    pub current_position: Position,
    /// Pointer-to-component offset, stored as a `Position` — TODO confirm
    /// which fields are meaningful.
    pub offset: Position,
}
690
/// Current selection; the derived `Default` is empty with no box and
/// multi-select off.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SelectionState {
    /// IDs of selected components.
    pub selected: Vec<String>,
    /// Rubber-band selection rectangle, if one is being drawn.
    pub selection_box: Option<SelectionBox>,
    /// Multi-select flag.
    pub multi_select: bool,
}
701
/// Selection rectangle described by two corner positions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelectionBox {
    /// Corner where the selection started (presumably — confirm with the UI).
    pub start: Position,
    /// Opposite corner (presumably the current pointer position).
    pub end: Position,
}
710
/// Grid settings are part of [`CanvasConfig`]; this alias names that use.
pub type GridConfig = CanvasConfig;
713
/// Visible viewport size, pan offset and zoom factor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViewportConfig {
    /// Viewport width; 1200.0 by default.
    pub width: f64,
    /// Viewport height; 800.0 by default.
    pub height: f64,
    /// Horizontal pan offset; 0.0 by default.
    pub pan_x: f64,
    /// Vertical pan offset; 0.0 by default.
    pub pan_y: f64,
    /// Zoom factor; 1.0 by default.
    pub zoom: f64,
}
728
/// Zoom level, limits and step size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZoomConfig {
    /// Current zoom factor; 1.0 by default.
    pub level: f64,
    /// Lower zoom limit; 0.1 by default. Not enforced in this module.
    pub min_zoom: f64,
    /// Upper zoom limit; 5.0 by default. Not enforced in this module.
    pub max_zoom: f64,
    /// Increment per zoom action; 0.1 by default.
    pub zoom_step: f64,
    /// Zoom center; only `x`/`y` presumably matter here — confirm.
    pub center: Position,
}
743
/// Ordered list of workflow snapshots, oldest first.
pub type WorkflowHistory = Vec<WorkflowSnapshot>;
746
/// Standalone undo/redo stack of [`WorkflowSnapshot`]s.
///
/// `VisualPipelineBuilder` in this module keeps its own `history` and
/// `history_index` rather than using this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UndoRedoManager {
    /// Snapshots, oldest first.
    pub history: WorkflowHistory,
    /// Index of the snapshot matching the current state.
    pub current_index: usize,
    /// Maximum number of snapshots retained; 50 by default.
    pub max_history_size: usize,
}
757
758impl Default for CanvasInteraction {
759 fn default() -> Self {
760 Self {
761 mode: InteractionMode::Select,
762 selected_components: Vec::new(),
763 drag_state: None,
764 selection_state: SelectionState::default(),
765 }
766 }
767}
768
769impl Default for ViewportConfig {
770 fn default() -> Self {
771 Self {
772 width: 1200.0,
773 height: 800.0,
774 pan_x: 0.0,
775 pan_y: 0.0,
776 zoom: 1.0,
777 }
778 }
779}
780
781impl Default for ZoomConfig {
782 fn default() -> Self {
783 Self {
784 level: 1.0,
785 min_zoom: 0.1,
786 max_zoom: 5.0,
787 zoom_step: 0.1,
788 center: Position {
789 x: 0.0,
790 y: 0.0,
791 width: 0.0,
792 height: 0.0,
793 },
794 }
795 }
796}
797
798impl UndoRedoManager {
799 #[must_use]
801 pub fn new() -> Self {
802 Self {
803 history: Vec::new(),
804 current_index: 0,
805 max_history_size: 50,
806 }
807 }
808
809 pub fn add_snapshot(&mut self, snapshot: WorkflowSnapshot) {
811 self.history.truncate(self.current_index + 1);
813
814 self.history.push(snapshot);
816 self.current_index = self.history.len() - 1;
817
818 if self.history.len() > self.max_history_size {
820 self.history.remove(0);
821 self.current_index = self.history.len() - 1;
822 }
823 }
824
825 #[must_use]
827 pub fn can_undo(&self) -> bool {
828 self.current_index > 0
829 }
830
831 #[must_use]
833 pub fn can_redo(&self) -> bool {
834 self.current_index < self.history.len() - 1
835 }
836}
837
838impl Default for UndoRedoManager {
839 fn default() -> Self {
840 Self::new()
841 }
842}
843
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow_language::workflow_definitions::StepType;

    /// Two unconnected steps, each exposing one output, added to a fresh builder.
    fn two_step_builder() -> VisualPipelineBuilder {
        let mut builder = VisualPipelineBuilder::new();
        let step1 =
            StepDefinition::new("step1", StepType::Transformer, "StandardScaler").with_output("X_scaled");
        let step2 =
            StepDefinition::new("step2", StepType::Transformer, "StandardScaler").with_output("X_out");
        builder.add_step(step1).unwrap();
        builder.add_step(step2).unwrap();
        builder
    }

    #[test]
    fn test_visual_pipeline_builder_creation() {
        let builder = VisualPipelineBuilder::new();
        assert_eq!(builder.workflow.steps.len(), 0);
        assert_eq!(builder.component_positions.len(), 0);
        assert!(builder.validation_state.is_valid);
        assert_eq!(builder.history.len(), 1);
    }

    #[test]
    fn test_add_step() {
        let mut builder = VisualPipelineBuilder::new();

        let step = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
        let result = builder.add_step(step);

        assert!(result.is_ok());
        assert_eq!(builder.workflow.steps.len(), 1);
        assert_eq!(builder.component_positions.len(), 1);
        assert!(builder.component_positions.contains_key("step1"));
    }

    #[test]
    fn test_add_duplicate_step() {
        let mut builder = VisualPipelineBuilder::new();

        let step1 = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
        let step2 = StepDefinition::new("step1", StepType::Predictor, "LinearRegression");

        assert!(builder.add_step(step1).is_ok());
        assert!(builder.add_step(step2).is_err());
    }

    #[test]
    fn test_remove_step() {
        let mut builder = VisualPipelineBuilder::new();

        let step = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
        builder.add_step(step).unwrap();

        assert_eq!(builder.workflow.steps.len(), 1);

        let result = builder.remove_step("step1");
        assert!(result.is_ok());
        assert_eq!(builder.workflow.steps.len(), 0);
        assert!(!builder.component_positions.contains_key("step1"));
    }

    #[test]
    fn test_remove_missing_step_fails() {
        let mut builder = VisualPipelineBuilder::new();
        assert!(builder.remove_step("missing").is_err());
    }

    #[test]
    fn test_add_connection() {
        let mut builder = VisualPipelineBuilder::new();

        let step1 = StepDefinition::new("step1", StepType::Transformer, "StandardScaler")
            .with_output("X_scaled");
        let step2 =
            StepDefinition::new("step2", StepType::Predictor, "LinearRegression").with_input("X");

        builder.add_step(step1).unwrap();
        builder.add_step(step2).unwrap();

        let connection = Connection::direct("step1", "X_scaled", "step2", "X");
        let result = builder.add_connection(connection);

        assert!(result.is_ok());
        assert_eq!(builder.workflow.connections.len(), 1);
    }

    #[test]
    fn test_add_connection_to_unknown_step_fails() {
        let mut builder = two_step_builder();
        let connection = Connection::direct("step1", "X_scaled", "missing", "X");
        assert!(builder.add_connection(connection).is_err());
        assert!(builder.workflow.connections.is_empty());
    }

    #[test]
    fn test_add_duplicate_connection_fails() {
        let mut builder = two_step_builder();
        assert!(builder
            .add_connection(Connection::direct("step1", "X_scaled", "step2", "X"))
            .is_ok());
        assert!(builder
            .add_connection(Connection::direct("step1", "X_scaled", "step2", "X"))
            .is_err());
        assert_eq!(builder.workflow.connections.len(), 1);
    }

    #[test]
    fn test_remove_connection() {
        let mut builder = two_step_builder();
        builder
            .add_connection(Connection::direct("step1", "X_scaled", "step2", "X"))
            .unwrap();

        assert!(builder
            .remove_connection("step1", "X_scaled", "step2", "X")
            .is_ok());
        assert!(builder.workflow.connections.is_empty());
        assert!(builder
            .remove_connection("step1", "X_scaled", "step2", "X")
            .is_err());
    }

    #[test]
    fn test_undo_redo() {
        let mut builder = VisualPipelineBuilder::new();

        let step = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
        builder.add_step(step).unwrap();

        assert_eq!(builder.workflow.steps.len(), 1);
        assert_eq!(builder.history_index, 1);

        builder.undo().unwrap();
        assert_eq!(builder.workflow.steps.len(), 0);
        assert_eq!(builder.history_index, 0);

        builder.redo().unwrap();
        assert_eq!(builder.workflow.steps.len(), 1);
        assert_eq!(builder.history_index, 1);
    }

    #[test]
    fn test_undo_redo_at_boundaries_fail() {
        let mut builder = VisualPipelineBuilder::new();
        assert!(builder.undo().is_err());
        assert!(builder.redo().is_err());
    }

    #[test]
    fn test_snap_to_grid() {
        let builder = VisualPipelineBuilder::new();
        let position = Position {
            x: 23.7,
            y: 47.3,
            width: 100.0,
            height: 80.0,
        };

        let snapped = builder.snap_to_grid(position);
        assert_eq!(snapped.x, 20.0);
        assert_eq!(snapped.y, 40.0);
    }

    #[test]
    fn test_move_component_snaps_to_grid() {
        let mut builder = two_step_builder();
        let position = Position {
            x: 23.7,
            y: 47.3,
            width: 100.0,
            height: 80.0,
        };

        builder.move_component("step1", position).unwrap();
        let moved = &builder.component_positions["step1"];
        assert_eq!(moved.x, 20.0);
        assert_eq!(moved.y, 40.0);
    }

    #[test]
    fn test_move_unknown_component_fails() {
        let mut builder = VisualPipelineBuilder::new();
        let position = Position {
            x: 0.0,
            y: 0.0,
            width: 10.0,
            height: 10.0,
        };
        assert!(builder.move_component("missing", position).is_err());
    }

    #[test]
    fn test_update_parameters_of_unknown_step_fails() {
        let mut builder = VisualPipelineBuilder::new();
        assert!(builder
            .update_step_parameters("missing", std::collections::BTreeMap::new())
            .is_err());
    }

    #[test]
    fn test_validation_disconnected_component() {
        let mut builder = VisualPipelineBuilder::new();

        let step1 = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
        let step2 = StepDefinition::new("step2", StepType::Predictor, "LinearRegression");

        builder.add_step(step1).unwrap();
        builder.add_step(step2).unwrap();

        assert!(!builder.validation_state.is_valid);
        assert!(!builder.validation_state.errors.is_empty());
    }

    #[test]
    fn test_validation_circular_dependency() {
        let mut builder = two_step_builder();
        builder
            .add_connection(Connection::direct("step1", "X_scaled", "step2", "X"))
            .unwrap();
        builder
            .add_connection(Connection::direct("step2", "X_out", "step1", "X"))
            .unwrap();

        assert!(!builder.validation_state.is_valid);
        assert!(builder
            .validation_state
            .errors
            .iter()
            .any(|e| matches!(e.error_type, ValidationErrorType::CircularDependency)));
    }

    #[test]
    fn test_clear_is_undoable() {
        let mut builder = VisualPipelineBuilder::new();
        let step = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
        builder.add_step(step).unwrap();

        builder.clear();
        assert!(builder.workflow.steps.is_empty());
        assert!(builder.component_positions.is_empty());
        assert!(builder.validation_state.is_valid);

        builder.undo().unwrap();
        assert_eq!(builder.workflow.steps.len(), 1);
    }
}
967}