// File: plan.rs (somatize_compiler/)

//! Execution plan — the compiled representation of a pipeline.
//!
//! Variants: Sequence, Parallel, Execute, Cached, Loop, Branch, Remote, Empty.
//! Plans are data-free (no filter implementations) and serializable.
5
6use serde::{Deserialize, Serialize};
7use somatize_core::cache::CacheKey;
8use somatize_core::filter::RemoteTarget;
9use somatize_core::graph::NodeId;
10use std::fmt;
11
12/// A compiled execution plan produced by the compiler.
13///
14/// This is a recursive tree that the runtime walks to execute a pipeline.
15/// The compiler resolves caching, parallelism, and distribution before
16/// the runtime sees the plan.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[non_exhaustive]
19pub enum ExecutionPlan {
20    /// Execute steps sequentially, one after another.
21    Sequence(Vec<ExecutionPlan>),
22
23    /// Execute branches concurrently (fork-join).
24    Parallel(Vec<ExecutionPlan>),
25
26    /// Execute a single filter node.
27    Execute { node_id: NodeId },
28
29    /// Load result from cache (resolved at compile time).
30    Cached { node_id: NodeId, key: CacheKey },
31
32    /// Iterate: execute body for each item in a collection.
33    Loop {
34        node_id: NodeId,
35        body: Box<ExecutionPlan>,
36        max_iterations: Option<usize>,
37    },
38
39    /// Conditional branching: evaluate condition, pick an arm.
40    Branch {
41        node_id: NodeId,
42        arms: Vec<(String, ExecutionPlan)>,
43    },
44
45    /// Execute a sub-plan on a remote worker.
46    Remote {
47        node_id: NodeId,
48        target: RemoteTarget,
49        plan: Box<ExecutionPlan>,
50    },
51
52    /// No-op: nothing to execute (e.g. empty graph).
53    Empty,
54}
55
56impl ExecutionPlan {
57    /// Count total nodes in the plan (Execute + Cached).
58    pub fn node_count(&self) -> usize {
59        match self {
60            Self::Execute { .. } | Self::Cached { .. } => 1,
61            Self::Sequence(steps) | Self::Parallel(steps) => {
62                steps.iter().map(|s| s.node_count()).sum()
63            }
64            Self::Loop { body, .. } => 1 + body.node_count(),
65            Self::Branch { arms, .. } => {
66                1 + arms.iter().map(|(_, p)| p.node_count()).sum::<usize>()
67            }
68            Self::Remote { plan, .. } => plan.node_count(),
69            Self::Empty => 0,
70        }
71    }
72
73    /// Count cached nodes in the plan.
74    pub fn cached_count(&self) -> usize {
75        match self {
76            Self::Cached { .. } => 1,
77            Self::Execute { .. } => 0,
78            Self::Sequence(steps) | Self::Parallel(steps) => {
79                steps.iter().map(|s| s.cached_count()).sum()
80            }
81            Self::Loop { body, .. } => body.cached_count(),
82            Self::Branch { arms, .. } => arms.iter().map(|(_, p)| p.cached_count()).sum(),
83            Self::Remote { plan, .. } => plan.cached_count(),
84            Self::Empty => 0,
85        }
86    }
87
88    /// Count parallel branches at the top level of the plan.
89    pub fn parallel_branch_count(&self) -> usize {
90        match self {
91            Self::Parallel(branches) => branches.len(),
92            Self::Sequence(steps) => steps.iter().map(|s| s.parallel_branch_count()).sum(),
93            _ => 0,
94        }
95    }
96
97    /// Collect all node IDs referenced in the plan.
98    pub fn node_ids(&self) -> Vec<&str> {
99        match self {
100            Self::Execute { node_id } | Self::Cached { node_id, .. } => vec![node_id.as_str()],
101            Self::Sequence(steps) | Self::Parallel(steps) => {
102                steps.iter().flat_map(|s| s.node_ids()).collect()
103            }
104            Self::Loop { node_id, body, .. } => {
105                let mut ids = vec![node_id.as_str()];
106                ids.extend(body.node_ids());
107                ids
108            }
109            Self::Branch { node_id, arms, .. } => {
110                let mut ids = vec![node_id.as_str()];
111                for (_, p) in arms {
112                    ids.extend(p.node_ids());
113                }
114                ids
115            }
116            Self::Remote { node_id, plan, .. } => {
117                let mut ids = vec![node_id.as_str()];
118                ids.extend(plan.node_ids());
119                ids
120            }
121            Self::Empty => vec![],
122        }
123    }
124
125    /// Create a PlanSummary for event payloads.
126    pub fn summary(&self) -> somatize_core::event::PlanSummary {
127        somatize_core::event::PlanSummary {
128            total_nodes: self.node_count(),
129            cached_nodes: self.cached_count(),
130            parallel_branches: self.parallel_branch_count(),
131        }
132    }
133
134    /// Flatten unnecessary nesting (e.g. Sequence of one element).
135    pub fn simplify(self) -> Self {
136        match self {
137            Self::Sequence(mut steps) => {
138                steps = steps.into_iter().map(|s| s.simplify()).collect();
139                steps.retain(|s| !matches!(s, Self::Empty));
140                match steps.len() {
141                    0 => Self::Empty,
142                    1 => steps.into_iter().next().unwrap(),
143                    _ => Self::Sequence(steps),
144                }
145            }
146            Self::Parallel(mut branches) => {
147                branches = branches.into_iter().map(|b| b.simplify()).collect();
148                branches.retain(|b| !matches!(b, Self::Empty));
149                match branches.len() {
150                    0 => Self::Empty,
151                    1 => branches.into_iter().next().unwrap(),
152                    _ => Self::Parallel(branches),
153                }
154            }
155            other => other,
156        }
157    }
158}
159
160impl ExecutionPlan {
161    /// Render the execution plan as a Mermaid flowchart.
162    pub fn to_mermaid(&self) -> String {
163        let mut out = String::from("graph TD\n");
164        let mut counter = 0;
165        self.mermaid_nodes(&mut out, &mut counter, None);
166        out
167    }
168
169    fn mermaid_nodes(&self, out: &mut String, counter: &mut usize, parent: Option<&str>) {
170        use std::fmt::Write;
171        match self {
172            Self::Execute { node_id } => {
173                let _ = writeln!(out, "    {node_id}[{node_id}]");
174                if let Some(p) = parent {
175                    let _ = writeln!(out, "    {p} --> {node_id}");
176                }
177            }
178            Self::Cached { node_id, .. } => {
179                let _ = writeln!(out, "    {node_id}[/{node_id} cached/]");
180                if let Some(p) = parent {
181                    let _ = writeln!(out, "    {p} --> {node_id}");
182                }
183            }
184            Self::Sequence(steps) => {
185                let mut prev = parent.map(String::from);
186                for step in steps {
187                    step.mermaid_nodes(out, counter, prev.as_deref());
188                    prev = step.first_node_id().map(String::from);
189                }
190            }
191            Self::Parallel(branches) => {
192                let fork_id = format!("fork_{counter}");
193                *counter += 1;
194                let _ = writeln!(out, "    {fork_id}{{{{fork}}}}");
195                if let Some(p) = parent {
196                    let _ = writeln!(out, "    {p} --> {fork_id}");
197                }
198                for branch in branches {
199                    branch.mermaid_nodes(out, counter, Some(&fork_id));
200                }
201            }
202            Self::Loop {
203                node_id,
204                body,
205                max_iterations,
206            } => {
207                let label = match max_iterations {
208                    Some(n) => format!("{node_id} loop max={n}"),
209                    None => format!("{node_id} loop"),
210                };
211                let _ = writeln!(out, "    {node_id}(({label}))");
212                if let Some(p) = parent {
213                    let _ = writeln!(out, "    {p} --> {node_id}");
214                }
215                body.mermaid_nodes(out, counter, Some(node_id));
216            }
217            Self::Branch { node_id, arms } => {
218                let _ = writeln!(out, "    {node_id}{{{{{node_id}}}}}");
219                if let Some(p) = parent {
220                    let _ = writeln!(out, "    {p} --> {node_id}");
221                }
222                for (label, plan) in arms {
223                    let arm_id = format!("arm_{counter}");
224                    *counter += 1;
225                    let _ = writeln!(out, "    {node_id} -->|{label}| {arm_id}[{label}]");
226                    plan.mermaid_nodes(out, counter, Some(&arm_id));
227                }
228            }
229            Self::Remote {
230                node_id,
231                target,
232                plan,
233            } => {
234                let _ = writeln!(out, "    {node_id}>{{{node_id} remote: {target:?}}}]");
235                if let Some(p) = parent {
236                    let _ = writeln!(out, "    {p} --> {node_id}");
237                }
238                plan.mermaid_nodes(out, counter, Some(node_id));
239            }
240            Self::Empty => {}
241        }
242    }
243
244    fn first_node_id(&self) -> Option<&str> {
245        match self {
246            Self::Execute { node_id } | Self::Cached { node_id, .. } => Some(node_id),
247            Self::Sequence(steps) => steps.first().and_then(|s| s.first_node_id()),
248            Self::Parallel(_) => None,
249            Self::Loop { node_id, .. }
250            | Self::Branch { node_id, .. }
251            | Self::Remote { node_id, .. } => Some(node_id),
252            Self::Empty => None,
253        }
254    }
255}
256
257impl fmt::Display for ExecutionPlan {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        self.fmt_indent(f, 0)
260    }
261}
262
263impl ExecutionPlan {
264    fn fmt_indent(&self, f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
265        let pad = "  ".repeat(indent);
266        match self {
267            Self::Sequence(steps) => {
268                writeln!(f, "{pad}Sequence:")?;
269                for step in steps {
270                    step.fmt_indent(f, indent + 1)?;
271                }
272                Ok(())
273            }
274            Self::Parallel(branches) => {
275                writeln!(f, "{pad}Parallel:")?;
276                for branch in branches {
277                    branch.fmt_indent(f, indent + 1)?;
278                }
279                Ok(())
280            }
281            Self::Execute { node_id } => writeln!(f, "{pad}Execute({node_id})"),
282            Self::Cached { node_id, key } => writeln!(f, "{pad}Cached({node_id}, {key})"),
283            Self::Loop {
284                node_id,
285                body,
286                max_iterations,
287            } => {
288                writeln!(f, "{pad}Loop({node_id}, max={max_iterations:?}):")?;
289                body.fmt_indent(f, indent + 1)
290            }
291            Self::Branch { node_id, arms } => {
292                writeln!(f, "{pad}Branch({node_id}):")?;
293                for (label, plan) in arms {
294                    writeln!(f, "{pad}  [{label}]:")?;
295                    plan.fmt_indent(f, indent + 2)?;
296                }
297                Ok(())
298            }
299            Self::Remote {
300                node_id,
301                target,
302                plan,
303            } => {
304                writeln!(f, "{pad}Remote({node_id}, target={target:?}):")?;
305                plan.fmt_indent(f, indent + 1)
306            }
307            Self::Empty => writeln!(f, "{pad}Empty"),
308        }
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an `Execute` step on node `id`.
    fn exec(id: &'static str) -> ExecutionPlan {
        ExecutionPlan::Execute { node_id: id.into() }
    }

    /// Shorthand for a `Cached` step on node `id` stored under `key`.
    fn cached(id: &'static str, key: CacheKey) -> ExecutionPlan {
        ExecutionPlan::Cached {
            node_id: id.into(),
            key,
        }
    }

    #[test]
    fn node_count_linear() {
        let plan = ExecutionPlan::Sequence(vec![exec("a"), exec("b"), exec("c")]);

        assert_eq!(plan.node_count(), 3);
        assert_eq!(plan.cached_count(), 0);
    }

    #[test]
    fn cached_count() {
        let plan = ExecutionPlan::Sequence(vec![
            cached("a", CacheKey::hash_data(b"a")),
            exec("b"),
            cached("c", CacheKey::hash_data(b"c")),
        ]);

        assert_eq!(plan.node_count(), 3);
        assert_eq!(plan.cached_count(), 2);
    }

    #[test]
    fn parallel_branch_count() {
        let fan_out = ExecutionPlan::Parallel(vec![exec("b"), exec("c"), exec("d")]);
        let plan = ExecutionPlan::Sequence(vec![exec("a"), fan_out, exec("e")]);

        assert_eq!(plan.parallel_branch_count(), 3);
        assert_eq!(plan.node_count(), 5);
    }

    #[test]
    fn node_ids_collected() {
        let plan = ExecutionPlan::Sequence(vec![
            cached("a", CacheKey::hash_data(b"a")),
            exec("b"),
        ]);

        let ids = plan.node_ids();
        assert_eq!(ids, vec!["a", "b"]);
    }

    #[test]
    fn simplify_removes_empty() {
        let plan = ExecutionPlan::Sequence(vec![
            ExecutionPlan::Empty,
            exec("a"),
            ExecutionPlan::Empty,
        ]);

        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn simplify_unwraps_single_element() {
        let plan = ExecutionPlan::Sequence(vec![exec("a")]);

        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Execute { .. }));
    }

    #[test]
    fn simplify_preserves_multi() {
        let plan = ExecutionPlan::Sequence(vec![exec("a"), exec("b")]);

        let simplified = plan.simplify();
        assert!(matches!(simplified, ExecutionPlan::Sequence(_)));
    }

    #[test]
    fn display_format() {
        let reducers = ExecutionPlan::Parallel(vec![exec("pca"), exec("umap")]);
        let plan = ExecutionPlan::Sequence(vec![exec("scaler"), reducers, exec("svm")]);

        let output = format!("{plan}");
        assert!(output.contains("Sequence:"));
        assert!(output.contains("Parallel:"));
        assert!(output.contains("Execute(scaler)"));
        assert!(output.contains("Execute(pca)"));
    }

    #[test]
    fn summary_values() {
        let plan = ExecutionPlan::Sequence(vec![
            cached("a", CacheKey::hash_data(b"a")),
            ExecutionPlan::Parallel(vec![exec("b"), exec("c")]),
            exec("d"),
        ]);

        let summary = plan.summary();
        assert_eq!(summary.total_nodes, 4);
        assert_eq!(summary.cached_nodes, 1);
        assert_eq!(summary.parallel_branches, 2);
    }

    #[test]
    fn serde_roundtrip() {
        let plan = ExecutionPlan::Sequence(vec![
            cached("a", CacheKey::hash_data(b"test")),
            exec("b"),
        ]);

        let json = serde_json::to_string(&plan).unwrap();
        let deserialized: ExecutionPlan = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.node_count(), 2);
    }

    #[test]
    fn empty_plan() {
        let plan = ExecutionPlan::Empty;

        assert_eq!(plan.node_count(), 0);
        assert_eq!(plan.cached_count(), 0);
        assert!(plan.node_ids().is_empty());
    }

    #[test]
    fn to_mermaid_sequence() {
        let plan = ExecutionPlan::Sequence(vec![exec("scaler"), exec("model")]);

        let m = plan.to_mermaid();
        assert!(m.starts_with("graph TD"));
        assert!(m.contains("scaler[scaler]"));
        assert!(m.contains("model[model]"));
        assert!(m.contains("scaler --> model"));
    }

    #[test]
    fn to_mermaid_parallel() {
        let plan = ExecutionPlan::Parallel(vec![exec("a"), exec("b")]);

        let m = plan.to_mermaid();
        assert!(m.contains("fork_0{"));
        assert!(m.contains("fork_0 --> a"));
        assert!(m.contains("fork_0 --> b"));
    }

    #[test]
    fn to_mermaid_cached() {
        let plan = cached("x", CacheKey::hash_data(b"x"));

        let m = plan.to_mermaid();
        assert!(m.contains("x[/x cached/]"));
    }
}