// tensorlogic_infer/auto_parallel.rs

1//! Automatic parallelization for computation graphs.
2//!
3//! This module provides automatic detection and exploitation of parallelism opportunities:
4//! - **Dependency analysis**: Build dependency graphs and detect parallelizable operations
5//! - **Cost modeling**: Estimate execution costs and communication overhead
6//! - **Work partitioning**: Dynamically partition work across threads/devices
7//! - **Load balancing**: Balance work to minimize idle time
8//! - **Pipeline detection**: Identify pipeline parallelism opportunities
9//!
10//! ## Example
11//!
12//! ```rust,ignore
13//! use tensorlogic_infer::{AutoParallelizer, ParallelizationStrategy, CostModel};
14//!
15//! // Create auto-parallelizer with cost model
16//! let parallelizer = AutoParallelizer::new()
17//!     .with_strategy(ParallelizationStrategy::Aggressive)
18//!     .with_cost_model(CostModel::ProfileBased);
19//!
20//! // Analyze graph for parallelism
21//! let analysis = parallelizer.analyze(&graph)?;
22//! println!("Found {} parallelizable stages", analysis.num_stages);
23//!
24//! // Generate parallel execution plan
25//! let plan = parallelizer.generate_plan(&graph)?;
26//! ```
27
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet, VecDeque};
30use thiserror::Error;
31
/// Auto-parallelization errors.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum AutoParallelError {
    /// The dependency graph contains a cycle, so no valid schedule exists.
    #[error("Dependency cycle detected: {0}")]
    DependencyCycle(String),

    /// A node references a dependency id that is not present in the graph.
    #[error("Invalid graph: {0}")]
    InvalidGraph(String),

    /// The cost model could not produce an estimate.
    #[error("Cost model error: {0}")]
    CostModelError(String),

    /// Work could not be distributed across the available workers.
    #[error("Partitioning failed: {0}")]
    PartitioningFailed(String),
}
47
/// Node ID in the computation graph.
///
/// Currently a plain `String`; used as the key type throughout the analysis.
pub type NodeId = String;
50
/// Parallelization strategy.
///
/// Consulted by the worker-count recommendation to trade parallelism
/// against scheduling overhead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ParallelizationStrategy {
    /// Conservative: Only parallelize when clearly beneficial
    Conservative,
    /// Balanced: Balance parallelism and overhead
    Balanced,
    /// Aggressive: Maximize parallelism even with potential overhead
    Aggressive,
    /// Cost-based: Use cost model to decide
    CostBased,
}
63
/// Cost model type.
///
/// NOTE(review): the selected model is stored on [`AutoParallelizer`] but is
/// not yet consulted by any analysis path in this module — confirm intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CostModel {
    /// Simple heuristic-based cost model
    Heuristic,
    /// Profile-based cost model using historical data
    ProfileBased,
    /// Analytical cost model based on operation complexity
    Analytical,
    /// Hybrid approach combining multiple models
    Hybrid,
}
76
/// Dependency type between nodes.
///
/// All three kinds currently impose the same scheduling constraint: the
/// graph builder ignores the type when adding edges.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DependencyType {
    /// Data dependency: consumer needs producer's output
    Data,
    /// Control dependency: execution order matters
    Control,
    /// Memory dependency: shared memory access
    Memory,
}
87
/// Node information for parallelization analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of the node within the graph.
    pub id: NodeId,
    /// Operation type name; also the key used for profile data.
    pub op_type: String,
    /// Estimated execution cost in microseconds.
    pub estimated_cost: f64, // in microseconds
    /// Estimated memory footprint in bytes.
    pub memory_size: usize,  // in bytes
    /// Producers this node depends on, with the kind of dependency.
    pub dependencies: Vec<(NodeId, DependencyType)>,
    /// Whether the node may run concurrently with others.
    /// NOTE(review): not currently consulted by the scheduler — confirm.
    pub can_parallelize: bool,
}
98
/// Parallel stage containing nodes that can execute concurrently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelStage {
    /// Zero-based position of the stage in execution order.
    pub stage_id: usize,
    /// Nodes that may run concurrently within this stage.
    pub nodes: Vec<NodeId>,
    /// Stage duration estimate: the maximum node cost in the stage (µs).
    pub estimated_time: f64,
    /// Sum of the member nodes' memory sizes (bytes).
    pub memory_requirement: usize,
    pub predecessors: Vec<usize>, // Stages that must complete before this
}
108
/// Work partition for a single worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkPartition {
    /// Index of the worker this partition belongs to.
    pub worker_id: usize,
    /// Nodes assigned to this worker, accumulated across all stages.
    pub nodes: Vec<NodeId>,
    /// Sum of the estimated costs of the assigned nodes.
    pub estimated_load: f64,
}
116
/// Parallelization analysis results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelizationAnalysis {
    /// Number of entries in `stages`.
    pub num_stages: usize,
    /// Topologically ordered stages of mutually independent nodes.
    pub stages: Vec<ParallelStage>,
    /// Sum of per-stage maximum costs (µs); lower bound on parallel runtime.
    pub critical_path_length: f64,
    /// Sum of all node costs (µs); sequential runtime estimate.
    pub total_work: f64,
    pub parallelism_factor: f64, // total_work / critical_path_length
    /// Estimated synchronization plus data-transfer cost (µs).
    pub communication_overhead: f64,
    /// Worker count suggested by the configured strategy.
    pub recommended_workers: usize,
}
128
/// Parallel execution plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelExecutionPlan {
    /// Ordered stages to execute, with a barrier between stages.
    pub stages: Vec<ParallelStage>,
    /// Per-worker node assignments.
    pub partitions: Vec<WorkPartition>,
    /// Predicted speedup over sequential execution (may be below 1.0).
    pub estimated_speedup: f64,
    /// Average load divided by maximum load; 1.0 means perfect balance.
    pub load_balance_ratio: f64,
}
137
/// Automatic parallelizer.
///
/// Configure via the builder methods (`with_strategy`, `with_cost_model`,
/// `with_max_workers`); see [`AutoParallelizer::new`] for the defaults.
pub struct AutoParallelizer {
    strategy: ParallelizationStrategy,
    /// NOTE(review): stored but not yet consulted by any analysis path.
    cost_model: CostModel,
    max_workers: usize,
    overhead_per_task: f64,             // microseconds
    communication_bandwidth: f64,       // GB/s
    /// Exponential moving average of observed execution time per op type.
    profile_data: HashMap<String, f64>, // op_type -> avg_time_us
}
147
148impl AutoParallelizer {
    /// Create a new auto-parallelizer with default settings.
    ///
    /// Defaults: `Balanced` strategy, `Heuristic` cost model, one worker per
    /// logical CPU, 10 µs per-task overhead, 100 GB/s assumed bandwidth, and
    /// an empty profile table.
    pub fn new() -> Self {
        Self {
            strategy: ParallelizationStrategy::Balanced,
            cost_model: CostModel::Heuristic,
            // One worker per logical CPU detected at runtime.
            max_workers: num_cpus::get(),
            overhead_per_task: 10.0,        // 10 microseconds per task
            communication_bandwidth: 100.0, // 100 GB/s
            profile_data: HashMap::new(),
        }
    }
160
    /// Set parallelization strategy.
    ///
    /// Builder-style setter; consumes and returns `self` for chaining.
    pub fn with_strategy(mut self, strategy: ParallelizationStrategy) -> Self {
        self.strategy = strategy;
        self
    }
166
    /// Set cost model.
    ///
    /// Builder-style setter; consumes and returns `self` for chaining.
    pub fn with_cost_model(mut self, model: CostModel) -> Self {
        self.cost_model = model;
        self
    }
172
    /// Set maximum number of workers.
    ///
    /// Builder-style setter; consumes and returns `self` for chaining.
    pub fn with_max_workers(mut self, workers: usize) -> Self {
        self.max_workers = workers;
        self
    }
178
179    /// Update profile data with observed execution times.
180    pub fn update_profile(&mut self, op_type: String, time_us: f64) {
181        let entry = self.profile_data.entry(op_type).or_insert(0.0);
182        *entry = 0.9 * *entry + 0.1 * time_us; // Exponential moving average
183    }
184
185    /// Analyze graph for parallelization opportunities.
186    pub fn analyze(
187        &self,
188        nodes: &[NodeInfo],
189    ) -> Result<ParallelizationAnalysis, AutoParallelError> {
190        // Build dependency graph
191        let dep_graph = self.build_dependency_graph(nodes)?;
192
193        // Topological sort to find stages
194        let stages = self.compute_stages(nodes, &dep_graph)?;
195
196        // Calculate critical path
197        let critical_path_length = self.calculate_critical_path(&stages);
198
199        // Calculate total work
200        let total_work: f64 = nodes.iter().map(|n| n.estimated_cost).sum();
201
202        // Estimate communication overhead
203        let communication_overhead = self.estimate_communication_overhead(&stages, nodes);
204
205        // Calculate parallelism factor
206        let parallelism_factor = if critical_path_length > 0.0 {
207            total_work / critical_path_length
208        } else {
209            1.0
210        };
211
212        // Recommend number of workers
213        let recommended_workers = self.recommend_worker_count(parallelism_factor);
214
215        Ok(ParallelizationAnalysis {
216            num_stages: stages.len(),
217            stages,
218            critical_path_length,
219            total_work,
220            parallelism_factor,
221            communication_overhead,
222            recommended_workers,
223        })
224    }
225
226    /// Generate parallel execution plan.
227    pub fn generate_plan(
228        &self,
229        nodes: &[NodeInfo],
230    ) -> Result<ParallelExecutionPlan, AutoParallelError> {
231        let analysis = self.analyze(nodes)?;
232
233        // Partition work across workers
234        let partitions = self.partition_work(&analysis)?;
235
236        // Calculate estimated speedup
237        let sequential_time = analysis.total_work;
238        let parallel_time = analysis.critical_path_length + analysis.communication_overhead;
239        let estimated_speedup = if parallel_time > 0.0 {
240            sequential_time / parallel_time
241        } else {
242            1.0
243        };
244
245        // Calculate load balance ratio
246        let load_balance_ratio = self.calculate_load_balance(&partitions);
247
248        Ok(ParallelExecutionPlan {
249            stages: analysis.stages,
250            partitions,
251            estimated_speedup,
252            load_balance_ratio,
253        })
254    }
255
256    /// Build dependency graph from nodes.
257    fn build_dependency_graph(
258        &self,
259        nodes: &[NodeInfo],
260    ) -> Result<HashMap<NodeId, HashSet<NodeId>>, AutoParallelError> {
261        let mut graph: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
262
263        // Initialize graph with all nodes
264        for node in nodes {
265            graph.entry(node.id.clone()).or_insert_with(HashSet::new);
266        }
267
268        // Add edges
269        for node in nodes {
270            for (dep_id, _dep_type) in &node.dependencies {
271                if !graph.contains_key(dep_id) {
272                    return Err(AutoParallelError::InvalidGraph(format!(
273                        "Unknown dependency: {}",
274                        dep_id
275                    )));
276                }
277                graph
278                    .entry(node.id.clone())
279                    .or_insert_with(HashSet::new)
280                    .insert(dep_id.clone());
281            }
282        }
283
284        // Check for cycles
285        self.check_cycles(&graph)?;
286
287        Ok(graph)
288    }
289
290    /// Check for dependency cycles using DFS.
291    fn check_cycles(
292        &self,
293        graph: &HashMap<NodeId, HashSet<NodeId>>,
294    ) -> Result<(), AutoParallelError> {
295        let mut visited = HashSet::new();
296        let mut rec_stack = HashSet::new();
297
298        for node in graph.keys() {
299            if !visited.contains(node) {
300                if self.has_cycle_util(node, graph, &mut visited, &mut rec_stack)? {
301                    return Err(AutoParallelError::DependencyCycle(format!(
302                        "Cycle detected involving node: {}",
303                        node
304                    )));
305                }
306            }
307        }
308
309        Ok(())
310    }
311
312    fn has_cycle_util(
313        &self,
314        node: &NodeId,
315        graph: &HashMap<NodeId, HashSet<NodeId>>,
316        visited: &mut HashSet<NodeId>,
317        rec_stack: &mut HashSet<NodeId>,
318    ) -> Result<bool, AutoParallelError> {
319        visited.insert(node.clone());
320        rec_stack.insert(node.clone());
321
322        if let Some(neighbors) = graph.get(node) {
323            for neighbor in neighbors {
324                if !visited.contains(neighbor) {
325                    if self.has_cycle_util(neighbor, graph, visited, rec_stack)? {
326                        return Ok(true);
327                    }
328                } else if rec_stack.contains(neighbor) {
329                    return Ok(true);
330                }
331            }
332        }
333
334        rec_stack.remove(node);
335        Ok(false)
336    }
337
338    /// Compute parallel stages using level-based topological sort.
339    fn compute_stages(
340        &self,
341        nodes: &[NodeInfo],
342        dep_graph: &HashMap<NodeId, HashSet<NodeId>>,
343    ) -> Result<Vec<ParallelStage>, AutoParallelError> {
344        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
345        let mut node_map: HashMap<NodeId, &NodeInfo> = HashMap::new();
346
347        // Initialize in-degree and node map
348        for node in nodes {
349            node_map.insert(node.id.clone(), node);
350            let deps = dep_graph
351                .get(&node.id)
352                .expect("dep_graph built from same nodes");
353            in_degree.insert(node.id.clone(), deps.len());
354        }
355
356        let mut stages = Vec::new();
357        let mut current_level: VecDeque<NodeId> = VecDeque::new();
358
359        // Find nodes with no dependencies
360        for (node_id, &degree) in &in_degree {
361            if degree == 0 {
362                current_level.push_back(node_id.clone());
363            }
364        }
365
366        let mut stage_id = 0;
367        while !current_level.is_empty() {
368            let mut stage_nodes = Vec::new();
369            let mut estimated_time: f64 = 0.0;
370            let mut memory_requirement = 0;
371
372            // Process all nodes at current level
373            for _ in 0..current_level.len() {
374                if let Some(node_id) = current_level.pop_front() {
375                    let node = node_map[&node_id];
376                    stage_nodes.push(node_id.clone());
377                    estimated_time = estimated_time.max(node.estimated_cost);
378                    memory_requirement += node.memory_size;
379
380                    // Decrease in-degree of dependent nodes
381                    for other_id in node_map.keys() {
382                        if dep_graph[other_id].contains(&node_id) {
383                            if let Some(degree) = in_degree.get_mut(other_id) {
384                                *degree -= 1;
385                                if *degree == 0 {
386                                    current_level.push_back(other_id.clone());
387                                }
388                            }
389                        }
390                    }
391                }
392            }
393
394            if !stage_nodes.is_empty() {
395                stages.push(ParallelStage {
396                    stage_id,
397                    nodes: stage_nodes,
398                    estimated_time,
399                    memory_requirement,
400                    predecessors: if stage_id > 0 {
401                        vec![stage_id - 1]
402                    } else {
403                        vec![]
404                    },
405                });
406                stage_id += 1;
407            }
408        }
409
410        // Check if all nodes were processed
411        if stages.iter().map(|s| s.nodes.len()).sum::<usize>() != nodes.len() {
412            return Err(AutoParallelError::DependencyCycle(
413                "Not all nodes were processed - cycle detected".to_string(),
414            ));
415        }
416
417        Ok(stages)
418    }
419
420    /// Calculate critical path length.
421    fn calculate_critical_path(&self, stages: &[ParallelStage]) -> f64 {
422        stages.iter().map(|s| s.estimated_time).sum()
423    }
424
425    /// Estimate communication overhead.
426    fn estimate_communication_overhead(
427        &self,
428        stages: &[ParallelStage],
429        _nodes: &[NodeInfo],
430    ) -> f64 {
431        let mut overhead = 0.0;
432
433        // Add overhead for each stage boundary
434        for stage in stages {
435            if stage.nodes.len() > 1 {
436                // Multiple nodes in stage need synchronization
437                overhead += self.overhead_per_task * stage.nodes.len() as f64;
438
439                // Add communication overhead based on memory transfer
440                let transfer_time =
441                    stage.memory_requirement as f64 / (self.communication_bandwidth * 1e9) * 1e6;
442                overhead += transfer_time;
443            }
444        }
445
446        overhead
447    }
448
449    /// Recommend number of workers based on parallelism factor.
450    fn recommend_worker_count(&self, parallelism_factor: f64) -> usize {
451        let ideal = parallelism_factor.ceil() as usize;
452
453        match self.strategy {
454            ParallelizationStrategy::Conservative => ideal.min(self.max_workers / 2).max(1),
455            ParallelizationStrategy::Balanced => ideal.min(self.max_workers),
456            ParallelizationStrategy::Aggressive => self.max_workers,
457            ParallelizationStrategy::CostBased => {
458                // Use cost model to decide
459                if parallelism_factor > 2.0 {
460                    ideal.min(self.max_workers)
461                } else {
462                    (ideal / 2).max(1)
463                }
464            }
465        }
466    }
467
468    /// Partition work across workers.
469    fn partition_work(
470        &self,
471        analysis: &ParallelizationAnalysis,
472    ) -> Result<Vec<WorkPartition>, AutoParallelError> {
473        let num_workers = analysis.recommended_workers;
474        let mut partitions: Vec<WorkPartition> = (0..num_workers)
475            .map(|i| WorkPartition {
476                worker_id: i,
477                nodes: Vec::new(),
478                estimated_load: 0.0,
479            })
480            .collect();
481
482        // For each stage, distribute nodes across workers
483        for stage in &analysis.stages {
484            // Sort nodes by estimated cost (descending)
485            let mut stage_nodes: Vec<(NodeId, f64)> = stage
486                .nodes
487                .iter()
488                .map(|id| (id.clone(), 1.0)) // Simplified: assume uniform cost
489                .collect();
490            stage_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
491
492            // Greedy assignment to least loaded worker
493            for (node_id, cost) in stage_nodes {
494                let min_partition = partitions
495                    .iter_mut()
496                    .min_by(|a, b| {
497                        a.estimated_load
498                            .partial_cmp(&b.estimated_load)
499                            .unwrap_or(std::cmp::Ordering::Equal)
500                    })
501                    .ok_or_else(|| {
502                        AutoParallelError::PartitioningFailed("No partitions available".to_string())
503                    })?;
504
505                min_partition.nodes.push(node_id);
506                min_partition.estimated_load += cost;
507            }
508        }
509
510        Ok(partitions)
511    }
512
513    /// Calculate load balance ratio (1.0 = perfect balance).
514    fn calculate_load_balance(&self, partitions: &[WorkPartition]) -> f64 {
515        if partitions.is_empty() {
516            return 1.0;
517        }
518
519        let loads: Vec<f64> = partitions.iter().map(|p| p.estimated_load).collect();
520        let max_load = loads.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
521        let avg_load = loads.iter().sum::<f64>() / loads.len() as f64;
522
523        if max_load > 0.0 {
524            avg_load / max_load
525        } else {
526            1.0
527        }
528    }
529}
530
impl Default for AutoParallelizer {
    fn default() -> Self {
        // Defaults mirror `AutoParallelizer::new()`.
        Self::new()
    }
}
536
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a diamond-shaped graph `a -> {b, c} -> d`, where `b` and `c`
    /// are independent of each other and may run in parallel.
    fn create_test_nodes() -> Vec<NodeInfo> {
        vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "input".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 20.0,
                memory_size: 2000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "c".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 15.0,
                memory_size: 1500,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "d".to_string(),
                op_type: "output".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![
                    ("b".to_string(), DependencyType::Data),
                    ("c".to_string(), DependencyType::Data),
                ],
                can_parallelize: false,
            },
        ]
    }

    #[test]
    fn test_auto_parallelizer_creation() {
        // Defaults documented on `AutoParallelizer::new`.
        let parallelizer = AutoParallelizer::new();
        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Balanced);
        assert_eq!(parallelizer.cost_model, CostModel::Heuristic);
    }

    #[test]
    fn test_builder_pattern() {
        let parallelizer = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .with_cost_model(CostModel::ProfileBased)
            .with_max_workers(8);

        assert_eq!(parallelizer.strategy, ParallelizationStrategy::Aggressive);
        assert_eq!(parallelizer.cost_model, CostModel::ProfileBased);
        assert_eq!(parallelizer.max_workers, 8);
    }

    #[test]
    fn test_dependency_graph_building() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let graph = parallelizer.build_dependency_graph(&nodes).expect("unwrap");

        // Edges point from consumer to producer (node -> its dependencies).
        assert_eq!(graph.len(), 4);
        assert!(graph["b"].contains("a"));
        assert!(graph["c"].contains("a"));
        assert!(graph["d"].contains("b"));
        assert!(graph["d"].contains("c"));
    }

    #[test]
    fn test_cycle_detection() {
        let parallelizer = AutoParallelizer::new();

        // Create nodes with a cycle: a depends on b and b depends on a.
        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("b".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
        ];

        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(result.is_err());
    }

    #[test]
    fn test_stage_computation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        // Diamond graph: a | {b, c} | d.
        assert_eq!(analysis.num_stages, 3);
        assert_eq!(analysis.stages[0].nodes, vec!["a"]);
        assert_eq!(analysis.stages[1].nodes.len(), 2); // b and c can run in parallel
        assert!(analysis.stages[1].nodes.contains(&"b".to_string()));
        assert!(analysis.stages[1].nodes.contains(&"c".to_string()));
        assert_eq!(analysis.stages[2].nodes, vec!["d"]);
    }

    #[test]
    fn test_critical_path_calculation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        // Critical path: a (10) -> max(b (20), c (15)) -> d (10) = 40
        assert_eq!(analysis.critical_path_length, 40.0);
    }

    #[test]
    fn test_parallelism_factor() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        // Total work: 10 + 20 + 15 + 10 = 55
        // Critical path: 40
        // Parallelism factor: 55 / 40 = 1.375
        assert!((analysis.parallelism_factor - 1.375).abs() < 0.01);
    }

    #[test]
    fn test_execution_plan_generation() {
        let parallelizer = AutoParallelizer::new();
        let nodes = create_test_nodes();

        let plan = parallelizer.generate_plan(&nodes).expect("unwrap");

        assert_eq!(plan.stages.len(), 3);
        assert!(!plan.partitions.is_empty());
        // May not always have speedup due to overhead, just check it's positive
        assert!(plan.estimated_speedup > 0.0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_profile_update() {
        let mut parallelizer = AutoParallelizer::new();

        parallelizer.update_profile("compute".to_string(), 100.0);
        parallelizer.update_profile("compute".to_string(), 200.0);

        assert!(parallelizer.profile_data.contains_key("compute"));
        let avg = parallelizer.profile_data["compute"];
        // The exact value depends on how the EMA is seeded for the first
        // sample, so only sanity-check that a non-negative average is kept.
        assert!(avg >= 0.0);
    }

    #[test]
    fn test_strategy_variations() {
        let nodes = create_test_nodes();

        let conservative = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Conservative)
            .analyze(&nodes)
            .expect("unwrap");

        let aggressive = AutoParallelizer::new()
            .with_strategy(ParallelizationStrategy::Aggressive)
            .analyze(&nodes)
            .expect("unwrap");

        // Aggressive should recommend more workers
        assert!(aggressive.recommended_workers >= conservative.recommended_workers);
    }

    #[test]
    fn test_sequential_graph() {
        let parallelizer = AutoParallelizer::new();

        // Create a sequential graph (no parallelism): a -> b.
        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![("a".to_string(), DependencyType::Data)],
                can_parallelize: true,
            },
        ];

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        assert_eq!(analysis.num_stages, 2);
        assert_eq!(analysis.parallelism_factor, 1.0); // No parallelism
    }

    #[test]
    fn test_fully_parallel_graph() {
        let parallelizer = AutoParallelizer::new();

        // Create a fully parallel graph: three independent nodes.
        let nodes = vec![
            NodeInfo {
                id: "a".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "b".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
            NodeInfo {
                id: "c".to_string(),
                op_type: "compute".to_string(),
                estimated_cost: 10.0,
                memory_size: 1000,
                dependencies: vec![],
                can_parallelize: true,
            },
        ];

        let analysis = parallelizer.analyze(&nodes).expect("unwrap");

        assert_eq!(analysis.num_stages, 1);
        assert_eq!(analysis.parallelism_factor, 3.0); // Perfect parallelism
    }

    #[test]
    fn test_load_balancing() {
        let parallelizer = AutoParallelizer::new().with_max_workers(2);
        let nodes = create_test_nodes();

        let plan = parallelizer.generate_plan(&nodes).expect("unwrap");

        // Check that partitions exist and have reasonable balance
        assert!(plan.partitions.len() > 0);
        assert!(plan.load_balance_ratio > 0.0 && plan.load_balance_ratio <= 1.0);
    }

    #[test]
    fn test_invalid_graph() {
        let parallelizer = AutoParallelizer::new();

        // Node with unknown dependency
        let nodes = vec![NodeInfo {
            id: "a".to_string(),
            op_type: "compute".to_string(),
            estimated_cost: 10.0,
            memory_size: 1000,
            dependencies: vec![("unknown".to_string(), DependencyType::Data)],
            can_parallelize: true,
        }];

        let result = parallelizer.build_dependency_graph(&nodes);
        assert!(result.is_err());
    }
}