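// synaptic_graph/prebuilt.rs
//
// Prebuilt agent graphs: a single tool-calling agent (`create_agent`,
// `create_react_agent`), the supervisor multi-agent pattern
// (`create_supervisor`), and the swarm multi-agent pattern (`create_swarm`).
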
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::Value;
use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError, Tool, ToolDefinition};
use synaptic_macros::traceable;
use synaptic_middleware::{AgentMiddleware, BaseChatModelCaller, MiddlewareChain, ModelRequest};
use synaptic_store::Store;
use synaptic_tools::SerialToolExecutor;

use crate::builder::StateGraph;
use crate::checkpoint::Checkpointer;
use crate::command::NodeOutput;
use crate::compiled::CompiledGraph;
use crate::node::Node;
use crate::state::MessageState;
use crate::tool_node::ToolNode;
use crate::END;

24// ---------------------------------------------------------------------------
25// Hook types
26// ---------------------------------------------------------------------------
27
28/// A hook called before each model invocation. Can modify the state.
29pub type PreModelHook = Arc<
30    dyn Fn(
31            &mut MessageState,
32        ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
33        + Send
34        + Sync,
35>;
36
37/// A hook called after each model invocation. Can modify the state.
38pub type PostModelHook = Arc<
39    dyn Fn(
40            &mut MessageState,
41        ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
42        + Send
43        + Sync,
44>;
45
46// ---------------------------------------------------------------------------
47// ChatModelNode — prebuilt node that calls a ChatModel through middleware
48// ---------------------------------------------------------------------------
49
50struct ChatModelNode {
51    model: Arc<dyn ChatModel>,
52    tool_defs: Vec<ToolDefinition>,
53    system_prompt: Option<String>,
54    middleware: Arc<MiddlewareChain>,
55    is_first_call: AtomicBool,
56    pre_model_hook: Option<PreModelHook>,
57    post_model_hook: Option<PostModelHook>,
58    /// When set, the final response (no tool calls) is re-called with
59    /// structured output instructions matching this JSON schema.
60    response_format: Option<Value>,
61}
62
63#[async_trait]
64impl Node<MessageState> for ChatModelNode {
65    async fn process(
66        &self,
67        mut state: MessageState,
68    ) -> Result<NodeOutput<MessageState>, SynapticError> {
69        // On first call, run before_agent middleware hooks
70        if self.is_first_call.swap(false, Ordering::SeqCst) {
71            self.middleware
72                .run_before_agent(&mut state.messages)
73                .await?;
74        }
75
76        // Run pre_model_hook
77        if let Some(ref hook) = self.pre_model_hook {
78            hook(&mut state).await?;
79        }
80
81        let request = ModelRequest {
82            messages: state.messages.clone(),
83            tools: self.tool_defs.clone(),
84            tool_choice: None,
85            system_prompt: self.system_prompt.clone(),
86        };
87
88        let base_caller = BaseChatModelCaller::new(self.model.clone());
89        let response = self.middleware.call_model(request, &base_caller).await?;
90
91        state.messages.push(response.message.clone());
92
93        // Run post_model_hook
94        if let Some(ref hook) = self.post_model_hook {
95            hook(&mut state).await?;
96        }
97
98        // If no tool calls, this is the final answer
99        if response.message.tool_calls().is_empty() {
100            // If response_format is set, re-call with structured output instructions
101            if let Some(ref schema) = self.response_format {
102                let instruction = format!(
103                    "You MUST respond with valid JSON matching this schema:\n{}\n\n\
104                     Do not include any text outside the JSON object. \
105                     Do not use markdown code blocks.",
106                    schema
107                );
108                let mut structured_messages = vec![Message::system(instruction)];
109                structured_messages.extend(state.messages.clone());
110
111                let structured_request = ChatRequest::new(structured_messages);
112                let structured_response = self.model.chat(structured_request).await?;
113                // Replace the last message with the structured response
114                state.messages.pop();
115                state.messages.push(structured_response.message);
116            }
117
118            self.middleware.run_after_agent(&mut state.messages).await?;
119        }
120
121        Ok(state.into())
122    }
123}
124
125// ---------------------------------------------------------------------------
126// ReactAgentOptions (legacy)
127// ---------------------------------------------------------------------------
128
129/// Options for creating a ReAct agent with `create_react_agent_with_options`.
130#[derive(Default)]
131pub struct ReactAgentOptions {
132    /// Optional checkpointer for state persistence across invocations.
133    pub checkpointer: Option<Arc<dyn Checkpointer>>,
134    /// Node names that should interrupt BEFORE execution (human-in-the-loop).
135    pub interrupt_before: Vec<String>,
136    /// Node names that should interrupt AFTER execution (human-in-the-loop).
137    pub interrupt_after: Vec<String>,
138    /// Optional system prompt to prepend to messages before calling the model.
139    pub system_prompt: Option<String>,
140}
141
142/// Create a prebuilt ReAct agent graph.
143pub fn create_react_agent(
144    model: Arc<dyn ChatModel>,
145    tools: Vec<Arc<dyn Tool>>,
146) -> Result<CompiledGraph<MessageState>, SynapticError> {
147    create_react_agent_with_options(model, tools, ReactAgentOptions::default())
148}
149
150/// Create a prebuilt ReAct agent graph with additional configuration options.
151pub fn create_react_agent_with_options(
152    model: Arc<dyn ChatModel>,
153    tools: Vec<Arc<dyn Tool>>,
154    options: ReactAgentOptions,
155) -> Result<CompiledGraph<MessageState>, SynapticError> {
156    create_agent(
157        model,
158        tools,
159        AgentOptions {
160            checkpointer: options.checkpointer,
161            interrupt_before: options.interrupt_before,
162            interrupt_after: options.interrupt_after,
163            system_prompt: options.system_prompt,
164            ..Default::default()
165        },
166    )
167}
168
169// ---------------------------------------------------------------------------
170// AgentOptions — new unified options for create_agent
171// ---------------------------------------------------------------------------
172
173/// Options for creating an agent with `create_agent`.
174#[derive(Default)]
175pub struct AgentOptions {
176    pub checkpointer: Option<Arc<dyn Checkpointer>>,
177    pub interrupt_before: Vec<String>,
178    pub interrupt_after: Vec<String>,
179    pub system_prompt: Option<String>,
180    pub middleware: Vec<Arc<dyn AgentMiddleware>>,
181    pub store: Option<Arc<dyn Store>>,
182    pub name: Option<String>,
183    pub pre_model_hook: Option<PreModelHook>,
184    pub post_model_hook: Option<PostModelHook>,
185    /// Optional JSON schema for structured output on the final model call.
186    pub response_format: Option<Value>,
187}
188
189/// Create a prebuilt agent graph with full middleware and store support.
190#[traceable(skip = "model,tools,options")]
191pub fn create_agent(
192    model: Arc<dyn ChatModel>,
193    tools: Vec<Arc<dyn Tool>>,
194    options: AgentOptions,
195) -> Result<CompiledGraph<MessageState>, SynapticError> {
196    let tool_defs: Vec<ToolDefinition> = tools.iter().map(|t| t.as_tool_definition()).collect();
197
198    let registry = synaptic_tools::ToolRegistry::new();
199    for tool in tools {
200        registry.register(tool)?;
201    }
202    let executor = SerialToolExecutor::new(registry);
203
204    let middleware_chain = Arc::new(MiddlewareChain::new(options.middleware));
205
206    let agent_node = ChatModelNode {
207        model,
208        tool_defs,
209        system_prompt: options.system_prompt,
210        middleware: middleware_chain.clone(),
211        is_first_call: AtomicBool::new(true),
212        pre_model_hook: options.pre_model_hook,
213        post_model_hook: options.post_model_hook,
214        response_format: options.response_format,
215    };
216
217    let mut tool_node = ToolNode::with_middleware(executor, middleware_chain);
218    if let Some(ref store) = options.store {
219        tool_node = tool_node.with_store(store.clone());
220    }
221
222    let mut builder = StateGraph::new()
223        .add_node("agent", agent_node)
224        .add_node("tools", tool_node)
225        .set_entry_point("agent")
226        .add_conditional_edges_with_path_map(
227            "agent",
228            |state: &MessageState| {
229                if let Some(last) = state.last_message() {
230                    if !last.tool_calls().is_empty() {
231                        return "tools".to_string();
232                    }
233                }
234                END.to_string()
235            },
236            HashMap::from([
237                ("tools".to_string(), "tools".to_string()),
238                (END.to_string(), END.to_string()),
239            ]),
240        )
241        .add_edge("tools", "agent");
242
243    if !options.interrupt_before.is_empty() {
244        builder = builder.interrupt_before(options.interrupt_before);
245    }
246    if !options.interrupt_after.is_empty() {
247        builder = builder.interrupt_after(options.interrupt_after);
248    }
249
250    let mut graph = builder.compile()?;
251
252    if let Some(checkpointer) = options.checkpointer {
253        graph = graph.with_checkpointer(checkpointer);
254    }
255
256    Ok(graph)
257}
258
259// ---------------------------------------------------------------------------
260// Handoff tool — for multi-agent collaboration
261// ---------------------------------------------------------------------------
262
263struct HandoffTool {
264    target_agent: String,
265    tool_description: String,
266}
267
268#[async_trait]
269impl Tool for HandoffTool {
270    fn name(&self) -> &'static str {
271        Box::leak(format!("transfer_to_{}", self.target_agent).into_boxed_str())
272    }
273
274    fn description(&self) -> &'static str {
275        Box::leak(self.tool_description.clone().into_boxed_str())
276    }
277
278    async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
279        Ok(Value::String(format!(
280            "Transferring to agent '{}'.",
281            self.target_agent
282        )))
283    }
284}
285
286/// Create a handoff tool that signals transfer to another agent.
287pub fn create_handoff_tool(agent_name: &str, description: &str) -> Arc<dyn Tool> {
288    Arc::new(HandoffTool {
289        target_agent: agent_name.to_string(),
290        tool_description: description.to_string(),
291    })
292}
293
294// ---------------------------------------------------------------------------
295// Supervisor — centralized multi-agent orchestration
296// ---------------------------------------------------------------------------
297
298/// Options for the supervisor multi-agent pattern.
299#[derive(Default)]
300pub struct SupervisorOptions {
301    pub checkpointer: Option<Arc<dyn Checkpointer>>,
302    pub store: Option<Arc<dyn Store>>,
303    pub system_prompt: Option<String>,
304}
305
306/// A sub-agent node that invokes a compiled agent graph as a node.
307struct SubAgentNode {
308    graph: CompiledGraph<MessageState>,
309}
310
311#[async_trait]
312impl Node<MessageState> for SubAgentNode {
313    async fn process(
314        &self,
315        state: MessageState,
316    ) -> Result<NodeOutput<MessageState>, SynapticError> {
317        let result = self.graph.invoke(state).await?;
318        Ok(result.into_state().into())
319    }
320}
321
322/// Create a supervisor multi-agent graph.
323#[traceable(skip = "model,agents,options")]
324pub fn create_supervisor(
325    model: Arc<dyn ChatModel>,
326    agents: Vec<(String, CompiledGraph<MessageState>)>,
327    options: SupervisorOptions,
328) -> Result<CompiledGraph<MessageState>, SynapticError> {
329    let agent_names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
330
331    // Create handoff tools for each agent
332    let handoff_tools: Vec<Arc<dyn Tool>> = agent_names
333        .iter()
334        .map(|name| {
335            create_handoff_tool(
336                name,
337                &format!("Transfer the conversation to the '{name}' agent."),
338            )
339        })
340        .collect();
341
342    let handoff_tool_defs: Vec<ToolDefinition> = handoff_tools
343        .iter()
344        .map(|t| ToolDefinition {
345            name: t.name().to_string(),
346            description: t.description().to_string(),
347            parameters: serde_json::json!({}),
348            extras: None,
349        })
350        .collect();
351
352    let default_prompt = format!(
353        "You are a supervisor managing these agents: {}. \
354         Use the transfer tools to delegate tasks to the appropriate agent. \
355         When the task is complete, respond directly to the user.",
356        agent_names.join(", ")
357    );
358    let system_prompt = options.system_prompt.unwrap_or(default_prompt);
359
360    let supervisor_node = ChatModelNode {
361        model,
362        tool_defs: handoff_tool_defs.clone(),
363        system_prompt: Some(system_prompt),
364        middleware: Arc::new(MiddlewareChain::new(vec![])),
365        is_first_call: AtomicBool::new(false),
366        pre_model_hook: None,
367        post_model_hook: None,
368        response_format: None,
369    };
370
371    let mut builder = StateGraph::new()
372        .add_node("supervisor", supervisor_node)
373        .set_entry_point("supervisor");
374
375    for (name, graph) in agents {
376        builder = builder
377            .add_node(&name, SubAgentNode { graph })
378            .add_edge(&name, "supervisor");
379    }
380
381    let agent_names_for_router = agent_names.clone();
382    builder = builder.add_conditional_edges("supervisor", move |state: &MessageState| {
383        if let Some(last) = state.last_message() {
384            for tc in last.tool_calls() {
385                for agent_name in &agent_names_for_router {
386                    if tc.name == format!("transfer_to_{agent_name}") {
387                        return agent_name.clone();
388                    }
389                }
390            }
391        }
392        END.to_string()
393    });
394
395    let mut graph = builder.compile()?;
396
397    if let Some(checkpointer) = options.checkpointer {
398        graph = graph.with_checkpointer(checkpointer);
399    }
400
401    Ok(graph)
402}
403
404// ---------------------------------------------------------------------------
405// Swarm — decentralized multi-agent collaboration
406// ---------------------------------------------------------------------------
407
408/// Options for the swarm multi-agent pattern.
409#[derive(Default)]
410pub struct SwarmOptions {
411    pub checkpointer: Option<Arc<dyn Checkpointer>>,
412    pub store: Option<Arc<dyn Store>>,
413}
414
415/// A swarm agent node: calls a model with its own tools + handoff tools.
416struct SwarmAgentNode {
417    model: Arc<dyn ChatModel>,
418    tool_defs: Vec<ToolDefinition>,
419    system_prompt: Option<String>,
420}
421
422#[async_trait]
423impl Node<MessageState> for SwarmAgentNode {
424    async fn process(
425        &self,
426        mut state: MessageState,
427    ) -> Result<NodeOutput<MessageState>, SynapticError> {
428        let mut messages = Vec::new();
429        if let Some(ref prompt) = self.system_prompt {
430            messages.push(Message::system(prompt));
431        }
432        messages.extend(state.messages.clone());
433
434        let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
435        let response = self.model.chat(request).await?;
436        state.messages.push(response.message);
437        Ok(state.into())
438    }
439}
440
441/// Swarm tool node: executes tool calls, but skips handoff tools.
442struct SwarmToolNode {
443    executor: SerialToolExecutor,
444    handoff_tool_names: Vec<String>,
445}
446
447#[async_trait]
448impl Node<MessageState> for SwarmToolNode {
449    async fn process(
450        &self,
451        mut state: MessageState,
452    ) -> Result<NodeOutput<MessageState>, SynapticError> {
453        let last = state
454            .last_message()
455            .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
456
457        let tool_calls = last.tool_calls().to_vec();
458        for call in &tool_calls {
459            if self.handoff_tool_names.contains(&call.name) {
460                state.messages.push(Message::tool(
461                    "Transferring to agent.".to_string(),
462                    &call.id,
463                ));
464            } else {
465                let result = self
466                    .executor
467                    .execute(&call.name, call.arguments.clone())
468                    .await?;
469                state
470                    .messages
471                    .push(Message::tool(result.to_string(), &call.id));
472            }
473        }
474
475        Ok(state.into())
476    }
477}
478
479/// A swarm agent definition.
480pub struct SwarmAgent {
481    pub name: String,
482    pub model: Arc<dyn ChatModel>,
483    pub tools: Vec<Arc<dyn Tool>>,
484    pub system_prompt: Option<String>,
485}
486
487/// Create a swarm multi-agent graph.
488#[traceable(skip = "agents,options")]
489pub fn create_swarm(
490    agents: Vec<SwarmAgent>,
491    options: SwarmOptions,
492) -> Result<CompiledGraph<MessageState>, SynapticError> {
493    if agents.is_empty() {
494        return Err(SynapticError::Graph(
495            "swarm requires at least one agent".to_string(),
496        ));
497    }
498
499    let agent_names: Vec<String> = agents.iter().map(|a| a.name.clone()).collect();
500    let entry_agent = agent_names[0].clone();
501
502    let all_handoff_tools: HashMap<String, Arc<dyn Tool>> = agent_names
503        .iter()
504        .map(|name| {
505            (
506                name.clone(),
507                create_handoff_tool(
508                    name,
509                    &format!("Transfer the conversation to the '{name}' agent."),
510                ),
511            )
512        })
513        .collect();
514
515    let handoff_tool_names: Vec<String> = all_handoff_tools
516        .values()
517        .map(|t| t.name().to_string())
518        .collect();
519
520    let mut builder = StateGraph::new();
521
522    let global_registry = synaptic_tools::ToolRegistry::new();
523
524    for agent in agents {
525        let SwarmAgent {
526            name,
527            model,
528            tools,
529            system_prompt,
530        } = agent;
531
532        let mut tool_defs: Vec<ToolDefinition> = tools
533            .iter()
534            .map(|t| ToolDefinition {
535                name: t.name().to_string(),
536                description: t.description().to_string(),
537                parameters: serde_json::json!({}),
538                extras: None,
539            })
540            .collect();
541
542        for tool in &tools {
543            let _ = global_registry.register(tool.clone());
544        }
545
546        for other_name in &agent_names {
547            if other_name != &name {
548                if let Some(ht) = all_handoff_tools.get(other_name) {
549                    tool_defs.push(ToolDefinition {
550                        name: ht.name().to_string(),
551                        description: ht.description().to_string(),
552                        parameters: serde_json::json!({}),
553                        extras: None,
554                    });
555                }
556            }
557        }
558
559        let agent_node = SwarmAgentNode {
560            model,
561            tool_defs,
562            system_prompt,
563        };
564
565        builder = builder.add_node(&name, agent_node);
566    }
567
568    let executor = SerialToolExecutor::new(global_registry);
569    let swarm_tool_node = SwarmToolNode {
570        executor,
571        handoff_tool_names: handoff_tool_names.clone(),
572    };
573    builder = builder.add_node("tools", swarm_tool_node);
574
575    builder = builder.set_entry_point(&entry_agent);
576
577    for agent_name in &agent_names {
578        builder = builder.add_conditional_edges(agent_name, |state: &MessageState| {
579            if let Some(last) = state.last_message() {
580                if !last.tool_calls().is_empty() {
581                    return "tools".to_string();
582                }
583            }
584            END.to_string()
585        });
586    }
587
588    let all_agent_names = agent_names.clone();
589    builder = builder.add_conditional_edges("tools", move |state: &MessageState| {
590        for msg in state.messages.iter().rev() {
591            if msg.is_ai() && !msg.tool_calls().is_empty() {
592                for tc in msg.tool_calls() {
593                    for agent_name in &all_agent_names {
594                        if tc.name == format!("transfer_to_{agent_name}") {
595                            return agent_name.clone();
596                        }
597                    }
598                }
599                return all_agent_names[0].clone();
600            }
601        }
602        all_agent_names[0].clone()
603    });
604
605    let mut graph = builder.compile()?;
606
607    if let Some(checkpointer) = options.checkpointer {
608        graph = graph.with_checkpointer(checkpointer);
609    }
610
611    Ok(graph)
612}