1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use serde_json::Value;
9use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError, Tool, ToolDefinition};
10use synaptic_macros::traceable;
11use synaptic_middleware::{AgentMiddleware, BaseChatModelCaller, MiddlewareChain, ModelRequest};
12use synaptic_store::Store;
13use synaptic_tools::SerialToolExecutor;
14
15use crate::builder::StateGraph;
16use crate::checkpoint::Checkpointer;
17use crate::command::NodeOutput;
18use crate::compiled::CompiledGraph;
19use crate::node::Node;
20use crate::state::MessageState;
21use crate::tool_node::ToolNode;
22use crate::END;
23
24pub type PreModelHook = Arc<
30 dyn Fn(
31 &mut MessageState,
32 ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
33 + Send
34 + Sync,
35>;
36
37pub type PostModelHook = Arc<
39 dyn Fn(
40 &mut MessageState,
41 ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
42 + Send
43 + Sync,
44>;
45
46struct ChatModelNode {
51 model: Arc<dyn ChatModel>,
52 tool_defs: Vec<ToolDefinition>,
53 system_prompt: Option<String>,
54 middleware: Arc<MiddlewareChain>,
55 is_first_call: AtomicBool,
56 pre_model_hook: Option<PreModelHook>,
57 post_model_hook: Option<PostModelHook>,
58 response_format: Option<Value>,
61}
62
63#[async_trait]
64impl Node<MessageState> for ChatModelNode {
65 async fn process(
66 &self,
67 mut state: MessageState,
68 ) -> Result<NodeOutput<MessageState>, SynapticError> {
69 if self.is_first_call.swap(false, Ordering::SeqCst) {
71 self.middleware
72 .run_before_agent(&mut state.messages)
73 .await?;
74 }
75
76 if let Some(ref hook) = self.pre_model_hook {
78 hook(&mut state).await?;
79 }
80
81 let request = ModelRequest {
82 messages: state.messages.clone(),
83 tools: self.tool_defs.clone(),
84 tool_choice: None,
85 system_prompt: self.system_prompt.clone(),
86 };
87
88 let base_caller = BaseChatModelCaller::new(self.model.clone());
89 let response = self.middleware.call_model(request, &base_caller).await?;
90
91 state.messages.push(response.message.clone());
92
93 if let Some(ref hook) = self.post_model_hook {
95 hook(&mut state).await?;
96 }
97
98 if response.message.tool_calls().is_empty() {
100 if let Some(ref schema) = self.response_format {
102 let instruction = format!(
103 "You MUST respond with valid JSON matching this schema:\n{}\n\n\
104 Do not include any text outside the JSON object. \
105 Do not use markdown code blocks.",
106 schema
107 );
108 let mut structured_messages = vec![Message::system(instruction)];
109 structured_messages.extend(state.messages.clone());
110
111 let structured_request = ChatRequest::new(structured_messages);
112 let structured_response = self.model.chat(structured_request).await?;
113 state.messages.pop();
115 state.messages.push(structured_response.message);
116 }
117
118 self.middleware.run_after_agent(&mut state.messages).await?;
119 }
120
121 Ok(state.into())
122 }
123}
124
125#[derive(Default)]
131pub struct ReactAgentOptions {
132 pub checkpointer: Option<Arc<dyn Checkpointer>>,
134 pub interrupt_before: Vec<String>,
136 pub interrupt_after: Vec<String>,
138 pub system_prompt: Option<String>,
140}
141
142pub fn create_react_agent(
144 model: Arc<dyn ChatModel>,
145 tools: Vec<Arc<dyn Tool>>,
146) -> Result<CompiledGraph<MessageState>, SynapticError> {
147 create_react_agent_with_options(model, tools, ReactAgentOptions::default())
148}
149
150pub fn create_react_agent_with_options(
152 model: Arc<dyn ChatModel>,
153 tools: Vec<Arc<dyn Tool>>,
154 options: ReactAgentOptions,
155) -> Result<CompiledGraph<MessageState>, SynapticError> {
156 create_agent(
157 model,
158 tools,
159 AgentOptions {
160 checkpointer: options.checkpointer,
161 interrupt_before: options.interrupt_before,
162 interrupt_after: options.interrupt_after,
163 system_prompt: options.system_prompt,
164 ..Default::default()
165 },
166 )
167}
168
169#[derive(Default)]
175pub struct AgentOptions {
176 pub checkpointer: Option<Arc<dyn Checkpointer>>,
177 pub interrupt_before: Vec<String>,
178 pub interrupt_after: Vec<String>,
179 pub system_prompt: Option<String>,
180 pub middleware: Vec<Arc<dyn AgentMiddleware>>,
181 pub store: Option<Arc<dyn Store>>,
182 pub name: Option<String>,
183 pub pre_model_hook: Option<PreModelHook>,
184 pub post_model_hook: Option<PostModelHook>,
185 pub response_format: Option<Value>,
187 pub parallel_tools: bool,
189}
190
191#[traceable(skip = "model,tools,options")]
193pub fn create_agent(
194 model: Arc<dyn ChatModel>,
195 tools: Vec<Arc<dyn Tool>>,
196 options: AgentOptions,
197) -> Result<CompiledGraph<MessageState>, SynapticError> {
198 let tool_defs: Vec<ToolDefinition> = tools.iter().map(|t| t.as_tool_definition()).collect();
199
200 let registry = synaptic_tools::ToolRegistry::new();
201 for tool in tools {
202 registry.register(tool)?;
203 }
204 let executor = SerialToolExecutor::new(registry);
205
206 let middleware_chain = Arc::new(MiddlewareChain::new(options.middleware));
207
208 let agent_node = ChatModelNode {
209 model,
210 tool_defs,
211 system_prompt: options.system_prompt,
212 middleware: middleware_chain.clone(),
213 is_first_call: AtomicBool::new(true),
214 pre_model_hook: options.pre_model_hook,
215 post_model_hook: options.post_model_hook,
216 response_format: options.response_format,
217 };
218
219 let mut tool_node =
220 ToolNode::with_middleware(executor, middleware_chain).with_parallel(options.parallel_tools);
221 if let Some(ref store) = options.store {
222 tool_node = tool_node.with_store(store.clone());
223 }
224
225 let mut builder = StateGraph::new()
226 .add_node("agent", agent_node)
227 .add_node("tools", tool_node)
228 .set_entry_point("agent")
229 .add_conditional_edges_with_path_map(
230 "agent",
231 |state: &MessageState| {
232 if let Some(last) = state.last_message() {
233 if !last.tool_calls().is_empty() {
234 return "tools".to_string();
235 }
236 }
237 END.to_string()
238 },
239 HashMap::from([
240 ("tools".to_string(), "tools".to_string()),
241 (END.to_string(), END.to_string()),
242 ]),
243 )
244 .add_edge("tools", "agent");
245
246 if !options.interrupt_before.is_empty() {
247 builder = builder.interrupt_before(options.interrupt_before);
248 }
249 if !options.interrupt_after.is_empty() {
250 builder = builder.interrupt_after(options.interrupt_after);
251 }
252
253 let mut graph = builder.compile()?;
254
255 let checkpointer: Option<Arc<dyn Checkpointer>> = match (&options.store, options.checkpointer) {
257 (_, Some(ckpt)) => Some(ckpt),
258 (Some(store), None) => Some(Arc::new(crate::StoreCheckpointer::new(store.clone()))),
259 (None, None) => None,
260 };
261
262 if let Some(checkpointer) = checkpointer {
263 graph = graph.with_checkpointer(checkpointer);
264 }
265
266 Ok(graph)
267}
268
269struct HandoffTool {
274 target_agent: String,
275 tool_description: String,
276}
277
278#[async_trait]
279impl Tool for HandoffTool {
280 fn name(&self) -> &'static str {
281 Box::leak(format!("transfer_to_{}", self.target_agent).into_boxed_str())
282 }
283
284 fn description(&self) -> &'static str {
285 Box::leak(self.tool_description.clone().into_boxed_str())
286 }
287
288 async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
289 Ok(Value::String(format!(
290 "Transferring to agent '{}'.",
291 self.target_agent
292 )))
293 }
294}
295
296pub fn create_handoff_tool(agent_name: &str, description: &str) -> Arc<dyn Tool> {
298 Arc::new(HandoffTool {
299 target_agent: agent_name.to_string(),
300 tool_description: description.to_string(),
301 })
302}
303
304#[derive(Default)]
310pub struct SupervisorOptions {
311 pub checkpointer: Option<Arc<dyn Checkpointer>>,
312 pub store: Option<Arc<dyn Store>>,
313 pub system_prompt: Option<String>,
314}
315
316struct SubAgentNode {
318 graph: CompiledGraph<MessageState>,
319}
320
321#[async_trait]
322impl Node<MessageState> for SubAgentNode {
323 async fn process(
324 &self,
325 state: MessageState,
326 ) -> Result<NodeOutput<MessageState>, SynapticError> {
327 let result = self.graph.invoke(state).await?;
328 Ok(result.into_state().into())
329 }
330}
331
332#[traceable(skip = "model,agents,options")]
334pub fn create_supervisor(
335 model: Arc<dyn ChatModel>,
336 agents: Vec<(String, CompiledGraph<MessageState>)>,
337 options: SupervisorOptions,
338) -> Result<CompiledGraph<MessageState>, SynapticError> {
339 let agent_names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
340
341 let handoff_tools: Vec<Arc<dyn Tool>> = agent_names
343 .iter()
344 .map(|name| {
345 create_handoff_tool(
346 name,
347 &format!("Transfer the conversation to the '{name}' agent."),
348 )
349 })
350 .collect();
351
352 let handoff_tool_defs: Vec<ToolDefinition> = handoff_tools
353 .iter()
354 .map(|t| ToolDefinition {
355 name: t.name().to_string(),
356 description: t.description().to_string(),
357 parameters: serde_json::json!({}),
358 extras: None,
359 })
360 .collect();
361
362 let default_prompt = format!(
363 "You are a supervisor managing these agents: {}. \
364 Use the transfer tools to delegate tasks to the appropriate agent. \
365 When the task is complete, respond directly to the user.",
366 agent_names.join(", ")
367 );
368 let system_prompt = options.system_prompt.unwrap_or(default_prompt);
369
370 let supervisor_node = ChatModelNode {
371 model,
372 tool_defs: handoff_tool_defs.clone(),
373 system_prompt: Some(system_prompt),
374 middleware: Arc::new(MiddlewareChain::new(vec![])),
375 is_first_call: AtomicBool::new(false),
376 pre_model_hook: None,
377 post_model_hook: None,
378 response_format: None,
379 };
380
381 let mut builder = StateGraph::new()
382 .add_node("supervisor", supervisor_node)
383 .set_entry_point("supervisor");
384
385 for (name, graph) in agents {
386 builder = builder
387 .add_node(&name, SubAgentNode { graph })
388 .add_edge(&name, "supervisor");
389 }
390
391 let agent_names_for_router = agent_names.clone();
392 builder = builder.add_conditional_edges("supervisor", move |state: &MessageState| {
393 if let Some(last) = state.last_message() {
394 for tc in last.tool_calls() {
395 for agent_name in &agent_names_for_router {
396 if tc.name == format!("transfer_to_{agent_name}") {
397 return agent_name.clone();
398 }
399 }
400 }
401 }
402 END.to_string()
403 });
404
405 let mut graph = builder.compile()?;
406
407 if let Some(checkpointer) = options.checkpointer {
408 graph = graph.with_checkpointer(checkpointer);
409 }
410
411 Ok(graph)
412}
413
414#[derive(Default)]
420pub struct SwarmOptions {
421 pub checkpointer: Option<Arc<dyn Checkpointer>>,
422 pub store: Option<Arc<dyn Store>>,
423}
424
425struct SwarmAgentNode {
427 model: Arc<dyn ChatModel>,
428 tool_defs: Vec<ToolDefinition>,
429 system_prompt: Option<String>,
430}
431
432#[async_trait]
433impl Node<MessageState> for SwarmAgentNode {
434 async fn process(
435 &self,
436 mut state: MessageState,
437 ) -> Result<NodeOutput<MessageState>, SynapticError> {
438 let mut messages = Vec::new();
439 if let Some(ref prompt) = self.system_prompt {
440 messages.push(Message::system(prompt));
441 }
442 messages.extend(state.messages.clone());
443
444 let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
445 let response = self.model.chat(request).await?;
446 state.messages.push(response.message);
447 Ok(state.into())
448 }
449}
450
451struct SwarmToolNode {
453 executor: SerialToolExecutor,
454 handoff_tool_names: Vec<String>,
455}
456
457#[async_trait]
458impl Node<MessageState> for SwarmToolNode {
459 async fn process(
460 &self,
461 mut state: MessageState,
462 ) -> Result<NodeOutput<MessageState>, SynapticError> {
463 let last = state
464 .last_message()
465 .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
466
467 let tool_calls = last.tool_calls().to_vec();
468 for call in &tool_calls {
469 if self.handoff_tool_names.contains(&call.name) {
470 state.messages.push(Message::tool(
471 "Transferring to agent.".to_string(),
472 &call.id,
473 ));
474 } else {
475 let result = self
476 .executor
477 .execute(&call.name, call.arguments.clone())
478 .await?;
479 state
480 .messages
481 .push(Message::tool(result.to_string(), &call.id));
482 }
483 }
484
485 Ok(state.into())
486 }
487}
488
489pub struct SwarmAgent {
491 pub name: String,
492 pub model: Arc<dyn ChatModel>,
493 pub tools: Vec<Arc<dyn Tool>>,
494 pub system_prompt: Option<String>,
495}
496
497#[traceable(skip = "agents,options")]
499pub fn create_swarm(
500 agents: Vec<SwarmAgent>,
501 options: SwarmOptions,
502) -> Result<CompiledGraph<MessageState>, SynapticError> {
503 if agents.is_empty() {
504 return Err(SynapticError::Graph(
505 "swarm requires at least one agent".to_string(),
506 ));
507 }
508
509 let agent_names: Vec<String> = agents.iter().map(|a| a.name.clone()).collect();
510 let entry_agent = agent_names[0].clone();
511
512 let all_handoff_tools: HashMap<String, Arc<dyn Tool>> = agent_names
513 .iter()
514 .map(|name| {
515 (
516 name.clone(),
517 create_handoff_tool(
518 name,
519 &format!("Transfer the conversation to the '{name}' agent."),
520 ),
521 )
522 })
523 .collect();
524
525 let handoff_tool_names: Vec<String> = all_handoff_tools
526 .values()
527 .map(|t| t.name().to_string())
528 .collect();
529
530 let mut builder = StateGraph::new();
531
532 let global_registry = synaptic_tools::ToolRegistry::new();
533
534 for agent in agents {
535 let SwarmAgent {
536 name,
537 model,
538 tools,
539 system_prompt,
540 } = agent;
541
542 let mut tool_defs: Vec<ToolDefinition> = tools
543 .iter()
544 .map(|t| ToolDefinition {
545 name: t.name().to_string(),
546 description: t.description().to_string(),
547 parameters: serde_json::json!({}),
548 extras: None,
549 })
550 .collect();
551
552 for tool in &tools {
553 let _ = global_registry.register(tool.clone());
554 }
555
556 for other_name in &agent_names {
557 if other_name != &name {
558 if let Some(ht) = all_handoff_tools.get(other_name) {
559 tool_defs.push(ToolDefinition {
560 name: ht.name().to_string(),
561 description: ht.description().to_string(),
562 parameters: serde_json::json!({}),
563 extras: None,
564 });
565 }
566 }
567 }
568
569 let agent_node = SwarmAgentNode {
570 model,
571 tool_defs,
572 system_prompt,
573 };
574
575 builder = builder.add_node(&name, agent_node);
576 }
577
578 let executor = SerialToolExecutor::new(global_registry);
579 let swarm_tool_node = SwarmToolNode {
580 executor,
581 handoff_tool_names: handoff_tool_names.clone(),
582 };
583 builder = builder.add_node("tools", swarm_tool_node);
584
585 builder = builder.set_entry_point(&entry_agent);
586
587 for agent_name in &agent_names {
588 builder = builder.add_conditional_edges(agent_name, |state: &MessageState| {
589 if let Some(last) = state.last_message() {
590 if !last.tool_calls().is_empty() {
591 return "tools".to_string();
592 }
593 }
594 END.to_string()
595 });
596 }
597
598 let all_agent_names = agent_names.clone();
599 builder = builder.add_conditional_edges("tools", move |state: &MessageState| {
600 for msg in state.messages.iter().rev() {
601 if msg.is_ai() && !msg.tool_calls().is_empty() {
602 for tc in msg.tool_calls() {
603 for agent_name in &all_agent_names {
604 if tc.name == format!("transfer_to_{agent_name}") {
605 return agent_name.clone();
606 }
607 }
608 }
609 return all_agent_names[0].clone();
610 }
611 }
612 all_agent_names[0].clone()
613 });
614
615 let mut graph = builder.compile()?;
616
617 if let Some(checkpointer) = options.checkpointer {
618 graph = graph.with_checkpointer(checkpointer);
619 }
620
621 Ok(graph)
622}