1use super::{Connection, ConnectionMode, EngineMode, Node, Pipeline};
13use indexmap::IndexMap;
14use serde::Deserialize;
15
/// One entry of the linear `steps:` pipeline form.
///
/// `compile_steps` turns the i-th step into a node named `step_{i}` whose `in` pin is fed by
/// the previous step's `out` pin.
#[derive(Debug, Deserialize)]
pub struct Step {
    /// Node kind identifier (e.g. `audio::gain`), copied verbatim into the compiled node.
    pub kind: String,
    /// Free-form node parameters; `None` when the `params` key is absent or null.
    pub params: Option<serde_json::Value>,
}
22
/// A named node in the DAG (`nodes:`) pipeline form.
#[derive(Debug, Deserialize)]
pub struct UserNode {
    /// Node kind identifier, e.g. `audio::mixer` or `transport::moq::peer`.
    pub kind: String,
    /// Free-form node parameters; `None` when the `params` key is absent or null.
    pub params: Option<serde_json::Value>,
    /// Upstream nodes this node consumes from; `Needs::default()` when the key is omitted.
    #[serde(default)]
    pub needs: Needs,
}
31
/// One upstream reference inside a `needs` field.
///
/// Accepted shapes (untagged, so serde tries the variants in declaration order):
/// - a bare node name: `needs: source`
/// - a map with an explicit connection mode: `needs: { node: source, mode: best_effort }`;
///   `mode` falls back to `ConnectionMode::default()` when omitted.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum NeedsDependency {
    /// Bare node name; the resulting connection uses the default mode.
    Simple(String),
    /// Node name plus an explicit connection mode.
    WithMode {
        /// Name of the upstream node.
        node: String,
        /// Connection mode for this edge (defaulted when the key is missing).
        #[serde(default)]
        mode: ConnectionMode,
    },
}
45
46impl NeedsDependency {
47 fn node(&self) -> &str {
48 match self {
49 Self::Simple(s) => s,
50 Self::WithMode { node, .. } => node,
51 }
52 }
53
54 fn mode(&self) -> ConnectionMode {
55 match self {
56 Self::Simple(_) => ConnectionMode::default(),
57 Self::WithMode { mode, .. } => *mode,
58 }
59 }
60}
61
/// The `needs` field of a DAG node: no dependencies, one, or a list.
///
/// Untagged: serde tries `None`, then `Single`, then `Multiple`, and keeps the first variant
/// that deserializes. `None` is also the value used when the field is omitted.
///
/// NOTE(review): serde can deserialize the struct variant `NeedsDependency::WithMode` from a
/// sequence, and `Single` is tried before `Multiple`, so a two-element list such as
/// `[a, <mode name>]` may be read as one dependency with a mode rather than two
/// dependencies — confirm whether that edge case matters.
#[derive(Debug, Deserialize, Default)]
#[serde(untagged)]
pub enum Needs {
    /// No upstream dependencies (the default).
    #[default]
    None,
    /// A single dependency, written as a scalar or a `{ node, mode }` map.
    Single(NeedsDependency),
    /// Several dependencies; their order determines `in_0`, `in_1`, ... pin numbering.
    Multiple(Vec<NeedsDependency>),
}
71
/// Top-level user pipeline document, in either linear (`steps`) or DAG (`nodes`) form.
///
/// Untagged: `Steps` is tried first, so a document that deserializes with a `steps` list is
/// treated as linear; otherwise a `nodes` map is required for `Dag`. In both forms `mode`
/// falls back to `EngineMode::default()` when omitted (Dynamic, per the tests below).
///
/// NOTE(review): only `Deserialize` is derived here, so the `skip_serializing_if` attributes
/// have no effect unless serialization is added elsewhere — confirm or drop them.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum UserPipeline {
    /// Linear chain: each step feeds the next one.
    Steps {
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        #[serde(default)]
        mode: EngineMode,
        /// Steps in execution order.
        steps: Vec<Step>,
    },
    /// Explicit graph: nodes keyed by name (declaration order kept), wired via `needs`.
    Dag {
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        #[serde(default)]
        mode: EngineMode,
        /// Node definitions keyed by node name.
        nodes: IndexMap<String, UserNode>,
    },
}
97
98pub fn compile(pipeline: UserPipeline) -> Result<Pipeline, String> {
104 match pipeline {
105 UserPipeline::Steps { name, description, mode, steps } => {
106 Ok(compile_steps(name, description, mode, steps))
107 },
108 UserPipeline::Dag { name, description, mode, nodes } => {
109 compile_dag(name, description, mode, nodes)
110 },
111 }
112}
113
114fn compile_steps(
116 name: Option<String>,
117 description: Option<String>,
118 mode: EngineMode,
119 steps: Vec<Step>,
120) -> Pipeline {
121 let mut nodes = IndexMap::new();
122 let mut connections = Vec::new();
123
124 for (i, step) in steps.into_iter().enumerate() {
125 let node_name = format!("step_{i}");
126
127 if i > 0 {
129 connections.push(Connection {
130 from_node: format!("step_{}", i - 1),
131 from_pin: "out".to_string(),
132 to_node: node_name.clone(),
133 to_pin: "in".to_string(),
134 mode: ConnectionMode::default(),
135 });
136 }
137
138 nodes.insert(node_name, Node { kind: step.kind, params: step.params, state: None });
139 }
140
141 Pipeline { name, description, mode, nodes, connections }
142}
143
/// Node kinds that may sit on a `needs` cycle without the pipeline being rejected.
///
/// `detect_cycles` consults this list through `is_bidirectional_kind`.
const BIDIRECTIONAL_NODE_KINDS: &[&str] = &["transport::moq::peer"];

/// Returns `true` when `kind` exactly matches an entry of `BIDIRECTIONAL_NODE_KINDS`.
fn is_bidirectional_kind(kind: &str) -> bool {
    BIDIRECTIONAL_NODE_KINDS.iter().any(|&known| known == kind)
}
153
154fn detect_cycles(user_nodes: &IndexMap<String, UserNode>) -> Result<(), String> {
160 use std::collections::HashSet;
161
162 fn dfs<'a>(
165 node: &'a String,
166 adjacency: &IndexMap<&'a String, Vec<&'a String>>,
167 visited: &mut HashSet<&'a String>,
168 rec_stack: &mut HashSet<&'a String>,
169 cycle_path: &mut Vec<&'a String>,
170 ) -> Option<(Vec<&'a String>, String)> {
171 visited.insert(node);
172 rec_stack.insert(node);
173 cycle_path.push(node);
174
175 if let Some(neighbors) = adjacency.get(node) {
176 for neighbor in neighbors {
177 if !visited.contains(neighbor) {
178 if let Some(cycle) = dfs(neighbor, adjacency, visited, rec_stack, cycle_path) {
179 rec_stack.remove(node);
181 cycle_path.pop();
182 return Some(cycle);
183 }
184 } else if rec_stack.contains(neighbor) {
185 let cycle_start_idx =
187 cycle_path.iter().position(|&n| n == *neighbor).unwrap_or(0);
188 let cycle_nodes: Vec<&'a String> = cycle_path[cycle_start_idx..].to_vec();
189 let cycle_strs: Vec<&str> = cycle_nodes.iter().map(|s| s.as_str()).collect();
190 let description = format!(
191 "Circular dependency detected: {} -> {}",
192 cycle_strs.join(" -> "),
193 neighbor
194 );
195 rec_stack.remove(node);
197 cycle_path.pop();
198 return Some((cycle_nodes, description));
199 }
200 }
201 }
202
203 rec_stack.remove(node);
204 cycle_path.pop();
205 None
206 }
207
208 let mut adjacency: IndexMap<&String, Vec<&String>> = IndexMap::new();
212
213 for (node_name, node_def) in user_nodes {
214 adjacency.entry(node_name).or_default();
215
216 let dependencies: Vec<&str> = match &node_def.needs {
217 Needs::None => vec![],
218 Needs::Single(dep) => vec![dep.node()],
219 Needs::Multiple(deps) => deps.iter().map(NeedsDependency::node).collect(),
220 };
221
222 for dep_name in dependencies {
223 if let Some((key, _)) = user_nodes.get_key_value(dep_name) {
226 adjacency.entry(key).or_default().push(node_name);
227 }
228 }
229 }
230
231 let mut visited: HashSet<&String> = HashSet::new();
233 let mut rec_stack: HashSet<&String> = HashSet::new();
234 let mut cycle_path: Vec<&String> = Vec::new();
235
236 for node_name in user_nodes.keys() {
237 if !visited.contains(node_name) {
238 if let Some((cycle_nodes, cycle_error)) =
239 dfs(node_name, &adjacency, &mut visited, &mut rec_stack, &mut cycle_path)
240 {
241 let has_bidirectional = cycle_nodes.iter().any(|node_name| {
243 user_nodes.get(*node_name).is_some_and(|node| is_bidirectional_kind(&node.kind))
244 });
245
246 if !has_bidirectional {
248 return Err(cycle_error);
249 }
250 }
251 }
252 }
253
254 Ok(())
255}
256
257fn compile_dag(
259 name: Option<String>,
260 description: Option<String>,
261 mode: EngineMode,
262 user_nodes: IndexMap<String, UserNode>,
263) -> Result<Pipeline, String> {
264 detect_cycles(&user_nodes)?;
266
267 let mut connections = Vec::new();
268
269 for (node_name, node_def) in &user_nodes {
270 let dependencies: Vec<&NeedsDependency> = match &node_def.needs {
271 Needs::None => vec![],
272 Needs::Single(dep) => vec![dep],
273 Needs::Multiple(deps) => deps.iter().collect(),
274 };
275
276 for (idx, dep) in dependencies.iter().enumerate() {
277 let dep_name = dep.node();
278
279 if !user_nodes.contains_key(dep_name) {
281 return Err(format!(
282 "Node '{node_name}' references non-existent node '{dep_name}' in 'needs' field"
283 ));
284 }
285
286 let to_pin =
288 if dependencies.len() > 1 { format!("in_{idx}") } else { "in".to_string() };
289
290 connections.push(Connection {
291 from_node: dep_name.to_string(),
292 from_pin: "out".to_string(),
293 to_node: node_name.clone(),
294 to_pin,
295 mode: dep.mode(),
296 });
297 }
298 }
299
300 let mut incoming_counts: IndexMap<String, usize> = IndexMap::new();
302 for conn in &connections {
303 *incoming_counts.entry(conn.to_node.clone()).or_insert(0) += 1;
304 }
305
306 let nodes = user_nodes
307 .into_iter()
308 .map(|(name, def)| {
309 let mut params = def.params;
310
311 if def.kind == "audio::mixer" && mode != EngineMode::Dynamic {
314 if let Some(count) = incoming_counts.get(&name) {
315 if *count > 1 {
316 if let Some(serde_json::Value::Object(ref mut map)) = params {
318 let should_inject = matches!(
319 map.get("num_inputs"),
320 Some(serde_json::Value::Null) | None
321 );
322 if should_inject {
323 map.insert(
324 "num_inputs".to_string(),
325 serde_json::Value::Number((*count).into()),
326 );
327 }
328 } else if params.is_none() {
329 let mut map = serde_json::Map::new();
331 map.insert(
332 "num_inputs".to_string(),
333 serde_json::Value::Number((*count).into()),
334 );
335 params = Some(serde_json::Value::Object(map));
336 }
337 }
338 }
339 }
340
341 (name, Node { kind: def.kind, params, state: None })
342 })
343 .collect();
344
345 Ok(Pipeline { name, description, mode, nodes, connections })
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// A node listing itself in `needs` is a one-node cycle and must be rejected.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_self_reference_needs_rejected() {
        let yaml = r"
mode: dynamic
nodes:
  peer:
    kind: test_node
    params: {}
    needs: peer
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let result = compile(user_pipeline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.contains("Circular dependency"),
            "Error should mention circular dependency: {err}"
        );
        assert!(err.contains("peer"), "Error should mention the node name: {err}");
    }

    /// Two ordinary nodes depending on each other form a forbidden cycle.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_circular_needs_rejected() {
        let yaml = r"
mode: dynamic
nodes:
  node_a:
    kind: test_node
    needs: node_b
  node_b:
    kind: test_node
    needs: node_a
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let result = compile(user_pipeline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.contains("Circular dependency"),
            "Error should mention circular dependency: {err}"
        );
    }

    /// A `needs` entry naming an unknown node is reported with both node names.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_invalid_needs_reference() {
        let yaml = r"
mode: dynamic
nodes:
  node_a:
    kind: test_node
    needs: non_existent_node
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let result = compile(user_pipeline);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("node_a"));
        assert!(err.contains("non_existent_node"));
        assert!(err.contains("needs"));
    }

    /// Publisher and peer share a broadcast name but have no `needs` cycle.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_bidirectional_transport_not_flagged_as_cycle() {
        let yaml = r"
mode: dynamic
nodes:
  file_reader:
    kind: core::file_reader
    params:
      path: /tmp/test.opus
  ogg_demuxer:
    kind: containers::ogg::demuxer
    needs: file_reader
  pacer:
    kind: core::pacer
    needs: ogg_demuxer
  moq_publisher:
    kind: transport::moq::publisher
    params:
      broadcast: input
    needs: pacer
  moq_peer:
    kind: transport::moq::peer
    params:
      input_broadcast: input
      output_broadcast: output
  ogg_muxer:
    kind: containers::ogg::muxer
    needs: moq_peer
  file_writer:
    kind: core::file_writer
    params:
      path: /tmp/output.opus
    needs: ogg_muxer
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let result = compile(user_pipeline);

        assert!(
            result.is_ok(),
            "Bidirectional transport pattern should not be flagged as a cycle: {:?}",
            result.err()
        );
    }

    /// A cycle that passes through a bidirectional peer node is allowed.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_bidirectional_cycle_allowed() {
        let yaml = r"
mode: dynamic
nodes:
  decoder:
    kind: audio::opus::decoder
    needs: moq_peer
  encoder:
    kind: audio::opus::encoder
    needs: mixer
  gain:
    kind: audio::gain
    needs: decoder
  mixer:
    kind: audio::mixer
    needs: gain
  moq_peer:
    kind: transport::moq::peer
    params:
      input_broadcast: input
      output_broadcast: output
    needs: encoder
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let result = compile(user_pipeline);

        assert!(
            result.is_ok(),
            "Cycle with bidirectional node should be allowed: {:?}",
            result.err()
        );
    }

    /// The shipped MoQ mixing sample must keep compiling.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_sample_moq_mixing_compiles() {
        let yaml = include_str!("../../../samples/pipelines/dynamic/moq_mixing.yml");
        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let result = compile(user_pipeline);

        assert!(
            result.is_ok(),
            "Sample pipeline moq_mixing.yml should compile: {:?}",
            result.err()
        );
    }

    /// Several dependencies are wired to numbered pins in `needs` order.
    #[test]
    #[allow(clippy::unwrap_used, clippy::expect_used)]
    fn test_multiple_inputs_numbered_pins() {
        let yaml = r"
mode: dynamic
nodes:
  input_a:
    kind: test_source
  input_b:
    kind: test_source
  mixer:
    kind: audio::mixer
    needs:
      - input_a
      - input_b
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let pipeline = compile(user_pipeline).unwrap();

        assert_eq!(pipeline.nodes.len(), 3);
        assert_eq!(pipeline.connections.len(), 2);

        let conn_a = pipeline
            .connections
            .iter()
            .find(|c| c.from_node == "input_a")
            .expect("Should have connection from input_a");
        assert_eq!(conn_a.to_node, "mixer");
        assert_eq!(conn_a.from_pin, "out");
        assert_eq!(conn_a.to_pin, "in_0");

        let conn_b = pipeline
            .connections
            .iter()
            .find(|c| c.from_node == "input_b")
            .expect("Should have connection from input_b");
        assert_eq!(conn_b.to_node, "mixer");
        assert_eq!(conn_b.from_pin, "out");
        assert_eq!(conn_b.to_pin, "in_1");
    }

    /// A single dependency is wired to the plain `in` pin.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_single_input_uses_in_pin() {
        let yaml = r"
mode: dynamic
nodes:
  source:
    kind: test_source
  sink:
    kind: test_sink
    needs: source
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let pipeline = compile(user_pipeline).unwrap();

        assert_eq!(pipeline.nodes.len(), 2);
        assert_eq!(pipeline.connections.len(), 1);

        let conn = &pipeline.connections[0];
        assert_eq!(conn.from_node, "source");
        assert_eq!(conn.to_node, "sink");
        assert_eq!(conn.from_pin, "out");
        assert_eq!(conn.to_pin, "in");
    }

    /// Outside dynamic mode, a mixer with omitted `num_inputs` gets its fan-in filled in.
    #[test]
    #[allow(clippy::unwrap_used, clippy::expect_used)]
    fn test_mixer_auto_configures_num_inputs() {
        let yaml = r"
mode: oneshot
nodes:
  input_a:
    kind: test_source
  input_b:
    kind: test_source
  mixer:
    kind: audio::mixer
    params:
      # num_inputs intentionally omitted
    needs:
      - input_a
      - input_b
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let pipeline = compile(user_pipeline).unwrap();

        let mixer_node = pipeline.nodes.get("mixer").expect("mixer node should exist");
        assert_eq!(mixer_node.kind, "audio::mixer");

        let Some(serde_json::Value::Object(map)) = &mixer_node.params else {
            panic!("mixer params should be an object");
        };
        let num_inputs_value = map.get("num_inputs").expect("num_inputs should be set");
        assert_eq!(num_inputs_value.as_u64(), Some(2), "num_inputs should be the number 2");
    }

    /// The `steps` form produces `step_{i}` nodes chained `out` -> `in`.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_steps_format_compilation() {
        let yaml = r"
mode: oneshot
steps:
  - kind: streamkit::http_input
  - kind: audio::gain
    params:
      gain: 2.0
  - kind: streamkit::http_output
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let pipeline = compile(user_pipeline).unwrap();

        assert_eq!(pipeline.nodes.len(), 3);
        assert!(pipeline.nodes.contains_key("step_0"));
        assert!(pipeline.nodes.contains_key("step_1"));
        assert!(pipeline.nodes.contains_key("step_2"));

        assert_eq!(pipeline.connections.len(), 2);

        let conn0 = &pipeline.connections[0];
        assert_eq!(conn0.from_node, "step_0");
        assert_eq!(conn0.to_node, "step_1");
        assert_eq!(conn0.from_pin, "out");
        assert_eq!(conn0.to_pin, "in");

        let conn1 = &pipeline.connections[1];
        assert_eq!(conn1.from_node, "step_1");
        assert_eq!(conn1.to_node, "step_2");

        let gain_node = pipeline.nodes.get("step_1").unwrap();
        assert!(gain_node.params.is_some());
    }

    /// An explicit `mode` survives compilation unchanged.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_mode_preservation() {
        let yaml_oneshot = r"
mode: oneshot
steps:
  - kind: streamkit::http_input
  - kind: streamkit::http_output
";
        let pipeline: UserPipeline = serde_saphyr::from_str(yaml_oneshot).unwrap();
        let compiled = compile(pipeline).unwrap();
        assert_eq!(compiled.mode, EngineMode::OneShot);

        let yaml_dynamic = r"
mode: dynamic
steps:
  - kind: core::passthrough
";
        let pipeline: UserPipeline = serde_saphyr::from_str(yaml_dynamic).unwrap();
        let compiled = compile(pipeline).unwrap();
        assert_eq!(compiled.mode, EngineMode::Dynamic);
    }

    /// Omitting `mode` yields the dynamic engine mode.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_default_mode_is_dynamic() {
        let yaml = r"
# mode not specified
steps:
  - kind: core::passthrough
";
        let pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let compiled = compile(pipeline).unwrap();
        assert_eq!(compiled.mode, EngineMode::Dynamic);
    }

    /// `name` and `description` are copied into the compiled pipeline.
    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_name_and_description_preservation() {
        let yaml = r"
name: Test Pipeline
description: A test pipeline for validation
mode: dynamic
steps:
  - kind: core::passthrough
";
        let pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let compiled = compile(pipeline).unwrap();

        assert_eq!(compiled.name, Some("Test Pipeline".to_string()));
        assert_eq!(compiled.description, Some("A test pipeline for validation".to_string()));
    }

    /// A `{ node, mode }` dependency carries its mode; a bare name uses the default.
    #[test]
    #[allow(clippy::unwrap_used, clippy::expect_used)]
    fn test_connection_mode_in_needs() {
        let yaml = r"
mode: dynamic
nodes:
  source:
    kind: test_source
  main_sink:
    kind: test_sink
    needs: source
  metrics:
    kind: test_metrics
    needs:
      node: source
      mode: best_effort
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let pipeline = compile(user_pipeline).unwrap();

        assert_eq!(pipeline.nodes.len(), 3);
        assert_eq!(pipeline.connections.len(), 2);

        let main_conn = pipeline
            .connections
            .iter()
            .find(|c| c.to_node == "main_sink")
            .expect("Should have connection to main_sink");
        assert_eq!(main_conn.mode, ConnectionMode::Reliable);

        let metrics_conn = pipeline
            .connections
            .iter()
            .find(|c| c.to_node == "metrics")
            .expect("Should have connection to metrics");
        assert_eq!(metrics_conn.mode, ConnectionMode::BestEffort);
    }

    /// Bare names and `{ node, mode }` entries can be mixed inside a `needs` list.
    #[test]
    #[allow(clippy::unwrap_used, clippy::expect_used)]
    fn test_connection_mode_in_needs_list() {
        let yaml = r"
mode: dynamic
nodes:
  input_a:
    kind: test_source
  input_b:
    kind: test_source
  mixer:
    kind: audio::mixer
    needs:
      - input_a
      - node: input_b
        mode: best_effort
";

        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        let pipeline = compile(user_pipeline).unwrap();

        assert_eq!(pipeline.nodes.len(), 3);
        assert_eq!(pipeline.connections.len(), 2);

        let conn_a = pipeline
            .connections
            .iter()
            .find(|c| c.from_node == "input_a")
            .expect("Should have connection from input_a");
        assert_eq!(conn_a.mode, ConnectionMode::Reliable);
        assert_eq!(conn_a.to_pin, "in_0");

        let conn_b = pipeline
            .connections
            .iter()
            .find(|c| c.from_node == "input_b")
            .expect("Should have connection from input_b");
        assert_eq!(conn_b.mode, ConnectionMode::BestEffort);
        assert_eq!(conn_b.to_pin, "in_1");
    }
}