Skip to main content

synaptic_graph/
prebuilt.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use serde_json::Value;
9use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError, Tool, ToolDefinition};
10use synaptic_macros::traceable;
11use synaptic_middleware::{AgentMiddleware, BaseChatModelCaller, MiddlewareChain, ModelRequest};
12use synaptic_store::Store;
13use synaptic_tools::SerialToolExecutor;
14
15use crate::builder::StateGraph;
16use crate::checkpoint::Checkpointer;
17use crate::command::NodeOutput;
18use crate::compiled::CompiledGraph;
19use crate::node::Node;
20use crate::state::MessageState;
21use crate::tool_node::ToolNode;
22use crate::END;
23
24// ---------------------------------------------------------------------------
25// Hook types
26// ---------------------------------------------------------------------------
27
28/// A hook called before each model invocation. Can modify the state.
29pub type PreModelHook = Arc<
30    dyn Fn(
31            &mut MessageState,
32        ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
33        + Send
34        + Sync,
35>;
36
37/// A hook called after each model invocation. Can modify the state.
38pub type PostModelHook = Arc<
39    dyn Fn(
40            &mut MessageState,
41        ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
42        + Send
43        + Sync,
44>;
45
46// ---------------------------------------------------------------------------
47// ChatModelNode — prebuilt node that calls a ChatModel through middleware
48// ---------------------------------------------------------------------------
49
50struct ChatModelNode {
51    model: Arc<dyn ChatModel>,
52    tool_defs: Vec<ToolDefinition>,
53    system_prompt: Option<String>,
54    middleware: Arc<MiddlewareChain>,
55    is_first_call: AtomicBool,
56    pre_model_hook: Option<PreModelHook>,
57    post_model_hook: Option<PostModelHook>,
58    /// When set, the final response (no tool calls) is re-called with
59    /// structured output instructions matching this JSON schema.
60    response_format: Option<Value>,
61}
62
63#[async_trait]
64impl Node<MessageState> for ChatModelNode {
65    async fn process(
66        &self,
67        mut state: MessageState,
68    ) -> Result<NodeOutput<MessageState>, SynapticError> {
69        // On first call, run before_agent middleware hooks
70        if self.is_first_call.swap(false, Ordering::SeqCst) {
71            self.middleware
72                .run_before_agent(&mut state.messages)
73                .await?;
74        }
75
76        // Run pre_model_hook
77        if let Some(ref hook) = self.pre_model_hook {
78            hook(&mut state).await?;
79        }
80
81        let request = ModelRequest {
82            messages: state.messages.clone(),
83            tools: self.tool_defs.clone(),
84            tool_choice: None,
85            system_prompt: self.system_prompt.clone(),
86        };
87
88        let base_caller = BaseChatModelCaller::new(self.model.clone());
89        let response = self.middleware.call_model(request, &base_caller).await?;
90
91        state.messages.push(response.message.clone());
92
93        // Run post_model_hook
94        if let Some(ref hook) = self.post_model_hook {
95            hook(&mut state).await?;
96        }
97
98        // If no tool calls, this is the final answer
99        if response.message.tool_calls().is_empty() {
100            // If response_format is set, re-call with structured output instructions
101            if let Some(ref schema) = self.response_format {
102                let instruction = format!(
103                    "You MUST respond with valid JSON matching this schema:\n{}\n\n\
104                     Do not include any text outside the JSON object. \
105                     Do not use markdown code blocks.",
106                    schema
107                );
108                let mut structured_messages = vec![Message::system(instruction)];
109                structured_messages.extend(state.messages.clone());
110
111                let structured_request = ChatRequest::new(structured_messages);
112                let structured_response = self.model.chat(structured_request).await?;
113                // Replace the last message with the structured response
114                state.messages.pop();
115                state.messages.push(structured_response.message);
116            }
117
118            self.middleware.run_after_agent(&mut state.messages).await?;
119        }
120
121        Ok(state.into())
122    }
123}
124
125// ---------------------------------------------------------------------------
126// ReactAgentOptions (legacy)
127// ---------------------------------------------------------------------------
128
129/// Options for creating a ReAct agent with `create_react_agent_with_options`.
130#[derive(Default)]
131pub struct ReactAgentOptions {
132    /// Optional checkpointer for state persistence across invocations.
133    pub checkpointer: Option<Arc<dyn Checkpointer>>,
134    /// Node names that should interrupt BEFORE execution (human-in-the-loop).
135    pub interrupt_before: Vec<String>,
136    /// Node names that should interrupt AFTER execution (human-in-the-loop).
137    pub interrupt_after: Vec<String>,
138    /// Optional system prompt to prepend to messages before calling the model.
139    pub system_prompt: Option<String>,
140}
141
142/// Create a prebuilt ReAct agent graph.
143pub fn create_react_agent(
144    model: Arc<dyn ChatModel>,
145    tools: Vec<Arc<dyn Tool>>,
146) -> Result<CompiledGraph<MessageState>, SynapticError> {
147    create_react_agent_with_options(model, tools, ReactAgentOptions::default())
148}
149
150/// Create a prebuilt ReAct agent graph with additional configuration options.
151pub fn create_react_agent_with_options(
152    model: Arc<dyn ChatModel>,
153    tools: Vec<Arc<dyn Tool>>,
154    options: ReactAgentOptions,
155) -> Result<CompiledGraph<MessageState>, SynapticError> {
156    create_agent(
157        model,
158        tools,
159        AgentOptions {
160            checkpointer: options.checkpointer,
161            interrupt_before: options.interrupt_before,
162            interrupt_after: options.interrupt_after,
163            system_prompt: options.system_prompt,
164            ..Default::default()
165        },
166    )
167}
168
169// ---------------------------------------------------------------------------
170// AgentOptions — new unified options for create_agent
171// ---------------------------------------------------------------------------
172
173/// Options for creating an agent with `create_agent`.
174#[derive(Default)]
175pub struct AgentOptions {
176    pub checkpointer: Option<Arc<dyn Checkpointer>>,
177    pub interrupt_before: Vec<String>,
178    pub interrupt_after: Vec<String>,
179    pub system_prompt: Option<String>,
180    pub middleware: Vec<Arc<dyn AgentMiddleware>>,
181    pub store: Option<Arc<dyn Store>>,
182    pub name: Option<String>,
183    pub pre_model_hook: Option<PreModelHook>,
184    pub post_model_hook: Option<PostModelHook>,
185    /// Optional JSON schema for structured output on the final model call.
186    pub response_format: Option<Value>,
187}
188
189/// Create a prebuilt agent graph with full middleware and store support.
190#[traceable(skip = "model,tools,options")]
191pub fn create_agent(
192    model: Arc<dyn ChatModel>,
193    tools: Vec<Arc<dyn Tool>>,
194    options: AgentOptions,
195) -> Result<CompiledGraph<MessageState>, SynapticError> {
196    let tool_defs: Vec<ToolDefinition> = tools.iter().map(|t| t.as_tool_definition()).collect();
197
198    let registry = synaptic_tools::ToolRegistry::new();
199    for tool in tools {
200        registry.register(tool)?;
201    }
202    let executor = SerialToolExecutor::new(registry);
203
204    let middleware_chain = Arc::new(MiddlewareChain::new(options.middleware));
205
206    let agent_node = ChatModelNode {
207        model,
208        tool_defs,
209        system_prompt: options.system_prompt,
210        middleware: middleware_chain.clone(),
211        is_first_call: AtomicBool::new(true),
212        pre_model_hook: options.pre_model_hook,
213        post_model_hook: options.post_model_hook,
214        response_format: options.response_format,
215    };
216
217    let mut tool_node = ToolNode::with_middleware(executor, middleware_chain);
218    if let Some(ref store) = options.store {
219        tool_node = tool_node.with_store(store.clone());
220    }
221
222    let mut builder = StateGraph::new()
223        .add_node("agent", agent_node)
224        .add_node("tools", tool_node)
225        .set_entry_point("agent")
226        .add_conditional_edges_with_path_map(
227            "agent",
228            |state: &MessageState| {
229                if let Some(last) = state.last_message() {
230                    if !last.tool_calls().is_empty() {
231                        return "tools".to_string();
232                    }
233                }
234                END.to_string()
235            },
236            HashMap::from([
237                ("tools".to_string(), "tools".to_string()),
238                (END.to_string(), END.to_string()),
239            ]),
240        )
241        .add_edge("tools", "agent");
242
243    if !options.interrupt_before.is_empty() {
244        builder = builder.interrupt_before(options.interrupt_before);
245    }
246    if !options.interrupt_after.is_empty() {
247        builder = builder.interrupt_after(options.interrupt_after);
248    }
249
250    let mut graph = builder.compile()?;
251
252    // Auto-wire: explicit checkpointer > store-backed checkpointer > none
253    let checkpointer: Option<Arc<dyn Checkpointer>> = match (&options.store, options.checkpointer) {
254        (_, Some(ckpt)) => Some(ckpt),
255        (Some(store), None) => Some(Arc::new(crate::StoreCheckpointer::new(store.clone()))),
256        (None, None) => None,
257    };
258
259    if let Some(checkpointer) = checkpointer {
260        graph = graph.with_checkpointer(checkpointer);
261    }
262
263    Ok(graph)
264}
265
266// ---------------------------------------------------------------------------
267// Handoff tool — for multi-agent collaboration
268// ---------------------------------------------------------------------------
269
270struct HandoffTool {
271    target_agent: String,
272    tool_description: String,
273}
274
275#[async_trait]
276impl Tool for HandoffTool {
277    fn name(&self) -> &'static str {
278        Box::leak(format!("transfer_to_{}", self.target_agent).into_boxed_str())
279    }
280
281    fn description(&self) -> &'static str {
282        Box::leak(self.tool_description.clone().into_boxed_str())
283    }
284
285    async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
286        Ok(Value::String(format!(
287            "Transferring to agent '{}'.",
288            self.target_agent
289        )))
290    }
291}
292
293/// Create a handoff tool that signals transfer to another agent.
294pub fn create_handoff_tool(agent_name: &str, description: &str) -> Arc<dyn Tool> {
295    Arc::new(HandoffTool {
296        target_agent: agent_name.to_string(),
297        tool_description: description.to_string(),
298    })
299}
300
301// ---------------------------------------------------------------------------
302// Supervisor — centralized multi-agent orchestration
303// ---------------------------------------------------------------------------
304
305/// Options for the supervisor multi-agent pattern.
306#[derive(Default)]
307pub struct SupervisorOptions {
308    pub checkpointer: Option<Arc<dyn Checkpointer>>,
309    pub store: Option<Arc<dyn Store>>,
310    pub system_prompt: Option<String>,
311}
312
313/// A sub-agent node that invokes a compiled agent graph as a node.
314struct SubAgentNode {
315    graph: CompiledGraph<MessageState>,
316}
317
318#[async_trait]
319impl Node<MessageState> for SubAgentNode {
320    async fn process(
321        &self,
322        state: MessageState,
323    ) -> Result<NodeOutput<MessageState>, SynapticError> {
324        let result = self.graph.invoke(state).await?;
325        Ok(result.into_state().into())
326    }
327}
328
329/// Create a supervisor multi-agent graph.
330#[traceable(skip = "model,agents,options")]
331pub fn create_supervisor(
332    model: Arc<dyn ChatModel>,
333    agents: Vec<(String, CompiledGraph<MessageState>)>,
334    options: SupervisorOptions,
335) -> Result<CompiledGraph<MessageState>, SynapticError> {
336    let agent_names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
337
338    // Create handoff tools for each agent
339    let handoff_tools: Vec<Arc<dyn Tool>> = agent_names
340        .iter()
341        .map(|name| {
342            create_handoff_tool(
343                name,
344                &format!("Transfer the conversation to the '{name}' agent."),
345            )
346        })
347        .collect();
348
349    let handoff_tool_defs: Vec<ToolDefinition> = handoff_tools
350        .iter()
351        .map(|t| ToolDefinition {
352            name: t.name().to_string(),
353            description: t.description().to_string(),
354            parameters: serde_json::json!({}),
355            extras: None,
356        })
357        .collect();
358
359    let default_prompt = format!(
360        "You are a supervisor managing these agents: {}. \
361         Use the transfer tools to delegate tasks to the appropriate agent. \
362         When the task is complete, respond directly to the user.",
363        agent_names.join(", ")
364    );
365    let system_prompt = options.system_prompt.unwrap_or(default_prompt);
366
367    let supervisor_node = ChatModelNode {
368        model,
369        tool_defs: handoff_tool_defs.clone(),
370        system_prompt: Some(system_prompt),
371        middleware: Arc::new(MiddlewareChain::new(vec![])),
372        is_first_call: AtomicBool::new(false),
373        pre_model_hook: None,
374        post_model_hook: None,
375        response_format: None,
376    };
377
378    let mut builder = StateGraph::new()
379        .add_node("supervisor", supervisor_node)
380        .set_entry_point("supervisor");
381
382    for (name, graph) in agents {
383        builder = builder
384            .add_node(&name, SubAgentNode { graph })
385            .add_edge(&name, "supervisor");
386    }
387
388    let agent_names_for_router = agent_names.clone();
389    builder = builder.add_conditional_edges("supervisor", move |state: &MessageState| {
390        if let Some(last) = state.last_message() {
391            for tc in last.tool_calls() {
392                for agent_name in &agent_names_for_router {
393                    if tc.name == format!("transfer_to_{agent_name}") {
394                        return agent_name.clone();
395                    }
396                }
397            }
398        }
399        END.to_string()
400    });
401
402    let mut graph = builder.compile()?;
403
404    if let Some(checkpointer) = options.checkpointer {
405        graph = graph.with_checkpointer(checkpointer);
406    }
407
408    Ok(graph)
409}
410
411// ---------------------------------------------------------------------------
412// Swarm — decentralized multi-agent collaboration
413// ---------------------------------------------------------------------------
414
415/// Options for the swarm multi-agent pattern.
416#[derive(Default)]
417pub struct SwarmOptions {
418    pub checkpointer: Option<Arc<dyn Checkpointer>>,
419    pub store: Option<Arc<dyn Store>>,
420}
421
422/// A swarm agent node: calls a model with its own tools + handoff tools.
423struct SwarmAgentNode {
424    model: Arc<dyn ChatModel>,
425    tool_defs: Vec<ToolDefinition>,
426    system_prompt: Option<String>,
427}
428
429#[async_trait]
430impl Node<MessageState> for SwarmAgentNode {
431    async fn process(
432        &self,
433        mut state: MessageState,
434    ) -> Result<NodeOutput<MessageState>, SynapticError> {
435        let mut messages = Vec::new();
436        if let Some(ref prompt) = self.system_prompt {
437            messages.push(Message::system(prompt));
438        }
439        messages.extend(state.messages.clone());
440
441        let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
442        let response = self.model.chat(request).await?;
443        state.messages.push(response.message);
444        Ok(state.into())
445    }
446}
447
448/// Swarm tool node: executes tool calls, but skips handoff tools.
449struct SwarmToolNode {
450    executor: SerialToolExecutor,
451    handoff_tool_names: Vec<String>,
452}
453
454#[async_trait]
455impl Node<MessageState> for SwarmToolNode {
456    async fn process(
457        &self,
458        mut state: MessageState,
459    ) -> Result<NodeOutput<MessageState>, SynapticError> {
460        let last = state
461            .last_message()
462            .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
463
464        let tool_calls = last.tool_calls().to_vec();
465        for call in &tool_calls {
466            if self.handoff_tool_names.contains(&call.name) {
467                state.messages.push(Message::tool(
468                    "Transferring to agent.".to_string(),
469                    &call.id,
470                ));
471            } else {
472                let result = self
473                    .executor
474                    .execute(&call.name, call.arguments.clone())
475                    .await?;
476                state
477                    .messages
478                    .push(Message::tool(result.to_string(), &call.id));
479            }
480        }
481
482        Ok(state.into())
483    }
484}
485
486/// A swarm agent definition.
487pub struct SwarmAgent {
488    pub name: String,
489    pub model: Arc<dyn ChatModel>,
490    pub tools: Vec<Arc<dyn Tool>>,
491    pub system_prompt: Option<String>,
492}
493
494/// Create a swarm multi-agent graph.
495#[traceable(skip = "agents,options")]
496pub fn create_swarm(
497    agents: Vec<SwarmAgent>,
498    options: SwarmOptions,
499) -> Result<CompiledGraph<MessageState>, SynapticError> {
500    if agents.is_empty() {
501        return Err(SynapticError::Graph(
502            "swarm requires at least one agent".to_string(),
503        ));
504    }
505
506    let agent_names: Vec<String> = agents.iter().map(|a| a.name.clone()).collect();
507    let entry_agent = agent_names[0].clone();
508
509    let all_handoff_tools: HashMap<String, Arc<dyn Tool>> = agent_names
510        .iter()
511        .map(|name| {
512            (
513                name.clone(),
514                create_handoff_tool(
515                    name,
516                    &format!("Transfer the conversation to the '{name}' agent."),
517                ),
518            )
519        })
520        .collect();
521
522    let handoff_tool_names: Vec<String> = all_handoff_tools
523        .values()
524        .map(|t| t.name().to_string())
525        .collect();
526
527    let mut builder = StateGraph::new();
528
529    let global_registry = synaptic_tools::ToolRegistry::new();
530
531    for agent in agents {
532        let SwarmAgent {
533            name,
534            model,
535            tools,
536            system_prompt,
537        } = agent;
538
539        let mut tool_defs: Vec<ToolDefinition> = tools
540            .iter()
541            .map(|t| ToolDefinition {
542                name: t.name().to_string(),
543                description: t.description().to_string(),
544                parameters: serde_json::json!({}),
545                extras: None,
546            })
547            .collect();
548
549        for tool in &tools {
550            let _ = global_registry.register(tool.clone());
551        }
552
553        for other_name in &agent_names {
554            if other_name != &name {
555                if let Some(ht) = all_handoff_tools.get(other_name) {
556                    tool_defs.push(ToolDefinition {
557                        name: ht.name().to_string(),
558                        description: ht.description().to_string(),
559                        parameters: serde_json::json!({}),
560                        extras: None,
561                    });
562                }
563            }
564        }
565
566        let agent_node = SwarmAgentNode {
567            model,
568            tool_defs,
569            system_prompt,
570        };
571
572        builder = builder.add_node(&name, agent_node);
573    }
574
575    let executor = SerialToolExecutor::new(global_registry);
576    let swarm_tool_node = SwarmToolNode {
577        executor,
578        handoff_tool_names: handoff_tool_names.clone(),
579    };
580    builder = builder.add_node("tools", swarm_tool_node);
581
582    builder = builder.set_entry_point(&entry_agent);
583
584    for agent_name in &agent_names {
585        builder = builder.add_conditional_edges(agent_name, |state: &MessageState| {
586            if let Some(last) = state.last_message() {
587                if !last.tool_calls().is_empty() {
588                    return "tools".to_string();
589                }
590            }
591            END.to_string()
592        });
593    }
594
595    let all_agent_names = agent_names.clone();
596    builder = builder.add_conditional_edges("tools", move |state: &MessageState| {
597        for msg in state.messages.iter().rev() {
598            if msg.is_ai() && !msg.tool_calls().is_empty() {
599                for tc in msg.tool_calls() {
600                    for agent_name in &all_agent_names {
601                        if tc.name == format!("transfer_to_{agent_name}") {
602                            return agent_name.clone();
603                        }
604                    }
605                }
606                return all_agent_names[0].clone();
607            }
608        }
609        all_agent_names[0].clone()
610    });
611
612    let mut graph = builder.compile()?;
613
614    if let Some(checkpointer) = options.checkpointer {
615        graph = graph.with_checkpointer(checkpointer);
616    }
617
618    Ok(graph)
619}