1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::{Map, Value};
7
8use crate::agent::Agent;
9use crate::{LanguageModel, Result};
10
11#[derive(Debug, Clone, Default)]
13pub struct WorkflowContext {
14 pub state: Map<String, Value>,
15 pub logs: Vec<String>,
16}
17
18impl WorkflowContext {
19 pub fn insert(&mut self, key: impl Into<String>, value: Value) {
20 self.state.insert(key.into(), value);
21 }
22
23 pub fn get(&self, key: &str) -> Option<&Value> {
24 self.state.get(key)
25 }
26}
27
28#[async_trait]
29pub trait WorkflowTask: Send + Sync {
30 async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value>;
31}
32
33type TaskFuture<'a> = Pin<Box<dyn Future<Output = Result<Value>> + Send + 'a>>;
34
/// Adapts a closure into a [`WorkflowTask`].
///
/// The closure receives the workflow context and must return a boxed future
/// ([`TaskFuture`]) that borrows that context. The higher-ranked closure bound
/// is declared on the `impl` blocks rather than on the struct. That way the
/// type can be named without repeating the bound, and the bound is enforced
/// exactly where the closure is constructed or invoked.
pub struct FunctionTask<F> {
    func: F,
}
42
43impl<F> FunctionTask<F>
44where
45 F: for<'a> Fn(&'a mut WorkflowContext) -> TaskFuture<'a> + Send + Sync,
46{
47 pub fn new(func: F) -> Self {
48 Self { func }
49 }
50}
51
52#[async_trait]
53impl<F> WorkflowTask for FunctionTask<F>
54where
55 F: for<'a> Fn(&'a mut WorkflowContext) -> TaskFuture<'a> + Send + Sync,
56{
57 async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value> {
58 (self.func)(ctx).await
59 }
60}
61
62pub struct AgentTask<M: LanguageModel> {
64 agent: Arc<tokio::sync::Mutex<Agent<M>>>,
65 prompt_key: Option<String>,
66 store_under: Option<String>,
67 fallback_prompt: String,
68}
69
70impl<M: LanguageModel> AgentTask<M> {
71 pub fn new(
72 agent: Arc<tokio::sync::Mutex<Agent<M>>>,
73 prompt_key: Option<String>,
74 store_under: Option<String>,
75 fallback_prompt: impl Into<String>,
76 ) -> Self {
77 Self {
78 agent,
79 prompt_key,
80 store_under,
81 fallback_prompt: fallback_prompt.into(),
82 }
83 }
84}
85
86#[async_trait]
87impl<M: LanguageModel> WorkflowTask for AgentTask<M> {
88 async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value> {
89 let prompt = self
90 .prompt_key
91 .as_ref()
92 .and_then(|k| ctx.get(k))
93 .and_then(|v| v.as_str())
94 .unwrap_or(&self.fallback_prompt)
95 .to_string();
96 let mut agent = self.agent.lock().await;
97 let reply = agent.respond(prompt).await?;
98 let value = Value::String(reply.clone());
99 if let Some(key) = &self.store_under {
100 ctx.insert(key.clone(), value.clone());
101 }
102 Ok(value)
103 }
104}
105
106pub type Condition = Arc<dyn Fn(&WorkflowContext) -> bool + Send + Sync>;
107
108#[derive(Clone)]
109pub enum WorkflowNode {
110 Task(Arc<dyn WorkflowTask>),
111 Sequence(Vec<WorkflowNode>),
112 Parallel(Vec<WorkflowNode>),
113 Conditional {
114 condition: Condition,
115 then_branch: Box<WorkflowNode>,
116 else_branch: Option<Box<WorkflowNode>>,
117 },
118 Loop {
119 condition: Condition,
120 body: Box<WorkflowNode>,
121 max_iterations: usize,
122 },
123}
124
125impl WorkflowNode {
126 fn execute<'a>(
127 &'a self,
128 ctx: &'a mut WorkflowContext,
129 ) -> std::pin::Pin<Box<dyn Future<Output = Result<Value>> + Send + 'a>> {
130 Box::pin(async move {
131 match self {
132 WorkflowNode::Task(task) => task.run(ctx).await,
133 WorkflowNode::Sequence(steps) => {
134 let mut last = Value::Null;
135 for step in steps {
136 last = step.execute(ctx).await?;
137 }
138 Ok(last)
139 }
140 WorkflowNode::Parallel(steps) => {
141 let mut combined = Vec::new();
142 for step in steps {
143 combined.push(step.execute(ctx).await?);
144 }
145 Ok(Value::Array(combined))
146 }
147 WorkflowNode::Conditional {
148 condition,
149 then_branch,
150 else_branch,
151 } => {
152 if condition(ctx) {
153 then_branch.execute(ctx).await
154 } else if let Some(other) = else_branch {
155 other.execute(ctx).await
156 } else {
157 Ok(Value::Null)
158 }
159 }
160 WorkflowNode::Loop {
161 condition,
162 body,
163 max_iterations,
164 } => {
165 let mut last = Value::Null;
166 for _ in 0..*max_iterations {
167 if !(condition)(ctx) {
168 break;
169 }
170 last = body.execute(ctx).await?;
171 }
172 Ok(last)
173 }
174 }
175 })
176 }
177}
178
179#[derive(Clone)]
180pub struct Workflow {
181 pub name: String,
182 pub root: WorkflowNode,
183}
184
185impl Workflow {
186 pub fn new(name: impl Into<String>, root: WorkflowNode) -> Self {
187 Self {
188 name: name.into(),
189 root,
190 }
191 }
192
193 pub async fn run(&self, ctx: &mut WorkflowContext) -> Result<Value> {
194 self.root.execute(ctx).await
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps a closure as a task node.
    ///
    /// The higher-ranked bound gives closures passed here the same signature
    /// inference that `FunctionTask::new` provides.
    fn task_node<F>(func: F) -> WorkflowNode
    where
        F: for<'a> Fn(&'a mut WorkflowContext) -> TaskFuture<'a> + Send + Sync + 'static,
    {
        WorkflowNode::Task(Arc::new(FunctionTask::new(func)))
    }

    /// Reads an integer from `ctx`, treating a missing or non-integer entry as 0.
    fn int_at(ctx: &WorkflowContext, key: &str) -> i64 {
        ctx.get(key).and_then(Value::as_i64).unwrap_or(0)
    }

    #[tokio::test]
    async fn executes_sequential_and_parallel_nodes() {
        let set_a = task_node(|ctx: &mut WorkflowContext| {
            Box::pin(async move {
                ctx.insert("a", json!(1));
                Ok(json!("done"))
            })
        });
        let bump_a_into_b = task_node(|ctx: &mut WorkflowContext| {
            Box::pin(async move {
                let a = int_at(ctx, "a");
                ctx.insert("b", json!(a + 1));
                Ok(json!("b"))
            })
        });
        let set_c = task_node(|ctx: &mut WorkflowContext| {
            Box::pin(async move {
                ctx.insert("c", json!(true));
                Ok(json!("c"))
            })
        });

        let root = WorkflowNode::Sequence(vec![
            set_a,
            WorkflowNode::Parallel(vec![bump_a_into_b, set_c]),
        ]);
        let flow = Workflow::new("demo", root);

        let mut ctx = WorkflowContext::default();
        let outcome = flow.run(&mut ctx).await.expect("workflow should succeed");

        assert!(outcome.is_array());
        for (key, expected) in [("a", json!(1)), ("b", json!(2)), ("c", json!(true))] {
            assert_eq!(ctx.get(key), Some(&expected), "unexpected value for {}", key);
        }
    }

    #[tokio::test]
    async fn executes_conditional_loop() {
        let increment = task_node(|ctx: &mut WorkflowContext| {
            Box::pin(async move {
                let next = int_at(ctx, "count") + 1;
                ctx.insert("count", json!(next));
                Ok(json!(next))
            })
        });
        let below_three: Condition =
            Arc::new(|ctx: &WorkflowContext| int_at(ctx, "count") < 3);

        let flow = Workflow::new(
            "looping",
            WorkflowNode::Loop {
                condition: below_three,
                body: Box::new(increment),
                max_iterations: 10,
            },
        );

        let mut ctx = WorkflowContext::default();
        ctx.insert("count", json!(0));
        flow.run(&mut ctx).await.expect("workflow should succeed");
        assert_eq!(ctx.get("count"), Some(&json!(3)));
    }
}