sklears_compose/workflow_language/
visual_builder.rs

1//! Visual Pipeline Builder
2//!
3//! This module provides visual pipeline building capabilities for creating machine learning
4//! workflows through a graphical interface, including drag-and-drop component assembly,
5//! visual connection management, and interactive workflow construction.
6
7use serde::{Deserialize, Serialize};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::collections::{BTreeMap, HashMap};
10
11use super::workflow_definitions::{Connection, ParameterValue, StepDefinition, WorkflowDefinition};
12
13/// Visual pipeline builder for creating workflows graphically
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct VisualPipelineBuilder {
16    /// Current workflow being built
17    pub workflow: WorkflowDefinition,
18    /// Component positioning information
19    pub component_positions: HashMap<String, Position>,
20    /// Canvas configuration
21    pub canvas_config: CanvasConfig,
22    /// Validation state
23    pub validation_state: ValidationState,
24    /// Undo/redo history
25    pub history: Vec<WorkflowSnapshot>,
26    /// Current history index
27    pub history_index: usize,
28}
29
30/// Position of a component on the visual canvas
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct Position {
33    /// X coordinate
34    pub x: f64,
35    /// Y coordinate
36    pub y: f64,
37    /// Width of the component
38    pub width: f64,
39    /// Height of the component
40    pub height: f64,
41}
42
43/// Canvas configuration for the visual builder
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct CanvasConfig {
46    /// Canvas width
47    pub width: f64,
48    /// Canvas height
49    pub height: f64,
50    /// Grid size for snapping
51    pub grid_size: f64,
52    /// Zoom level
53    pub zoom: f64,
54    /// Pan offset X
55    pub pan_x: f64,
56    /// Pan offset Y
57    pub pan_y: f64,
58    /// Enable grid snapping
59    pub snap_to_grid: bool,
60    /// Show grid lines
61    pub show_grid: bool,
62}
63
64/// Validation state for the workflow
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ValidationState {
67    /// Whether the workflow is valid
68    pub is_valid: bool,
69    /// Validation errors
70    pub errors: Vec<ValidationError>,
71    /// Validation warnings
72    pub warnings: Vec<ValidationWarning>,
73    /// Last validation timestamp
74    pub last_validated: String,
75}
76
77/// Validation error
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct ValidationError {
80    /// Error type
81    pub error_type: ValidationErrorType,
82    /// Error message
83    pub message: String,
84    /// Step ID where error occurred (if applicable)
85    pub step_id: Option<String>,
86    /// Connection ID where error occurred (if applicable)
87    pub connection_id: Option<String>,
88}
89
90/// Validation warning
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct ValidationWarning {
93    /// Warning type
94    pub warning_type: ValidationWarningType,
95    /// Warning message
96    pub message: String,
97    /// Step ID where warning occurred (if applicable)
98    pub step_id: Option<String>,
99}
100
101/// Types of validation errors
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub enum ValidationErrorType {
104    /// Missing required input
105    MissingInput,
106    /// Type mismatch between connected steps
107    TypeMismatch,
108    /// Circular dependency detected
109    CircularDependency,
110    /// Disconnected component
111    DisconnectedComponent,
112    /// Invalid parameter value
113    InvalidParameter,
114    /// Missing required step
115    MissingStep,
116}
117
118/// Types of validation warnings
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub enum ValidationWarningType {
121    /// Unused output
122    UnusedOutput,
123    /// Performance concern
124    PerformanceConcern,
125    /// Deprecated component
126    DeprecatedComponent,
127    /// Suboptimal configuration
128    SuboptimalConfiguration,
129}
130
131/// Workflow snapshot for undo/redo functionality
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct WorkflowSnapshot {
134    /// Snapshot of the workflow state
135    pub workflow: WorkflowDefinition,
136    /// Snapshot of component positions
137    pub positions: HashMap<String, Position>,
138    /// Description of the action that created this snapshot
139    pub action_description: String,
140    /// Timestamp of the snapshot
141    pub timestamp: String,
142}
143
144impl VisualPipelineBuilder {
145    /// Create a new visual pipeline builder
146    #[must_use]
147    pub fn new() -> Self {
148        let workflow = WorkflowDefinition::default();
149        let initial_snapshot = WorkflowSnapshot {
150            workflow: workflow.clone(),
151            positions: HashMap::new(),
152            action_description: "Initial state".to_string(),
153            timestamp: chrono::Utc::now().to_rfc3339(),
154        };
155
156        Self {
157            workflow,
158            component_positions: HashMap::new(),
159            canvas_config: CanvasConfig::default(),
160            validation_state: ValidationState::new(),
161            history: vec![initial_snapshot],
162            history_index: 0,
163        }
164    }
165
166    /// Add a step to the workflow
167    pub fn add_step(&mut self, step: StepDefinition) -> SklResult<()> {
168        // Check if step ID already exists
169        if self.workflow.steps.iter().any(|s| s.id == step.id) {
170            return Err(SklearsError::InvalidInput(format!(
171                "Step with ID '{}' already exists",
172                step.id
173            )));
174        }
175
176        // Add step to workflow
177        self.workflow.steps.push(step.clone());
178
179        // Add default position
180        let position = self.find_optimal_position(&step.id);
181        self.component_positions.insert(step.id.clone(), position);
182
183        // Create snapshot for undo/redo
184        self.create_snapshot(&format!("Added step '{}'", step.id));
185
186        // Validate workflow and update state
187        let _ = self.validate();
188
189        Ok(())
190    }
191
192    /// Remove a step from the workflow
193    pub fn remove_step(&mut self, step_id: &str) -> SklResult<()> {
194        // Remove step
195        let initial_count = self.workflow.steps.len();
196        self.workflow.steps.retain(|s| s.id != step_id);
197
198        if self.workflow.steps.len() == initial_count {
199            return Err(SklearsError::InvalidInput(format!(
200                "Step '{step_id}' not found"
201            )));
202        }
203
204        // Remove position
205        self.component_positions.remove(step_id);
206
207        // Remove connections involving this step
208        self.workflow
209            .connections
210            .retain(|c| c.from_step != step_id && c.to_step != step_id);
211
212        // Create snapshot
213        self.create_snapshot(&format!("Removed step '{step_id}'"));
214
215        // Validate workflow and update state
216        let _ = self.validate();
217
218        Ok(())
219    }
220
221    /// Add a connection between steps
222    pub fn add_connection(&mut self, connection: Connection) -> SklResult<()> {
223        // Validate that both steps exist
224        let from_exists = self
225            .workflow
226            .steps
227            .iter()
228            .any(|s| s.id == connection.from_step);
229        let to_exists = self
230            .workflow
231            .steps
232            .iter()
233            .any(|s| s.id == connection.to_step);
234
235        if !from_exists {
236            return Err(SklearsError::InvalidInput(format!(
237                "Source step '{}' not found",
238                connection.from_step
239            )));
240        }
241
242        if !to_exists {
243            return Err(SklearsError::InvalidInput(format!(
244                "Target step '{}' not found",
245                connection.to_step
246            )));
247        }
248
249        // Check for duplicate connections
250        let duplicate = self.workflow.connections.iter().any(|c| {
251            c.from_step == connection.from_step
252                && c.from_output == connection.from_output
253                && c.to_step == connection.to_step
254                && c.to_input == connection.to_input
255        });
256
257        if duplicate {
258            return Err(SklearsError::InvalidInput(
259                "Connection already exists".to_string(),
260            ));
261        }
262
263        // Add connection
264        self.workflow.connections.push(connection.clone());
265
266        // Create snapshot
267        self.create_snapshot(&format!(
268            "Added connection from '{}' to '{}'",
269            connection.from_step, connection.to_step
270        ));
271
272        // Validate workflow and update state
273        let _ = self.validate();
274
275        Ok(())
276    }
277
278    /// Remove a connection
279    pub fn remove_connection(
280        &mut self,
281        from_step: &str,
282        from_output: &str,
283        to_step: &str,
284        to_input: &str,
285    ) -> SklResult<()> {
286        let initial_count = self.workflow.connections.len();
287
288        self.workflow.connections.retain(|c| {
289            !(c.from_step == from_step
290                && c.from_output == from_output
291                && c.to_step == to_step
292                && c.to_input == to_input)
293        });
294
295        if self.workflow.connections.len() == initial_count {
296            return Err(SklearsError::InvalidInput(
297                "Connection not found".to_string(),
298            ));
299        }
300
301        // Create snapshot
302        self.create_snapshot(&format!(
303            "Removed connection from '{from_step}' to '{to_step}'"
304        ));
305
306        // Validate workflow and update state
307        let _ = self.validate();
308
309        Ok(())
310    }
311
312    /// Move a component to a new position
313    pub fn move_component(&mut self, step_id: &str, position: Position) -> SklResult<()> {
314        if !self.workflow.steps.iter().any(|s| s.id == step_id) {
315            return Err(SklearsError::InvalidInput(format!(
316                "Step '{step_id}' not found"
317            )));
318        }
319
320        let final_position = if self.canvas_config.snap_to_grid {
321            self.snap_to_grid(position)
322        } else {
323            position
324        };
325
326        self.component_positions
327            .insert(step_id.to_string(), final_position);
328
329        // Create snapshot for significant moves (optional - might be too frequent)
330        // self.create_snapshot(&format!("Moved component '{}'", step_id));
331
332        Ok(())
333    }
334
335    /// Update step parameters
336    pub fn update_step_parameters(
337        &mut self,
338        step_id: &str,
339        parameters: BTreeMap<String, ParameterValue>,
340    ) -> SklResult<()> {
341        let step = self
342            .workflow
343            .steps
344            .iter_mut()
345            .find(|s| s.id == step_id)
346            .ok_or_else(|| SklearsError::InvalidInput(format!("Step '{step_id}' not found")))?;
347
348        step.parameters = parameters;
349
350        // Create snapshot
351        self.create_snapshot(&format!("Updated parameters for '{step_id}'"));
352
353        // Validate workflow and update state
354        let _ = self.validate();
355
356        Ok(())
357    }
358
359    /// Validate the current workflow and update the validation state.
360    ///
361    /// This method performs a comprehensive validation pass, updating
362    /// [`self.validation_state`] with the latest errors and warnings. It does
363    /// **not** short-circuit the caller when issues are detected; instead,
364    /// callers are expected to inspect the returned [`ValidationState`] and
365    /// decide whether to proceed.
366    #[must_use]
367    pub fn validate(&mut self) -> &ValidationState {
368        let mut errors = Vec::new();
369        let mut warnings = Vec::new();
370
371        // Check for disconnected components
372        for step in &self.workflow.steps {
373            let has_input = self
374                .workflow
375                .connections
376                .iter()
377                .any(|c| c.to_step == step.id);
378            let has_output = self
379                .workflow
380                .connections
381                .iter()
382                .any(|c| c.from_step == step.id);
383
384            if !has_input && !has_output && self.workflow.steps.len() > 1 {
385                errors.push(ValidationError {
386                    error_type: ValidationErrorType::DisconnectedComponent,
387                    message: format!(
388                        "Step '{}' is not connected to any other components",
389                        step.id
390                    ),
391                    step_id: Some(step.id.clone()),
392                    connection_id: None,
393                });
394            }
395        }
396
397        // Check for circular dependencies
398        if self.has_circular_dependency() {
399            errors.push(ValidationError {
400                error_type: ValidationErrorType::CircularDependency,
401                message: "Circular dependency detected in workflow".to_string(),
402                step_id: None,
403                connection_id: None,
404            });
405        }
406
407        // Check for unused outputs
408        for step in &self.workflow.steps {
409            for output in &step.outputs {
410                let is_used = self
411                    .workflow
412                    .connections
413                    .iter()
414                    .any(|c| c.from_step == step.id && c.from_output == *output);
415
416                if !is_used {
417                    warnings.push(ValidationWarning {
418                        warning_type: ValidationWarningType::UnusedOutput,
419                        message: format!("Output '{}' of step '{}' is not used", output, step.id),
420                        step_id: Some(step.id.clone()),
421                    });
422                }
423            }
424        }
425
426        self.validation_state = ValidationState {
427            is_valid: errors.is_empty(),
428            errors,
429            warnings,
430            last_validated: chrono::Utc::now().to_rfc3339(),
431        };
432
433        &self.validation_state
434    }
435
436    /// Check for circular dependencies in the workflow
437    fn has_circular_dependency(&self) -> bool {
438        let mut visited = HashMap::new();
439        let mut recursion_stack = HashMap::new();
440
441        for step in &self.workflow.steps {
442            if !visited.get(&step.id).unwrap_or(&false)
443                && self.has_cycle_util(&step.id, &mut visited, &mut recursion_stack)
444            {
445                return true;
446            }
447        }
448
449        false
450    }
451
452    /// Utility function for cycle detection
453    fn has_cycle_util(
454        &self,
455        step_id: &str,
456        visited: &mut HashMap<String, bool>,
457        recursion_stack: &mut HashMap<String, bool>,
458    ) -> bool {
459        visited.insert(step_id.to_string(), true);
460        recursion_stack.insert(step_id.to_string(), true);
461
462        // Get all steps that depend on this step
463        for connection in &self.workflow.connections {
464            if connection.from_step == step_id {
465                let next_step = &connection.to_step;
466
467                if !visited.get(next_step).unwrap_or(&false) {
468                    if self.has_cycle_util(next_step, visited, recursion_stack) {
469                        return true;
470                    }
471                } else if *recursion_stack.get(next_step).unwrap_or(&false) {
472                    return true;
473                }
474            }
475        }
476
477        recursion_stack.insert(step_id.to_string(), false);
478        false
479    }
480
481    /// Find optimal position for a new component
482    fn find_optimal_position(&self, _step_id: &str) -> Position {
483        // Simple positioning logic - place components in a grid
484        let num_components = self.component_positions.len();
485        let grid_cols = 4;
486        let component_width = 120.0;
487        let component_height = 80.0;
488        let spacing_x = 160.0;
489        let spacing_y = 120.0;
490
491        let col = num_components % grid_cols;
492        let row = num_components / grid_cols;
493
494        Position {
495            x: 50.0 + col as f64 * spacing_x,
496            y: 50.0 + row as f64 * spacing_y,
497            width: component_width,
498            height: component_height,
499        }
500    }
501
502    /// Snap position to grid
503    #[must_use]
504    pub fn snap_to_grid(&self, position: Position) -> Position {
505        let grid_size = self.canvas_config.grid_size;
506        Position {
507            x: (position.x / grid_size).round() * grid_size,
508            y: (position.y / grid_size).round() * grid_size,
509            width: position.width,
510            height: position.height,
511        }
512    }
513
514    /// Create a snapshot for undo/redo
515    fn create_snapshot(&mut self, action_description: &str) {
516        let snapshot = WorkflowSnapshot {
517            workflow: self.workflow.clone(),
518            positions: self.component_positions.clone(),
519            action_description: action_description.to_string(),
520            timestamp: chrono::Utc::now().to_rfc3339(),
521        };
522
523        // Remove any snapshots after current index (when creating new branch)
524        self.history.truncate(self.history_index + 1);
525
526        // Add new snapshot
527        self.history.push(snapshot);
528        self.history_index = self.history.len() - 1;
529
530        // Limit history size
531        const MAX_HISTORY_SIZE: usize = 50;
532        if self.history.len() > MAX_HISTORY_SIZE {
533            self.history.remove(0);
534            self.history_index = self.history.len() - 1;
535        }
536    }
537
538    /// Undo last action
539    pub fn undo(&mut self) -> SklResult<()> {
540        if self.history_index == 0 {
541            return Err(SklearsError::InvalidInput("Nothing to undo".to_string()));
542        }
543
544        self.history_index -= 1;
545        let snapshot = &self.history[self.history_index];
546
547        self.workflow = snapshot.workflow.clone();
548        self.component_positions = snapshot.positions.clone();
549
550        Ok(())
551    }
552
553    /// Redo last undone action
554    pub fn redo(&mut self) -> SklResult<()> {
555        if self.history_index >= self.history.len() - 1 {
556            return Err(SklearsError::InvalidInput("Nothing to redo".to_string()));
557        }
558
559        self.history_index += 1;
560        let snapshot = &self.history[self.history_index];
561
562        self.workflow = snapshot.workflow.clone();
563        self.component_positions = snapshot.positions.clone();
564
565        Ok(())
566    }
567
568    /// Get current workflow
569    #[must_use]
570    pub fn get_workflow(&self) -> &WorkflowDefinition {
571        &self.workflow
572    }
573
574    /// Get component positions
575    #[must_use]
576    pub fn get_component_positions(&self) -> &HashMap<String, Position> {
577        &self.component_positions
578    }
579
580    /// Get validation state
581    #[must_use]
582    pub fn get_validation_state(&self) -> &ValidationState {
583        &self.validation_state
584    }
585
586    /// Clear the workflow
587    pub fn clear(&mut self) {
588        self.workflow = WorkflowDefinition::default();
589        self.component_positions.clear();
590        self.validation_state = ValidationState::new();
591
592        // Create snapshot
593        self.create_snapshot("Cleared workflow");
594    }
595
596    /// Set canvas configuration
597    pub fn set_canvas_config(&mut self, config: CanvasConfig) {
598        self.canvas_config = config;
599    }
600
601    /// Get canvas configuration
602    #[must_use]
603    pub fn get_canvas_config(&self) -> &CanvasConfig {
604        &self.canvas_config
605    }
606}
607
608impl Default for VisualPipelineBuilder {
609    fn default() -> Self {
610        Self::new()
611    }
612}
613
614impl Default for CanvasConfig {
615    fn default() -> Self {
616        Self {
617            width: 1200.0,
618            height: 800.0,
619            grid_size: 20.0,
620            zoom: 1.0,
621            pan_x: 0.0,
622            pan_y: 0.0,
623            snap_to_grid: true,
624            show_grid: true,
625        }
626    }
627}
628
629impl ValidationState {
630    #[must_use]
631    pub fn new() -> Self {
632        Self {
633            is_valid: true,
634            errors: Vec::new(),
635            warnings: Vec::new(),
636            last_validated: chrono::Utc::now().to_rfc3339(),
637        }
638    }
639}
640
641impl Default for ValidationState {
642    fn default() -> Self {
643        Self::new()
644    }
645}
646
647/// Type alias for component position
648pub type ComponentPosition = Position;
649
650/// Canvas interaction state for drag-and-drop operations
651#[derive(Debug, Clone, Serialize, Deserialize)]
652pub struct CanvasInteraction {
653    /// Current interaction mode
654    pub mode: InteractionMode,
655    /// Selected components
656    pub selected_components: Vec<String>,
657    /// Current drag state
658    pub drag_state: Option<DragState>,
659    /// Selection state
660    pub selection_state: SelectionState,
661}
662
663/// Interaction modes for the canvas
664#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
665pub enum InteractionMode {
666    /// Normal selection mode
667    Select,
668    /// Panning the canvas
669    Pan,
670    /// Zooming the canvas
671    Zoom,
672    /// Drawing connections
673    Connect,
674    /// Adding new components
675    AddComponent,
676}
677
678/// Drag and drop state
679#[derive(Debug, Clone, Serialize, Deserialize)]
680pub struct DragState {
681    /// Component being dragged
682    pub component_id: String,
683    /// Start position of drag
684    pub start_position: Position,
685    /// Current position during drag
686    pub current_position: Position,
687    /// Offset from component center
688    pub offset: Position,
689}
690
691/// Selection state management
692#[derive(Debug, Clone, Serialize, Deserialize, Default)]
693pub struct SelectionState {
694    /// Currently selected components
695    pub selected: Vec<String>,
696    /// Selection box coordinates (if active)
697    pub selection_box: Option<SelectionBox>,
698    /// Multi-select mode enabled
699    pub multi_select: bool,
700}
701
702/// Selection box for area selection
703#[derive(Debug, Clone, Serialize, Deserialize)]
704pub struct SelectionBox {
705    /// Start corner of selection box
706    pub start: Position,
707    /// End corner of selection box
708    pub end: Position,
709}
710
711/// Type alias for grid configuration
712pub type GridConfig = CanvasConfig;
713
714/// Viewport configuration for canvas display
715#[derive(Debug, Clone, Serialize, Deserialize)]
716pub struct ViewportConfig {
717    /// Viewport width
718    pub width: f64,
719    /// Viewport height
720    pub height: f64,
721    /// Pan offset X
722    pub pan_x: f64,
723    /// Pan offset Y
724    pub pan_y: f64,
725    /// Zoom level
726    pub zoom: f64,
727}
728
729/// Zoom configuration settings
730#[derive(Debug, Clone, Serialize, Deserialize)]
731pub struct ZoomConfig {
732    /// Current zoom level
733    pub level: f64,
734    /// Minimum zoom level
735    pub min_zoom: f64,
736    /// Maximum zoom level
737    pub max_zoom: f64,
738    /// Zoom step size
739    pub zoom_step: f64,
740    /// Zoom center point
741    pub center: Position,
742}
743
744/// Type alias for workflow history
745pub type WorkflowHistory = Vec<WorkflowSnapshot>;
746
747/// Undo/Redo manager for workflow operations
748#[derive(Debug, Clone, Serialize, Deserialize)]
749pub struct UndoRedoManager {
750    /// History of workflow snapshots
751    pub history: WorkflowHistory,
752    /// Current position in history
753    pub current_index: usize,
754    /// Maximum history size
755    pub max_history_size: usize,
756}
757
758impl Default for CanvasInteraction {
759    fn default() -> Self {
760        Self {
761            mode: InteractionMode::Select,
762            selected_components: Vec::new(),
763            drag_state: None,
764            selection_state: SelectionState::default(),
765        }
766    }
767}
768
769impl Default for ViewportConfig {
770    fn default() -> Self {
771        Self {
772            width: 1200.0,
773            height: 800.0,
774            pan_x: 0.0,
775            pan_y: 0.0,
776            zoom: 1.0,
777        }
778    }
779}
780
781impl Default for ZoomConfig {
782    fn default() -> Self {
783        Self {
784            level: 1.0,
785            min_zoom: 0.1,
786            max_zoom: 5.0,
787            zoom_step: 0.1,
788            center: Position {
789                x: 0.0,
790                y: 0.0,
791                width: 0.0,
792                height: 0.0,
793            },
794        }
795    }
796}
797
798impl UndoRedoManager {
799    /// Create a new undo/redo manager
800    #[must_use]
801    pub fn new() -> Self {
802        Self {
803            history: Vec::new(),
804            current_index: 0,
805            max_history_size: 50,
806        }
807    }
808
809    /// Add a new snapshot to history
810    pub fn add_snapshot(&mut self, snapshot: WorkflowSnapshot) {
811        // Remove any snapshots after current index
812        self.history.truncate(self.current_index + 1);
813
814        // Add new snapshot
815        self.history.push(snapshot);
816        self.current_index = self.history.len() - 1;
817
818        // Limit history size
819        if self.history.len() > self.max_history_size {
820            self.history.remove(0);
821            self.current_index = self.history.len() - 1;
822        }
823    }
824
825    /// Check if undo is available
826    #[must_use]
827    pub fn can_undo(&self) -> bool {
828        self.current_index > 0
829    }
830
831    /// Check if redo is available
832    #[must_use]
833    pub fn can_redo(&self) -> bool {
834        self.current_index < self.history.len() - 1
835    }
836}
837
838impl Default for UndoRedoManager {
839    fn default() -> Self {
840        Self::new()
841    }
842}
843
844#[allow(non_snake_case)]
845#[cfg(test)]
846mod tests {
847    use super::*;
848    use crate::workflow_language::workflow_definitions::{DataType, StepType};
849
850    #[test]
851    fn test_visual_pipeline_builder_creation() {
852        let builder = VisualPipelineBuilder::new();
853        assert_eq!(builder.workflow.steps.len(), 0);
854        assert_eq!(builder.component_positions.len(), 0);
855        assert!(builder.validation_state.is_valid);
856        assert_eq!(builder.history.len(), 1); // Initial snapshot
857    }
858
859    #[test]
860    fn test_add_step() {
861        let mut builder = VisualPipelineBuilder::new();
862
863        let step = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
864        let result = builder.add_step(step);
865
866        assert!(result.is_ok());
867        assert_eq!(builder.workflow.steps.len(), 1);
868        assert_eq!(builder.component_positions.len(), 1);
869        assert!(builder.component_positions.contains_key("step1"));
870    }
871
872    #[test]
873    fn test_add_duplicate_step() {
874        let mut builder = VisualPipelineBuilder::new();
875
876        let step1 = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
877        let step2 = StepDefinition::new("step1", StepType::Predictor, "LinearRegression");
878
879        assert!(builder.add_step(step1).is_ok());
880        assert!(builder.add_step(step2).is_err());
881    }
882
883    #[test]
884    fn test_remove_step() {
885        let mut builder = VisualPipelineBuilder::new();
886
887        let step = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
888        builder.add_step(step).unwrap();
889
890        assert_eq!(builder.workflow.steps.len(), 1);
891
892        let result = builder.remove_step("step1");
893        assert!(result.is_ok());
894        assert_eq!(builder.workflow.steps.len(), 0);
895        assert!(!builder.component_positions.contains_key("step1"));
896    }
897
898    #[test]
899    fn test_add_connection() {
900        let mut builder = VisualPipelineBuilder::new();
901
902        let step1 = StepDefinition::new("step1", StepType::Transformer, "StandardScaler")
903            .with_output("X_scaled");
904        let step2 =
905            StepDefinition::new("step2", StepType::Predictor, "LinearRegression").with_input("X");
906
907        builder.add_step(step1).unwrap();
908        builder.add_step(step2).unwrap();
909
910        let connection = Connection::direct("step1", "X_scaled", "step2", "X");
911        let result = builder.add_connection(connection);
912
913        assert!(result.is_ok());
914        assert_eq!(builder.workflow.connections.len(), 1);
915    }
916
917    #[test]
918    fn test_undo_redo() {
919        let mut builder = VisualPipelineBuilder::new();
920
921        let step = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
922        builder.add_step(step).unwrap();
923
924        assert_eq!(builder.workflow.steps.len(), 1);
925        assert_eq!(builder.history_index, 1);
926
927        // Undo
928        builder.undo().unwrap();
929        assert_eq!(builder.workflow.steps.len(), 0);
930        assert_eq!(builder.history_index, 0);
931
932        // Redo
933        builder.redo().unwrap();
934        assert_eq!(builder.workflow.steps.len(), 1);
935        assert_eq!(builder.history_index, 1);
936    }
937
938    #[test]
939    fn test_snap_to_grid() {
940        let builder = VisualPipelineBuilder::new();
941        let position = Position {
942            x: 23.7,
943            y: 47.3,
944            width: 100.0,
945            height: 80.0,
946        };
947
948        let snapped = builder.snap_to_grid(position);
949        assert_eq!(snapped.x, 20.0);
950        assert_eq!(snapped.y, 40.0);
951    }
952
953    #[test]
954    fn test_validation_disconnected_component() {
955        let mut builder = VisualPipelineBuilder::new();
956
957        let step1 = StepDefinition::new("step1", StepType::Transformer, "StandardScaler");
958        let step2 = StepDefinition::new("step2", StepType::Predictor, "LinearRegression");
959
960        builder.add_step(step1).unwrap();
961        builder.add_step(step2).unwrap();
962
963        // Both steps are disconnected, validation should fail
964        assert!(!builder.validation_state.is_valid);
965        assert!(!builder.validation_state.errors.is_empty());
966    }
967}