1use crate::ExecutionPlan;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct WorkerInfo {
16 pub id: String,
17 pub name: String,
18 pub tags: Vec<String>,
19 pub gpu: bool,
20 pub cpu_cores: usize,
21 pub active_jobs: usize,
22 pub max_concurrent: usize,
23}
24
25impl WorkerInfo {
26 pub fn available_slots(&self) -> usize {
27 self.max_concurrent.saturating_sub(self.active_jobs)
28 }
29
30 pub fn has_capacity(&self) -> bool {
31 self.available_slots() > 0
32 }
33
34 pub fn matches_tag(&self, tag: &str) -> bool {
35 self.tags.iter().any(|t| t == tag)
36 }
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Assignment {
42 pub node_id: String,
43 pub worker_id: String,
44 pub worker_name: String,
45 pub phase: Phase,
46 pub reason: String,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
51#[serde(rename_all = "snake_case")]
52pub enum Phase {
53 Sequential,
54 Parallel,
55 Trial { trial_index: usize, total: usize },
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct DistributionPlan {
61 pub assignments: Vec<Assignment>,
62 pub phases: Vec<PlanPhase>,
63 pub data_transfers: Vec<DataTransfer>,
64 pub warnings: Vec<String>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct PlanPhase {
70 pub phase_index: usize,
71 pub phase_type: Phase,
72 pub node_ids: Vec<String>,
73 pub worker_ids: Vec<String>,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct DataTransfer {
79 pub from_node: String,
80 pub to_node: String,
81 pub from_worker: String,
82 pub to_worker: String,
83 pub transfer_type: String, }
85
86struct ScheduleState<'a> {
88 workers: Vec<&'a WorkerInfo>,
89 diff_nodes: &'a [String],
90 assignments: Vec<Assignment>,
91 phases: Vec<PlanPhase>,
92 transfers: Vec<DataTransfer>,
93 warnings: Vec<String>,
94 phase_index: usize,
95}
96
97pub fn schedule(
99 plan: &ExecutionPlan,
100 workers: &[WorkerInfo],
101 differentiable_nodes: &[String],
102) -> DistributionPlan {
103 let mut state = ScheduleState {
104 workers: Vec::new(),
105 diff_nodes: differentiable_nodes,
106 assignments: Vec::new(),
107 phases: Vec::new(),
108 transfers: Vec::new(),
109 warnings: Vec::new(),
110 phase_index: 0,
111 };
112
113 if workers.is_empty() {
114 state
115 .warnings
116 .push("No workers available — will execute locally".into());
117 return DistributionPlan {
118 assignments: state.assignments,
119 phases: state.phases,
120 data_transfers: state.transfers,
121 warnings: state.warnings,
122 };
123 }
124
125 state.workers = workers.iter().filter(|w| w.has_capacity()).collect();
126 if state.workers.is_empty() {
127 state.warnings.push("All workers are at capacity".into());
128 return DistributionPlan {
129 assignments: state.assignments,
130 phases: state.phases,
131 data_transfers: state.transfers,
132 warnings: state.warnings,
133 };
134 }
135
136 schedule_plan(plan, &mut state, None);
137
138 DistributionPlan {
139 assignments: state.assignments,
140 phases: state.phases,
141 data_transfers: state.transfers,
142 warnings: state.warnings,
143 }
144}
145
146fn schedule_plan(plan: &ExecutionPlan, state: &mut ScheduleState<'_>, forced_worker: Option<&str>) {
147 match plan {
148 ExecutionPlan::Execute { node_id } => {
149 let worker = if let Some(fw) = forced_worker {
150 state
151 .workers
152 .iter()
153 .find(|w| w.id == fw)
154 .unwrap_or(&state.workers[0])
155 } else {
156 least_loaded(&state.workers)
157 };
158
159 state.assignments.push(Assignment {
160 node_id: node_id.clone(),
161 worker_id: worker.id.clone(),
162 worker_name: worker.name.clone(),
163 phase: Phase::Sequential,
164 reason: if forced_worker.is_some() {
165 "grouped with differentiable neighbors".into()
166 } else {
167 "least loaded worker".into()
168 },
169 });
170 }
171
172 ExecutionPlan::Sequence(steps) => {
173 let worker = forced_worker
174 .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
175 .unwrap_or_else(|| least_loaded(&state.workers));
176
177 let node_ids = collect_node_ids(plan);
178 let has_diff = node_ids.iter().any(|n| state.diff_nodes.contains(n));
179 let force = if has_diff {
180 Some(worker.id.as_str())
181 } else {
182 forced_worker
183 };
184
185 state.phases.push(PlanPhase {
186 phase_index: state.phase_index,
187 phase_type: Phase::Sequential,
188 node_ids: node_ids.clone(),
189 worker_ids: vec![worker.id.clone()],
190 });
191 state.phase_index += 1;
192
193 for step in steps {
194 schedule_plan(step, state, force);
195 }
196 }
197
198 ExecutionPlan::Parallel(branches) => {
199 let branch_ids: Vec<Vec<String>> = branches.iter().map(collect_node_ids).collect();
200 let mut assigned_workers = Vec::new();
201
202 for (i, branch) in branches.iter().enumerate() {
203 let worker_idx = i % state.workers.len();
204 let worker = state.workers[worker_idx];
205 assigned_workers.push(worker.id.clone());
206
207 let worker_id = worker.id.clone();
208 schedule_plan(branch, state, Some(&worker_id));
209
210 if let Some(prev) = state
212 .assignments
213 .iter()
214 .rev()
215 .find(|a| !branch_ids[i].contains(&a.node_id))
216 .filter(|prev| prev.worker_id != state.workers[worker_idx].id)
217 {
218 state.transfers.push(DataTransfer {
219 from_node: prev.node_id.clone(),
220 to_node: branch_ids[i].first().cloned().unwrap_or_default(),
221 from_worker: prev.worker_id.clone(),
222 to_worker: state.workers[worker_idx].id.clone(),
223 transfer_type: "s3".into(),
224 });
225 }
226 }
227
228 state.phases.push(PlanPhase {
229 phase_index: state.phase_index,
230 phase_type: Phase::Parallel,
231 node_ids: branch_ids.into_iter().flatten().collect(),
232 worker_ids: assigned_workers,
233 });
234 state.phase_index += 1;
235 }
236
237 ExecutionPlan::Cached { node_id, .. } => {
238 let worker = forced_worker
239 .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
240 .unwrap_or_else(|| least_loaded(&state.workers));
241 state.assignments.push(Assignment {
242 node_id: node_id.clone(),
243 worker_id: worker.id.clone(),
244 worker_name: worker.name.clone(),
245 phase: Phase::Sequential,
246 reason: "cached — will skip execution".into(),
247 });
248 }
249
250 ExecutionPlan::Remote { plan, .. } => {
251 schedule_plan(plan, state, None);
252 }
253
254 ExecutionPlan::Loop { body, node_id, .. } => {
255 let worker = forced_worker
256 .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
257 .unwrap_or_else(|| least_loaded(&state.workers));
258 state.assignments.push(Assignment {
259 node_id: node_id.clone(),
260 worker_id: worker.id.clone(),
261 worker_name: worker.name.clone(),
262 phase: Phase::Sequential,
263 reason: "loop controller".into(),
264 });
265 let worker_id = worker.id.clone();
266 schedule_plan(body, state, Some(&worker_id));
267 }
268
269 ExecutionPlan::Branch { node_id, arms, .. } => {
270 let worker = forced_worker
271 .and_then(|fw| state.workers.iter().find(|w| w.id == fw).copied())
272 .unwrap_or_else(|| least_loaded(&state.workers));
273 state.assignments.push(Assignment {
274 node_id: node_id.clone(),
275 worker_id: worker.id.clone(),
276 worker_name: worker.name.clone(),
277 phase: Phase::Sequential,
278 reason: "branch condition".into(),
279 });
280 let worker_id = worker.id.clone();
281 for (_, arm_plan) in arms {
282 schedule_plan(arm_plan, state, Some(&worker_id));
283 }
284 }
285
286 ExecutionPlan::Empty => {}
287 }
288}
289
290fn least_loaded<'a>(workers: &[&'a WorkerInfo]) -> &'a WorkerInfo {
291 workers.iter().max_by_key(|w| w.available_slots()).unwrap()
292}
293
294fn collect_node_ids(plan: &ExecutionPlan) -> Vec<String> {
295 plan.node_ids().into_iter().map(|s| s.to_string()).collect()
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a leaf step that executes `node_id`.
    fn exec(node_id: &str) -> ExecutionPlan {
        ExecutionPlan::Execute {
            node_id: node_id.to_string(),
        }
    }

    /// Two workers: a GPU box with 4 free slots and a CPU box with 7.
    fn test_workers() -> Vec<WorkerInfo> {
        let gpu_node = WorkerInfo {
            id: "w1".into(),
            name: "GPU-A100".into(),
            tags: vec!["gpu".into()],
            gpu: true,
            cpu_cores: 16,
            active_jobs: 0,
            max_concurrent: 4,
        };
        let cpu_node = WorkerInfo {
            id: "w2".into(),
            name: "CPU-Server".into(),
            tags: vec!["cpu".into()],
            gpu: false,
            cpu_cores: 64,
            active_jobs: 1,
            max_concurrent: 8,
        };
        vec![gpu_node, cpu_node]
    }

    #[test]
    fn sequential_same_worker() {
        let plan =
            ExecutionPlan::Sequence(vec![exec("normalize"), exec("select"), exec("classify")]);

        let result = schedule(&plan, &test_workers(), &[]);

        let all_same = result
            .assignments
            .windows(2)
            .all(|pair| pair[0].worker_id == pair[1].worker_id);
        assert!(all_same);
    }

    #[test]
    fn parallel_distributes() {
        let plan = ExecutionPlan::Parallel(vec![exec("train_svm"), exec("train_knn")]);

        let result = schedule(&plan, &test_workers(), &[]);

        assert_eq!(result.assignments.len(), 2);
        let (first, second) = (&result.assignments[0], &result.assignments[1]);
        assert_ne!(first.worker_id, second.worker_id);
    }

    #[test]
    fn no_workers_warns() {
        let result = schedule(&exec("test"), &[], &[]);
        assert!(!result.warnings.is_empty());
    }

    #[test]
    fn sequence_then_parallel() {
        let training = ExecutionPlan::Parallel(vec![exec("train_a"), exec("train_b")]);
        let plan = ExecutionPlan::Sequence(vec![exec("load"), exec("normalize"), training]);

        let result = schedule(&plan, &test_workers(), &[]);

        assert!(result.assignments.len() >= 4);
        assert_eq!(
            result.assignments[0].worker_id,
            result.assignments[1].worker_id
        );
    }

    #[test]
    fn data_transfer_on_split() {
        let plan = ExecutionPlan::Sequence(vec![
            exec("preprocess"),
            ExecutionPlan::Parallel(vec![exec("branch_a"), exec("branch_b")]),
        ]);

        let result = schedule(&plan, &test_workers(), &[]);

        let has_transfer = !result.data_transfers.is_empty();
        let single_worker = result
            .assignments
            .iter()
            .all(|a| a.worker_id == result.assignments[0].worker_id);
        assert!(has_transfer || single_worker);
    }
}