1use crate::error::GraphError;
4use crate::node::{BaseNode, NodeResult};
5use crate::state::{generate_run_id, GraphRunContext, GraphRunResult, GraphState};
6use std::marker::PhantomData;
7
8pub struct GraphIter<'a, State, Deps, End>
10where
11 State: GraphState,
12 Deps: Clone + Send + Sync + 'static,
13 End: Clone + Send + Sync + 'static,
14{
15 ctx: GraphRunContext<State, Deps>,
16 current: Option<Box<dyn BaseNode<State, Deps, End>>>,
17 finished: bool,
18 result: Option<End>,
19 history: Vec<String>,
20 _phantom: PhantomData<&'a ()>,
21}
22
23impl<'a, State, Deps, End> GraphIter<'a, State, Deps, End>
24where
25 State: GraphState,
26 Deps: Clone + Send + Sync + 'static,
27 End: Clone + Send + Sync + 'static,
28{
29 pub fn new<N: BaseNode<State, Deps, End> + Clone + 'static>(
31 start: N,
32 state: State,
33 deps: Deps,
34 ) -> Self {
35 let run_id = generate_run_id();
36 Self {
37 ctx: GraphRunContext::new(state, deps, run_id),
38 current: Some(Box::new(start)),
39 finished: false,
40 result: None,
41 history: Vec::new(),
42 _phantom: PhantomData,
43 }
44 }
45
46 pub async fn step(&mut self) -> Option<StepResult<State>> {
48 if self.finished {
49 return None;
50 }
51
52 let current = self.current.take()?;
53 self.ctx.increment_step();
54
55 let node_name = current.name().to_string();
56 self.history.push(node_name.clone());
57
58 match current.run(&mut self.ctx).await {
59 Ok(NodeResult::Next(next)) => {
60 self.current = Some(next);
61 Some(StepResult::Continue { node: node_name })
62 }
63 Ok(NodeResult::NextNamed(name)) => {
64 self.finished = true;
66 Some(StepResult::NamedTransition {
67 node: node_name,
68 next: name,
69 })
70 }
71 Ok(NodeResult::End(_end)) => {
72 self.finished = true;
73 Some(StepResult::Finished { node: node_name })
74 }
75 Err(e) => {
76 self.finished = true;
77 Some(StepResult::Error(e))
78 }
79 }
80 }
81
82 pub fn state(&self) -> &State {
84 &self.ctx.state
85 }
86
87 pub fn state_mut(&mut self) -> &mut State {
89 &mut self.ctx.state
90 }
91
92 pub fn step_count(&self) -> u32 {
94 self.ctx.step
95 }
96
97 pub fn is_finished(&self) -> bool {
99 self.finished
100 }
101
102 pub fn history(&self) -> &[String] {
104 &self.history
105 }
106
107 pub fn into_result(self) -> Option<GraphRunResult<State, End>> {
109 self.result.map(|r| {
110 GraphRunResult::new(r, self.ctx.state, self.ctx.step, self.ctx.run_id)
111 .with_history(self.history)
112 })
113 }
114}
115
116#[derive(Debug)]
118pub enum StepResult<State> {
119 Continue {
121 node: String,
123 },
124 NamedTransition {
126 node: String,
128 next: String,
130 },
131 Finished {
133 node: String,
135 },
136 Error(GraphError),
138 #[doc(hidden)]
140 _State(PhantomData<State>),
141}
142
143impl<State> StepResult<State> {
144 pub fn is_finished(&self) -> bool {
146 matches!(self, Self::Finished { .. })
147 }
148
149 pub fn is_error(&self) -> bool {
151 matches!(self, Self::Error(_))
152 }
153
154 pub fn node(&self) -> Option<&str> {
156 match self {
157 Self::Continue { node } => Some(node),
158 Self::NamedTransition { node, .. } => Some(node),
159 Self::Finished { node } => Some(node),
160 _ => None,
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state type used only to instantiate `StepResult<State>`.
    #[derive(Debug, Clone, Default)]
    struct TestState {
        _value: i32,
    }

    #[test]
    fn test_step_result_is_finished() {
        let result: StepResult<TestState> = StepResult::Finished {
            node: "test".to_string(),
        };
        assert!(result.is_finished());
        assert!(!result.is_error());
        assert_eq!(result.node(), Some("test"));
    }

    #[test]
    fn test_step_result_is_error() {
        let result: StepResult<TestState> = StepResult::Error(GraphError::NoEntryNode);
        assert!(result.is_error());
        assert!(!result.is_finished());
        // Errors carry no node name.
        assert_eq!(result.node(), None);
    }

    #[test]
    fn test_step_result_node() {
        let result: StepResult<TestState> = StepResult::Continue {
            node: "my_node".to_string(),
        };
        assert_eq!(result.node(), Some("my_node"));
        assert!(!result.is_finished());
        assert!(!result.is_error());
    }

    #[test]
    fn test_step_result_named_transition() {
        let result: StepResult<TestState> = StepResult::NamedTransition {
            node: "from".to_string(),
            next: "to".to_string(),
        };
        // `node()` reports the node that ran, not the requested target.
        assert_eq!(result.node(), Some("from"));
        assert!(!result.is_finished());
        assert!(!result.is_error());
    }
}