use crate::error::{GraphError, GraphResult};
use crate::graph::Graph;
use crate::persistence::StatePersistence;
use crate::state::{generate_run_id, GraphRunResult, GraphState};
use std::sync::Arc;
use tracing::{info, span, Instrument, Level};
9
10pub struct GraphExecutor<State, Deps, End, P = NoPersistence>
12where
13 State: GraphState,
14{
15 _persistence_type: std::marker::PhantomData<P>,
16 graph: Arc<Graph<State, Deps, End>>,
17 persistence: Option<Arc<P>>,
18 auto_save: bool,
19 max_steps: u32,
20 instrumentation: bool,
21}
22
/// Placeholder persistence type for executors that have no storage backend.
///
/// It is the default `P` of `GraphExecutor` and the type produced by
/// `GraphExecutor::new`. Being a stateless unit struct, it derives every
/// common trait so it can be defaulted, compared and hashed freely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoPersistence;
26
27impl<State, Deps, End> GraphExecutor<State, Deps, End, NoPersistence>
28where
29 State: GraphState,
30 Deps: Clone + Send + Sync + 'static,
31 End: Clone + Send + Sync + 'static,
32{
33 pub fn new(graph: Graph<State, Deps, End>) -> Self {
35 Self {
36 _persistence_type: std::marker::PhantomData,
37 graph: Arc::new(graph),
38 persistence: None,
39 auto_save: false,
40 max_steps: 100,
41 instrumentation: true,
42 }
43 }
44}
45
46impl<State, Deps, End, P> GraphExecutor<State, Deps, End, P>
47where
48 State: GraphState,
49 Deps: Clone + Send + Sync + 'static,
50 End: Clone + Send + Sync + 'static,
51 P: StatePersistence<State, End> + 'static,
52{
53 pub fn with_persistence(graph: Graph<State, Deps, End>, persistence: P) -> Self {
55 Self {
56 _persistence_type: std::marker::PhantomData,
57 graph: Arc::new(graph),
58 persistence: Some(Arc::new(persistence)),
59 auto_save: true,
60 max_steps: 100,
61 instrumentation: true,
62 }
63 }
64
65 pub fn auto_save(mut self, enabled: bool) -> Self {
67 self.auto_save = enabled;
68 self
69 }
70
71 pub fn max_steps(mut self, max: u32) -> Self {
73 self.max_steps = max;
74 self
75 }
76
77 pub fn without_instrumentation(mut self) -> Self {
79 self.instrumentation = false;
80 self
81 }
82
83 pub fn graph(&self) -> &Graph<State, Deps, End> {
85 &self.graph
86 }
87
88 pub async fn run(&self, state: State, deps: Deps) -> GraphResult<GraphRunResult<State, End>> {
90 let options = ExecutionOptions::new()
91 .max_steps(self.max_steps)
92 .tracing(self.instrumentation);
93 self.run_with_options(state, deps, options).await
94 }
95
96 pub async fn run_with_options(
98 &self,
99 state: State,
100 deps: Deps,
101 mut options: ExecutionOptions,
102 ) -> GraphResult<GraphRunResult<State, End>> {
103 let run_id = options.run_id.clone().unwrap_or_else(generate_run_id);
104 options.run_id = Some(run_id.clone());
105
106 if options.tracing {
107 let _span = span!(Level::INFO, "graph_run", run_id = %run_id).entered();
108 info!("Starting graph execution");
109 }
110
111 self.graph.run_with_options(state, deps, options).await
112 }
113
114 pub async fn resume(
116 &self,
117 run_id: &str,
118 deps: Deps,
119 ) -> GraphResult<Option<GraphRunResult<State, End>>> {
120 let Some(ref persistence) = self.persistence else {
121 return Err(GraphError::persistence("No persistence configured"));
122 };
123
124 let Some((state, _step)) = persistence.load_state(run_id).await? else {
125 return Ok(None);
126 };
127
128 let options = ExecutionOptions::new()
130 .max_steps(self.max_steps)
131 .tracing(self.instrumentation)
132 .run_id(run_id.to_string());
133 let result = self.graph.run_with_options(state, deps, options).await?;
134
135 if self.auto_save {
137 persistence.save_result(run_id, &result.result).await?;
138 }
139
140 Ok(Some(result))
141 }
142
143 pub async fn get_result(&self, run_id: &str) -> GraphResult<Option<End>> {
145 let Some(ref persistence) = self.persistence else {
146 return Err(GraphError::persistence("No persistence configured"));
147 };
148
149 Ok(persistence.load_result(run_id).await?)
150 }
151
152 pub async fn list_runs(&self) -> GraphResult<Vec<String>> {
154 let Some(ref persistence) = self.persistence else {
155 return Err(GraphError::persistence("No persistence configured"));
156 };
157
158 Ok(persistence.list_runs().await?)
159 }
160}
161
/// Per-run settings handed to the graph when it executes.
///
/// Build with `ExecutionOptions::new()` (same as `Default::default()`) and the
/// chained setters. Derives `PartialEq`/`Eq` so options can be compared by
/// value, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionOptions {
    /// Step limit forwarded to the graph.
    pub max_steps: u32,
    /// Whether tracing is enabled for the run.
    pub tracing: bool,
    /// Checkpoint interval set by `checkpoint_every`; `None` when unset.
    /// NOTE(review): presumably counted in steps; confirm with the graph runner.
    pub checkpoint_interval: Option<u32>,
    /// Run identifier; `None` lets the executor generate one.
    pub run_id: Option<String>,
}
174
175impl Default for ExecutionOptions {
176 fn default() -> Self {
177 Self {
178 max_steps: 100,
179 tracing: true,
180 checkpoint_interval: None,
181 run_id: None,
182 }
183 }
184}
185
186impl ExecutionOptions {
187 pub fn new() -> Self {
189 Self::default()
190 }
191
192 pub fn max_steps(mut self, max: u32) -> Self {
194 self.max_steps = max;
195 self
196 }
197
198 pub fn tracing(mut self, enabled: bool) -> Self {
200 self.tracing = enabled;
201 self
202 }
203
204 pub fn checkpoint_every(mut self, steps: u32) -> Self {
206 self.checkpoint_interval = Some(steps);
207 self
208 }
209
210 pub fn run_id(mut self, id: impl Into<String>) -> Self {
212 self.run_id = Some(id.into());
213 self
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter lands in the matching field.
    #[test]
    fn test_execution_options() {
        let opts = ExecutionOptions::new()
            .max_steps(50)
            .tracing(false)
            .checkpoint_every(10)
            .run_id("custom-run");

        assert_eq!(opts.max_steps, 50);
        assert!(!opts.tracing);
        assert_eq!(opts.checkpoint_interval, Some(10));
        assert_eq!(opts.run_id, Some("custom-run".to_string()));
    }

    /// `new()` and `default()` agree and carry the documented defaults.
    #[test]
    fn test_execution_options_defaults() {
        let defaults = ExecutionOptions::default();
        assert_eq!(defaults.max_steps, 100);
        assert!(defaults.tracing);
        assert_eq!(defaults.checkpoint_interval, None);
        assert_eq!(defaults.run_id, None);

        let via_new = ExecutionOptions::new();
        assert_eq!(via_new.max_steps, defaults.max_steps);
        assert_eq!(via_new.tracing, defaults.tracing);
        assert_eq!(via_new.checkpoint_interval, defaults.checkpoint_interval);
        assert_eq!(via_new.run_id, defaults.run_id);
    }

    /// Repeating a setter keeps the last value.
    #[test]
    fn test_execution_options_last_write_wins() {
        let opts = ExecutionOptions::new()
            .max_steps(1)
            .max_steps(2)
            .run_id("a")
            .run_id(String::from("b"));
        assert_eq!(opts.max_steps, 2);
        assert_eq!(opts.run_id.as_deref(), Some("b"));
    }
}