1use std::fmt;
5use std::path::PathBuf;
6use std::str::FromStr;
7
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10use zeph_memory::store::graph_store::{GraphSummary, RawGraphStore};
11
12use super::error::OrchestrationError;
13use super::verify_predicate::{PredicateOutcome, VerifyPredicate};
14
/// Identifier of a task within a [`TaskGraph`], wrapping a raw `u32`.
///
/// `TaskNode::new` builds one from its `id` argument; `Display` renders the
/// bare number (e.g. `TaskId(3)` -> `"3"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TaskId(pub u32);
32
33impl TaskId {
34 #[must_use]
36 pub fn index(self) -> usize {
37 self.0 as usize
38 }
39
40 #[must_use]
42 pub fn as_u32(self) -> u32 {
43 self.0
44 }
45}
46
47impl fmt::Display for TaskId {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 write!(f, "{}", self.0)
50 }
51}
52
/// Unique identifier of a [`TaskGraph`], wrapping a [`Uuid`].
///
/// [`GraphId::new`] (and therefore `Default`) generates a random v4 UUID;
/// `Display` and `FromStr` convert to and from the textual UUID form.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GraphId(Uuid);
72
73impl GraphId {
74 #[must_use]
76 pub fn new() -> Self {
77 Self(Uuid::new_v4())
78 }
79}
80
81impl Default for GraphId {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl fmt::Display for GraphId {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 write!(f, "{}", self.0)
90 }
91}
92
93impl FromStr for GraphId {
94 type Err = OrchestrationError;
95
96 fn from_str(s: &str) -> Result<Self, Self::Err> {
97 Uuid::parse_str(s)
98 .map(GraphId)
99 .map_err(|e| OrchestrationError::InvalidGraph(format!("invalid graph id '{s}': {e}")))
100 }
101}
102
/// Lifecycle state of a single [`TaskNode`].
///
/// Serialized in `snake_case` (`"pending"`, `"ready"`, ...); `Display` emits
/// the same lowercase names. [`TaskStatus::is_terminal`] reports which states
/// are final.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
    /// Initial state assigned by `TaskNode::new`.
    Pending,
    /// Non-terminal state.
    Ready,
    /// Non-terminal state.
    Running,
    /// Terminal state.
    Completed,
    /// Terminal state.
    Failed,
    /// Terminal state.
    Skipped,
    /// Terminal state.
    Canceled,
}
144
145impl TaskStatus {
146 #[must_use]
148 pub fn is_terminal(self) -> bool {
149 matches!(
150 self,
151 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Skipped | TaskStatus::Canceled
152 )
153 }
154}
155
156impl fmt::Display for TaskStatus {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 match self {
159 TaskStatus::Pending => write!(f, "pending"),
160 TaskStatus::Ready => write!(f, "ready"),
161 TaskStatus::Running => write!(f, "running"),
162 TaskStatus::Completed => write!(f, "completed"),
163 TaskStatus::Failed => write!(f, "failed"),
164 TaskStatus::Skipped => write!(f, "skipped"),
165 TaskStatus::Canceled => write!(f, "canceled"),
166 }
167 }
168}
169
/// Overall state of a [`TaskGraph`].
///
/// Serialized in `snake_case`; `Display` emits the same lowercase names, and
/// `GraphPersistence::save` passes that string to the store next to the JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GraphStatus {
    /// Initial state assigned by `TaskGraph::new`.
    Created,
    Running,
    Completed,
    Failed,
    Canceled,
    Paused,
}
196
197impl fmt::Display for GraphStatus {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 match self {
200 GraphStatus::Created => write!(f, "created"),
201 GraphStatus::Running => write!(f, "running"),
202 GraphStatus::Completed => write!(f, "completed"),
203 GraphStatus::Failed => write!(f, "failed"),
204 GraphStatus::Canceled => write!(f, "canceled"),
205 GraphStatus::Paused => write!(f, "paused"),
206 }
207 }
208}
209
/// Policy to apply when a task fails.
///
/// Serialized in `snake_case`; `Display` and `FromStr` use the same lowercase
/// names (`abort`, `retry`, `skip`, `ask`). `Abort` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FailureStrategy {
    /// Default policy (used by `TaskGraph::new` for `default_failure_strategy`).
    #[default]
    Abort,
    Retry,
    Skip,
    Ask,
}
238
239impl fmt::Display for FailureStrategy {
240 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241 match self {
242 FailureStrategy::Abort => write!(f, "abort"),
243 FailureStrategy::Retry => write!(f, "retry"),
244 FailureStrategy::Skip => write!(f, "skip"),
245 FailureStrategy::Ask => write!(f, "ask"),
246 }
247 }
248}
249
250impl FromStr for FailureStrategy {
251 type Err = OrchestrationError;
252
253 fn from_str(s: &str) -> Result<Self, Self::Err> {
254 match s {
255 "abort" => Ok(FailureStrategy::Abort),
256 "retry" => Ok(FailureStrategy::Retry),
257 "skip" => Ok(FailureStrategy::Skip),
258 "ask" => Ok(FailureStrategy::Ask),
259 other => Err(OrchestrationError::InvalidGraph(format!(
260 "unknown failure strategy '{other}': expected one of abort, retry, skip, ask"
261 ))),
262 }
263 }
264}
265
/// Outcome recorded for a task (stored in `TaskNode::result`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// Textual output of the task.
    pub output: String,
    /// Paths of artifacts associated with the task.
    pub artifacts: Vec<PathBuf>,
    /// Duration of the run, in milliseconds per the field name.
    pub duration_ms: u64,
    /// Identifier of the agent that produced the result, if recorded.
    pub agent_id: Option<String>,
    /// Agent definition used, if recorded (presumably a definition name — confirm with callers).
    pub agent_def: Option<String>,
}
285
/// Scheduling mode of a task.
///
/// Serialized in `snake_case` (`"parallel"` / `"sequential"`). `Parallel` is
/// the default, including for `TaskNode` JSON that omits the field.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, schemars::JsonSchema,
)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionMode {
    /// Default mode.
    #[default]
    Parallel,
    /// Non-default mode; its scheduling semantics are implemented by the scheduler (not shown here).
    Sequential,
}
315
/// A single unit of work inside a [`TaskGraph`].
///
/// Fields marked `#[serde(default)]` may be missing from stored JSON and then
/// deserialize as `0`, `None`, or [`ExecutionMode::Parallel`]. The last two
/// optional fields are also omitted from serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskNode {
    /// Identifier of this task within its graph.
    pub id: TaskId,
    /// Short title of the task.
    pub title: String,
    /// Description of the work to perform.
    pub description: String,
    /// Optional agent hint (the tests use values like `"rust-dev"`).
    pub agent_hint: Option<String>,
    /// Current lifecycle state; `TaskNode::new` starts at [`TaskStatus::Pending`].
    pub status: TaskStatus,
    /// Ids of the tasks this one depends on.
    pub depends_on: Vec<TaskId>,
    /// Recorded outcome, if any.
    pub result: Option<TaskResult>,
    /// Agent assigned to the task, if any.
    pub assigned_agent: Option<String>,
    /// Retry counter; `TaskNode::new` starts it at 0.
    pub retry_count: u32,
    /// Predicate re-run counter; 0 when absent from JSON.
    #[serde(default)]
    pub predicate_rerun_count: u32,
    /// Per-task failure strategy; presumably overrides
    /// `TaskGraph::default_failure_strategy` when `Some` — confirm with the scheduler.
    pub failure_strategy: Option<FailureStrategy>,
    /// Per-task retry limit; presumably overrides `TaskGraph::default_max_retries`
    /// when `Some` — confirm with the scheduler.
    pub max_retries: Option<u32>,
    /// Scheduling mode; `Parallel` when absent from JSON.
    #[serde(default)]
    pub execution_mode: ExecutionMode,
    /// Optional verification predicate; `None` when absent from JSON.
    #[serde(default)]
    pub verify_predicate: Option<VerifyPredicate>,
    /// Recorded predicate outcome; `None` when absent from JSON.
    #[serde(default)]
    pub predicate_outcome: Option<PredicateOutcome>,
    /// Optional execution environment name; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub execution_environment: Option<String>,

    /// Optional budget for this task (in cents, per the field name); omitted
    /// from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_budget_cents: Option<f64>,
}
391
392impl TaskNode {
393 #[must_use]
395 pub fn new(id: u32, title: impl Into<String>, description: impl Into<String>) -> Self {
396 Self {
397 id: TaskId(id),
398 title: title.into(),
399 description: description.into(),
400 agent_hint: None,
401 status: TaskStatus::Pending,
402 depends_on: Vec::new(),
403 result: None,
404 assigned_agent: None,
405 retry_count: 0,
406 predicate_rerun_count: 0,
407 failure_strategy: None,
408 max_retries: None,
409 execution_mode: ExecutionMode::default(),
410 verify_predicate: None,
411 predicate_outcome: None,
412 execution_environment: None,
413 token_budget_cents: None,
414 }
415 }
416}
417
/// A goal broken down into [`TaskNode`]s, plus graph-wide status and defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskGraph {
    /// Unique id; `GraphPersistence` stores and loads the graph under it.
    pub id: GraphId,
    /// Goal text; `GraphPersistence::save` rejects goals longer than
    /// `MAX_GOAL_LEN` bytes.
    pub goal: String,
    /// Tasks in the graph.
    pub tasks: Vec<TaskNode>,
    /// Overall state; `TaskGraph::new` starts at [`GraphStatus::Created`].
    pub status: GraphStatus,
    /// Graph-wide failure strategy; `TaskGraph::new` uses
    /// `FailureStrategy::default()` (`Abort`).
    pub default_failure_strategy: FailureStrategy,
    /// Graph-wide retry limit; `TaskGraph::new` sets 3.
    pub default_max_retries: u32,
    /// Creation timestamp from `chrono_now()`, formatted `YYYY-MM-DDTHH:MM:SSZ` (UTC).
    pub created_at: String,
    /// Completion timestamp; `None` until set (presumably when the graph finishes — confirm).
    pub finished_at: Option<String>,
}
456
457impl TaskGraph {
458 #[must_use]
460 pub fn new(goal: impl Into<String>) -> Self {
461 Self {
462 id: GraphId::new(),
463 goal: goal.into(),
464 tasks: Vec::new(),
465 status: GraphStatus::Created,
466 default_failure_strategy: FailureStrategy::default(),
467 default_max_retries: 3,
468 created_at: chrono_now(),
469 finished_at: None,
470 }
471 }
472}
473
474pub(crate) fn chrono_now() -> String {
475 let secs = std::time::SystemTime::now()
478 .duration_since(std::time::UNIX_EPOCH)
479 .map_or(0, |d| d.as_secs());
480 let (y, mo, d, h, mi, s) = epoch_secs_to_datetime(secs);
483 format!("{y:04}-{mo:02}-{d:02}T{h:02}:{mi:02}:{s:02}Z")
484}
485
/// Converts seconds since the Unix epoch (1970-01-01T00:00:00Z) into a UTC
/// civil date-time `(year, month, day, hour, minute, second)`.
///
/// `month` is 1-12 and `day` is 1-31. The calendar part uses Howard Hinnant's
/// `civil_from_days` algorithm for the proleptic Gregorian calendar, which
/// applies the 4/100/400-year leap rules exactly.
///
/// The previous hand-rolled decomposition assumed 4-year blocks ending in a
/// leap year and 36524-day centuries counted from 1970. That produced dates
/// one day late in 1973, 1977, ..., 2025, 2029 (and from 2070 on), e.g.
/// 2025-01-01 was reported as 2025-01-02.
fn epoch_secs_to_datetime(secs: u64) -> (u64, u8, u8, u8, u8, u8) {
    const SECS_PER_DAY: u64 = 86_400;
    // Days in one full 400-year Gregorian cycle ("era").
    const DAYS_PER_ERA: u64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01. Starting years in March puts each
    // leap day at the end of its year, which keeps the arithmetic branch-free.
    const DAYS_FROM_0000_03_01_TO_EPOCH: u64 = 719_468;

    let days = secs / SECS_PER_DAY;
    let secs_of_day = secs % SECS_PER_DAY;
    // Bounded by 24 / 60 / 60 respectively, so the narrowing casts cannot truncate.
    let h = (secs_of_day / 3_600) as u8;
    let mi = (secs_of_day / 60 % 60) as u8;
    let s = (secs_of_day % 60) as u8;

    let shifted = days + DAYS_FROM_0000_03_01_TO_EPOCH;
    let era = shifted / DAYS_PER_ERA;
    let doe = shifted % DAYS_PER_ERA; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month index, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the next calendar year in the March-based scheme.
    let year = era * 400 + yoe + u64::from(month <= 2);

    (
        year,
        u8::try_from(month).unwrap_or(1),
        u8::try_from(day).unwrap_or(1),
        h,
        mi,
        s,
    )
}
536
537const MAX_GOAL_LEN: usize = 1024;
539
540pub struct GraphPersistence<S: RawGraphStore> {
553 store: S,
554}
555
556impl<S: RawGraphStore> GraphPersistence<S> {
557 pub fn new(store: S) -> Self {
559 Self { store }
560 }
561
562 pub async fn save(&self, graph: &TaskGraph) -> Result<(), OrchestrationError> {
571 if graph.goal.len() > MAX_GOAL_LEN {
572 return Err(OrchestrationError::InvalidGraph(format!(
573 "goal exceeds {MAX_GOAL_LEN} character limit ({} chars)",
574 graph.goal.len()
575 )));
576 }
577 let json = serde_json::to_string(graph)
578 .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
579 self.store
580 .save_graph(
581 &graph.id.to_string(),
582 &graph.goal,
583 &graph.status.to_string(),
584 &json,
585 &graph.created_at,
586 graph.finished_at.as_deref(),
587 )
588 .await
589 .map_err(|e| OrchestrationError::Persistence(e.to_string()))
590 }
591
592 pub async fn load(&self, id: &GraphId) -> Result<Option<TaskGraph>, OrchestrationError> {
600 match self
601 .store
602 .load_graph(&id.to_string())
603 .await
604 .map_err(|e| OrchestrationError::Persistence(e.to_string()))?
605 {
606 Some(json) => {
607 let graph = serde_json::from_str(&json)
608 .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
609 Ok(Some(graph))
610 }
611 None => Ok(None),
612 }
613 }
614
615 pub async fn list(&self, limit: u32) -> Result<Vec<GraphSummary>, OrchestrationError> {
621 self.store
622 .list_graphs(limit)
623 .await
624 .map_err(|e| OrchestrationError::Persistence(e.to_string()))
625 }
626
627 pub async fn delete(&self, id: &GraphId) -> Result<bool, OrchestrationError> {
635 self.store
636 .delete_graph(&id.to_string())
637 .await
638 .map_err(|e| OrchestrationError::Persistence(e.to_string()))
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// `TaskNode` JSON that omits every `#[serde(default)]` field (the "old
    /// JSON" referred to below); it must keep deserializing.
    const LEGACY_TASK_NODE_JSON: &str = r#"{
        "id": 0,
        "title": "t",
        "description": "d",
        "agent_hint": null,
        "status": "pending",
        "depends_on": [],
        "result": null,
        "assigned_agent": null,
        "retry_count": 0,
        "failure_strategy": null,
        "max_retries": null
    }"#;

    #[test]
    fn test_taskid_display() {
        assert_eq!(TaskId(3).to_string(), "3");
    }

    #[test]
    fn test_taskid_accessors() {
        assert_eq!(TaskId(5).index(), 5);
        assert_eq!(TaskId(5).as_u32(), 5);
    }

    #[test]
    fn test_graphid_display_and_new() {
        let id = GraphId::new();
        let s = id.to_string();
        assert_eq!(s.len(), 36, "UUID string should be 36 chars");
        let parsed: GraphId = s.parse().expect("should parse back");
        assert_eq!(id, parsed);
        assert_ne!(GraphId::new(), GraphId::new(), "ids should be unique");
    }

    #[test]
    fn test_graphid_from_str_invalid() {
        let err = "not-a-uuid".parse::<GraphId>();
        assert!(matches!(err, Err(OrchestrationError::InvalidGraph(_))));
    }

    #[test]
    fn test_task_status_is_terminal() {
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Skipped.is_terminal());
        assert!(TaskStatus::Canceled.is_terminal());

        assert!(!TaskStatus::Pending.is_terminal());
        assert!(!TaskStatus::Ready.is_terminal());
        assert!(!TaskStatus::Running.is_terminal());
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Pending.to_string(), "pending");
        assert_eq!(TaskStatus::Ready.to_string(), "ready");
        assert_eq!(TaskStatus::Running.to_string(), "running");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Skipped.to_string(), "skipped");
        assert_eq!(TaskStatus::Canceled.to_string(), "canceled");
    }

    #[test]
    fn test_failure_strategy_default() {
        assert_eq!(FailureStrategy::default(), FailureStrategy::Abort);
    }

    #[test]
    fn test_failure_strategy_display() {
        assert_eq!(FailureStrategy::Abort.to_string(), "abort");
        assert_eq!(FailureStrategy::Retry.to_string(), "retry");
        assert_eq!(FailureStrategy::Skip.to_string(), "skip");
        assert_eq!(FailureStrategy::Ask.to_string(), "ask");
    }

    #[test]
    fn test_graph_status_display() {
        assert_eq!(GraphStatus::Created.to_string(), "created");
        assert_eq!(GraphStatus::Running.to_string(), "running");
        assert_eq!(GraphStatus::Completed.to_string(), "completed");
        assert_eq!(GraphStatus::Failed.to_string(), "failed");
        assert_eq!(GraphStatus::Canceled.to_string(), "canceled");
        assert_eq!(GraphStatus::Paused.to_string(), "paused");
    }

    #[test]
    fn test_task_graph_serde_roundtrip() {
        let mut graph = TaskGraph::new("test goal");
        graph.tasks.push(TaskNode::new(0, "task 0", "do something"));
        let json = serde_json::to_string(&graph).expect("serialize");
        let restored: TaskGraph = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(graph.id, restored.id);
        assert_eq!(graph.goal, restored.goal);
        assert_eq!(graph.tasks.len(), restored.tasks.len());
        assert_eq!(graph.status, restored.status);
        assert_eq!(
            graph.default_failure_strategy,
            restored.default_failure_strategy
        );
        assert_eq!(graph.default_max_retries, restored.default_max_retries);
        assert_eq!(graph.created_at, restored.created_at);
        assert_eq!(graph.finished_at, restored.finished_at);
    }

    #[test]
    fn test_task_node_serde_roundtrip() {
        let mut node = TaskNode::new(1, "compile", "run cargo build");
        node.agent_hint = Some("rust-dev".to_string());
        node.depends_on = vec![TaskId(0)];
        let json = serde_json::to_string(&node).expect("serialize");
        let restored: TaskNode = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(node.id, restored.id);
        assert_eq!(node.title, restored.title);
        assert_eq!(node.depends_on, restored.depends_on);
        assert_eq!(node.agent_hint, restored.agent_hint);
        assert_eq!(node.status, restored.status);
    }

    #[test]
    fn test_task_node_none_optionals_are_not_serialized() {
        let json = serde_json::to_string(&TaskNode::new(0, "t", "d")).expect("serialize");
        assert!(!json.contains("execution_environment"), "{json}");
        assert!(!json.contains("token_budget_cents"), "{json}");
    }

    #[test]
    fn test_task_result_serde_roundtrip() {
        let result = TaskResult {
            output: "ok".to_string(),
            artifacts: vec![PathBuf::from("/tmp/out.bin")],
            duration_ms: 500,
            agent_id: Some("agent-1".to_string()),
            agent_def: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: TaskResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(result.output, restored.output);
        assert_eq!(result.duration_ms, restored.duration_ms);
        assert_eq!(result.artifacts, restored.artifacts);
        assert_eq!(result.agent_id, restored.agent_id);
        assert_eq!(result.agent_def, restored.agent_def);
    }

    #[test]
    fn test_failure_strategy_from_str() {
        assert_eq!(
            "abort".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Abort
        );
        assert_eq!(
            "retry".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Retry
        );
        assert_eq!(
            "skip".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Skip
        );
        assert_eq!(
            "ask".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Ask
        );
        assert!("abort_all".parse::<FailureStrategy>().is_err());
        assert!("".parse::<FailureStrategy>().is_err());
    }

    #[test]
    fn test_failure_strategy_display_from_str_roundtrip() {
        for strategy in [
            FailureStrategy::Abort,
            FailureStrategy::Retry,
            FailureStrategy::Skip,
            FailureStrategy::Ask,
        ] {
            let parsed: FailureStrategy = strategy
                .to_string()
                .parse()
                .expect("Display output should parse back");
            assert_eq!(parsed, strategy);
        }
    }

    #[test]
    fn test_chrono_now_iso8601_format() {
        let ts = chrono_now();
        assert_eq!(ts.len(), 20, "timestamp should be 20 chars: {ts}");
        assert!(ts.ends_with('Z'), "should end with Z: {ts}");
        assert!(ts.contains('T'), "should contain T: {ts}");
        let year: u32 = ts[..4].parse().expect("year should be numeric");
        assert!(year >= 2024, "year should be >= 2024: {year}");
    }

    #[test]
    fn test_failure_strategy_serde_snake_case() {
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Abort).unwrap(),
            "\"abort\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Retry).unwrap(),
            "\"retry\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Skip).unwrap(),
            "\"skip\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Ask).unwrap(),
            "\"ask\""
        );
    }

    #[test]
    fn test_graph_persistence_save_rejects_long_goal() {
        // NOTE(review): this only pins the precondition that `save` guards
        // against. `save` is async and needs a `RawGraphStore`, so the
        // rejection path itself is not exercised here.
        let graph = TaskGraph::new("x".repeat(MAX_GOAL_LEN + 1));
        assert!(
            graph.goal.len() > MAX_GOAL_LEN,
            "test setup: goal must exceed limit"
        );
        assert_eq!(MAX_GOAL_LEN, 1024);
    }

    #[test]
    fn test_task_node_predicate_fields_default_to_none() {
        let node: TaskNode =
            serde_json::from_str(LEGACY_TASK_NODE_JSON).expect("should deserialize old JSON");
        assert!(node.verify_predicate.is_none());
        assert!(node.predicate_outcome.is_none());
    }

    #[test]
    fn test_task_node_missing_execution_mode_deserializes_as_parallel() {
        let node: TaskNode =
            serde_json::from_str(LEGACY_TASK_NODE_JSON).expect("should deserialize old JSON");
        assert_eq!(node.execution_mode, ExecutionMode::Parallel);
    }

    #[test]
    fn test_task_node_legacy_json_defaults_remaining_fields() {
        let node: TaskNode =
            serde_json::from_str(LEGACY_TASK_NODE_JSON).expect("should deserialize old JSON");
        assert_eq!(node.predicate_rerun_count, 0);
        assert!(node.execution_environment.is_none());
        assert!(node.token_budget_cents.is_none());
    }

    #[test]
    fn test_execution_mode_serde_snake_case() {
        assert_eq!(
            serde_json::to_string(&ExecutionMode::Parallel).unwrap(),
            "\"parallel\""
        );
        assert_eq!(
            serde_json::to_string(&ExecutionMode::Sequential).unwrap(),
            "\"sequential\""
        );
        let p: ExecutionMode = serde_json::from_str("\"parallel\"").unwrap();
        assert_eq!(p, ExecutionMode::Parallel);
        let s: ExecutionMode = serde_json::from_str("\"sequential\"").unwrap();
        assert_eq!(s, ExecutionMode::Sequential);
    }
}