1use super::{Connection, ConnectionMode, EngineMode, Node, Pipeline};
13use indexmap::IndexMap;
14use serde::Deserialize;
15
/// One entry of the linear `steps:` pipeline form.
///
/// `compile_steps` turns each entry into a node named `step_<index>`, copying
/// `kind` and `params` over unchanged.
#[derive(Debug, Deserialize)]
pub struct Step {
    /// Node kind identifier (for example `audio::gain`).
    pub kind: String,
    /// Optional node parameters, passed through to the compiled node as-is.
    pub params: Option<serde_json::Value>,
}
22
/// A node declared under `nodes:` in the DAG pipeline form.
#[derive(Debug, Deserialize)]
pub struct UserNode {
    /// Node kind identifier; `compile_dag` special-cases `audio::mixer`, and
    /// cycle detection tolerates cycles through bidirectional kinds.
    pub kind: String,
    /// Optional node parameters, passed through to the compiled node (for
    /// mixers, `compile_dag` may fill in `num_inputs`).
    pub params: Option<serde_json::Value>,
    /// Upstream nodes this node consumes from; omitting the key yields
    /// `Needs::None`.
    #[serde(default)]
    pub needs: Needs,
}
31
/// A single dependency entry in a node's `needs` field.
///
/// Deserialized `untagged` (variants are tried in declaration order), so both
/// of these YAML shapes are accepted:
/// - a bare node name: `needs: source`
/// - an object with an explicit connection mode:
///   `needs: { node: source, mode: best_effort }`
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum NeedsDependency {
    /// Upstream node referenced by name; uses `ConnectionMode::default()`.
    Simple(String),
    /// Upstream node with an explicit connection mode.
    WithMode {
        /// Name of the upstream node.
        node: String,
        /// Connection mode for the edge from `node`; defaults to
        /// `ConnectionMode::default()` when omitted.
        #[serde(default)]
        mode: ConnectionMode,
    },
}
45
46impl NeedsDependency {
47 fn node(&self) -> &str {
48 match self {
49 Self::Simple(s) => s,
50 Self::WithMode { node, .. } => node,
51 }
52 }
53
54 fn mode(&self) -> ConnectionMode {
55 match self {
56 Self::Simple(_) => ConnectionMode::default(),
57 Self::WithMode { mode, .. } => *mode,
58 }
59 }
60}
61
/// The `needs` field of a DAG node: no dependency, a single one, or a list.
///
/// Deserialized `untagged`. Together with `#[serde(default)]` on
/// `UserNode::needs`, a missing `needs` key yields the `#[default]` variant,
/// `Needs::None`.
#[derive(Debug, Deserialize, Default)]
#[serde(untagged)]
pub enum Needs {
    /// No upstream dependencies.
    #[default]
    None,
    /// Exactly one dependency (`needs: a` or `needs: { node: a, mode: ... }`).
    Single(NeedsDependency),
    /// A list of dependencies; when it has more than one entry, `compile_dag`
    /// wires them to numbered input pins (`in_0`, `in_1`, ...).
    Multiple(Vec<NeedsDependency>),
}
71
/// A user-authored pipeline document, in one of two shapes.
///
/// Deserialized `untagged`: serde tries `Steps` first, then `Dag`, and keeps the
/// first variant that parses.
///
/// NOTE(review): no `deny_unknown_fields` is visible here, so a document that
/// contains both `steps` and `nodes` presumably parses as `Steps` with `nodes`
/// silently ignored — confirm whether that should be rejected instead.
///
/// NOTE(review): only `Deserialize` is derived in this view, so the
/// `skip_serializing_if` attributes below appear to have no effect — confirm no
/// `Serialize` derive is added elsewhere before removing them.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum UserPipeline {
    /// Linear form: steps are chained in list order (`step_0 -> step_1 -> ...`).
    Steps {
        /// Optional human-readable pipeline name.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Optional human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        /// Engine mode; `EngineMode::default()` when omitted (dynamic, per the
        /// tests in this file).
        #[serde(default)]
        mode: EngineMode,
        /// Steps in chain order.
        steps: Vec<Step>,
    },
    /// DAG form: named nodes wired together through their `needs` fields.
    Dag {
        /// Optional human-readable pipeline name.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Optional human-readable description.
        #[serde(skip_serializing_if = "Option::is_none")]
        description: Option<String>,
        /// Engine mode; `EngineMode::default()` when omitted.
        #[serde(default)]
        mode: EngineMode,
        /// Nodes keyed by name, kept in declaration order.
        nodes: IndexMap<String, UserNode>,
    },
}
97
98pub fn compile(pipeline: UserPipeline) -> Result<Pipeline, String> {
104 match pipeline {
105 UserPipeline::Steps { name, description, mode, steps } => {
106 Ok(compile_steps(name, description, mode, steps))
107 },
108 UserPipeline::Dag { name, description, mode, nodes } => {
109 compile_dag(name, description, mode, nodes)
110 },
111 }
112}
113
114fn compile_steps(
116 name: Option<String>,
117 description: Option<String>,
118 mode: EngineMode,
119 steps: Vec<Step>,
120) -> Pipeline {
121 let mut nodes = IndexMap::new();
122 let mut connections = Vec::new();
123
124 for (i, step) in steps.into_iter().enumerate() {
125 let node_name = format!("step_{i}");
126
127 if i > 0 {
129 connections.push(Connection {
130 from_node: format!("step_{}", i - 1),
131 from_pin: "out".to_string(),
132 to_node: node_name.clone(),
133 to_pin: "in".to_string(),
134 mode: ConnectionMode::default(),
135 });
136 }
137
138 nodes.insert(node_name, Node { kind: step.kind, params: step.params, state: None });
139 }
140
141 Pipeline { name, description, mode, nodes, connections }
142}
143
/// Node kinds that are allowed to sit on a dependency cycle.
///
/// `detect_cycles` accepts any cycle that passes through at least one node of
/// one of these kinds.
const BIDIRECTIONAL_NODE_KINDS: &[&str] = &["transport::moq::peer"];

/// Returns `true` when `kind` is listed in [`BIDIRECTIONAL_NODE_KINDS`]
/// (exact, case-sensitive match).
fn is_bidirectional_kind(kind: &str) -> bool {
    BIDIRECTIONAL_NODE_KINDS.iter().any(|&candidate| candidate == kind)
}
153
154fn detect_cycles(user_nodes: &IndexMap<String, UserNode>) -> Result<(), String> {
160 use std::collections::HashSet;
161
162 fn dfs<'a>(
165 node: &'a String,
166 adjacency: &IndexMap<&'a String, Vec<&'a String>>,
167 visited: &mut HashSet<&'a String>,
168 rec_stack: &mut HashSet<&'a String>,
169 cycle_path: &mut Vec<&'a String>,
170 ) -> Option<(Vec<&'a String>, String)> {
171 visited.insert(node);
172 rec_stack.insert(node);
173 cycle_path.push(node);
174
175 if let Some(neighbors) = adjacency.get(node) {
176 for neighbor in neighbors {
177 if !visited.contains(neighbor) {
178 if let Some(cycle) = dfs(neighbor, adjacency, visited, rec_stack, cycle_path) {
179 return Some(cycle);
180 }
181 } else if rec_stack.contains(neighbor) {
182 let cycle_start_idx =
184 cycle_path.iter().position(|&n| n == *neighbor).unwrap_or(0);
185 let cycle_nodes: Vec<&'a String> = cycle_path[cycle_start_idx..].to_vec();
186 let cycle_strs: Vec<&str> = cycle_nodes.iter().map(|s| s.as_str()).collect();
187 let description = format!(
188 "Circular dependency detected: {} -> {}",
189 cycle_strs.join(" -> "),
190 neighbor
191 );
192 return Some((cycle_nodes, description));
193 }
194 }
195 }
196
197 rec_stack.remove(node);
198 cycle_path.pop();
199 None
200 }
201
202 let mut adjacency: IndexMap<&String, Vec<&String>> = IndexMap::new();
206
207 for (node_name, node_def) in user_nodes {
208 adjacency.entry(node_name).or_default();
209
210 let dependencies: Vec<&str> = match &node_def.needs {
211 Needs::None => vec![],
212 Needs::Single(dep) => vec![dep.node()],
213 Needs::Multiple(deps) => deps.iter().map(NeedsDependency::node).collect(),
214 };
215
216 for dep_name in dependencies {
217 if let Some((key, _)) = user_nodes.get_key_value(dep_name) {
220 adjacency.entry(key).or_default().push(node_name);
221 }
222 }
223 }
224
225 let mut visited: HashSet<&String> = HashSet::new();
227 let mut rec_stack: HashSet<&String> = HashSet::new();
228 let mut cycle_path: Vec<&String> = Vec::new();
229
230 for node_name in user_nodes.keys() {
231 if !visited.contains(node_name) {
232 if let Some((cycle_nodes, cycle_error)) =
233 dfs(node_name, &adjacency, &mut visited, &mut rec_stack, &mut cycle_path)
234 {
235 let has_bidirectional = cycle_nodes.iter().any(|node_name| {
237 user_nodes.get(*node_name).is_some_and(|node| is_bidirectional_kind(&node.kind))
238 });
239
240 if !has_bidirectional {
242 return Err(cycle_error);
243 }
244 }
245 }
246 }
247
248 Ok(())
249}
250
251fn compile_dag(
253 name: Option<String>,
254 description: Option<String>,
255 mode: EngineMode,
256 user_nodes: IndexMap<String, UserNode>,
257) -> Result<Pipeline, String> {
258 detect_cycles(&user_nodes)?;
260
261 let mut connections = Vec::new();
262
263 for (node_name, node_def) in &user_nodes {
264 let dependencies: Vec<&NeedsDependency> = match &node_def.needs {
265 Needs::None => vec![],
266 Needs::Single(dep) => vec![dep],
267 Needs::Multiple(deps) => deps.iter().collect(),
268 };
269
270 for (idx, dep) in dependencies.iter().enumerate() {
271 let dep_name = dep.node();
272
273 if !user_nodes.contains_key(dep_name) {
275 return Err(format!(
276 "Node '{node_name}' references non-existent node '{dep_name}' in 'needs' field"
277 ));
278 }
279
280 let to_pin =
282 if dependencies.len() > 1 { format!("in_{idx}") } else { "in".to_string() };
283
284 connections.push(Connection {
285 from_node: dep_name.to_string(),
286 from_pin: "out".to_string(),
287 to_node: node_name.clone(),
288 to_pin,
289 mode: dep.mode(),
290 });
291 }
292 }
293
294 let mut incoming_counts: IndexMap<String, usize> = IndexMap::new();
296 for conn in &connections {
297 *incoming_counts.entry(conn.to_node.clone()).or_insert(0) += 1;
298 }
299
300 let nodes = user_nodes
301 .into_iter()
302 .map(|(name, def)| {
303 let mut params = def.params;
304
305 if def.kind == "audio::mixer" && mode != EngineMode::Dynamic {
308 if let Some(count) = incoming_counts.get(&name) {
309 if *count > 1 {
310 if let Some(serde_json::Value::Object(ref mut map)) = params {
312 let should_inject = matches!(
313 map.get("num_inputs"),
314 Some(serde_json::Value::Null) | None
315 );
316 if should_inject {
317 map.insert(
318 "num_inputs".to_string(),
319 serde_json::Value::Number((*count).into()),
320 );
321 }
322 } else if params.is_none() {
323 let mut map = serde_json::Map::new();
325 map.insert(
326 "num_inputs".to_string(),
327 serde_json::Value::Number((*count).into()),
328 );
329 params = Some(serde_json::Value::Object(map));
330 }
331 }
332 }
333 }
334
335 (name, Node { kind: def.kind, params, state: None })
336 })
337 .collect();
338
339 Ok(Pipeline { name, description, mode, nodes, connections })
340}
341
#[cfg(test)]
mod tests {
    // Tests report failures by panicking, so unwrap/expect are intentional here.
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Deserializes `yaml` into a [`UserPipeline`] and runs [`compile`] on it.
    fn compile_yaml(yaml: &str) -> Result<Pipeline, String> {
        let user_pipeline: UserPipeline = serde_saphyr::from_str(yaml).unwrap();
        compile(user_pipeline)
    }

    /// Returns the first connection satisfying `pred`, panicking with `what` if none does.
    fn find_connection<'a>(
        pipeline: &'a Pipeline,
        what: &str,
        pred: impl Fn(&Connection) -> bool,
    ) -> &'a Connection {
        pipeline.connections.iter().find(|&conn| pred(conn)).expect(what)
    }

    #[test]
    fn test_self_reference_needs_rejected() {
        let yaml = r"
mode: dynamic
nodes:
  peer:
    kind: test_node
    params: {}
    needs: peer
";

        let Err(err) = compile_yaml(yaml) else {
            panic!("a node that needs itself must be rejected");
        };
        assert!(
            err.contains("Circular dependency"),
            "Error should mention circular dependency: {err}"
        );
        assert!(err.contains("peer"), "Error should mention the node name: {err}");
    }

    #[test]
    fn test_circular_needs_rejected() {
        let yaml = r"
mode: dynamic
nodes:
  node_a:
    kind: test_node
    needs: node_b
  node_b:
    kind: test_node
    needs: node_a
";

        let Err(err) = compile_yaml(yaml) else {
            panic!("mutually dependent nodes must be rejected");
        };
        assert!(
            err.contains("Circular dependency"),
            "Error should mention circular dependency: {err}"
        );
    }

    #[test]
    fn test_invalid_needs_reference() {
        let yaml = r"
mode: dynamic
nodes:
  node_a:
    kind: test_node
    needs: non_existent_node
";

        let Err(err) = compile_yaml(yaml) else {
            panic!("a dangling `needs` reference must be rejected");
        };
        for fragment in ["node_a", "non_existent_node", "needs"] {
            assert!(err.contains(fragment));
        }
    }

    #[test]
    fn test_bidirectional_transport_not_flagged_as_cycle() {
        let yaml = r"
mode: dynamic
nodes:
  file_reader:
    kind: core::file_reader
    params:
      path: /tmp/test.opus
  ogg_demuxer:
    kind: containers::ogg::demuxer
    needs: file_reader
  pacer:
    kind: core::pacer
    needs: ogg_demuxer
  moq_publisher:
    kind: transport::moq::publisher
    params:
      broadcast: input
    needs: pacer
  moq_peer:
    kind: transport::moq::peer
    params:
      input_broadcast: input
      output_broadcast: output
  ogg_muxer:
    kind: containers::ogg::muxer
    needs: moq_peer
  file_writer:
    kind: core::file_writer
    params:
      path: /tmp/output.opus
    needs: ogg_muxer
";

        let result = compile_yaml(yaml);
        assert!(
            result.is_ok(),
            "Bidirectional transport pattern should not be flagged as a cycle: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_bidirectional_cycle_allowed() {
        let yaml = r"
mode: dynamic
nodes:
  decoder:
    kind: audio::opus::decoder
    needs: moq_peer
  encoder:
    kind: audio::opus::encoder
    needs: mixer
  gain:
    kind: audio::gain
    needs: decoder
  mixer:
    kind: audio::mixer
    needs: gain
  moq_peer:
    kind: transport::moq::peer
    params:
      input_broadcast: input
      output_broadcast: output
    needs: encoder
";

        let result = compile_yaml(yaml);
        assert!(
            result.is_ok(),
            "Cycle with bidirectional node should be allowed: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_multiple_inputs_numbered_pins() {
        let yaml = r"
mode: dynamic
nodes:
  input_a:
    kind: test_source
  input_b:
    kind: test_source
  mixer:
    kind: audio::mixer
    needs:
      - input_a
      - input_b
";

        let pipeline = compile_yaml(yaml).unwrap();
        assert_eq!(pipeline.nodes.len(), 3);
        assert_eq!(pipeline.connections.len(), 2);

        for (source, expected_pin) in [("input_a", "in_0"), ("input_b", "in_1")] {
            let conn = find_connection(
                &pipeline,
                &format!("Should have connection from {source}"),
                |c| c.from_node == source,
            );
            assert_eq!(conn.to_node, "mixer");
            assert_eq!(conn.from_pin, "out");
            assert_eq!(conn.to_pin, expected_pin);
        }
    }

    #[test]
    fn test_single_input_uses_in_pin() {
        let yaml = r"
mode: dynamic
nodes:
  source:
    kind: test_source
  sink:
    kind: test_sink
    needs: source
";

        let pipeline = compile_yaml(yaml).unwrap();
        assert_eq!(pipeline.nodes.len(), 2);
        assert_eq!(pipeline.connections.len(), 1);

        let conn = &pipeline.connections[0];
        assert_eq!((conn.from_node.as_str(), conn.to_node.as_str()), ("source", "sink"));
        assert_eq!((conn.from_pin.as_str(), conn.to_pin.as_str()), ("out", "in"));
    }

    #[test]
    fn test_mixer_auto_configures_num_inputs() {
        let yaml = r"
mode: oneshot
nodes:
  input_a:
    kind: test_source
  input_b:
    kind: test_source
  mixer:
    kind: audio::mixer
    params:
      # num_inputs intentionally omitted
    needs:
      - input_a
      - input_b
";

        let pipeline = compile_yaml(yaml).unwrap();
        let mixer_node = pipeline.nodes.get("mixer").expect("mixer node should exist");
        assert_eq!(mixer_node.kind, "audio::mixer");

        let Some(serde_json::Value::Object(map)) = &mixer_node.params else {
            panic!("mixer params should be an object");
        };
        let num_inputs = map.get("num_inputs").expect("num_inputs should be set");
        let serde_json::Value::Number(n) = num_inputs else {
            panic!("num_inputs should be a number");
        };
        assert_eq!(n.as_u64(), Some(2));
    }

    #[test]
    fn test_steps_format_compilation() {
        let yaml = r"
mode: oneshot
steps:
  - kind: streamkit::http_input
  - kind: audio::gain
    params:
      gain: 2.0
  - kind: streamkit::http_output
";

        let pipeline = compile_yaml(yaml).unwrap();
        assert_eq!(pipeline.nodes.len(), 3);
        for key in ["step_0", "step_1", "step_2"] {
            assert!(pipeline.nodes.contains_key(key));
        }
        assert_eq!(pipeline.connections.len(), 2);

        let first = &pipeline.connections[0];
        assert_eq!(first.from_node, "step_0");
        assert_eq!(first.to_node, "step_1");
        assert_eq!(first.from_pin, "out");
        assert_eq!(first.to_pin, "in");

        let second = &pipeline.connections[1];
        assert_eq!(second.from_node, "step_1");
        assert_eq!(second.to_node, "step_2");

        assert!(pipeline.nodes.get("step_1").unwrap().params.is_some());
    }

    #[test]
    fn test_mode_preservation() {
        let oneshot = r"
mode: oneshot
steps:
  - kind: streamkit::http_input
  - kind: streamkit::http_output
";
        let dynamic = r"
mode: dynamic
steps:
  - kind: core::passthrough
";

        for (yaml, expected) in [(oneshot, EngineMode::OneShot), (dynamic, EngineMode::Dynamic)] {
            let compiled = compile_yaml(yaml).unwrap();
            assert_eq!(compiled.mode, expected);
        }
    }

    #[test]
    fn test_default_mode_is_dynamic() {
        let yaml = r"
# mode not specified
steps:
  - kind: core::passthrough
";

        let compiled = compile_yaml(yaml).unwrap();
        assert_eq!(compiled.mode, EngineMode::Dynamic);
    }

    #[test]
    fn test_name_and_description_preservation() {
        let yaml = r"
name: Test Pipeline
description: A test pipeline for validation
mode: dynamic
steps:
  - kind: core::passthrough
";

        let compiled = compile_yaml(yaml).unwrap();
        assert_eq!(compiled.name.as_deref(), Some("Test Pipeline"));
        assert_eq!(compiled.description.as_deref(), Some("A test pipeline for validation"));
    }

    #[test]
    fn test_connection_mode_in_needs() {
        let yaml = r"
mode: dynamic
nodes:
  source:
    kind: test_source
  main_sink:
    kind: test_sink
    needs: source
  metrics:
    kind: test_metrics
    needs:
      node: source
      mode: best_effort
";

        let pipeline = compile_yaml(yaml).unwrap();
        assert_eq!(pipeline.nodes.len(), 3);
        assert_eq!(pipeline.connections.len(), 2);

        let main_conn = find_connection(&pipeline, "Should have connection to main_sink", |c| {
            c.to_node == "main_sink"
        });
        assert_eq!(main_conn.mode, ConnectionMode::Reliable);

        let metrics_conn = find_connection(&pipeline, "Should have connection to metrics", |c| {
            c.to_node == "metrics"
        });
        assert_eq!(metrics_conn.mode, ConnectionMode::BestEffort);
    }

    #[test]
    fn test_connection_mode_in_needs_list() {
        let yaml = r"
mode: dynamic
nodes:
  input_a:
    kind: test_source
  input_b:
    kind: test_source
  mixer:
    kind: audio::mixer
    needs:
      - input_a
      - node: input_b
        mode: best_effort
";

        let pipeline = compile_yaml(yaml).unwrap();
        assert_eq!(pipeline.nodes.len(), 3);
        assert_eq!(pipeline.connections.len(), 2);

        let expectations = [
            ("input_a", ConnectionMode::Reliable, "in_0"),
            ("input_b", ConnectionMode::BestEffort, "in_1"),
        ];
        for (source, expected_mode, expected_pin) in expectations {
            let conn = find_connection(
                &pipeline,
                &format!("Should have connection from {source}"),
                |c| c.from_node == source,
            );
            assert_eq!(conn.mode, expected_mode);
            assert_eq!(conn.to_pin, expected_pin);
        }
    }
}