Skip to main content

serdes_ai_graph/
executor.rs

//! Graph execution engine: [`GraphExecutor`] runs a [`Graph`] with optional
//! persistence and `tracing` instrumentation.

use crate::error::{GraphError, GraphResult};
use crate::graph::Graph;
use crate::persistence::StatePersistence;
use crate::state::{generate_run_id, GraphRunResult, GraphState};
use std::sync::Arc;
use tracing::{info, span, Instrument, Level};

10/// Graph executor with optional persistence and instrumentation.
11pub struct GraphExecutor<State, Deps, End, P = NoPersistence>
12where
13    State: GraphState,
14{
15    _persistence_type: std::marker::PhantomData<P>,
16    graph: Arc<Graph<State, Deps, End>>,
17    persistence: Option<Arc<P>>,
18    auto_save: bool,
19    max_steps: u32,
20    instrumentation: bool,
21}
22
/// Placeholder persistence type for executors built with [`GraphExecutor::new`],
/// which configures no persistence backend at all.
#[derive(Clone, Copy, Debug)]
pub struct NoPersistence;

27impl<State, Deps, End> GraphExecutor<State, Deps, End, NoPersistence>
28where
29    State: GraphState,
30    Deps: Clone + Send + Sync + 'static,
31    End: Clone + Send + Sync + 'static,
32{
33    /// Create a new executor without persistence.
34    pub fn new(graph: Graph<State, Deps, End>) -> Self {
35        Self {
36            _persistence_type: std::marker::PhantomData,
37            graph: Arc::new(graph),
38            persistence: None,
39            auto_save: false,
40            max_steps: 100,
41            instrumentation: true,
42        }
43    }
44}
45
46impl<State, Deps, End, P> GraphExecutor<State, Deps, End, P>
47where
48    State: GraphState,
49    Deps: Clone + Send + Sync + 'static,
50    End: Clone + Send + Sync + 'static,
51    P: StatePersistence<State, End> + 'static,
52{
53    /// Create an executor with persistence.
54    pub fn with_persistence(graph: Graph<State, Deps, End>, persistence: P) -> Self {
55        Self {
56            _persistence_type: std::marker::PhantomData,
57            graph: Arc::new(graph),
58            persistence: Some(Arc::new(persistence)),
59            auto_save: true,
60            max_steps: 100,
61            instrumentation: true,
62        }
63    }
64
65    /// Set whether to automatically save state.
66    pub fn auto_save(mut self, enabled: bool) -> Self {
67        self.auto_save = enabled;
68        self
69    }
70
71    /// Set maximum steps.
72    pub fn max_steps(mut self, max: u32) -> Self {
73        self.max_steps = max;
74        self
75    }
76
77    /// Disable instrumentation.
78    pub fn without_instrumentation(mut self) -> Self {
79        self.instrumentation = false;
80        self
81    }
82
83    /// Get a reference to the graph.
84    pub fn graph(&self) -> &Graph<State, Deps, End> {
85        &self.graph
86    }
87
88    /// Run the graph.
89    pub async fn run(&self, state: State, deps: Deps) -> GraphResult<GraphRunResult<State, End>> {
90        let options = ExecutionOptions::new()
91            .max_steps(self.max_steps)
92            .tracing(self.instrumentation);
93        self.run_with_options(state, deps, options).await
94    }
95
96    /// Run the graph with options.
97    pub async fn run_with_options(
98        &self,
99        state: State,
100        deps: Deps,
101        mut options: ExecutionOptions,
102    ) -> GraphResult<GraphRunResult<State, End>> {
103        let run_id = options.run_id.clone().unwrap_or_else(generate_run_id);
104        options.run_id = Some(run_id.clone());
105
106        if options.tracing {
107            let _span = span!(Level::INFO, "graph_run", run_id = %run_id).entered();
108            info!("Starting graph execution");
109        }
110
111        self.graph.run_with_options(state, deps, options).await
112    }
113
114    /// Resume a previous run.
115    pub async fn resume(
116        &self,
117        run_id: &str,
118        deps: Deps,
119    ) -> GraphResult<Option<GraphRunResult<State, End>>> {
120        let Some(ref persistence) = self.persistence else {
121            return Err(GraphError::persistence("No persistence configured"));
122        };
123
124        let Some((state, _step)) = persistence.load_state(run_id).await? else {
125            return Ok(None);
126        };
127
128        // Resume from the loaded state
129        let options = ExecutionOptions::new()
130            .max_steps(self.max_steps)
131            .tracing(self.instrumentation)
132            .run_id(run_id.to_string());
133        let result = self.graph.run_with_options(state, deps, options).await?;
134
135        // Save final result
136        if self.auto_save {
137            persistence.save_result(run_id, &result.result).await?;
138        }
139
140        Ok(Some(result))
141    }
142
143    /// Get a saved result.
144    pub async fn get_result(&self, run_id: &str) -> GraphResult<Option<End>> {
145        let Some(ref persistence) = self.persistence else {
146            return Err(GraphError::persistence("No persistence configured"));
147        };
148
149        Ok(persistence.load_result(run_id).await?)
150    }
151
152    /// List all saved runs.
153    pub async fn list_runs(&self) -> GraphResult<Vec<String>> {
154        let Some(ref persistence) = self.persistence else {
155            return Err(GraphError::persistence("No persistence configured"));
156        };
157
158        Ok(persistence.list_runs().await?)
159    }
160}
161
/// Per-run execution options.
///
/// Build with [`ExecutionOptions::new`] (equivalent to [`Default::default`]) and the
/// chainable setters.
#[derive(Debug, Clone)]
pub struct ExecutionOptions {
    /// Maximum number of steps (default: 100).
    pub max_steps: u32,
    /// Whether tracing instrumentation is enabled (default: `true`).
    pub tracing: bool,
    /// Checkpoint interval in steps (default: `None`).
    /// NOTE(review): not read in this file; confirm how the graph interprets `Some(0)`.
    pub checkpoint_interval: Option<u32>,
    /// Custom run ID (default: `None`).
    ///
    /// When `None`, `GraphExecutor::run_with_options` generates one.
    pub run_id: Option<String>,
}

impl Default for ExecutionOptions {
    fn default() -> Self {
        Self {
            max_steps: 100,
            tracing: true,
            checkpoint_interval: None,
            run_id: None,
        }
    }
}

impl ExecutionOptions {
    /// Create options with the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum number of steps.
    #[must_use]
    pub fn max_steps(mut self, max: u32) -> Self {
        self.max_steps = max;
        self
    }

    /// Enable or disable tracing.
    #[must_use]
    pub fn tracing(mut self, enabled: bool) -> Self {
        self.tracing = enabled;
        self
    }

    /// Set the checkpoint interval, in steps.
    #[must_use]
    pub fn checkpoint_every(mut self, steps: u32) -> Self {
        self.checkpoint_interval = Some(steps);
        self
    }

    /// Set a custom run ID.
    #[must_use]
    pub fn run_id(mut self, id: impl Into<String>) -> Self {
        self.run_id = Some(id.into());
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder setter must be reflected in the resulting options,
    /// regardless of the order the setters are chained in.
    #[test]
    fn test_execution_options() {
        let ExecutionOptions {
            max_steps,
            tracing,
            checkpoint_interval,
            run_id,
        } = ExecutionOptions::new()
            .run_id("custom-run")
            .checkpoint_every(10)
            .tracing(false)
            .max_steps(50);

        assert_eq!(max_steps, 50);
        assert!(!tracing);
        assert_eq!(checkpoint_interval, Some(10));
        assert_eq!(run_id.as_deref(), Some("custom-run"));
    }
}