Skip to main content

zeph_orchestration/
graph.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5use std::path::PathBuf;
6use std::str::FromStr;
7
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10use zeph_memory::store::graph_store::{GraphSummary, RawGraphStore};
11
12use super::error::OrchestrationError;
13
14/// Index of a task within a `TaskGraph.tasks` Vec.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub struct TaskId(pub(crate) u32);
17
18impl TaskId {
19    /// Returns the index for Vec access.
20    #[must_use]
21    pub fn index(self) -> usize {
22        self.0 as usize
23    }
24
25    /// Returns the raw `u32` value.
26    #[must_use]
27    pub fn as_u32(self) -> u32 {
28        self.0
29    }
30}
31
32impl fmt::Display for TaskId {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        write!(f, "{}", self.0)
35    }
36}
37
38/// Unique identifier for a `TaskGraph`.
39#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
40pub struct GraphId(Uuid);
41
42impl GraphId {
43    /// Generate a new random v4 `GraphId`.
44    #[must_use]
45    pub fn new() -> Self {
46        Self(Uuid::new_v4())
47    }
48}
49
50impl Default for GraphId {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl fmt::Display for GraphId {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        write!(f, "{}", self.0)
59    }
60}
61
62impl FromStr for GraphId {
63    type Err = OrchestrationError;
64
65    fn from_str(s: &str) -> Result<Self, Self::Err> {
66        Uuid::parse_str(s)
67            .map(GraphId)
68            .map_err(|e| OrchestrationError::InvalidGraph(format!("invalid graph id '{s}': {e}")))
69    }
70}
71
72/// Lifecycle status of a single task node.
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
74#[serde(rename_all = "snake_case")]
75pub enum TaskStatus {
76    Pending,
77    Ready,
78    Running,
79    Completed,
80    Failed,
81    Skipped,
82    Canceled,
83}
84
85impl TaskStatus {
86    /// Returns `true` if the status is a terminal state.
87    #[must_use]
88    pub fn is_terminal(self) -> bool {
89        matches!(
90            self,
91            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Skipped | TaskStatus::Canceled
92        )
93    }
94}
95
96impl fmt::Display for TaskStatus {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        match self {
99            TaskStatus::Pending => write!(f, "pending"),
100            TaskStatus::Ready => write!(f, "ready"),
101            TaskStatus::Running => write!(f, "running"),
102            TaskStatus::Completed => write!(f, "completed"),
103            TaskStatus::Failed => write!(f, "failed"),
104            TaskStatus::Skipped => write!(f, "skipped"),
105            TaskStatus::Canceled => write!(f, "canceled"),
106        }
107    }
108}
109
110/// Lifecycle status of a `TaskGraph`.
111#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
112#[serde(rename_all = "snake_case")]
113pub enum GraphStatus {
114    Created,
115    Running,
116    Completed,
117    Failed,
118    Canceled,
119    Paused,
120}
121
122impl fmt::Display for GraphStatus {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        match self {
125            GraphStatus::Created => write!(f, "created"),
126            GraphStatus::Running => write!(f, "running"),
127            GraphStatus::Completed => write!(f, "completed"),
128            GraphStatus::Failed => write!(f, "failed"),
129            GraphStatus::Canceled => write!(f, "canceled"),
130            GraphStatus::Paused => write!(f, "paused"),
131        }
132    }
133}
134
135/// What to do when a task fails.
136#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
137#[serde(rename_all = "snake_case")]
138pub enum FailureStrategy {
139    /// Abort the entire graph.
140    #[default]
141    Abort,
142    /// Retry the task up to `max_retries` times.
143    Retry,
144    /// Skip the task and its dependents.
145    Skip,
146    /// Pause the graph and ask the user.
147    Ask,
148}
149
150impl fmt::Display for FailureStrategy {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        match self {
153            FailureStrategy::Abort => write!(f, "abort"),
154            FailureStrategy::Retry => write!(f, "retry"),
155            FailureStrategy::Skip => write!(f, "skip"),
156            FailureStrategy::Ask => write!(f, "ask"),
157        }
158    }
159}
160
161impl FromStr for FailureStrategy {
162    type Err = OrchestrationError;
163
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        match s {
166            "abort" => Ok(FailureStrategy::Abort),
167            "retry" => Ok(FailureStrategy::Retry),
168            "skip" => Ok(FailureStrategy::Skip),
169            "ask" => Ok(FailureStrategy::Ask),
170            other => Err(OrchestrationError::InvalidGraph(format!(
171                "unknown failure strategy '{other}': expected one of abort, retry, skip, ask"
172            ))),
173        }
174    }
175}
176
/// Output produced by a completed task.
///
/// Stored in `TaskNode::result` and (de)serialized with serde as part of the
/// node it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    /// Textual output of the task.
    pub output: String,
    /// Paths attached to the result.
    // NOTE(review): presumably files the task produced — confirm with the producer.
    pub artifacts: Vec<PathBuf>,
    /// Duration in milliseconds (per the `_ms` suffix).
    // NOTE(review): which interval is measured is not visible in this file.
    pub duration_ms: u64,
    /// Optional agent identifier.
    // NOTE(review): presumably the agent instance that ran the task — confirm.
    pub agent_id: Option<String>,
    /// Optional agent definition name.
    // NOTE(review): presumably the definition the agent was created from — confirm.
    pub agent_def: Option<String>,
}
186
/// Execution mode annotation from the LLM planner.
///
/// Serialized in `snake_case` (`"parallel"` / `"sequential"`). The default is
/// [`ExecutionMode::Parallel`], which is also what `TaskNode` falls back to
/// when the field is absent from stored JSON.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, schemars::JsonSchema,
)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionMode {
    /// Task can run in parallel with others at the same DAG level.
    #[default]
    Parallel,
    /// Task is globally serialized: at most one `Sequential` task runs at a time across
    /// the entire graph (e.g. deploy, exclusive-resource access, shared-state mutation).
    // NOTE(review): the enforcement of this rule lives in the scheduler, not in this file.
    Sequential,
}
200
201/// A single node in the task DAG.
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct TaskNode {
204    pub id: TaskId,
205    pub title: String,
206    pub description: String,
207    pub agent_hint: Option<String>,
208    pub status: TaskStatus,
209    /// Indices of tasks this node depends on.
210    pub depends_on: Vec<TaskId>,
211    pub result: Option<TaskResult>,
212    pub assigned_agent: Option<String>,
213    pub retry_count: u32,
214    /// Per-task override; `None` means use graph default.
215    pub failure_strategy: Option<FailureStrategy>,
216    pub max_retries: Option<u32>,
217    /// LLM planner annotation. Old SQLite-stored JSON without this field
218    /// deserializes to the default (`Parallel`).
219    #[serde(default)]
220    pub execution_mode: ExecutionMode,
221}
222
223impl TaskNode {
224    /// Create a new pending task with the given index.
225    #[must_use]
226    pub fn new(id: u32, title: impl Into<String>, description: impl Into<String>) -> Self {
227        Self {
228            id: TaskId(id),
229            title: title.into(),
230            description: description.into(),
231            agent_hint: None,
232            status: TaskStatus::Pending,
233            depends_on: Vec::new(),
234            result: None,
235            assigned_agent: None,
236            retry_count: 0,
237            failure_strategy: None,
238            max_retries: None,
239            execution_mode: ExecutionMode::default(),
240        }
241    }
242}
243
244/// A directed acyclic graph of tasks to be executed by the orchestrator.
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct TaskGraph {
247    pub id: GraphId,
248    pub goal: String,
249    pub tasks: Vec<TaskNode>,
250    pub status: GraphStatus,
251    pub default_failure_strategy: FailureStrategy,
252    pub default_max_retries: u32,
253    pub created_at: String,
254    pub finished_at: Option<String>,
255}
256
257impl TaskGraph {
258    /// Create a new graph with `Created` status.
259    #[must_use]
260    pub fn new(goal: impl Into<String>) -> Self {
261        Self {
262            id: GraphId::new(),
263            goal: goal.into(),
264            tasks: Vec::new(),
265            status: GraphStatus::Created,
266            default_failure_strategy: FailureStrategy::default(),
267            default_max_retries: 3,
268            created_at: chrono_now(),
269            finished_at: None,
270        }
271    }
272}
273
274pub(crate) fn chrono_now() -> String {
275    // ISO-8601 UTC timestamp, consistent with the rest of the codebase.
276    // Format: "2026-03-05T22:04:41Z"
277    let secs = std::time::SystemTime::now()
278        .duration_since(std::time::UNIX_EPOCH)
279        .map_or(0, |d| d.as_secs());
280    // Manual formatting: seconds since epoch → ISO-8601 UTC
281    // Days since epoch, then decompose into year/month/day
282    let (y, mo, d, h, mi, s) = epoch_secs_to_datetime(secs);
283    format!("{y:04}-{mo:02}-{d:02}T{h:02}:{mi:02}:{s:02}Z")
284}
285
/// Convert Unix epoch seconds to (year, month, day, hour, min, sec) UTC.
///
/// The date is computed in the proleptic Gregorian calendar with the
/// standard "days to civil" algorithm: the day count is re-based onto
/// 0000-03-01 so that every year starts in March and the leap day, when
/// present, is the *last* day of the shifted year. That keeps the 400-year
/// cycle aligned with the real leap-year rule and removes all special cases
/// (in particular 31 December of leap years and century boundaries).
///
/// Returned ranges: month `1..=12`, day `1..=31`, hour `0..=23`,
/// minute and second `0..=59`. Leap seconds are not represented.
fn epoch_secs_to_datetime(secs: u64) -> (u64, u8, u8, u8, u8, u8) {
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT_DAYS: u64 = 719_468;
    // Length of one 400-year Gregorian cycle ("era").
    const DAYS_PER_ERA: u64 = 146_097;

    // Time of day.
    let s = (secs % 60) as u8;
    let mins = secs / 60;
    let mi = (mins % 60) as u8;
    let hours = mins / 60;
    let h = (hours % 24) as u8;
    let days = hours / 24; // days since 1970-01-01

    // Re-base so that eras begin on 1 March of a year divisible by 400.
    let shifted = days + EPOCH_SHIFT_DAYS;
    let era = shifted / DAYS_PER_ERA;
    let day_of_era = shifted % DAYS_PER_ERA; // 0..=146_096

    // Year within the era, correcting for the leap days that occur every
    // 4 years (but not every 100, except every 400).
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // 0..=399

    // Day within the March-based year.
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // 0..=365

    // Month index with 0 = March … 11 = February; 153 days per 5 months.
    let month_index = (5 * day_of_year + 2) / 153; // 0..=11
    let day_of_month = day_of_year - (153 * month_index + 2) / 5 + 1; // 1..=31
    let civil_month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    }; // 1..=12

    // January and February belong to the following civil year.
    let year = era * 400 + year_of_era + u64::from(civil_month <= 2);

    // Both values are bounded as noted above, so the conversions cannot fail.
    let month = u8::try_from(civil_month).unwrap_or(1);
    let day = u8::try_from(day_of_month).unwrap_or(1);

    (year, month, day, h, mi, s)
}
336
/// Maximum allowed length for a `TaskGraph` goal string.
///
/// Enforced by `GraphPersistence::save`, which rejects longer goals with
/// `OrchestrationError::InvalidGraph`.
const MAX_GOAL_LEN: usize = 1024;
339
340/// Type-safe wrapper around `RawGraphStore` that handles `TaskGraph` serialization.
341///
342/// Consumers in `zeph-core` use this instead of `RawGraphStore` directly, so they
343/// never need to deal with JSON strings.
344///
345/// # Storage layout
346///
347/// The `task_graphs` table stores both metadata columns (`goal`, `status`,
348/// `created_at`, `finished_at`) and the full `graph_json` blob. The metadata
349/// columns are summary/index data used for listing and filtering; `graph_json`
350/// is the authoritative source for full graph reconstruction. On `load`, only
351/// `graph_json` is deserialized — the columns are not consulted.
352pub struct GraphPersistence<S: RawGraphStore> {
353    store: S,
354}
355
356impl<S: RawGraphStore> GraphPersistence<S> {
357    /// Create a new `GraphPersistence` wrapping the given store.
358    pub fn new(store: S) -> Self {
359        Self { store }
360    }
361
362    /// Persist a `TaskGraph` (upsert).
363    ///
364    /// Returns `OrchestrationError::InvalidGraph` if `graph.goal` exceeds
365    /// `MAX_GOAL_LEN` (1024) characters.
366    ///
367    /// # Errors
368    ///
369    /// Returns `OrchestrationError::Persistence` on serialization or database failure.
370    pub async fn save(&self, graph: &TaskGraph) -> Result<(), OrchestrationError> {
371        if graph.goal.len() > MAX_GOAL_LEN {
372            return Err(OrchestrationError::InvalidGraph(format!(
373                "goal exceeds {MAX_GOAL_LEN} character limit ({} chars)",
374                graph.goal.len()
375            )));
376        }
377        let json = serde_json::to_string(graph)
378            .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
379        self.store
380            .save_graph(
381                &graph.id.to_string(),
382                &graph.goal,
383                &graph.status.to_string(),
384                &json,
385                &graph.created_at,
386                graph.finished_at.as_deref(),
387            )
388            .await
389            .map_err(|e| OrchestrationError::Persistence(e.to_string()))
390    }
391
392    /// Load a `TaskGraph` by its `GraphId`.
393    ///
394    /// Returns `None` if not found.
395    ///
396    /// # Errors
397    ///
398    /// Returns `OrchestrationError::Persistence` on database or deserialization failure.
399    pub async fn load(&self, id: &GraphId) -> Result<Option<TaskGraph>, OrchestrationError> {
400        match self
401            .store
402            .load_graph(&id.to_string())
403            .await
404            .map_err(|e| OrchestrationError::Persistence(e.to_string()))?
405        {
406            Some(json) => {
407                let graph = serde_json::from_str(&json)
408                    .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
409                Ok(Some(graph))
410            }
411            None => Ok(None),
412        }
413    }
414
415    /// List stored graphs (newest first).
416    ///
417    /// # Errors
418    ///
419    /// Returns `OrchestrationError::Persistence` on database failure.
420    pub async fn list(&self, limit: u32) -> Result<Vec<GraphSummary>, OrchestrationError> {
421        self.store
422            .list_graphs(limit)
423            .await
424            .map_err(|e| OrchestrationError::Persistence(e.to_string()))
425    }
426
427    /// Delete a graph by its `GraphId`.
428    ///
429    /// Returns `true` if a row was deleted.
430    ///
431    /// # Errors
432    ///
433    /// Returns `OrchestrationError::Persistence` on database failure.
434    pub async fn delete(&self, id: &GraphId) -> Result<bool, OrchestrationError> {
435        self.store
436            .delete_graph(&id.to_string())
437            .await
438            .map_err(|e| OrchestrationError::Persistence(e.to_string()))
439    }
440}
441
/// Unit tests for the graph data types, their string/serde representations
/// and the hand-written timestamp helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_taskid_display() {
        assert_eq!(TaskId(3).to_string(), "3");
    }

    #[test]
    fn test_graphid_display_and_new() {
        let id = GraphId::new();
        let s = id.to_string();
        assert_eq!(s.len(), 36, "UUID string should be 36 chars");
        let parsed: GraphId = s.parse().expect("should parse back");
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_graphid_from_str_invalid() {
        let err = "not-a-uuid".parse::<GraphId>();
        assert!(err.is_err());
    }

    #[test]
    fn test_task_status_is_terminal() {
        let terminal = [
            TaskStatus::Completed,
            TaskStatus::Failed,
            TaskStatus::Skipped,
            TaskStatus::Canceled,
        ];
        for status in terminal {
            assert!(status.is_terminal(), "{status} should be terminal");
        }

        let non_terminal = [TaskStatus::Pending, TaskStatus::Ready, TaskStatus::Running];
        for status in non_terminal {
            assert!(!status.is_terminal(), "{status} should not be terminal");
        }
    }

    #[test]
    fn test_task_status_display() {
        let cases = [
            (TaskStatus::Pending, "pending"),
            (TaskStatus::Ready, "ready"),
            (TaskStatus::Running, "running"),
            (TaskStatus::Completed, "completed"),
            (TaskStatus::Failed, "failed"),
            (TaskStatus::Skipped, "skipped"),
            (TaskStatus::Canceled, "canceled"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.to_string(), expected);
        }
    }

    #[test]
    fn test_failure_strategy_default() {
        assert_eq!(FailureStrategy::default(), FailureStrategy::Abort);
    }

    #[test]
    fn test_failure_strategy_display() {
        let cases = [
            (FailureStrategy::Abort, "abort"),
            (FailureStrategy::Retry, "retry"),
            (FailureStrategy::Skip, "skip"),
            (FailureStrategy::Ask, "ask"),
        ];
        for (strategy, expected) in cases {
            assert_eq!(strategy.to_string(), expected);
        }
    }

    #[test]
    fn test_graph_status_display() {
        let cases = [
            (GraphStatus::Created, "created"),
            (GraphStatus::Running, "running"),
            (GraphStatus::Completed, "completed"),
            (GraphStatus::Failed, "failed"),
            (GraphStatus::Canceled, "canceled"),
            (GraphStatus::Paused, "paused"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.to_string(), expected);
        }
    }

    #[test]
    fn test_task_graph_serde_roundtrip() {
        let mut graph = TaskGraph::new("test goal");
        graph.tasks.push(TaskNode::new(0, "task 0", "do something"));
        let json = serde_json::to_string(&graph).expect("serialize");
        let restored: TaskGraph = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(graph.id, restored.id);
        assert_eq!(graph.goal, restored.goal);
        assert_eq!(graph.tasks.len(), restored.tasks.len());
    }

    #[test]
    fn test_task_node_serde_roundtrip() {
        let mut node = TaskNode::new(1, "compile", "run cargo build");
        node.agent_hint = Some("rust-dev".to_string());
        node.depends_on = vec![TaskId(0)];
        let json = serde_json::to_string(&node).expect("serialize");
        let restored: TaskNode = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(node.id, restored.id);
        assert_eq!(node.title, restored.title);
        assert_eq!(node.depends_on, restored.depends_on);
    }

    #[test]
    fn test_task_result_serde_roundtrip() {
        let result = TaskResult {
            output: "ok".to_string(),
            artifacts: vec![PathBuf::from("/tmp/out.bin")],
            duration_ms: 500,
            agent_id: Some("agent-1".to_string()),
            agent_def: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: TaskResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(result.output, restored.output);
        assert_eq!(result.duration_ms, restored.duration_ms);
        assert_eq!(result.artifacts, restored.artifacts);
    }

    #[test]
    fn test_failure_strategy_from_str() {
        let valid = [
            ("abort", FailureStrategy::Abort),
            ("retry", FailureStrategy::Retry),
            ("skip", FailureStrategy::Skip),
            ("ask", FailureStrategy::Ask),
        ];
        for (input, expected) in valid {
            assert_eq!(input.parse::<FailureStrategy>().unwrap(), expected);
        }

        for invalid in ["abort_all", ""] {
            assert!(invalid.parse::<FailureStrategy>().is_err());
        }
    }

    #[test]
    fn test_chrono_now_iso8601_format() {
        let ts = chrono_now();
        // Format: "YYYY-MM-DDTHH:MM:SSZ" — 20 chars
        assert_eq!(ts.len(), 20, "timestamp should be 20 chars: {ts}");
        assert!(ts.ends_with('Z'), "should end with Z: {ts}");
        assert!(ts.contains('T'), "should contain T: {ts}");
        // Year should be >= 2024
        let year: u32 = ts[..4].parse().expect("year should be numeric");
        assert!(year >= 2024, "year should be >= 2024: {year}");
    }

    #[test]
    fn test_epoch_secs_to_datetime_known_values() {
        // Direct coverage for the hand-rolled calendar conversion.
        // 1970-01-01T00:00:00Z
        assert_eq!(epoch_secs_to_datetime(0), (1970, 1, 1, 0, 0, 0));
        // 2000-02-29T00:00:00Z (leap day in a year divisible by 400)
        assert_eq!(epoch_secs_to_datetime(951_782_400), (2000, 2, 29, 0, 0, 0));
        // 2026-03-05T22:04:41Z (the example from `chrono_now`'s comment)
        assert_eq!(
            epoch_secs_to_datetime(1_772_748_281),
            (2026, 3, 5, 22, 4, 41)
        );
    }

    #[test]
    fn test_failure_strategy_serde_snake_case() {
        let cases = [
            (FailureStrategy::Abort, "\"abort\""),
            (FailureStrategy::Retry, "\"retry\""),
            (FailureStrategy::Skip, "\"skip\""),
            (FailureStrategy::Ask, "\"ask\""),
        ];
        for (strategy, expected) in cases {
            assert_eq!(serde_json::to_string(&strategy).unwrap(), expected);
        }
    }

    #[test]
    fn test_graph_persistence_save_rejects_long_goal() {
        // GraphPersistence::save() is async and requires a real store;
        // we verify the goal-length guard directly via the const.
        let graph = TaskGraph::new("x".repeat(MAX_GOAL_LEN + 1));
        assert!(
            graph.goal.len() > MAX_GOAL_LEN,
            "test setup: goal must exceed limit"
        );
        // The check itself lives in GraphPersistence::save(), exercised by
        // the async persistence tests in zeph-memory; here we verify the constant.
        assert_eq!(MAX_GOAL_LEN, 1024);
    }

    #[test]
    fn test_task_node_missing_execution_mode_deserializes_as_parallel() {
        // Old SQLite-stored JSON blobs lack the execution_mode field.
        // #[serde(default)] must make them deserialize to Parallel without error.
        let json = r#"{
            "id": 0,
            "title": "t",
            "description": "d",
            "agent_hint": null,
            "status": "pending",
            "depends_on": [],
            "result": null,
            "assigned_agent": null,
            "retry_count": 0,
            "failure_strategy": null,
            "max_retries": null
        }"#;
        let node: TaskNode = serde_json::from_str(json).expect("should deserialize old JSON");
        assert_eq!(node.execution_mode, ExecutionMode::Parallel);
    }

    #[test]
    fn test_execution_mode_serde_snake_case() {
        let cases = [
            (ExecutionMode::Parallel, "\"parallel\""),
            (ExecutionMode::Sequential, "\"sequential\""),
        ];
        for (mode, json) in cases {
            // Serialization and deserialization must agree on the same token.
            assert_eq!(serde_json::to_string(&mode).unwrap(), json);
            let parsed: ExecutionMode = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, mode);
        }
    }
}