1use serde::{Deserialize, Serialize};
7use somatize_core::cache::CacheKey;
8use somatize_core::filter::RemoteTarget;
9use somatize_core::graph::NodeId;
10use std::fmt;
11
/// A tree describing how the nodes of a graph are scheduled.
///
/// Leaves ([`ExecutionPlan::Execute`], [`ExecutionPlan::Cached`]) name individual nodes;
/// the remaining variants compose sub-plans sequentially, in parallel, in a loop, by
/// branching, or on a remote target. [`ExecutionPlan::Empty`] is the neutral element
/// that `simplify` strips out.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ExecutionPlan {
    /// Steps listed in order; `to_mermaid` chains each step after the previous one.
    Sequence(Vec<ExecutionPlan>),

    /// Independent branches; `to_mermaid` draws them fanning out from a fork node.
    Parallel(Vec<ExecutionPlan>),

    /// Execute the single node `node_id`.
    Execute { node_id: NodeId },

    /// Node `node_id` with an associated cache `key`; tallied by `cached_count`.
    /// Presumably its result is served from cache instead of being recomputed —
    /// NOTE(review): confirm against the executor.
    Cached { node_id: NodeId, key: CacheKey },

    /// Loop controlled by node `node_id` that runs `body`.
    Loop {
        node_id: NodeId,
        body: Box<ExecutionPlan>,
        // Optional iteration cap; `None` means no cap is specified in the plan.
        max_iterations: Option<usize>,
    },

    /// Branch node `node_id` with labelled arms; each arm pairs a label with its
    /// sub-plan (presumably only the selected arm runs — TODO confirm).
    Branch {
        node_id: NodeId,
        arms: Vec<(String, ExecutionPlan)>,
    },

    /// Sub-`plan` associated with node `node_id` and remote `target`
    /// (presumably executed on that target — TODO confirm).
    Remote {
        node_id: NodeId,
        target: RemoteTarget,
        plan: Box<ExecutionPlan>,
    },

    /// A plan that does nothing; counts as zero nodes and is removed by `simplify`.
    Empty,
}
55
56impl ExecutionPlan {
57 pub fn node_count(&self) -> usize {
59 match self {
60 Self::Execute { .. } | Self::Cached { .. } => 1,
61 Self::Sequence(steps) | Self::Parallel(steps) => {
62 steps.iter().map(|s| s.node_count()).sum()
63 }
64 Self::Loop { body, .. } => 1 + body.node_count(),
65 Self::Branch { arms, .. } => {
66 1 + arms.iter().map(|(_, p)| p.node_count()).sum::<usize>()
67 }
68 Self::Remote { plan, .. } => plan.node_count(),
69 Self::Empty => 0,
70 }
71 }
72
73 pub fn cached_count(&self) -> usize {
75 match self {
76 Self::Cached { .. } => 1,
77 Self::Execute { .. } => 0,
78 Self::Sequence(steps) | Self::Parallel(steps) => {
79 steps.iter().map(|s| s.cached_count()).sum()
80 }
81 Self::Loop { body, .. } => body.cached_count(),
82 Self::Branch { arms, .. } => arms.iter().map(|(_, p)| p.cached_count()).sum(),
83 Self::Remote { plan, .. } => plan.cached_count(),
84 Self::Empty => 0,
85 }
86 }
87
88 pub fn parallel_branch_count(&self) -> usize {
90 match self {
91 Self::Parallel(branches) => branches.len(),
92 Self::Sequence(steps) => steps.iter().map(|s| s.parallel_branch_count()).sum(),
93 _ => 0,
94 }
95 }
96
97 pub fn node_ids(&self) -> Vec<&str> {
99 match self {
100 Self::Execute { node_id } | Self::Cached { node_id, .. } => vec![node_id.as_str()],
101 Self::Sequence(steps) | Self::Parallel(steps) => {
102 steps.iter().flat_map(|s| s.node_ids()).collect()
103 }
104 Self::Loop { node_id, body, .. } => {
105 let mut ids = vec![node_id.as_str()];
106 ids.extend(body.node_ids());
107 ids
108 }
109 Self::Branch { node_id, arms, .. } => {
110 let mut ids = vec![node_id.as_str()];
111 for (_, p) in arms {
112 ids.extend(p.node_ids());
113 }
114 ids
115 }
116 Self::Remote { node_id, plan, .. } => {
117 let mut ids = vec![node_id.as_str()];
118 ids.extend(plan.node_ids());
119 ids
120 }
121 Self::Empty => vec![],
122 }
123 }
124
125 pub fn summary(&self) -> somatize_core::event::PlanSummary {
127 somatize_core::event::PlanSummary {
128 total_nodes: self.node_count(),
129 cached_nodes: self.cached_count(),
130 parallel_branches: self.parallel_branch_count(),
131 }
132 }
133
134 pub fn simplify(self) -> Self {
136 match self {
137 Self::Sequence(mut steps) => {
138 steps = steps.into_iter().map(|s| s.simplify()).collect();
139 steps.retain(|s| !matches!(s, Self::Empty));
140 match steps.len() {
141 0 => Self::Empty,
142 1 => steps.into_iter().next().unwrap(),
143 _ => Self::Sequence(steps),
144 }
145 }
146 Self::Parallel(mut branches) => {
147 branches = branches.into_iter().map(|b| b.simplify()).collect();
148 branches.retain(|b| !matches!(b, Self::Empty));
149 match branches.len() {
150 0 => Self::Empty,
151 1 => branches.into_iter().next().unwrap(),
152 _ => Self::Parallel(branches),
153 }
154 }
155 other => other,
156 }
157 }
158}
159
160impl ExecutionPlan {
161 pub fn to_mermaid(&self) -> String {
163 let mut out = String::from("graph TD\n");
164 let mut counter = 0;
165 self.mermaid_nodes(&mut out, &mut counter, None);
166 out
167 }
168
169 fn mermaid_nodes(&self, out: &mut String, counter: &mut usize, parent: Option<&str>) {
170 use std::fmt::Write;
171 match self {
172 Self::Execute { node_id } => {
173 let _ = writeln!(out, " {node_id}[{node_id}]");
174 if let Some(p) = parent {
175 let _ = writeln!(out, " {p} --> {node_id}");
176 }
177 }
178 Self::Cached { node_id, .. } => {
179 let _ = writeln!(out, " {node_id}[/{node_id} cached/]");
180 if let Some(p) = parent {
181 let _ = writeln!(out, " {p} --> {node_id}");
182 }
183 }
184 Self::Sequence(steps) => {
185 let mut prev = parent.map(String::from);
186 for step in steps {
187 step.mermaid_nodes(out, counter, prev.as_deref());
188 prev = step.first_node_id().map(String::from);
189 }
190 }
191 Self::Parallel(branches) => {
192 let fork_id = format!("fork_{counter}");
193 *counter += 1;
194 let _ = writeln!(out, " {fork_id}{{{{fork}}}}");
195 if let Some(p) = parent {
196 let _ = writeln!(out, " {p} --> {fork_id}");
197 }
198 for branch in branches {
199 branch.mermaid_nodes(out, counter, Some(&fork_id));
200 }
201 }
202 Self::Loop {
203 node_id,
204 body,
205 max_iterations,
206 } => {
207 let label = match max_iterations {
208 Some(n) => format!("{node_id} loop max={n}"),
209 None => format!("{node_id} loop"),
210 };
211 let _ = writeln!(out, " {node_id}(({label}))");
212 if let Some(p) = parent {
213 let _ = writeln!(out, " {p} --> {node_id}");
214 }
215 body.mermaid_nodes(out, counter, Some(node_id));
216 }
217 Self::Branch { node_id, arms } => {
218 let _ = writeln!(out, " {node_id}{{{{{node_id}}}}}");
219 if let Some(p) = parent {
220 let _ = writeln!(out, " {p} --> {node_id}");
221 }
222 for (label, plan) in arms {
223 let arm_id = format!("arm_{counter}");
224 *counter += 1;
225 let _ = writeln!(out, " {node_id} -->|{label}| {arm_id}[{label}]");
226 plan.mermaid_nodes(out, counter, Some(&arm_id));
227 }
228 }
229 Self::Remote {
230 node_id,
231 target,
232 plan,
233 } => {
234 let _ = writeln!(out, " {node_id}>{{{node_id} remote: {target:?}}}]");
235 if let Some(p) = parent {
236 let _ = writeln!(out, " {p} --> {node_id}");
237 }
238 plan.mermaid_nodes(out, counter, Some(node_id));
239 }
240 Self::Empty => {}
241 }
242 }
243
244 fn first_node_id(&self) -> Option<&str> {
245 match self {
246 Self::Execute { node_id } | Self::Cached { node_id, .. } => Some(node_id),
247 Self::Sequence(steps) => steps.first().and_then(|s| s.first_node_id()),
248 Self::Parallel(_) => None,
249 Self::Loop { node_id, .. }
250 | Self::Branch { node_id, .. }
251 | Self::Remote { node_id, .. } => Some(node_id),
252 Self::Empty => None,
253 }
254 }
255}
256
257impl fmt::Display for ExecutionPlan {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 self.fmt_indent(f, 0)
260 }
261}
262
263impl ExecutionPlan {
264 fn fmt_indent(&self, f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
265 let pad = " ".repeat(indent);
266 match self {
267 Self::Sequence(steps) => {
268 writeln!(f, "{pad}Sequence:")?;
269 for step in steps {
270 step.fmt_indent(f, indent + 1)?;
271 }
272 Ok(())
273 }
274 Self::Parallel(branches) => {
275 writeln!(f, "{pad}Parallel:")?;
276 for branch in branches {
277 branch.fmt_indent(f, indent + 1)?;
278 }
279 Ok(())
280 }
281 Self::Execute { node_id } => writeln!(f, "{pad}Execute({node_id})"),
282 Self::Cached { node_id, key } => writeln!(f, "{pad}Cached({node_id}, {key})"),
283 Self::Loop {
284 node_id,
285 body,
286 max_iterations,
287 } => {
288 writeln!(f, "{pad}Loop({node_id}, max={max_iterations:?}):")?;
289 body.fmt_indent(f, indent + 1)
290 }
291 Self::Branch { node_id, arms } => {
292 writeln!(f, "{pad}Branch({node_id}):")?;
293 for (label, plan) in arms {
294 writeln!(f, "{pad} [{label}]:")?;
295 plan.fmt_indent(f, indent + 2)?;
296 }
297 Ok(())
298 }
299 Self::Remote {
300 node_id,
301 target,
302 plan,
303 } => {
304 writeln!(f, "{pad}Remote({node_id}, target={target:?}):")?;
305 plan.fmt_indent(f, indent + 1)
306 }
307 Self::Empty => writeln!(f, "{pad}Empty"),
308 }
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Execute` leaf for `id`.
    fn exec(id: &'static str) -> ExecutionPlan {
        ExecutionPlan::Execute { node_id: id.into() }
    }

    /// Builds a `Cached` leaf for `id` whose cache key is hashed from `key_data`.
    fn cached(id: &'static str, key_data: &[u8]) -> ExecutionPlan {
        ExecutionPlan::Cached {
            node_id: id.into(),
            key: CacheKey::hash_data(key_data),
        }
    }

    #[test]
    fn node_count_linear() {
        let plan = ExecutionPlan::Sequence(vec![exec("a"), exec("b"), exec("c")]);
        assert_eq!(plan.node_count(), 3);
        assert_eq!(plan.cached_count(), 0);
    }

    #[test]
    fn cached_count() {
        let plan = ExecutionPlan::Sequence(vec![cached("a", b"a"), exec("b"), cached("c", b"c")]);
        assert_eq!(plan.node_count(), 3);
        assert_eq!(plan.cached_count(), 2);
    }

    #[test]
    fn parallel_branch_count() {
        let fan_out = ExecutionPlan::Parallel(vec![exec("b"), exec("c"), exec("d")]);
        let plan = ExecutionPlan::Sequence(vec![exec("a"), fan_out, exec("e")]);
        assert_eq!(plan.parallel_branch_count(), 3);
        assert_eq!(plan.node_count(), 5);
    }

    #[test]
    fn node_ids_collected() {
        let plan = ExecutionPlan::Sequence(vec![cached("a", b"a"), exec("b")]);
        let ids = plan.node_ids();
        assert_eq!(ids, vec!["a", "b"]);
    }

    #[test]
    fn simplify_removes_empty() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Empty,
            exec("a"),
            ExecutionPlan::Empty,
        ]);
        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn simplify_unwraps_single_element() {
        let plan = ExecutionPlan::Sequence(vec![exec("a")]);
        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn simplify_preserves_multi() {
        let plan = ExecutionPlan::Sequence(vec![exec("a"), exec("b")]);
        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Sequence(_)));
    }

    #[test]
    fn display_format() {
        let plan = ExecutionPlan::Sequence(vec![
            exec("scaler"),
            ExecutionPlan::Parallel(vec![exec("pca"), exec("umap")]),
            exec("svm"),
        ]);
        let output = format!("{plan}");
        for expected in ["Sequence:", "Parallel:", "Execute(scaler)", "Execute(pca)"] {
            assert!(output.contains(expected));
        }
    }

    #[test]
    fn summary_values() {
        let plan = ExecutionPlan::Sequence(vec![
            cached("a", b"a"),
            ExecutionPlan::Parallel(vec![exec("b"), exec("c")]),
            exec("d"),
        ]);
        let summary = plan.summary();
        assert_eq!(summary.total_nodes, 4);
        assert_eq!(summary.cached_nodes, 1);
        assert_eq!(summary.parallel_branches, 2);
    }

    #[test]
    fn serde_roundtrip() {
        let plan = ExecutionPlan::Sequence(vec![cached("a", b"test"), exec("b")]);
        let json = serde_json::to_string(&plan).unwrap();
        let deserialized: ExecutionPlan = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.node_count(), 2);
    }

    #[test]
    fn empty_plan() {
        let plan = ExecutionPlan::Empty;
        assert_eq!(plan.node_count(), 0);
        assert_eq!(plan.cached_count(), 0);
        assert!(plan.node_ids().is_empty());
    }

    #[test]
    fn to_mermaid_sequence() {
        let plan = ExecutionPlan::Sequence(vec![exec("scaler"), exec("model")]);
        let m = plan.to_mermaid();
        assert!(m.starts_with("graph TD"));
        assert!(m.contains("scaler[scaler]"));
        assert!(m.contains("model[model]"));
        assert!(m.contains("scaler --> model"));
    }

    #[test]
    fn to_mermaid_parallel() {
        let plan = ExecutionPlan::Parallel(vec![exec("a"), exec("b")]);
        let m = plan.to_mermaid();
        assert!(m.contains("fork_0{"));
        assert!(m.contains("fork_0 --> a"));
        assert!(m.contains("fork_0 --> b"));
    }

    #[test]
    fn to_mermaid_cached() {
        let plan = cached("x", b"x");
        let m = plan.to_mermaid();
        assert!(m.contains("x[/x cached/]"));
    }
}