1use std::fmt;
5use std::path::PathBuf;
6use std::str::FromStr;
7
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10pub use zeph_config::FailureStrategy;
11use zeph_memory::store::graph_store::{GraphSummary, RawGraphStore};
12
13use super::error::OrchestrationError;
14use super::verify_predicate::{PredicateOutcome, VerifyPredicate};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
32pub struct TaskId(pub u32);
33
34impl TaskId {
35 #[must_use]
37 pub fn index(self) -> usize {
38 self.0 as usize
39 }
40
41 #[must_use]
43 pub fn as_u32(self) -> u32 {
44 self.0
45 }
46}
47
48impl fmt::Display for TaskId {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 write!(f, "{}", self.0)
51 }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
72pub struct GraphId(Uuid);
73
74impl GraphId {
75 #[must_use]
77 pub fn new() -> Self {
78 Self(Uuid::new_v4())
79 }
80}
81
82impl Default for GraphId {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl fmt::Display for GraphId {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 write!(f, "{}", self.0)
91 }
92}
93
94impl FromStr for GraphId {
95 type Err = OrchestrationError;
96
97 fn from_str(s: &str) -> Result<Self, Self::Err> {
98 Uuid::parse_str(s)
99 .map(GraphId)
100 .map_err(|e| OrchestrationError::InvalidGraph(format!("invalid graph id '{s}': {e}")))
101 }
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
128#[serde(rename_all = "snake_case")]
129pub enum TaskStatus {
130 Pending,
132 Ready,
134 Running,
136 Completed,
138 Failed,
140 Skipped,
142 Canceled,
144}
145
146impl TaskStatus {
147 #[must_use]
149 pub fn is_terminal(self) -> bool {
150 matches!(
151 self,
152 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Skipped | TaskStatus::Canceled
153 )
154 }
155}
156
157impl fmt::Display for TaskStatus {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 match self {
160 TaskStatus::Pending => write!(f, "pending"),
161 TaskStatus::Ready => write!(f, "ready"),
162 TaskStatus::Running => write!(f, "running"),
163 TaskStatus::Completed => write!(f, "completed"),
164 TaskStatus::Failed => write!(f, "failed"),
165 TaskStatus::Skipped => write!(f, "skipped"),
166 TaskStatus::Canceled => write!(f, "canceled"),
167 }
168 }
169}
170
171#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
182#[serde(rename_all = "snake_case")]
183pub enum GraphStatus {
184 Created,
186 Running,
188 Completed,
190 Failed,
192 Canceled,
194 Paused,
196}
197
198impl fmt::Display for GraphStatus {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 match self {
201 GraphStatus::Created => write!(f, "created"),
202 GraphStatus::Running => write!(f, "running"),
203 GraphStatus::Completed => write!(f, "completed"),
204 GraphStatus::Failed => write!(f, "failed"),
205 GraphStatus::Canceled => write!(f, "canceled"),
206 GraphStatus::Paused => write!(f, "paused"),
207 }
208 }
209}
210
/// Result data attached to a task (stored in `TaskNode::result`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// Textual output produced by the task.
    pub output: String,
    /// Paths of artifacts associated with the result
    /// (NOTE(review): confirm whether relative paths are allowed).
    pub artifacts: Vec<PathBuf>,
    /// Duration of the task in milliseconds, per the field name
    /// (NOTE(review): where it is measured is not shown here).
    pub duration_ms: u64,
    /// Identifier of the agent that produced the result, if known.
    pub agent_id: Option<String>,
    /// Name of the agent definition used, if any
    /// (NOTE(review): presumably relates to `TaskNode::agent_hint`; confirm).
    pub agent_def: Option<String>,
}
230
/// Scheduling mode of a task.
///
/// Serialized in `snake_case` (`"parallel"` / `"sequential"`); defaults to `Parallel`.
// NOTE(review): `///` text on this type also feeds the schemars JSON schema
// description, so reviewer notes are kept in plain `//` comments. The
// scheduler-level meaning of each mode is defined elsewhere; confirm there.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, schemars::JsonSchema,
)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionMode {
    // Also the value that `TaskNode`s without an `execution_mode` field
    // deserialize to, via `#[serde(default)]` on that field.
    #[default]
    Parallel,
    // NOTE(review): presumably runs the task without concurrency; confirm.
    Sequential,
}
260
/// A single unit of work inside a [`TaskGraph`].
///
/// Fields marked `#[serde(default)]` may be missing from JSON written before
/// they existed; they then deserialize to their `Default` value (`0`, `None`,
/// or [`ExecutionMode::Parallel`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskNode {
    /// Identifier of this task within its graph.
    pub id: TaskId,
    /// Short human-readable title.
    pub title: String,
    /// Longer description of the work to perform.
    pub description: String,
    /// Optional hint naming the kind of agent that should run the task.
    pub agent_hint: Option<String>,
    /// Current lifecycle state; `TaskNode::new` starts at `Pending`.
    pub status: TaskStatus,
    /// Ids of the tasks this task depends on.
    pub depends_on: Vec<TaskId>,
    /// Result recorded for the task, if any.
    pub result: Option<TaskResult>,
    /// Agent assigned to the task, if any.
    pub assigned_agent: Option<String>,
    /// Retry counter; starts at 0.
    pub retry_count: u32,
    /// Counter of predicate-driven re-runs (per the field name); 0 when absent from JSON.
    #[serde(default)]
    pub predicate_rerun_count: u32,
    /// Per-task failure strategy; `None` when unset
    /// (NOTE(review): presumably falls back to `TaskGraph::default_failure_strategy`).
    pub failure_strategy: Option<FailureStrategy>,
    /// Per-task retry limit; `None` when unset
    /// (NOTE(review): presumably falls back to `TaskGraph::default_max_retries`).
    pub max_retries: Option<u32>,
    /// Scheduling mode; JSON without this field deserializes to `Parallel`.
    #[serde(default)]
    pub execution_mode: ExecutionMode,
    /// Optional predicate used to verify the task; `None` when absent from JSON.
    #[serde(default)]
    pub verify_predicate: Option<VerifyPredicate>,
    /// Outcome of evaluating `verify_predicate`; `None` when absent from JSON.
    #[serde(default)]
    pub predicate_outcome: Option<PredicateOutcome>,
    /// Optional execution environment name; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub execution_environment: Option<String>,

    /// Optional budget in cents, per the field name; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_budget_cents: Option<f64>,
}
336
337impl TaskNode {
338 #[must_use]
340 pub fn new(id: u32, title: impl Into<String>, description: impl Into<String>) -> Self {
341 Self {
342 id: TaskId(id),
343 title: title.into(),
344 description: description.into(),
345 agent_hint: None,
346 status: TaskStatus::Pending,
347 depends_on: Vec::new(),
348 result: None,
349 assigned_agent: None,
350 retry_count: 0,
351 predicate_rerun_count: 0,
352 failure_strategy: None,
353 max_retries: None,
354 execution_mode: ExecutionMode::default(),
355 verify_predicate: None,
356 predicate_outcome: None,
357 execution_environment: None,
358 token_budget_cents: None,
359 }
360 }
361}
362
/// A goal decomposed into [`TaskNode`]s, plus graph-wide defaults and lifecycle metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskGraph {
    /// Unique id of the graph; `TaskGraph::new` assigns a random one.
    pub id: GraphId,
    /// The goal this graph works toward; `GraphPersistence::save` rejects goals
    /// longer than 1024 bytes.
    pub goal: String,
    /// Tasks belonging to the graph.
    pub tasks: Vec<TaskNode>,
    /// Overall lifecycle state; `TaskGraph::new` starts at `Created`.
    pub status: GraphStatus,
    /// Graph-wide failure strategy (NOTE(review): presumably used for tasks
    /// whose own `failure_strategy` is `None`).
    pub default_failure_strategy: FailureStrategy,
    /// Graph-wide retry limit; `TaskGraph::new` sets it to 3 (NOTE(review):
    /// presumably used for tasks whose own `max_retries` is `None`).
    pub default_max_retries: u32,
    /// Creation time as a `YYYY-MM-DDTHH:MM:SSZ` UTC string, set by `TaskGraph::new`.
    pub created_at: String,
    /// Finish time, if the graph has finished (NOTE(review): presumably the
    /// same format as `created_at`; the writer is not shown here).
    pub finished_at: Option<String>,
}
401
402impl TaskGraph {
403 #[must_use]
405 pub fn new(goal: impl Into<String>) -> Self {
406 Self {
407 id: GraphId::new(),
408 goal: goal.into(),
409 tasks: Vec::new(),
410 status: GraphStatus::Created,
411 default_failure_strategy: FailureStrategy::default(),
412 default_max_retries: 3,
413 created_at: chrono_now(),
414 finished_at: None,
415 }
416 }
417}
418
419pub(crate) fn chrono_now() -> String {
420 let secs = std::time::SystemTime::now()
423 .duration_since(std::time::UNIX_EPOCH)
424 .map_or(0, |d| d.as_secs());
425 let (y, mo, d, h, mi, s) = epoch_secs_to_datetime(secs);
428 format!("{y:04}-{mo:02}-{d:02}T{h:02}:{mi:02}:{s:02}Z")
429}
430
/// Converts seconds since the Unix epoch into a UTC civil date-time.
///
/// Returns `(year, month, day, hour, minute, second)`, with `month` in `1..=12`
/// and `day` in `1..=31`. It uses the proleptic Gregorian calendar and ignores
/// leap seconds.
///
/// The date part is Howard Hinnant's `civil_from_days` algorithm. It splits time
/// into 400-year eras that begin on 0000-03-01. Starting each computed year on
/// 1 March puts the leap day at the very end of the year, so no per-year
/// leap-year table is needed.
///
/// Why not 1970-based cycles: the Gregorian leap pattern does not align with
/// them. 1972 is the 3rd year of a 1970-based 4-year group, and 2000 (a leap
/// century) falls inside the first 100 years. Cycle arithmetic anchored at
/// 1970 therefore misplaces leap days.
fn epoch_secs_to_datetime(secs: u64) -> (u64, u8, u8, u8, u8, u8) {
    // Days from 0000-03-01 to 1970-01-01 in the proleptic Gregorian calendar.
    const DAYS_FROM_ERA_START_TO_UNIX_EPOCH: u64 = 719_468;
    // Length of one 400-year Gregorian era, in days.
    const DAYS_PER_ERA: u64 = 146_097;

    let s = (secs % 60) as u8;
    let mins = secs / 60;
    let mi = (mins % 60) as u8;
    let hours = mins / 60;
    let h = (hours % 24) as u8;
    let days = hours / 24;

    let z = days + DAYS_FROM_ERA_START_TO_UNIX_EPOCH;
    let era = z / DAYS_PER_ERA;
    let doe = z % DAYS_PER_ERA; // day of era: [0, 146_096]
    // Year of era: [0, 399]; the divisions remove the leap days accumulated so far.
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year: [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month: [0, 11], 0 = March
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the next civil year in a March-based year.
    let year = era * 400 + yoe + u64::from(month <= 2);

    // Both conversions always succeed given the ranges above; the fallbacks are unreachable.
    let month = u8::try_from(month).unwrap_or(1);
    let day = u8::try_from(day).unwrap_or(1);

    (year, month, day, h, mi, s)
}
481
482const MAX_GOAL_LEN: usize = 1024;
484
485pub struct GraphPersistence<S: RawGraphStore> {
498 store: S,
499}
500
501impl<S: RawGraphStore> GraphPersistence<S> {
502 pub fn new(store: S) -> Self {
504 Self { store }
505 }
506
507 pub async fn save(&self, graph: &TaskGraph) -> Result<(), OrchestrationError> {
516 if graph.goal.len() > MAX_GOAL_LEN {
517 return Err(OrchestrationError::InvalidGraph(format!(
518 "goal exceeds {MAX_GOAL_LEN} character limit ({} chars)",
519 graph.goal.len()
520 )));
521 }
522 let json = serde_json::to_string(graph)
523 .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
524 self.store
525 .save_graph(
526 &graph.id.to_string(),
527 &graph.goal,
528 &graph.status.to_string(),
529 &json,
530 &graph.created_at,
531 graph.finished_at.as_deref(),
532 )
533 .await
534 .map_err(|e| OrchestrationError::Persistence(e.to_string()))
535 }
536
537 pub async fn load(&self, id: &GraphId) -> Result<Option<TaskGraph>, OrchestrationError> {
545 match self
546 .store
547 .load_graph(&id.to_string())
548 .await
549 .map_err(|e| OrchestrationError::Persistence(e.to_string()))?
550 {
551 Some(json) => {
552 let graph = serde_json::from_str(&json)
553 .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
554 Ok(Some(graph))
555 }
556 None => Ok(None),
557 }
558 }
559
560 pub async fn list(&self, limit: u32) -> Result<Vec<GraphSummary>, OrchestrationError> {
566 self.store
567 .list_graphs(limit)
568 .await
569 .map_err(|e| OrchestrationError::Persistence(e.to_string()))
570 }
571
572 pub async fn delete(&self, id: &GraphId) -> Result<bool, OrchestrationError> {
580 self.store
581 .delete_graph(&id.to_string())
582 .await
583 .map_err(|e| OrchestrationError::Persistence(e.to_string()))
584 }
585}
586
// Unit tests for the identifier types, status enums, serde compatibility of
// `TaskNode`/`TaskGraph`, the timestamp helper, and the goal-length constant.
#[cfg(test)]
mod tests {
    use super::*;

    // `TaskId` displays as its bare number.
    #[test]
    fn test_taskid_display() {
        assert_eq!(TaskId(3).to_string(), "3");
    }

    // A generated id displays as a 36-char UUID string and parses back to itself.
    #[test]
    fn test_graphid_display_and_new() {
        let id = GraphId::new();
        let s = id.to_string();
        assert_eq!(s.len(), 36, "UUID string should be 36 chars");
        let parsed: GraphId = s.parse().expect("should parse back");
        assert_eq!(id, parsed);
    }

    // Non-UUID input is rejected by `FromStr`.
    #[test]
    fn test_graphid_from_str_invalid() {
        let err = "not-a-uuid".parse::<GraphId>();
        assert!(err.is_err());
    }

    // Exactly Completed/Failed/Skipped/Canceled are terminal.
    #[test]
    fn test_task_status_is_terminal() {
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Skipped.is_terminal());
        assert!(TaskStatus::Canceled.is_terminal());

        assert!(!TaskStatus::Pending.is_terminal());
        assert!(!TaskStatus::Ready.is_terminal());
        assert!(!TaskStatus::Running.is_terminal());
    }

    // Display output for every `TaskStatus` variant.
    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Pending.to_string(), "pending");
        assert_eq!(TaskStatus::Ready.to_string(), "ready");
        assert_eq!(TaskStatus::Running.to_string(), "running");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Skipped.to_string(), "skipped");
        assert_eq!(TaskStatus::Canceled.to_string(), "canceled");
    }

    // The re-exported `FailureStrategy` (from zeph_config) defaults to `Abort`.
    #[test]
    fn test_failure_strategy_default() {
        assert_eq!(FailureStrategy::default(), FailureStrategy::Abort);
    }

    // Display output for every `FailureStrategy` variant.
    #[test]
    fn test_failure_strategy_display() {
        assert_eq!(FailureStrategy::Abort.to_string(), "abort");
        assert_eq!(FailureStrategy::Retry.to_string(), "retry");
        assert_eq!(FailureStrategy::Skip.to_string(), "skip");
        assert_eq!(FailureStrategy::Ask.to_string(), "ask");
    }

    // Display output for every `GraphStatus` variant.
    #[test]
    fn test_graph_status_display() {
        assert_eq!(GraphStatus::Created.to_string(), "created");
        assert_eq!(GraphStatus::Running.to_string(), "running");
        assert_eq!(GraphStatus::Completed.to_string(), "completed");
        assert_eq!(GraphStatus::Failed.to_string(), "failed");
        assert_eq!(GraphStatus::Canceled.to_string(), "canceled");
        assert_eq!(GraphStatus::Paused.to_string(), "paused");
    }

    // A graph with one task survives a JSON round trip (id, goal, task count).
    #[test]
    fn test_task_graph_serde_roundtrip() {
        let mut graph = TaskGraph::new("test goal");
        graph.tasks.push(TaskNode::new(0, "task 0", "do something"));
        let json = serde_json::to_string(&graph).expect("serialize");
        let restored: TaskGraph = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(graph.id, restored.id);
        assert_eq!(graph.goal, restored.goal);
        assert_eq!(graph.tasks.len(), restored.tasks.len());
    }

    // A task with a hint and a dependency survives a JSON round trip.
    #[test]
    fn test_task_node_serde_roundtrip() {
        let mut node = TaskNode::new(1, "compile", "run cargo build");
        node.agent_hint = Some("rust-dev".to_string());
        node.depends_on = vec![TaskId(0)];
        let json = serde_json::to_string(&node).expect("serialize");
        let restored: TaskNode = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(node.id, restored.id);
        assert_eq!(node.title, restored.title);
        assert_eq!(node.depends_on, restored.depends_on);
    }

    // `TaskResult` survives a JSON round trip, including artifact paths.
    #[test]
    fn test_task_result_serde_roundtrip() {
        let result = TaskResult {
            output: "ok".to_string(),
            artifacts: vec![PathBuf::from("/tmp/out.bin")],
            duration_ms: 500,
            agent_id: Some("agent-1".to_string()),
            agent_def: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: TaskResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(result.output, restored.output);
        assert_eq!(result.duration_ms, restored.duration_ms);
        assert_eq!(result.artifacts, restored.artifacts);
    }

    // `FailureStrategy` parses its four lowercase names and rejects anything else.
    #[test]
    fn test_failure_strategy_from_str() {
        assert_eq!(
            "abort".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Abort
        );
        assert_eq!(
            "retry".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Retry
        );
        assert_eq!(
            "skip".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Skip
        );
        assert_eq!(
            "ask".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Ask
        );
        assert!("abort_all".parse::<FailureStrategy>().is_err());
        assert!("".parse::<FailureStrategy>().is_err());
    }

    // Shape-only check of `chrono_now` (length, `T`, trailing `Z`, plausible year).
    // NOTE(review): calendar correctness of `epoch_secs_to_datetime` (e.g. leap
    // years, month/day values) is not covered by this test.
    #[test]
    fn test_chrono_now_iso8601_format() {
        let ts = chrono_now();
        assert_eq!(ts.len(), 20, "timestamp should be 20 chars: {ts}");
        assert!(ts.ends_with('Z'), "should end with Z: {ts}");
        assert!(ts.contains('T'), "should contain T: {ts}");
        let year: u32 = ts[..4].parse().expect("year should be numeric");
        assert!(year >= 2024, "year should be >= 2024: {year}");
    }

    // `FailureStrategy` serializes as lowercase JSON strings.
    #[test]
    fn test_failure_strategy_serde_snake_case() {
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Abort).unwrap(),
            "\"abort\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Retry).unwrap(),
            "\"retry\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Skip).unwrap(),
            "\"skip\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Ask).unwrap(),
            "\"ask\""
        );
    }

    // NOTE(review): despite its name, this test never calls
    // `GraphPersistence::save` (that needs a `RawGraphStore`). It only checks
    // the fixture and the constant's value. The second assignment to
    // `graph.goal` repeats what `TaskGraph::new(long_goal)` already set.
    #[test]
    fn test_graph_persistence_save_rejects_long_goal() {
        let long_goal = "x".repeat(MAX_GOAL_LEN + 1);
        let mut graph = TaskGraph::new(long_goal);
        graph.goal = "x".repeat(MAX_GOAL_LEN + 1);
        assert!(
            graph.goal.len() > MAX_GOAL_LEN,
            "test setup: goal must exceed limit"
        );
        assert_eq!(MAX_GOAL_LEN, 1024);
    }

    // JSON written before the predicate fields existed still deserializes,
    // with both predicate fields defaulting to `None`.
    #[test]
    fn test_task_node_predicate_fields_default_to_none() {
        let json = r#"{
            "id": 0,
            "title": "t",
            "description": "d",
            "agent_hint": null,
            "status": "pending",
            "depends_on": [],
            "result": null,
            "assigned_agent": null,
            "retry_count": 0,
            "failure_strategy": null,
            "max_retries": null
        }"#;
        let node: TaskNode = serde_json::from_str(json).expect("should deserialize old JSON");
        assert!(node.verify_predicate.is_none());
        assert!(node.predicate_outcome.is_none());
    }

    // JSON without `execution_mode` deserializes to `ExecutionMode::Parallel`.
    #[test]
    fn test_task_node_missing_execution_mode_deserializes_as_parallel() {
        let json = r#"{
            "id": 0,
            "title": "t",
            "description": "d",
            "agent_hint": null,
            "status": "pending",
            "depends_on": [],
            "result": null,
            "assigned_agent": null,
            "retry_count": 0,
            "failure_strategy": null,
            "max_retries": null
        }"#;
        let node: TaskNode = serde_json::from_str(json).expect("should deserialize old JSON");
        assert_eq!(node.execution_mode, ExecutionMode::Parallel);
    }

    // `ExecutionMode` serializes and deserializes as lowercase JSON strings.
    #[test]
    fn test_execution_mode_serde_snake_case() {
        assert_eq!(
            serde_json::to_string(&ExecutionMode::Parallel).unwrap(),
            "\"parallel\""
        );
        assert_eq!(
            serde_json::to_string(&ExecutionMode::Sequential).unwrap(),
            "\"sequential\""
        );
        let p: ExecutionMode = serde_json::from_str("\"parallel\"").unwrap();
        assert_eq!(p, ExecutionMode::Parallel);
        let s: ExecutionMode = serde_json::from_str("\"sequential\"").unwrap();
        assert_eq!(s, ExecutionMode::Sequential);
    }
}