// tensorlogic_infer/auto_parallel.rs
//! Automatic parallelization for computation graphs.
//!
//! This module provides automatic detection and exploitation of parallelism opportunities:
//! - **Dependency analysis**: Build dependency graphs and detect parallelizable operations
//! - **Cost modeling**: Estimate execution costs and communication overhead
//! - **Work partitioning**: Dynamically partition work across threads/devices
//! - **Load balancing**: Balance work to minimize idle time
//! - **Pipeline detection**: Identify pipeline parallelism opportunities
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_infer::{AutoParallelizer, ParallelizationStrategy, CostModel};
//!
//! // Create auto-parallelizer with cost model
//! let parallelizer = AutoParallelizer::new()
//!     .with_strategy(ParallelizationStrategy::Aggressive)
//!     .with_cost_model(CostModel::ProfileBased);
//!
//! // Analyze graph for parallelism
//! let analysis = parallelizer.analyze(&graph)?;
//! println!("Found {} parallelizable stages", analysis.num_stages);
//!
//! // Generate parallel execution plan
//! let plan = parallelizer.generate_plan(&graph)?;
//! ```

28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet, VecDeque};
30use thiserror::Error;
31
/// Auto-parallelization errors.
///
/// Returned by [`AutoParallelizer::analyze`] / [`AutoParallelizer::generate_plan`]
/// and their internal helpers.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum AutoParallelError {
    /// The dependency graph contains a cycle and cannot be scheduled.
    #[error("Dependency cycle detected: {0}")]
    DependencyCycle(String),

    /// The node list is malformed (e.g. a dependency references an unknown node id).
    #[error("Invalid graph: {0}")]
    InvalidGraph(String),

    /// Error reported by the cost model.
    /// NOTE(review): not constructed anywhere in this module — confirm whether
    /// it is used elsewhere before removing.
    #[error("Cost model error: {0}")]
    CostModelError(String),

    /// Work could not be distributed across worker partitions.
    #[error("Partitioning failed: {0}")]
    PartitioningFailed(String),
}
47
/// Node ID in the computation graph.
///
/// Plain `String` alias; the analyses below assume ids are unique within a
/// node list (duplicates are rejected when building the dependency graph).
pub type NodeId = String;
50
/// Parallelization strategy.
///
/// Controls how the recommended worker count is clamped relative to the
/// available parallelism and `max_workers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParallelizationStrategy {
    /// Conservative: Only parallelize when clearly beneficial
    /// (caps workers at half of `max_workers`).
    Conservative,
    /// Balanced: Balance parallelism and overhead
    /// (caps workers at the ideal parallelism, bounded by `max_workers`).
    Balanced,
    /// Aggressive: Maximize parallelism even with potential overhead
    /// (always uses `max_workers`).
    Aggressive,
    /// Cost-based: Use cost model to decide
    /// (full width only when parallelism factor exceeds 2.0).
    CostBased,
}
63
/// Cost model type.
///
/// NOTE(review): the selected model is stored on [`AutoParallelizer`] but not
/// consulted by the scheduling code in this module — node-supplied
/// `estimated_cost` values are used directly. Confirm intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CostModel {
    /// Simple heuristic-based cost model
    Heuristic,
    /// Profile-based cost model using historical data
    ProfileBased,
    /// Analytical cost model based on operation complexity
    Analytical,
    /// Hybrid approach combining multiple models
    Hybrid,
}
76
/// Dependency type between nodes.
///
/// NOTE(review): the scheduler currently treats all kinds identically (the
/// type is ignored when building the dependency graph); the distinction is
/// informational for now.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DependencyType {
    /// Data dependency: consumer needs producer's output
    Data,
    /// Control dependency: execution order matters
    Control,
    /// Memory dependency: shared memory access
    Memory,
}
87
/// Node information for parallelization analysis.
///
/// The input unit consumed by [`AutoParallelizer::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of this node within the graph.
    pub id: NodeId,
    /// Operation type label (also the key used for profile data).
    pub op_type: String,
    /// Estimated execution cost of this node.
    pub estimated_cost: f64, // in microseconds
    /// Memory footprint of this node's data.
    pub memory_size: usize,  // in bytes
    /// Ids of nodes this node depends on, with the dependency kind.
    pub dependencies: Vec<(NodeId, DependencyType)>,
    /// Whether this node may run concurrently with others.
    /// NOTE(review): not consulted by the current scheduling code — confirm intended.
    pub can_parallelize: bool,
}
98
/// Parallel stage containing nodes that can execute concurrently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelStage {
    /// Zero-based position of this stage in the schedule.
    pub stage_id: usize,
    /// Nodes that may run concurrently within this stage.
    pub nodes: Vec<NodeId>,
    /// Stage latency estimate: the maximum cost among its nodes (microseconds).
    pub estimated_time: f64,
    /// Sum of the member nodes' memory sizes, in bytes.
    pub memory_requirement: usize,
    /// Currently just the immediately preceding stage (linear chain).
    pub predecessors: Vec<usize>, // Stages that must complete before this
}
108
/// Work partition for a single worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkPartition {
    /// Index of the worker this partition is assigned to.
    pub worker_id: usize,
    /// Nodes assigned to this worker, in assignment order.
    pub nodes: Vec<NodeId>,
    /// Accumulated cost estimate of the assigned nodes.
    pub estimated_load: f64,
}
116
/// Parallelization analysis results.
///
/// Produced by [`AutoParallelizer::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelizationAnalysis {
    /// Number of stages (equals `stages.len()`).
    pub num_stages: usize,
    /// The parallel stages in execution order.
    pub stages: Vec<ParallelStage>,
    /// Span: sum of per-stage latencies, in microseconds.
    pub critical_path_length: f64,
    /// Sum of all node cost estimates, in microseconds.
    pub total_work: f64,
    /// Average parallelism available in the graph.
    pub parallelism_factor: f64, // total_work / critical_path_length
    /// Estimated synchronization/transfer overhead, in microseconds.
    pub communication_overhead: f64,
    /// Suggested worker count given the strategy and parallelism factor.
    pub recommended_workers: usize,
}
128
/// Parallel execution plan.
///
/// Produced by [`AutoParallelizer::generate_plan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelExecutionPlan {
    /// Stages in execution order (from the analysis phase).
    pub stages: Vec<ParallelStage>,
    /// Per-worker node assignments.
    pub partitions: Vec<WorkPartition>,
    /// Predicted speedup: total work / (span + communication overhead).
    pub estimated_speedup: f64,
    /// Average load divided by maximum load; 1.0 means perfectly balanced.
    pub load_balance_ratio: f64,
}
137
/// Automatic parallelizer.
///
/// Configure via the builder-style `with_*` methods, then call `analyze`
/// or `generate_plan`.
pub struct AutoParallelizer {
    /// How aggressively to recommend workers.
    strategy: ParallelizationStrategy,
    /// Selected cost model.
    /// NOTE(review): stored but not consulted by the scheduling code in view.
    cost_model: CostModel,
    /// Upper bound on recommended/planned worker count.
    max_workers: usize,
    /// Fixed per-task scheduling/synchronization cost.
    overhead_per_task: f64,             // microseconds
    /// Interconnect bandwidth used for transfer-time estimates.
    communication_bandwidth: f64,       // GB/s
    /// Exponential moving average of observed per-op execution times.
    profile_data: HashMap<String, f64>, // op_type -> avg_time_us
}
147
148impl AutoParallelizer {
149    /// Create a new auto-parallelizer with default settings.
150    pub fn new() -> Self {
151        Self {
152            strategy: ParallelizationStrategy::Balanced,
153            cost_model: CostModel::Heuristic,
154            max_workers: num_cpus::get(),
155            overhead_per_task: 10.0,        // 10 microseconds per task
156            communication_bandwidth: 100.0, // 100 GB/s
157            profile_data: HashMap::new(),
158        }
159    }
160
161    /// Set parallelization strategy.
162    pub fn with_strategy(mut self, strategy: ParallelizationStrategy) -> Self {
163        self.strategy = strategy;
164        self
165    }
166
167    /// Set cost model.
168    pub fn with_cost_model(mut self, model: CostModel) -> Self {
169        self.cost_model = model;
170        self
171    }
172
173    /// Set maximum number of workers.
174    pub fn with_max_workers(mut self, workers: usize) -> Self {
175        self.max_workers = workers;
176        self
177    }
178
179    /// Update profile data with observed execution times.
180    pub fn update_profile(&mut self, op_type: String, time_us: f64) {
181        let entry = self.profile_data.entry(op_type).or_insert(0.0);
182        *entry = 0.9 * *entry + 0.1 * time_us; // Exponential moving average
183    }
184
185    /// Analyze graph for parallelization opportunities.
186    pub fn analyze(
187        &self,
188        nodes: &[NodeInfo],
189    ) -> Result<ParallelizationAnalysis, AutoParallelError> {
190        // Build dependency graph
191        let dep_graph = self.build_dependency_graph(nodes)?;
192
193        // Topological sort to find stages
194        let stages = self.compute_stages(nodes, &dep_graph)?;
195
196        // Calculate critical path
197        let critical_path_length = self.calculate_critical_path(&stages);
198
199        // Calculate total work
200        let total_work: f64 = nodes.iter().map(|n| n.estimated_cost).sum();
201
202        // Estimate communication overhead
203        let communication_overhead = self.estimate_communication_overhead(&stages, nodes);
204
205        // Calculate parallelism factor
206        let parallelism_factor = if critical_path_length > 0.0 {
207            total_work / critical_path_length
208        } else {
209            1.0
210        };
211
212        // Recommend number of workers
213        let recommended_workers = self.recommend_worker_count(parallelism_factor);
214
215        Ok(ParallelizationAnalysis {
216            num_stages: stages.len(),
217            stages,
218            critical_path_length,
219            total_work,
220            parallelism_factor,
221            communication_overhead,
222            recommended_workers,
223        })
224    }
225
226    /// Generate parallel execution plan.
227    pub fn generate_plan(
228        &self,
229        nodes: &[NodeInfo],
230    ) -> Result<ParallelExecutionPlan, AutoParallelError> {
231        let analysis = self.analyze(nodes)?;
232
233        // Partition work across workers
234        let partitions = self.partition_work(&analysis)?;
235
236        // Calculate estimated speedup
237        let sequential_time = analysis.total_work;
238        let parallel_time = analysis.critical_path_length + analysis.communication_overhead;
239        let estimated_speedup = if parallel_time > 0.0 {
240            sequential_time / parallel_time
241        } else {
242            1.0
243        };
244
245        // Calculate load balance ratio
246        let load_balance_ratio = self.calculate_load_balance(&partitions);
247
248        Ok(ParallelExecutionPlan {
249            stages: analysis.stages,
250            partitions,
251            estimated_speedup,
252            load_balance_ratio,
253        })
254    }
255
256    /// Build dependency graph from nodes.
257    fn build_dependency_graph(
258        &self,
259        nodes: &[NodeInfo],
260    ) -> Result<HashMap<NodeId, HashSet<NodeId>>, AutoParallelError> {
261        let mut graph: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
262
263        // Initialize graph with all nodes
264        for node in nodes {
265            graph.entry(node.id.clone()).or_insert_with(HashSet::new);
266        }
267
268        // Add edges
269        for node in nodes {
270            for (dep_id, _dep_type) in &node.dependencies {
271                if !graph.contains_key(dep_id) {
272                    return Err(AutoParallelError::InvalidGraph(format!(
273                        "Unknown dependency: {}",
274                        dep_id
275                    )));
276                }
277                graph
278                    .entry(node.id.clone())
279                    .or_insert_with(HashSet::new)
280                    .insert(dep_id.clone());
281            }
282        }
283
284        // Check for cycles
285        self.check_cycles(&graph)?;
286
287        Ok(graph)
288    }
289
290    /// Check for dependency cycles using DFS.
291    fn check_cycles(
292        &self,
293        graph: &HashMap<NodeId, HashSet<NodeId>>,
294    ) -> Result<(), AutoParallelError> {
295        let mut visited = HashSet::new();
296        let mut rec_stack = HashSet::new();
297
298        for node in graph.keys() {
299            if !visited.contains(node) {
300                if self.has_cycle_util(node, graph, &mut visited, &mut rec_stack)? {
301                    return Err(AutoParallelError::DependencyCycle(format!(
302                        "Cycle detected involving node: {}",
303                        node
304                    )));
305                }
306            }
307        }
308
309        Ok(())
310    }
311
312    fn has_cycle_util(
313        &self,
314        node: &NodeId,
315        graph: &HashMap<NodeId, HashSet<NodeId>>,
316        visited: &mut HashSet<NodeId>,
317        rec_stack: &mut HashSet<NodeId>,
318    ) -> Result<bool, AutoParallelError> {
319        visited.insert(node.clone());
320        rec_stack.insert(node.clone());
321
322        if let Some(neighbors) = graph.get(node) {
323            for neighbor in neighbors {
324                if !visited.contains(neighbor) {
325                    if self.has_cycle_util(neighbor, graph, visited, rec_stack)? {
326                        return Ok(true);
327                    }
328                } else if rec_stack.contains(neighbor) {
329                    return Ok(true);
330                }
331            }
332        }
333
334        rec_stack.remove(node);
335        Ok(false)
336    }
337
338    /// Compute parallel stages using level-based topological sort.
339    fn compute_stages(
340        &self,
341        nodes: &[NodeInfo],
342        dep_graph: &HashMap<NodeId, HashSet<NodeId>>,
343    ) -> Result<Vec<ParallelStage>, AutoParallelError> {
344        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
345        let mut node_map: HashMap<NodeId, &NodeInfo> = HashMap::new();
346
347        // Initialize in-degree and node map
348        for node in nodes {
349            node_map.insert(node.id.clone(), node);
350            let deps = dep_graph.get(&node.id).unwrap();
351            in_degree.insert(node.id.clone(), deps.len());
352        }
353
354        let mut stages = Vec::new();
355        let mut current_level: VecDeque<NodeId> = VecDeque::new();
356
357        // Find nodes with no dependencies
358        for (node_id, &degree) in &in_degree {
359            if degree == 0 {
360                current_level.push_back(node_id.clone());
361            }
362        }
363
364        let mut stage_id = 0;
365        while !current_level.is_empty() {
366            let mut stage_nodes = Vec::new();
367            let mut estimated_time: f64 = 0.0;
368            let mut memory_requirement = 0;
369
370            // Process all nodes at current level
371            for _ in 0..current_level.len() {
372                if let Some(node_id) = current_level.pop_front() {
373                    let node = node_map[&node_id];
374                    stage_nodes.push(node_id.clone());
375                    estimated_time = estimated_time.max(node.estimated_cost);
376                    memory_requirement += node.memory_size;
377
378                    // Decrease in-degree of dependent nodes
379                    for other_id in node_map.keys() {
380                        if dep_graph[other_id].contains(&node_id) {
381                            if let Some(degree) = in_degree.get_mut(other_id) {
382                                *degree -= 1;
383                                if *degree == 0 {
384                                    current_level.push_back(other_id.clone());
385                                }
386                            }
387                        }
388                    }
389                }
390            }
391
392            if !stage_nodes.is_empty() {
393                stages.push(ParallelStage {
394                    stage_id,
395                    nodes: stage_nodes,
396                    estimated_time,
397                    memory_requirement,
398                    predecessors: if stage_id > 0 {
399                        vec![stage_id - 1]
400                    } else {
401                        vec![]
402                    },
403                });
404                stage_id += 1;
405            }
406        }
407
408        // Check if all nodes were processed
409        if stages.iter().map(|s| s.nodes.len()).sum::<usize>() != nodes.len() {
410            return Err(AutoParallelError::DependencyCycle(
411                "Not all nodes were processed - cycle detected".to_string(),
412            ));
413        }
414
415        Ok(stages)
416    }
417
418    /// Calculate critical path length.
419    fn calculate_critical_path(&self, stages: &[ParallelStage]) -> f64 {
420        stages.iter().map(|s| s.estimated_time).sum()
421    }
422
423    /// Estimate communication overhead.
424    fn estimate_communication_overhead(
425        &self,
426        stages: &[ParallelStage],
427        _nodes: &[NodeInfo],
428    ) -> f64 {
429        let mut overhead = 0.0;
430
431        // Add overhead for each stage boundary
432        for stage in stages {
433            if stage.nodes.len() > 1 {
434                // Multiple nodes in stage need synchronization
435                overhead += self.overhead_per_task * stage.nodes.len() as f64;
436
437                // Add communication overhead based on memory transfer
438                let transfer_time =
439                    stage.memory_requirement as f64 / (self.communication_bandwidth * 1e9) * 1e6;
440                overhead += transfer_time;
441            }
442        }
443
444        overhead
445    }
446
447    /// Recommend number of workers based on parallelism factor.
448    fn recommend_worker_count(&self, parallelism_factor: f64) -> usize {
449        let ideal = parallelism_factor.ceil() as usize;
450
451        match self.strategy {
452            ParallelizationStrategy::Conservative => ideal.min(self.max_workers / 2).max(1),
453            ParallelizationStrategy::Balanced => ideal.min(self.max_workers),
454            ParallelizationStrategy::Aggressive => self.max_workers,
455            ParallelizationStrategy::CostBased => {
456                // Use cost model to decide
457                if parallelism_factor > 2.0 {
458                    ideal.min(self.max_workers)
459                } else {
460                    (ideal / 2).max(1)
461                }
462            }
463        }
464    }
465
466    /// Partition work across workers.
467    fn partition_work(
468        &self,
469        analysis: &ParallelizationAnalysis,
470    ) -> Result<Vec<WorkPartition>, AutoParallelError> {
471        let num_workers = analysis.recommended_workers;
472        let mut partitions: Vec<WorkPartition> = (0..num_workers)
473            .map(|i| WorkPartition {
474                worker_id: i,
475                nodes: Vec::new(),
476                estimated_load: 0.0,
477            })
478            .collect();
479
480        // For each stage, distribute nodes across workers
481        for stage in &analysis.stages {
482            // Sort nodes by estimated cost (descending)
483            let mut stage_nodes: Vec<(NodeId, f64)> = stage
484                .nodes
485                .iter()
486                .map(|id| (id.clone(), 1.0)) // Simplified: assume uniform cost
487                .collect();
488            stage_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
489
490            // Greedy assignment to least loaded worker
491            for (node_id, cost) in stage_nodes {
492                let min_partition = partitions
493                    .iter_mut()
494                    .min_by(|a, b| a.estimated_load.partial_cmp(&b.estimated_load).unwrap())
495                    .ok_or_else(|| {
496                        AutoParallelError::PartitioningFailed("No partitions available".to_string())
497                    })?;
498
499                min_partition.nodes.push(node_id);
500                min_partition.estimated_load += cost;
501            }
502        }
503
504        Ok(partitions)
505    }
506
507    /// Calculate load balance ratio (1.0 = perfect balance).
508    fn calculate_load_balance(&self, partitions: &[WorkPartition]) -> f64 {
509        if partitions.is_empty() {
510            return 1.0;
511        }
512
513        let loads: Vec<f64> = partitions.iter().map(|p| p.estimated_load).collect();
514        let max_load = loads.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
515        let avg_load = loads.iter().sum::<f64>() / loads.len() as f64;
516
517        if max_load > 0.0 {
518            avg_load / max_load
519        } else {
520            1.0
521        }
522    }
523}
524
525impl Default for AutoParallelizer {
526    fn default() -> Self {
527        Self::new()
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor: a `NodeInfo` whose dependencies are all `Data`.
    fn node(id: &str, op: &str, cost: f64, mem: usize, deps: &[&str], par: bool) -> NodeInfo {
        NodeInfo {
            id: id.to_string(),
            op_type: op.to_string(),
            estimated_cost: cost,
            memory_size: mem,
            dependencies: deps
                .iter()
                .map(|d| (d.to_string(), DependencyType::Data))
                .collect(),
            can_parallelize: par,
        }
    }

    /// Diamond graph: a feeds b and c, which both feed d.
    fn create_test_nodes() -> Vec<NodeInfo> {
        vec![
            node("a", "input", 10.0, 1000, &[], true),
            node("b", "compute", 20.0, 2000, &["a"], true),
            node("c", "compute", 15.0, 1500, &["a"], true),
            node("d", "output", 10.0, 1000, &["b", "c"], false),
        ]
    }

    #[test]
    fn test_auto_parallelizer_creation() {
        let p = AutoParallelizer::new();
        assert_eq!(p.strategy, ParallelizationStrategy::Balanced);
        assert_eq!(p.cost_model, CostModel::Heuristic);
    }

    #[test]
    fn test_builder_pattern() {
        let p = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .with_cost_model(CostModel::ProfileBased)
            .with_max_workers(8);

        assert_eq!(p.strategy, ParallelizationStrategy::Aggressive);
        assert_eq!(p.cost_model, CostModel::ProfileBased);
        assert_eq!(p.max_workers, 8);
    }

    #[test]
    fn test_dependency_graph_building() {
        let graph = AutoParallelizer::new()
            .build_dependency_graph(&create_test_nodes())
            .unwrap();

        assert_eq!(graph.len(), 4);
        // Each consumer must record its producer as a dependency.
        for (consumer, producer) in [("b", "a"), ("c", "a"), ("d", "b"), ("d", "c")] {
            assert!(graph[consumer].contains(producer));
        }
    }

    #[test]
    fn test_cycle_detection() {
        // a and b depend on each other: not a DAG.
        let cyclic = vec![
            node("a", "compute", 10.0, 1000, &["b"], true),
            node("b", "compute", 10.0, 1000, &["a"], true),
        ];

        assert!(AutoParallelizer::new()
            .build_dependency_graph(&cyclic)
            .is_err());
    }

    #[test]
    fn test_stage_computation() {
        let analysis = AutoParallelizer::new()
            .analyze(&create_test_nodes())
            .unwrap();

        assert_eq!(analysis.num_stages, 3);
        assert_eq!(analysis.stages[0].nodes, vec!["a"]);
        // b and c are independent and share the middle stage.
        assert_eq!(analysis.stages[1].nodes.len(), 2);
        assert!(analysis.stages[1].nodes.contains(&"b".to_string()));
        assert!(analysis.stages[1].nodes.contains(&"c".to_string()));
        assert_eq!(analysis.stages[2].nodes, vec!["d"]);
    }

    #[test]
    fn test_critical_path_calculation() {
        let analysis = AutoParallelizer::new()
            .analyze(&create_test_nodes())
            .unwrap();

        // a (10) -> max(b = 20, c = 15) -> d (10) = 40.
        assert_eq!(analysis.critical_path_length, 40.0);
    }

    #[test]
    fn test_parallelism_factor() {
        let analysis = AutoParallelizer::new()
            .analyze(&create_test_nodes())
            .unwrap();

        // Total work 55 spread over a 40-long critical path: 55 / 40 = 1.375.
        assert!((analysis.parallelism_factor - 1.375).abs() < 0.01);
    }

    #[test]
    fn test_execution_plan_generation() {
        let plan = AutoParallelizer::new()
            .generate_plan(&create_test_nodes())
            .unwrap();

        assert_eq!(plan.stages.len(), 3);
        assert!(!plan.partitions.is_empty());
        // Overhead may eat the speedup; it only has to be positive.
        assert!(plan.estimated_speedup > 0.0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_profile_update() {
        let mut p = AutoParallelizer::new();

        p.update_profile("compute".to_string(), 100.0);
        p.update_profile("compute".to_string(), 200.0);

        assert!(p.profile_data.contains_key("compute"));
        // EMA seeded from zero: 10.0 after the first sample, 29.0 after the second.
        assert!(p.profile_data["compute"] >= 0.0);
    }

    #[test]
    fn test_strategy_variations() {
        let nodes = create_test_nodes();
        let run = |strategy| {
            AutoParallelizer::new()
                .with_strategy(strategy)
                .analyze(&nodes)
                .unwrap()
        };

        let conservative = run(ParallelizationStrategy::Conservative);
        let aggressive = run(ParallelizationStrategy::Aggressive);

        // Aggressive should never recommend fewer workers than Conservative.
        assert!(aggressive.recommended_workers >= conservative.recommended_workers);
    }

    #[test]
    fn test_sequential_graph() {
        // Pure chain a -> b: no concurrency available.
        let chain = vec![
            node("a", "compute", 10.0, 1000, &[], true),
            node("b", "compute", 10.0, 1000, &["a"], true),
        ];

        let analysis = AutoParallelizer::new().analyze(&chain).unwrap();

        assert_eq!(analysis.num_stages, 2);
        assert_eq!(analysis.parallelism_factor, 1.0); // No parallelism
    }

    #[test]
    fn test_fully_parallel_graph() {
        // Three independent nodes: everything fits in one stage.
        let independent = vec![
            node("a", "compute", 10.0, 1000, &[], true),
            node("b", "compute", 10.0, 1000, &[], true),
            node("c", "compute", 10.0, 1000, &[], true),
        ];

        let analysis = AutoParallelizer::new().analyze(&independent).unwrap();

        assert_eq!(analysis.num_stages, 1);
        assert_eq!(analysis.parallelism_factor, 3.0); // Perfect parallelism
    }

    #[test]
    fn test_load_balancing() {
        let plan = AutoParallelizer::new()
            .with_max_workers(2)
            .generate_plan(&create_test_nodes())
            .unwrap();

        assert!(!plan.partitions.is_empty());
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_invalid_graph() {
        // A dependency on an id missing from the node list must be rejected.
        let bad = vec![node("a", "compute", 10.0, 1000, &["unknown"], true)];

        assert!(AutoParallelizer::new().build_dependency_graph(&bad).is_err());
    }
}