1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
/// Version string of the initial workflow IR.
///
/// Used as the `version` of a [`WorkflowDefinition`] whose serialized form omits one.
pub const WORKFLOW_IR_V0: &str = "v0";
6
/// Serializable workflow definition: a named, versioned list of [`Node`]s.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WorkflowDefinition {
    /// IR version string; deserializes to [`WORKFLOW_IR_V0`] when the field is absent.
    #[serde(default = "default_version")]
    pub version: String,
    /// Human-readable workflow name.
    pub name: String,
    /// Nodes in input order (`normalized` returns them sorted by `id`).
    pub nodes: Vec<Node>,
}
18
19fn default_version() -> String {
20 WORKFLOW_IR_V0.to_string()
21}
22
23impl WorkflowDefinition {
24 pub fn normalized(&self) -> Self {
26 let mut normalized = self.clone();
27 normalized.version = normalized.version.trim().to_string();
28 normalized.name = normalized.name.trim().to_string();
29
30 normalized.nodes = normalized
31 .nodes
32 .iter()
33 .cloned()
34 .map(|node| node.normalized())
35 .collect();
36 normalized.nodes.sort_by(|a, b| a.id.cmp(&b.id));
37 normalized
38 }
39}
40
/// One node of a workflow.
///
/// Because `kind` is flattened, the variant's `type` tag and fields sit in the same JSON
/// object as `id`, e.g. `{"id": "a", "type": "start", "next": "b"}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Node {
    /// Node identifier; `WorkflowDefinition::normalized` sorts nodes by it. Edge fields of
    /// [`NodeKind`] presumably name these ids — confirm against the executor.
    pub id: String,
    /// Variant-specific configuration, serialized inline alongside `id`.
    #[serde(flatten)]
    pub kind: NodeKind,
}
50
51impl Node {
52 pub fn normalized(mut self) -> Self {
54 self.id = self.id.trim().to_string();
55 self.kind = self.kind.normalized();
56 self
57 }
58
59 pub fn outgoing_edges(&self) -> Vec<&str> {
61 self.kind.outgoing_edges()
62 }
63}
64
/// Variant-specific configuration of a workflow [`Node`].
///
/// Serialized as an internally tagged object: the variant name, in snake_case, is stored
/// under the `"type"` key (e.g. `"cache_read"`, `"human_in_the_loop"`). Missing
/// `Option<String>` fields deserialize as `None`. Fields marked `#[serde(default)]` of type
/// `Value` deserialize as JSON `null` when absent. Which fields count as outgoing edges is
/// defined by `outgoing_edges`. Runtime semantics of each variant live outside this module;
/// the per-variant notes below that describe behavior are hedged accordingly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum NodeKind {
    /// Entry node with a single outgoing edge, `next`.
    Start { next: String },
    /// LLM step configured by `model` and `prompt`; no outgoing edge when `next` is `None`.
    Llm {
        model: String,
        prompt: String,
        next: Option<String>,
    },
    /// Tool invocation named by `tool` with an arbitrary JSON `input` (null if omitted);
    /// no outgoing edge when `next` is `None`.
    Tool {
        tool: String,
        #[serde(default)]
        input: Value,
        next: Option<String>,
    },
    /// Two-way branch on `expression`; both `on_true` and `on_false` are edges.
    /// Also accepted under the tags `"switch"` and `"if"` when deserializing.
    #[serde(alias = "switch", alias = "if")]
    Condition {
        expression: String,
        on_true: String,
        on_false: String,
    },
    /// Edges: `next`, then `on_suppressed` if set. Presumably suppresses repeats for the
    /// key at `key_path` within `window_steps` steps — confirm against the executor.
    Debounce {
        key_path: String,
        window_steps: u32,
        next: String,
        on_suppressed: Option<String>,
    },
    /// Edges: `next`, then `on_throttled` if set. Presumably rate-limits per key at
    /// `key_path` over `window_steps` steps — confirm against the executor.
    Throttle {
        key_path: String,
        window_steps: u32,
        next: String,
        on_throttled: Option<String>,
    },
    /// Runs `tool` with `input`, presumably retrying up to `max_retries` times and then
    /// running `compensate_tool` with `compensate_input` (TODO confirm). Both inputs default
    /// to JSON null. Edges: `next`, then `on_compensated` if set.
    RetryCompensate {
        tool: String,
        #[serde(default)]
        input: Value,
        max_retries: usize,
        compensate_tool: String,
        #[serde(default)]
        compensate_input: Value,
        next: String,
        on_compensated: Option<String>,
    },
    /// Human decision read from `decision_path` (optional `response_path`); both
    /// `on_approve` and `on_reject` are edges.
    HumanInTheLoop {
        decision_path: String,
        response_path: Option<String>,
        on_approve: String,
        on_reject: String,
    },
    /// Cache store using the key at `key_path` and value at `value_path`; single edge `next`.
    CacheWrite {
        key_path: String,
        value_path: String,
        next: String,
    },
    /// Cache lookup by the key at `key_path`. Edges: `next`, then `on_miss` if set.
    CacheRead {
        key_path: String,
        next: String,
        on_miss: Option<String>,
    },
    /// Matches the value at `event_path` against `event`. Edges: `next`, then
    /// `on_mismatch` if set.
    EventTrigger {
        event: String,
        event_path: String,
        next: String,
        on_mismatch: Option<String>,
    },
    /// Multi-way branch: every route's `next`, followed by `default`, are edges.
    Router {
        routes: Vec<RouterRoute>,
        default: String,
    },
    /// Applies `expression`; single edge `next`.
    Transform { expression: String, next: String },
    /// Loop guarded by `condition`, with an optional `max_iterations` cap (enforcement
    /// not visible here). Edges: `body`, then `next`.
    Loop {
        condition: String,
        body: String,
        next: String,
        max_iterations: Option<u32>,
    },
    /// Invokes the graph named by `graph`; no outgoing edge when `next` is `None`.
    Subgraph { graph: String, next: Option<String> },
    /// Batch over the items at `items_path`; single edge `next`.
    Batch { items_path: String, next: String },
    /// Filters the items at `items_path` by `expression`; single edge `next`.
    Filter {
        items_path: String,
        expression: String,
        next: String,
    },
    /// Fan-out: every entry of `branches`, followed by `next`, is an edge.
    /// `max_in_flight` presumably bounds concurrency — confirm.
    Parallel {
        branches: Vec<String>,
        next: String,
        max_in_flight: Option<usize>,
    },
    /// Joins `sources` according to `policy` (with `quorum` presumably used by
    /// [`MergePolicy::Quorum`]). `sources` are not edges; the only edge is `next`.
    Merge {
        sources: Vec<String>,
        policy: MergePolicy,
        quorum: Option<usize>,
        next: String,
    },
    /// Applies `tool` to the items at `items_path`; single edge `next`.
    Map {
        tool: String,
        items_path: String,
        next: String,
        max_in_flight: Option<usize>,
    },
    /// Aggregates `source` with `operation`; single edge `next`.
    Reduce {
        source: String,
        operation: ReduceOperation,
        next: String,
    },
    /// Terminal node with no outgoing edges.
    End,
}
196
197#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
199#[serde(rename_all = "snake_case")]
200pub enum MergePolicy {
201 First,
202 All,
203 Quorum,
204}
205
206#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
208#[serde(rename_all = "snake_case")]
209pub enum ReduceOperation {
210 Count,
211 Sum,
212}
213
/// One conditional branch of a [`NodeKind::Router`] node.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RouterRoute {
    /// Condition expression for this route (evaluated outside this module).
    pub when: String,
    /// Edge target of this route; listed by `outgoing_edges` before the router's `default`.
    pub next: String,
}
222
223impl NodeKind {
224 fn normalized(self) -> Self {
225 match self {
226 Self::Start { next } => Self::Start {
227 next: next.trim().to_string(),
228 },
229 Self::Llm {
230 model,
231 prompt,
232 next,
233 } => Self::Llm {
234 model: model.trim().to_string(),
235 prompt: prompt.trim().to_string(),
236 next: next.map(|edge| edge.trim().to_string()),
237 },
238 Self::Tool { tool, input, next } => Self::Tool {
239 tool: tool.trim().to_string(),
240 input,
241 next: next.map(|edge| edge.trim().to_string()),
242 },
243 Self::Condition {
244 expression,
245 on_true,
246 on_false,
247 } => Self::Condition {
248 expression: expression.trim().to_string(),
249 on_true: on_true.trim().to_string(),
250 on_false: on_false.trim().to_string(),
251 },
252 Self::Debounce {
253 key_path,
254 window_steps,
255 next,
256 on_suppressed,
257 } => Self::Debounce {
258 key_path: key_path.trim().to_string(),
259 window_steps,
260 next: next.trim().to_string(),
261 on_suppressed: on_suppressed.map(|edge| edge.trim().to_string()),
262 },
263 Self::Throttle {
264 key_path,
265 window_steps,
266 next,
267 on_throttled,
268 } => Self::Throttle {
269 key_path: key_path.trim().to_string(),
270 window_steps,
271 next: next.trim().to_string(),
272 on_throttled: on_throttled.map(|edge| edge.trim().to_string()),
273 },
274 Self::RetryCompensate {
275 tool,
276 input,
277 max_retries,
278 compensate_tool,
279 compensate_input,
280 next,
281 on_compensated,
282 } => Self::RetryCompensate {
283 tool: tool.trim().to_string(),
284 input,
285 max_retries,
286 compensate_tool: compensate_tool.trim().to_string(),
287 compensate_input,
288 next: next.trim().to_string(),
289 on_compensated: on_compensated.map(|edge| edge.trim().to_string()),
290 },
291 Self::HumanInTheLoop {
292 decision_path,
293 response_path,
294 on_approve,
295 on_reject,
296 } => Self::HumanInTheLoop {
297 decision_path: decision_path.trim().to_string(),
298 response_path: response_path.map(|path| path.trim().to_string()),
299 on_approve: on_approve.trim().to_string(),
300 on_reject: on_reject.trim().to_string(),
301 },
302 Self::CacheWrite {
303 key_path,
304 value_path,
305 next,
306 } => Self::CacheWrite {
307 key_path: key_path.trim().to_string(),
308 value_path: value_path.trim().to_string(),
309 next: next.trim().to_string(),
310 },
311 Self::CacheRead {
312 key_path,
313 next,
314 on_miss,
315 } => Self::CacheRead {
316 key_path: key_path.trim().to_string(),
317 next: next.trim().to_string(),
318 on_miss: on_miss.map(|edge| edge.trim().to_string()),
319 },
320 Self::EventTrigger {
321 event,
322 event_path,
323 next,
324 on_mismatch,
325 } => Self::EventTrigger {
326 event: event.trim().to_string(),
327 event_path: event_path.trim().to_string(),
328 next: next.trim().to_string(),
329 on_mismatch: on_mismatch.map(|edge| edge.trim().to_string()),
330 },
331 Self::Router { routes, default } => Self::Router {
332 routes: routes
333 .into_iter()
334 .map(|route| RouterRoute {
335 when: route.when.trim().to_string(),
336 next: route.next.trim().to_string(),
337 })
338 .collect(),
339 default: default.trim().to_string(),
340 },
341 Self::Transform { expression, next } => Self::Transform {
342 expression: expression.trim().to_string(),
343 next: next.trim().to_string(),
344 },
345 Self::Loop {
346 condition,
347 body,
348 next,
349 max_iterations,
350 } => Self::Loop {
351 condition: condition.trim().to_string(),
352 body: body.trim().to_string(),
353 next: next.trim().to_string(),
354 max_iterations,
355 },
356 Self::Subgraph { graph, next } => Self::Subgraph {
357 graph: graph.trim().to_string(),
358 next: next.map(|edge| edge.trim().to_string()),
359 },
360 Self::Batch { items_path, next } => Self::Batch {
361 items_path: items_path.trim().to_string(),
362 next: next.trim().to_string(),
363 },
364 Self::Filter {
365 items_path,
366 expression,
367 next,
368 } => Self::Filter {
369 items_path: items_path.trim().to_string(),
370 expression: expression.trim().to_string(),
371 next: next.trim().to_string(),
372 },
373 Self::Parallel {
374 branches,
375 next,
376 max_in_flight,
377 } => Self::Parallel {
378 branches: branches
379 .into_iter()
380 .map(|edge| edge.trim().to_string())
381 .collect(),
382 next: next.trim().to_string(),
383 max_in_flight,
384 },
385 Self::Merge {
386 sources,
387 policy,
388 quorum,
389 next,
390 } => Self::Merge {
391 sources: sources
392 .into_iter()
393 .map(|id| id.trim().to_string())
394 .collect(),
395 policy,
396 quorum,
397 next: next.trim().to_string(),
398 },
399 Self::Map {
400 tool,
401 items_path,
402 next,
403 max_in_flight,
404 } => Self::Map {
405 tool: tool.trim().to_string(),
406 items_path: items_path.trim().to_string(),
407 next: next.trim().to_string(),
408 max_in_flight,
409 },
410 Self::Reduce {
411 source,
412 operation,
413 next,
414 } => Self::Reduce {
415 source: source.trim().to_string(),
416 operation,
417 next: next.trim().to_string(),
418 },
419 Self::End => Self::End,
420 }
421 }
422
423 fn outgoing_edges(&self) -> Vec<&str> {
424 match self {
425 Self::Start { next } => vec![next.as_str()],
426 Self::Llm { next, .. } | Self::Tool { next, .. } => {
427 next.as_deref().map_or_else(Vec::new, |edge| vec![edge])
428 }
429 Self::Condition {
430 on_true, on_false, ..
431 } => vec![on_true.as_str(), on_false.as_str()],
432 Self::Debounce {
433 next,
434 on_suppressed,
435 ..
436 } => {
437 let mut edges = vec![next.as_str()];
438 if let Some(edge) = on_suppressed.as_deref() {
439 edges.push(edge);
440 }
441 edges
442 }
443 Self::Throttle {
444 next, on_throttled, ..
445 } => {
446 let mut edges = vec![next.as_str()];
447 if let Some(edge) = on_throttled.as_deref() {
448 edges.push(edge);
449 }
450 edges
451 }
452 Self::RetryCompensate {
453 next,
454 on_compensated,
455 ..
456 }
457 | Self::CacheRead {
458 next,
459 on_miss: on_compensated,
460 ..
461 }
462 | Self::EventTrigger {
463 next,
464 on_mismatch: on_compensated,
465 ..
466 } => {
467 let mut edges = vec![next.as_str()];
468 if let Some(edge) = on_compensated.as_deref() {
469 edges.push(edge);
470 }
471 edges
472 }
473 Self::HumanInTheLoop {
474 on_approve,
475 on_reject,
476 ..
477 } => vec![on_approve.as_str(), on_reject.as_str()],
478 Self::CacheWrite { next, .. } | Self::Transform { next, .. } => vec![next.as_str()],
479 Self::Router { routes, default } => {
480 let mut edges = routes
481 .iter()
482 .map(|route| route.next.as_str())
483 .collect::<Vec<_>>();
484 edges.push(default.as_str());
485 edges
486 }
487 Self::Loop { body, next, .. } => vec![body.as_str(), next.as_str()],
488 Self::Subgraph { next, .. } => next.as_deref().map_or_else(Vec::new, |edge| vec![edge]),
489 Self::Batch { next, .. } | Self::Filter { next, .. } => vec![next.as_str()],
490 Self::Parallel { branches, next, .. } => {
491 let mut edges = branches.iter().map(String::as_str).collect::<Vec<_>>();
492 edges.push(next.as_str());
493 edges
494 }
495 Self::Merge { next, .. } | Self::Map { next, .. } | Self::Reduce { next, .. } => {
496 vec![next.as_str()]
497 }
498 Self::End => Vec::new(),
499 }
500 }
501}
502
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{NodeKind, WorkflowDefinition, WORKFLOW_IR_V0};

    #[test]
    fn condition_deserializes_switch_alias() {
        let kind: NodeKind = serde_json::from_value(json!({
            "type": "switch",
            "expression": "input.ok == true",
            "on_true": "end_true",
            "on_false": "end_false"
        }))
        .expect("switch alias should deserialize");

        assert!(matches!(kind, NodeKind::Condition { .. }));
    }

    #[test]
    fn condition_deserializes_if_alias() {
        let kind: NodeKind = serde_json::from_value(json!({
            "type": "if",
            "expression": "input.ok == true",
            "on_true": "end_true",
            "on_false": "end_false"
        }))
        .expect("if alias should deserialize");

        assert!(matches!(kind, NodeKind::Condition { .. }));
    }

    /// Normalization fills the default version and trims all strings, including Unicode
    /// whitespace such as U+3000. It also orders nodes by trimmed id and leaves the source
    /// definition untouched.
    #[test]
    fn normalized_trims_strings_and_sorts_nodes_by_id() {
        let workflow: WorkflowDefinition = serde_json::from_value(json!({
            "name": "  demo  ",
            "nodes": [
                { "id": " start ", "type": "start", "next": " lookup " },
                {
                    "id": "lookup",
                    "type": "cache_read",
                    "key_path": " input.key ",
                    "next": " done\t",
                    "on_miss": "\u{3000}done"
                },
                { "id": "\ndone", "type": "end" }
            ]
        }))
        .expect("workflow should deserialize");

        let normalized = workflow.normalized();

        assert_eq!(normalized.version, WORKFLOW_IR_V0);
        assert_eq!(normalized.name, "demo");
        assert_eq!(workflow.name, "  demo  ");

        let ids: Vec<&str> = normalized
            .nodes
            .iter()
            .map(|node| node.id.as_str())
            .collect();
        assert_eq!(ids, ["done", "lookup", "start"]);

        assert!(normalized.nodes[0].outgoing_edges().is_empty());
        assert_eq!(normalized.nodes[1].outgoing_edges(), ["done", "done"]);
        assert_eq!(normalized.nodes[2].outgoing_edges(), ["lookup"]);
    }

    /// Router edges list every route target in declaration order, then `default`.
    #[test]
    fn router_outgoing_edges_list_routes_before_default() {
        let kind: NodeKind = serde_json::from_value(json!({
            "type": "router",
            "routes": [
                { "when": "a", "next": " first " },
                { "when": "b", "next": "second" }
            ],
            "default": "fallback "
        }))
        .expect("router should deserialize");

        assert_eq!(
            kind.normalized().outgoing_edges(),
            ["first", "second", "fallback"]
        );
    }
}