1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use serde_json::Value;
9use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError, Tool, ToolDefinition};
10use synaptic_macros::traceable;
11use synaptic_middleware::{AgentMiddleware, BaseChatModelCaller, MiddlewareChain, ModelRequest};
12use synaptic_store::Store;
13use synaptic_tools::SerialToolExecutor;
14
15use crate::builder::StateGraph;
16use crate::checkpoint::Checkpointer;
17use crate::command::NodeOutput;
18use crate::compiled::CompiledGraph;
19use crate::node::Node;
20use crate::state::MessageState;
21use crate::tool_node::ToolNode;
22use crate::END;
23
24pub type PreModelHook = Arc<
30 dyn Fn(
31 &mut MessageState,
32 ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
33 + Send
34 + Sync,
35>;
36
37pub type PostModelHook = Arc<
39 dyn Fn(
40 &mut MessageState,
41 ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
42 + Send
43 + Sync,
44>;
45
46struct ChatModelNode {
51 model: Arc<dyn ChatModel>,
52 tool_defs: Vec<ToolDefinition>,
53 system_prompt: Option<String>,
54 middleware: Arc<MiddlewareChain>,
55 is_first_call: AtomicBool,
56 pre_model_hook: Option<PreModelHook>,
57 post_model_hook: Option<PostModelHook>,
58 response_format: Option<Value>,
61}
62
63#[async_trait]
64impl Node<MessageState> for ChatModelNode {
65 async fn process(
66 &self,
67 mut state: MessageState,
68 ) -> Result<NodeOutput<MessageState>, SynapticError> {
69 if self.is_first_call.swap(false, Ordering::SeqCst) {
71 self.middleware
72 .run_before_agent(&mut state.messages)
73 .await?;
74 }
75
76 if let Some(ref hook) = self.pre_model_hook {
78 hook(&mut state).await?;
79 }
80
81 let request = ModelRequest {
82 messages: state.messages.clone(),
83 tools: self.tool_defs.clone(),
84 tool_choice: None,
85 system_prompt: self.system_prompt.clone(),
86 };
87
88 let base_caller = BaseChatModelCaller::new(self.model.clone());
89 let response = self.middleware.call_model(request, &base_caller).await?;
90
91 state.messages.push(response.message.clone());
92
93 if let Some(ref hook) = self.post_model_hook {
95 hook(&mut state).await?;
96 }
97
98 if response.message.tool_calls().is_empty() {
100 if let Some(ref schema) = self.response_format {
102 let instruction = format!(
103 "You MUST respond with valid JSON matching this schema:\n{}\n\n\
104 Do not include any text outside the JSON object. \
105 Do not use markdown code blocks.",
106 schema
107 );
108 let mut structured_messages = vec![Message::system(instruction)];
109 structured_messages.extend(state.messages.clone());
110
111 let structured_request = ChatRequest::new(structured_messages);
112 let structured_response = self.model.chat(structured_request).await?;
113 state.messages.pop();
115 state.messages.push(structured_response.message);
116 }
117
118 self.middleware.run_after_agent(&mut state.messages).await?;
119 }
120
121 Ok(state.into())
122 }
123}
124
125#[derive(Default)]
131pub struct ReactAgentOptions {
132 pub checkpointer: Option<Arc<dyn Checkpointer>>,
134 pub interrupt_before: Vec<String>,
136 pub interrupt_after: Vec<String>,
138 pub system_prompt: Option<String>,
140}
141
142pub fn create_react_agent(
144 model: Arc<dyn ChatModel>,
145 tools: Vec<Arc<dyn Tool>>,
146) -> Result<CompiledGraph<MessageState>, SynapticError> {
147 create_react_agent_with_options(model, tools, ReactAgentOptions::default())
148}
149
150pub fn create_react_agent_with_options(
152 model: Arc<dyn ChatModel>,
153 tools: Vec<Arc<dyn Tool>>,
154 options: ReactAgentOptions,
155) -> Result<CompiledGraph<MessageState>, SynapticError> {
156 create_agent(
157 model,
158 tools,
159 AgentOptions {
160 checkpointer: options.checkpointer,
161 interrupt_before: options.interrupt_before,
162 interrupt_after: options.interrupt_after,
163 system_prompt: options.system_prompt,
164 ..Default::default()
165 },
166 )
167}
168
169#[derive(Default)]
175pub struct AgentOptions {
176 pub checkpointer: Option<Arc<dyn Checkpointer>>,
177 pub interrupt_before: Vec<String>,
178 pub interrupt_after: Vec<String>,
179 pub system_prompt: Option<String>,
180 pub middleware: Vec<Arc<dyn AgentMiddleware>>,
181 pub store: Option<Arc<dyn Store>>,
182 pub name: Option<String>,
183 pub pre_model_hook: Option<PreModelHook>,
184 pub post_model_hook: Option<PostModelHook>,
185 pub response_format: Option<Value>,
187}
188
189#[traceable(skip = "model,tools,options")]
191pub fn create_agent(
192 model: Arc<dyn ChatModel>,
193 tools: Vec<Arc<dyn Tool>>,
194 options: AgentOptions,
195) -> Result<CompiledGraph<MessageState>, SynapticError> {
196 let tool_defs: Vec<ToolDefinition> = tools.iter().map(|t| t.as_tool_definition()).collect();
197
198 let registry = synaptic_tools::ToolRegistry::new();
199 for tool in tools {
200 registry.register(tool)?;
201 }
202 let executor = SerialToolExecutor::new(registry);
203
204 let middleware_chain = Arc::new(MiddlewareChain::new(options.middleware));
205
206 let agent_node = ChatModelNode {
207 model,
208 tool_defs,
209 system_prompt: options.system_prompt,
210 middleware: middleware_chain.clone(),
211 is_first_call: AtomicBool::new(true),
212 pre_model_hook: options.pre_model_hook,
213 post_model_hook: options.post_model_hook,
214 response_format: options.response_format,
215 };
216
217 let mut tool_node = ToolNode::with_middleware(executor, middleware_chain);
218 if let Some(ref store) = options.store {
219 tool_node = tool_node.with_store(store.clone());
220 }
221
222 let mut builder = StateGraph::new()
223 .add_node("agent", agent_node)
224 .add_node("tools", tool_node)
225 .set_entry_point("agent")
226 .add_conditional_edges_with_path_map(
227 "agent",
228 |state: &MessageState| {
229 if let Some(last) = state.last_message() {
230 if !last.tool_calls().is_empty() {
231 return "tools".to_string();
232 }
233 }
234 END.to_string()
235 },
236 HashMap::from([
237 ("tools".to_string(), "tools".to_string()),
238 (END.to_string(), END.to_string()),
239 ]),
240 )
241 .add_edge("tools", "agent");
242
243 if !options.interrupt_before.is_empty() {
244 builder = builder.interrupt_before(options.interrupt_before);
245 }
246 if !options.interrupt_after.is_empty() {
247 builder = builder.interrupt_after(options.interrupt_after);
248 }
249
250 let mut graph = builder.compile()?;
251
252 if let Some(checkpointer) = options.checkpointer {
253 graph = graph.with_checkpointer(checkpointer);
254 }
255
256 Ok(graph)
257}
258
259struct HandoffTool {
264 target_agent: String,
265 tool_description: String,
266}
267
268#[async_trait]
269impl Tool for HandoffTool {
270 fn name(&self) -> &'static str {
271 Box::leak(format!("transfer_to_{}", self.target_agent).into_boxed_str())
272 }
273
274 fn description(&self) -> &'static str {
275 Box::leak(self.tool_description.clone().into_boxed_str())
276 }
277
278 async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
279 Ok(Value::String(format!(
280 "Transferring to agent '{}'.",
281 self.target_agent
282 )))
283 }
284}
285
286pub fn create_handoff_tool(agent_name: &str, description: &str) -> Arc<dyn Tool> {
288 Arc::new(HandoffTool {
289 target_agent: agent_name.to_string(),
290 tool_description: description.to_string(),
291 })
292}
293
294#[derive(Default)]
300pub struct SupervisorOptions {
301 pub checkpointer: Option<Arc<dyn Checkpointer>>,
302 pub store: Option<Arc<dyn Store>>,
303 pub system_prompt: Option<String>,
304}
305
306struct SubAgentNode {
308 graph: CompiledGraph<MessageState>,
309}
310
311#[async_trait]
312impl Node<MessageState> for SubAgentNode {
313 async fn process(
314 &self,
315 state: MessageState,
316 ) -> Result<NodeOutput<MessageState>, SynapticError> {
317 let result = self.graph.invoke(state).await?;
318 Ok(result.into_state().into())
319 }
320}
321
322#[traceable(skip = "model,agents,options")]
324pub fn create_supervisor(
325 model: Arc<dyn ChatModel>,
326 agents: Vec<(String, CompiledGraph<MessageState>)>,
327 options: SupervisorOptions,
328) -> Result<CompiledGraph<MessageState>, SynapticError> {
329 let agent_names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
330
331 let handoff_tools: Vec<Arc<dyn Tool>> = agent_names
333 .iter()
334 .map(|name| {
335 create_handoff_tool(
336 name,
337 &format!("Transfer the conversation to the '{name}' agent."),
338 )
339 })
340 .collect();
341
342 let handoff_tool_defs: Vec<ToolDefinition> = handoff_tools
343 .iter()
344 .map(|t| ToolDefinition {
345 name: t.name().to_string(),
346 description: t.description().to_string(),
347 parameters: serde_json::json!({}),
348 extras: None,
349 })
350 .collect();
351
352 let default_prompt = format!(
353 "You are a supervisor managing these agents: {}. \
354 Use the transfer tools to delegate tasks to the appropriate agent. \
355 When the task is complete, respond directly to the user.",
356 agent_names.join(", ")
357 );
358 let system_prompt = options.system_prompt.unwrap_or(default_prompt);
359
360 let supervisor_node = ChatModelNode {
361 model,
362 tool_defs: handoff_tool_defs.clone(),
363 system_prompt: Some(system_prompt),
364 middleware: Arc::new(MiddlewareChain::new(vec![])),
365 is_first_call: AtomicBool::new(false),
366 pre_model_hook: None,
367 post_model_hook: None,
368 response_format: None,
369 };
370
371 let mut builder = StateGraph::new()
372 .add_node("supervisor", supervisor_node)
373 .set_entry_point("supervisor");
374
375 for (name, graph) in agents {
376 builder = builder
377 .add_node(&name, SubAgentNode { graph })
378 .add_edge(&name, "supervisor");
379 }
380
381 let agent_names_for_router = agent_names.clone();
382 builder = builder.add_conditional_edges("supervisor", move |state: &MessageState| {
383 if let Some(last) = state.last_message() {
384 for tc in last.tool_calls() {
385 for agent_name in &agent_names_for_router {
386 if tc.name == format!("transfer_to_{agent_name}") {
387 return agent_name.clone();
388 }
389 }
390 }
391 }
392 END.to_string()
393 });
394
395 let mut graph = builder.compile()?;
396
397 if let Some(checkpointer) = options.checkpointer {
398 graph = graph.with_checkpointer(checkpointer);
399 }
400
401 Ok(graph)
402}
403
404#[derive(Default)]
410pub struct SwarmOptions {
411 pub checkpointer: Option<Arc<dyn Checkpointer>>,
412 pub store: Option<Arc<dyn Store>>,
413}
414
415struct SwarmAgentNode {
417 model: Arc<dyn ChatModel>,
418 tool_defs: Vec<ToolDefinition>,
419 system_prompt: Option<String>,
420}
421
422#[async_trait]
423impl Node<MessageState> for SwarmAgentNode {
424 async fn process(
425 &self,
426 mut state: MessageState,
427 ) -> Result<NodeOutput<MessageState>, SynapticError> {
428 let mut messages = Vec::new();
429 if let Some(ref prompt) = self.system_prompt {
430 messages.push(Message::system(prompt));
431 }
432 messages.extend(state.messages.clone());
433
434 let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
435 let response = self.model.chat(request).await?;
436 state.messages.push(response.message);
437 Ok(state.into())
438 }
439}
440
441struct SwarmToolNode {
443 executor: SerialToolExecutor,
444 handoff_tool_names: Vec<String>,
445}
446
447#[async_trait]
448impl Node<MessageState> for SwarmToolNode {
449 async fn process(
450 &self,
451 mut state: MessageState,
452 ) -> Result<NodeOutput<MessageState>, SynapticError> {
453 let last = state
454 .last_message()
455 .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
456
457 let tool_calls = last.tool_calls().to_vec();
458 for call in &tool_calls {
459 if self.handoff_tool_names.contains(&call.name) {
460 state.messages.push(Message::tool(
461 "Transferring to agent.".to_string(),
462 &call.id,
463 ));
464 } else {
465 let result = self
466 .executor
467 .execute(&call.name, call.arguments.clone())
468 .await?;
469 state
470 .messages
471 .push(Message::tool(result.to_string(), &call.id));
472 }
473 }
474
475 Ok(state.into())
476 }
477}
478
479pub struct SwarmAgent {
481 pub name: String,
482 pub model: Arc<dyn ChatModel>,
483 pub tools: Vec<Arc<dyn Tool>>,
484 pub system_prompt: Option<String>,
485}
486
487#[traceable(skip = "agents,options")]
489pub fn create_swarm(
490 agents: Vec<SwarmAgent>,
491 options: SwarmOptions,
492) -> Result<CompiledGraph<MessageState>, SynapticError> {
493 if agents.is_empty() {
494 return Err(SynapticError::Graph(
495 "swarm requires at least one agent".to_string(),
496 ));
497 }
498
499 let agent_names: Vec<String> = agents.iter().map(|a| a.name.clone()).collect();
500 let entry_agent = agent_names[0].clone();
501
502 let all_handoff_tools: HashMap<String, Arc<dyn Tool>> = agent_names
503 .iter()
504 .map(|name| {
505 (
506 name.clone(),
507 create_handoff_tool(
508 name,
509 &format!("Transfer the conversation to the '{name}' agent."),
510 ),
511 )
512 })
513 .collect();
514
515 let handoff_tool_names: Vec<String> = all_handoff_tools
516 .values()
517 .map(|t| t.name().to_string())
518 .collect();
519
520 let mut builder = StateGraph::new();
521
522 let global_registry = synaptic_tools::ToolRegistry::new();
523
524 for agent in agents {
525 let SwarmAgent {
526 name,
527 model,
528 tools,
529 system_prompt,
530 } = agent;
531
532 let mut tool_defs: Vec<ToolDefinition> = tools
533 .iter()
534 .map(|t| ToolDefinition {
535 name: t.name().to_string(),
536 description: t.description().to_string(),
537 parameters: serde_json::json!({}),
538 extras: None,
539 })
540 .collect();
541
542 for tool in &tools {
543 let _ = global_registry.register(tool.clone());
544 }
545
546 for other_name in &agent_names {
547 if other_name != &name {
548 if let Some(ht) = all_handoff_tools.get(other_name) {
549 tool_defs.push(ToolDefinition {
550 name: ht.name().to_string(),
551 description: ht.description().to_string(),
552 parameters: serde_json::json!({}),
553 extras: None,
554 });
555 }
556 }
557 }
558
559 let agent_node = SwarmAgentNode {
560 model,
561 tool_defs,
562 system_prompt,
563 };
564
565 builder = builder.add_node(&name, agent_node);
566 }
567
568 let executor = SerialToolExecutor::new(global_registry);
569 let swarm_tool_node = SwarmToolNode {
570 executor,
571 handoff_tool_names: handoff_tool_names.clone(),
572 };
573 builder = builder.add_node("tools", swarm_tool_node);
574
575 builder = builder.set_entry_point(&entry_agent);
576
577 for agent_name in &agent_names {
578 let end_str = END.to_string();
579 builder = builder.add_conditional_edges(agent_name, move |state: &MessageState| {
580 if let Some(last) = state.last_message() {
581 if !last.tool_calls().is_empty() {
582 return "tools".to_string();
583 }
584 }
585 end_str.clone()
586 });
587 }
588
589 let all_agent_names = agent_names.clone();
590 builder = builder.add_conditional_edges("tools", move |state: &MessageState| {
591 for msg in state.messages.iter().rev() {
592 if msg.is_ai() && !msg.tool_calls().is_empty() {
593 for tc in msg.tool_calls() {
594 for agent_name in &all_agent_names {
595 if tc.name == format!("transfer_to_{agent_name}") {
596 return agent_name.clone();
597 }
598 }
599 }
600 return all_agent_names[0].clone();
601 }
602 }
603 all_agent_names[0].clone()
604 });
605
606 let mut graph = builder.compile()?;
607
608 if let Some(checkpointer) = options.checkpointer {
609 graph = graph.with_checkpointer(checkpointer);
610 }
611
612 Ok(graph)
613}