Skip to main content

scud_task_core/
waves.rs

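//! Wave computation for parallel task execution.
//!
//! Computes execution waves using Kahn's algorithm (topological sort with level assignment),
//! grouping tasks that can run in parallel based on their dependencies. Also provides
//! [`detect_id_collisions`] for finding local task IDs that appear under more than one tag when
//! tasks from multiple phases are merged.
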
6use std::collections::{HashMap, HashSet};
7
8use crate::models::Task;
9
/// A wave of tasks that can be executed in parallel.
///
/// Every task in a wave has all of its in-set dependencies satisfied by earlier waves, so the
/// tasks inside one wave do not depend on each other.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Wave {
    /// Wave number (1-indexed, in execution order).
    pub number: usize,
    /// IDs of the tasks in this wave.
    pub tasks: Vec<String>,
}

19/// Result of wave computation
20#[derive(Debug)]
21pub struct WaveResult {
22    /// Computed waves
23    pub waves: Vec<Wave>,
24    /// Tasks with circular dependencies (if any)
25    pub circular_deps: Vec<String>,
26}
27
28/// Compute execution waves using Kahn's algorithm (topological sort with level assignment)
29/// When processing tasks from multiple phases, task IDs are expected to be namespaced
30pub fn compute_waves(tasks: &[&Task]) -> WaveResult {
31    // Build a map from task pointer to its namespaced ID
32    // This handles the case where multiple phases have tasks with the same local ID
33    let task_ids: HashSet<String> = tasks.iter().map(|t| t.id.clone()).collect();
34
35    // Build in-degree map (how many dependencies does each task have within our set?)
36    let mut in_degree: HashMap<String, usize> = HashMap::new();
37    let mut dependents: HashMap<String, Vec<String>> = HashMap::new();
38
39    for task in tasks {
40        in_degree.entry(task.id.clone()).or_insert(0);
41
42        for dep in &task.dependencies {
43            // Only count dependencies that are in our actionable task set
44            if task_ids.contains(dep) {
45                *in_degree.entry(task.id.clone()).or_insert(0) += 1;
46                dependents
47                    .entry(dep.clone())
48                    .or_default()
49                    .push(task.id.clone());
50            }
51        }
52    }
53
54    // Kahn's algorithm with wave tracking
55    let mut waves: Vec<Wave> = Vec::new();
56    let mut remaining = in_degree.clone();
57    let mut wave_number = 1;
58    let mut circular_deps = Vec::new();
59
60    while !remaining.is_empty() {
61        // Find all tasks with no remaining dependencies (in-degree = 0)
62        let ready: Vec<String> = remaining
63            .iter()
64            .filter(|(_, &deg)| deg == 0)
65            .map(|(id, _)| id.clone())
66            .collect();
67
68        if ready.is_empty() {
69            // Circular dependency detected - collect remaining tasks
70            circular_deps = remaining.keys().cloned().collect();
71            break;
72        }
73
74        // Remove ready tasks from remaining and update dependents
75        for task_id in &ready {
76            remaining.remove(task_id);
77
78            if let Some(deps) = dependents.get(task_id) {
79                for dep_id in deps {
80                    if let Some(deg) = remaining.get_mut(dep_id) {
81                        *deg = deg.saturating_sub(1);
82                    }
83                }
84            }
85        }
86
87        waves.push(Wave {
88            number: wave_number,
89            tasks: ready,
90        });
91        wave_number += 1;
92    }
93
94    WaveResult {
95        waves,
96        circular_deps,
97    }
98}
99
100/// Detect ID collisions when merging tasks from multiple phases
101/// Returns a list of (local_id, Vec<tag>) for IDs that appear in multiple tags
102pub fn detect_id_collisions(tasks: &[&Task]) -> Vec<(String, Vec<String>)> {
103    let mut id_to_tags: HashMap<String, Vec<String>> = HashMap::new();
104
105    for task in tasks {
106        let local_id = task.local_id().to_string();
107        let tag = task.epic_tag().unwrap_or("unknown").to_string();
108
109        id_to_tags.entry(local_id).or_default().push(tag);
110    }
111
112    // Filter to only those with collisions (same local ID in multiple tags)
113    let mut collisions: Vec<(String, Vec<String>)> = id_to_tags
114        .into_iter()
115        .filter(|(_, tags)| {
116            // Dedupe tags and check if more than one unique tag
117            let mut unique_tags: Vec<_> = tags.to_vec();
118            unique_tags.sort();
119            unique_tags.dedup();
120            unique_tags.len() > 1
121        })
122        .map(|(id, mut tags)| {
123            tags.sort();
124            tags.dedup();
125            (id, tags)
126        })
127        .collect();
128
129    collisions.sort_by(|a, b| a.0.cmp(&b.0));
130    collisions
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Task;

    #[test]
    fn test_simple_linear_waves() {
        let task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        let mut task2 = Task::new("2".to_string(), "Task 2".to_string(), String::new());
        task2.dependencies = vec!["1".to_string()];
        let mut task3 = Task::new("3".to_string(), "Task 3".to_string(), String::new());
        task3.dependencies = vec!["2".to_string()];

        let tasks: Vec<&Task> = vec![&task1, &task2, &task3];
        let result = compute_waves(&tasks);

        assert_eq!(result.waves.len(), 3);
        assert!(result.circular_deps.is_empty());
        assert_eq!(result.waves[0].tasks, vec!["1"]);
        assert_eq!(result.waves[1].tasks, vec!["2"]);
        assert_eq!(result.waves[2].tasks, vec!["3"]);
    }

    #[test]
    fn test_parallel_waves() {
        let task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        let task2 = Task::new("2".to_string(), "Task 2".to_string(), String::new());
        let task3 = Task::new("3".to_string(), "Task 3".to_string(), String::new());

        let tasks: Vec<&Task> = vec![&task1, &task2, &task3];
        let result = compute_waves(&tasks);

        // All tasks can run in parallel (no dependencies)
        assert_eq!(result.waves.len(), 1);
        assert!(result.circular_deps.is_empty());
        assert_eq!(result.waves[0].tasks.len(), 3);
    }

    #[test]
    fn test_diamond_dependency() {
        //   1
        //  / \
        // 2   3
        //  \ /
        //   4
        let task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        let mut task2 = Task::new("2".to_string(), "Task 2".to_string(), String::new());
        task2.dependencies = vec!["1".to_string()];
        let mut task3 = Task::new("3".to_string(), "Task 3".to_string(), String::new());
        task3.dependencies = vec!["1".to_string()];
        let mut task4 = Task::new("4".to_string(), "Task 4".to_string(), String::new());
        task4.dependencies = vec!["2".to_string(), "3".to_string()];

        let tasks: Vec<&Task> = vec![&task1, &task2, &task3, &task4];
        let result = compute_waves(&tasks);

        assert_eq!(result.waves.len(), 3);
        assert!(result.circular_deps.is_empty());
        assert_eq!(result.waves[0].tasks, vec!["1"]);
        assert!(result.waves[1].tasks.contains(&"2".to_string()));
        assert!(result.waves[1].tasks.contains(&"3".to_string()));
        assert_eq!(result.waves[2].tasks, vec!["4"]);
    }

    #[test]
    fn test_circular_dependency_detected() {
        let mut task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        task1.dependencies = vec!["2".to_string()];
        let mut task2 = Task::new("2".to_string(), "Task 2".to_string(), String::new());
        task2.dependencies = vec!["1".to_string()];

        let tasks: Vec<&Task> = vec![&task1, &task2];
        let result = compute_waves(&tasks);

        assert!(result.waves.is_empty());
        assert_eq!(result.circular_deps.len(), 2);
    }

    #[test]
    fn test_external_dependency_ignored() {
        let task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        let mut task2 = Task::new("2".to_string(), "Task 2".to_string(), String::new());
        // Depends on task not in our set
        task2.dependencies = vec!["external:99".to_string()];

        let tasks: Vec<&Task> = vec![&task1, &task2];
        let result = compute_waves(&tasks);

        // Both can run in wave 1 since external dep is ignored
        assert_eq!(result.waves.len(), 1);
        assert_eq!(result.waves[0].tasks.len(), 2);
        assert!(result.circular_deps.is_empty());
    }

    #[test]
    fn test_empty_input_produces_no_waves() {
        let tasks: Vec<&Task> = Vec::new();
        let result = compute_waves(&tasks);

        assert!(result.waves.is_empty());
        assert!(result.circular_deps.is_empty());
    }

    #[test]
    fn test_self_dependency_is_circular() {
        let mut task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        task1.dependencies = vec!["1".to_string()];

        let tasks: Vec<&Task> = vec![&task1];
        let result = compute_waves(&tasks);

        assert!(result.waves.is_empty());
        assert_eq!(result.circular_deps, vec!["1"]);
    }

    #[test]
    fn test_dependents_of_cycle_are_reported() {
        // 1 is independent; 2 <-> 3 form a cycle; 4 only depends on the cycle.
        let task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        let mut task2 = Task::new("2".to_string(), "Task 2".to_string(), String::new());
        task2.dependencies = vec!["3".to_string()];
        let mut task3 = Task::new("3".to_string(), "Task 3".to_string(), String::new());
        task3.dependencies = vec!["2".to_string()];
        let mut task4 = Task::new("4".to_string(), "Task 4".to_string(), String::new());
        task4.dependencies = vec!["2".to_string()];

        let tasks: Vec<&Task> = vec![&task1, &task2, &task3, &task4];
        let result = compute_waves(&tasks);

        assert_eq!(result.waves.len(), 1);
        assert_eq!(result.waves[0].tasks, vec!["1"]);
        let mut circular = result.circular_deps.clone();
        circular.sort();
        assert_eq!(circular, vec!["2", "3", "4"]);
    }

    #[test]
    fn test_duplicate_dependency_entries() {
        let task1 = Task::new("1".to_string(), "Task 1".to_string(), String::new());
        let mut task2 = Task::new("2".to_string(), "Task 2".to_string(), String::new());
        task2.dependencies = vec!["1".to_string(), "1".to_string()];

        let tasks: Vec<&Task> = vec![&task1, &task2];
        let result = compute_waves(&tasks);

        assert_eq!(result.waves.len(), 2);
        assert!(result.circular_deps.is_empty());
        assert_eq!(result.waves[0].tasks, vec!["1"]);
        assert_eq!(result.waves[1].tasks, vec!["2"]);
    }

    #[test]
    fn test_id_collision_detection() {
        // Two tasks with same local ID but different tags
        let task1 = Task::new("auth:1".to_string(), "Auth Task".to_string(), String::new());
        let task2 = Task::new("api:1".to_string(), "API Task".to_string(), String::new());
        let task3 = Task::new(
            "auth:2".to_string(),
            "Auth Task 2".to_string(),
            String::new(),
        );

        let tasks: Vec<&Task> = vec![&task1, &task2, &task3];
        let collisions = detect_id_collisions(&tasks);

        assert_eq!(collisions.len(), 1);
        assert_eq!(collisions[0].0, "1");
        assert!(collisions[0].1.contains(&"auth".to_string()));
        assert!(collisions[0].1.contains(&"api".to_string()));
    }

    #[test]
    fn test_no_id_collisions() {
        let task1 = Task::new("auth:1".to_string(), "Auth Task".to_string(), String::new());
        let task2 = Task::new("api:2".to_string(), "API Task".to_string(), String::new());

        let tasks: Vec<&Task> = vec![&task1, &task2];
        let collisions = detect_id_collisions(&tasks);

        assert!(collisions.is_empty());
    }

    #[test]
    fn test_same_tag_repeated_is_not_a_collision() {
        let task1 = Task::new("auth:1".to_string(), "Auth Task".to_string(), String::new());
        let task2 = Task::new("auth:1".to_string(), "Auth Task Copy".to_string(), String::new());

        let tasks: Vec<&Task> = vec![&task1, &task2];
        let collisions = detect_id_collisions(&tasks);

        assert!(collisions.is_empty());
    }

    #[test]
    fn test_collisions_sorted_by_local_id() {
        let task1 = Task::new("b:2".to_string(), "B2".to_string(), String::new());
        let task2 = Task::new("a:2".to_string(), "A2".to_string(), String::new());
        let task3 = Task::new("b:1".to_string(), "B1".to_string(), String::new());
        let task4 = Task::new("a:1".to_string(), "A1".to_string(), String::new());

        let tasks: Vec<&Task> = vec![&task1, &task2, &task3, &task4];
        let collisions = detect_id_collisions(&tasks);

        let ids: Vec<&str> = collisions.iter().map(|(id, _)| id.as_str()).collect();
        assert_eq!(ids, vec!["1", "2"]);
        assert_eq!(collisions[0].1, vec!["a", "b"]);
        assert_eq!(collisions[1].1, vec!["a", "b"]);
    }
}