Skip to main content

synaptic_graph/
prebuilt.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use serde_json::Value;
9use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError, Tool, ToolDefinition};
10use synaptic_macros::traceable;
11use synaptic_middleware::{AgentMiddleware, BaseChatModelCaller, MiddlewareChain, ModelRequest};
12use synaptic_store::Store;
13use synaptic_tools::SerialToolExecutor;
14
15use crate::builder::StateGraph;
16use crate::checkpoint::Checkpointer;
17use crate::command::NodeOutput;
18use crate::compiled::CompiledGraph;
19use crate::node::Node;
20use crate::state::MessageState;
21use crate::tool_node::ToolNode;
22use crate::END;
23
24// ---------------------------------------------------------------------------
25// Hook types
26// ---------------------------------------------------------------------------
27
28/// A hook called before each model invocation. Can modify the state.
29pub type PreModelHook = Arc<
30    dyn Fn(
31            &mut MessageState,
32        ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
33        + Send
34        + Sync,
35>;
36
37/// A hook called after each model invocation. Can modify the state.
38pub type PostModelHook = Arc<
39    dyn Fn(
40            &mut MessageState,
41        ) -> Pin<Box<dyn Future<Output = Result<(), SynapticError>> + Send + '_>>
42        + Send
43        + Sync,
44>;
45
46// ---------------------------------------------------------------------------
47// ChatModelNode — prebuilt node that calls a ChatModel through middleware
48// ---------------------------------------------------------------------------
49
50struct ChatModelNode {
51    model: Arc<dyn ChatModel>,
52    tool_defs: Vec<ToolDefinition>,
53    system_prompt: Option<String>,
54    middleware: Arc<MiddlewareChain>,
55    is_first_call: AtomicBool,
56    pre_model_hook: Option<PreModelHook>,
57    post_model_hook: Option<PostModelHook>,
58    /// When set, the final response (no tool calls) is re-called with
59    /// structured output instructions matching this JSON schema.
60    response_format: Option<Value>,
61}
62
63#[async_trait]
64impl Node<MessageState> for ChatModelNode {
65    async fn process(
66        &self,
67        mut state: MessageState,
68    ) -> Result<NodeOutput<MessageState>, SynapticError> {
69        // On first call, run before_agent middleware hooks
70        if self.is_first_call.swap(false, Ordering::SeqCst) {
71            self.middleware
72                .run_before_agent(&mut state.messages)
73                .await?;
74        }
75
76        // Run pre_model_hook
77        if let Some(ref hook) = self.pre_model_hook {
78            hook(&mut state).await?;
79        }
80
81        let request = ModelRequest {
82            messages: state.messages.clone(),
83            tools: self.tool_defs.clone(),
84            tool_choice: None,
85            system_prompt: self.system_prompt.clone(),
86        };
87
88        let base_caller = BaseChatModelCaller::new(self.model.clone());
89        let response = self.middleware.call_model(request, &base_caller).await?;
90
91        state.messages.push(response.message.clone());
92
93        // Run post_model_hook
94        if let Some(ref hook) = self.post_model_hook {
95            hook(&mut state).await?;
96        }
97
98        // If no tool calls, this is the final answer
99        if response.message.tool_calls().is_empty() {
100            // If response_format is set, re-call with structured output instructions
101            if let Some(ref schema) = self.response_format {
102                let instruction = format!(
103                    "You MUST respond with valid JSON matching this schema:\n{}\n\n\
104                     Do not include any text outside the JSON object. \
105                     Do not use markdown code blocks.",
106                    schema
107                );
108                let mut structured_messages = vec![Message::system(instruction)];
109                structured_messages.extend(state.messages.clone());
110
111                let structured_request = ChatRequest::new(structured_messages);
112                let structured_response = self.model.chat(structured_request).await?;
113                // Replace the last message with the structured response
114                state.messages.pop();
115                state.messages.push(structured_response.message);
116            }
117
118            self.middleware.run_after_agent(&mut state.messages).await?;
119        }
120
121        Ok(state.into())
122    }
123}
124
125// ---------------------------------------------------------------------------
126// ReactAgentOptions (legacy)
127// ---------------------------------------------------------------------------
128
129/// Options for creating a ReAct agent with `create_react_agent_with_options`.
130#[derive(Default)]
131pub struct ReactAgentOptions {
132    /// Optional checkpointer for state persistence across invocations.
133    pub checkpointer: Option<Arc<dyn Checkpointer>>,
134    /// Node names that should interrupt BEFORE execution (human-in-the-loop).
135    pub interrupt_before: Vec<String>,
136    /// Node names that should interrupt AFTER execution (human-in-the-loop).
137    pub interrupt_after: Vec<String>,
138    /// Optional system prompt to prepend to messages before calling the model.
139    pub system_prompt: Option<String>,
140}
141
142/// Create a prebuilt ReAct agent graph.
143pub fn create_react_agent(
144    model: Arc<dyn ChatModel>,
145    tools: Vec<Arc<dyn Tool>>,
146) -> Result<CompiledGraph<MessageState>, SynapticError> {
147    create_react_agent_with_options(model, tools, ReactAgentOptions::default())
148}
149
150/// Create a prebuilt ReAct agent graph with additional configuration options.
151pub fn create_react_agent_with_options(
152    model: Arc<dyn ChatModel>,
153    tools: Vec<Arc<dyn Tool>>,
154    options: ReactAgentOptions,
155) -> Result<CompiledGraph<MessageState>, SynapticError> {
156    create_agent(
157        model,
158        tools,
159        AgentOptions {
160            checkpointer: options.checkpointer,
161            interrupt_before: options.interrupt_before,
162            interrupt_after: options.interrupt_after,
163            system_prompt: options.system_prompt,
164            ..Default::default()
165        },
166    )
167}
168
169// ---------------------------------------------------------------------------
170// AgentOptions — new unified options for create_agent
171// ---------------------------------------------------------------------------
172
173/// Options for creating an agent with `create_agent`.
174#[derive(Default)]
175pub struct AgentOptions {
176    pub checkpointer: Option<Arc<dyn Checkpointer>>,
177    pub interrupt_before: Vec<String>,
178    pub interrupt_after: Vec<String>,
179    pub system_prompt: Option<String>,
180    pub middleware: Vec<Arc<dyn AgentMiddleware>>,
181    pub store: Option<Arc<dyn Store>>,
182    pub name: Option<String>,
183    pub pre_model_hook: Option<PreModelHook>,
184    pub post_model_hook: Option<PostModelHook>,
185    /// Optional JSON schema for structured output on the final model call.
186    pub response_format: Option<Value>,
187    /// Enable parallel tool execution in ToolNode (default false).
188    pub parallel_tools: bool,
189}
190
191/// Create a prebuilt agent graph with full middleware and store support.
192#[traceable(skip = "model,tools,options")]
193pub fn create_agent(
194    model: Arc<dyn ChatModel>,
195    tools: Vec<Arc<dyn Tool>>,
196    options: AgentOptions,
197) -> Result<CompiledGraph<MessageState>, SynapticError> {
198    let tool_defs: Vec<ToolDefinition> = tools.iter().map(|t| t.as_tool_definition()).collect();
199
200    let registry = synaptic_tools::ToolRegistry::new();
201    for tool in tools {
202        registry.register(tool)?;
203    }
204    let executor = SerialToolExecutor::new(registry);
205
206    let middleware_chain = Arc::new(MiddlewareChain::new(options.middleware));
207
208    let agent_node = ChatModelNode {
209        model,
210        tool_defs,
211        system_prompt: options.system_prompt,
212        middleware: middleware_chain.clone(),
213        is_first_call: AtomicBool::new(true),
214        pre_model_hook: options.pre_model_hook,
215        post_model_hook: options.post_model_hook,
216        response_format: options.response_format,
217    };
218
219    let mut tool_node =
220        ToolNode::with_middleware(executor, middleware_chain).with_parallel(options.parallel_tools);
221    if let Some(ref store) = options.store {
222        tool_node = tool_node.with_store(store.clone());
223    }
224
225    let mut builder = StateGraph::new()
226        .add_node("agent", agent_node)
227        .add_node("tools", tool_node)
228        .set_entry_point("agent")
229        .add_conditional_edges_with_path_map(
230            "agent",
231            |state: &MessageState| {
232                if let Some(last) = state.last_message() {
233                    if !last.tool_calls().is_empty() {
234                        return "tools".to_string();
235                    }
236                }
237                END.to_string()
238            },
239            HashMap::from([
240                ("tools".to_string(), "tools".to_string()),
241                (END.to_string(), END.to_string()),
242            ]),
243        )
244        .add_edge("tools", "agent");
245
246    if !options.interrupt_before.is_empty() {
247        builder = builder.interrupt_before(options.interrupt_before);
248    }
249    if !options.interrupt_after.is_empty() {
250        builder = builder.interrupt_after(options.interrupt_after);
251    }
252
253    let mut graph = builder.compile()?;
254
255    // Auto-wire: explicit checkpointer > store-backed checkpointer > none
256    let checkpointer: Option<Arc<dyn Checkpointer>> = match (&options.store, options.checkpointer) {
257        (_, Some(ckpt)) => Some(ckpt),
258        (Some(store), None) => Some(Arc::new(crate::StoreCheckpointer::new(store.clone()))),
259        (None, None) => None,
260    };
261
262    if let Some(checkpointer) = checkpointer {
263        graph = graph.with_checkpointer(checkpointer);
264    }
265
266    Ok(graph)
267}
268
269// ---------------------------------------------------------------------------
270// Handoff tool — for multi-agent collaboration
271// ---------------------------------------------------------------------------
272
273struct HandoffTool {
274    target_agent: String,
275    tool_description: String,
276}
277
278#[async_trait]
279impl Tool for HandoffTool {
280    fn name(&self) -> &'static str {
281        Box::leak(format!("transfer_to_{}", self.target_agent).into_boxed_str())
282    }
283
284    fn description(&self) -> &'static str {
285        Box::leak(self.tool_description.clone().into_boxed_str())
286    }
287
288    async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
289        Ok(Value::String(format!(
290            "Transferring to agent '{}'.",
291            self.target_agent
292        )))
293    }
294}
295
296/// Create a handoff tool that signals transfer to another agent.
297pub fn create_handoff_tool(agent_name: &str, description: &str) -> Arc<dyn Tool> {
298    Arc::new(HandoffTool {
299        target_agent: agent_name.to_string(),
300        tool_description: description.to_string(),
301    })
302}
303
304// ---------------------------------------------------------------------------
305// Supervisor — centralized multi-agent orchestration
306// ---------------------------------------------------------------------------
307
308/// Options for the supervisor multi-agent pattern.
309#[derive(Default)]
310pub struct SupervisorOptions {
311    pub checkpointer: Option<Arc<dyn Checkpointer>>,
312    pub store: Option<Arc<dyn Store>>,
313    pub system_prompt: Option<String>,
314}
315
316/// A sub-agent node that invokes a compiled agent graph as a node.
317struct SubAgentNode {
318    graph: CompiledGraph<MessageState>,
319}
320
321#[async_trait]
322impl Node<MessageState> for SubAgentNode {
323    async fn process(
324        &self,
325        state: MessageState,
326    ) -> Result<NodeOutput<MessageState>, SynapticError> {
327        let result = self.graph.invoke(state).await?;
328        Ok(result.into_state().into())
329    }
330}
331
332/// Create a supervisor multi-agent graph.
333#[traceable(skip = "model,agents,options")]
334pub fn create_supervisor(
335    model: Arc<dyn ChatModel>,
336    agents: Vec<(String, CompiledGraph<MessageState>)>,
337    options: SupervisorOptions,
338) -> Result<CompiledGraph<MessageState>, SynapticError> {
339    let agent_names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
340
341    // Create handoff tools for each agent
342    let handoff_tools: Vec<Arc<dyn Tool>> = agent_names
343        .iter()
344        .map(|name| {
345            create_handoff_tool(
346                name,
347                &format!("Transfer the conversation to the '{name}' agent."),
348            )
349        })
350        .collect();
351
352    let handoff_tool_defs: Vec<ToolDefinition> = handoff_tools
353        .iter()
354        .map(|t| ToolDefinition {
355            name: t.name().to_string(),
356            description: t.description().to_string(),
357            parameters: serde_json::json!({}),
358            extras: None,
359        })
360        .collect();
361
362    let default_prompt = format!(
363        "You are a supervisor managing these agents: {}. \
364         Use the transfer tools to delegate tasks to the appropriate agent. \
365         When the task is complete, respond directly to the user.",
366        agent_names.join(", ")
367    );
368    let system_prompt = options.system_prompt.unwrap_or(default_prompt);
369
370    let supervisor_node = ChatModelNode {
371        model,
372        tool_defs: handoff_tool_defs.clone(),
373        system_prompt: Some(system_prompt),
374        middleware: Arc::new(MiddlewareChain::new(vec![])),
375        is_first_call: AtomicBool::new(false),
376        pre_model_hook: None,
377        post_model_hook: None,
378        response_format: None,
379    };
380
381    let mut builder = StateGraph::new()
382        .add_node("supervisor", supervisor_node)
383        .set_entry_point("supervisor");
384
385    for (name, graph) in agents {
386        builder = builder
387            .add_node(&name, SubAgentNode { graph })
388            .add_edge(&name, "supervisor");
389    }
390
391    let agent_names_for_router = agent_names.clone();
392    builder = builder.add_conditional_edges("supervisor", move |state: &MessageState| {
393        if let Some(last) = state.last_message() {
394            for tc in last.tool_calls() {
395                for agent_name in &agent_names_for_router {
396                    if tc.name == format!("transfer_to_{agent_name}") {
397                        return agent_name.clone();
398                    }
399                }
400            }
401        }
402        END.to_string()
403    });
404
405    let mut graph = builder.compile()?;
406
407    if let Some(checkpointer) = options.checkpointer {
408        graph = graph.with_checkpointer(checkpointer);
409    }
410
411    Ok(graph)
412}
413
414// ---------------------------------------------------------------------------
415// Swarm — decentralized multi-agent collaboration
416// ---------------------------------------------------------------------------
417
418/// Options for the swarm multi-agent pattern.
419#[derive(Default)]
420pub struct SwarmOptions {
421    pub checkpointer: Option<Arc<dyn Checkpointer>>,
422    pub store: Option<Arc<dyn Store>>,
423}
424
425/// A swarm agent node: calls a model with its own tools + handoff tools.
426struct SwarmAgentNode {
427    model: Arc<dyn ChatModel>,
428    tool_defs: Vec<ToolDefinition>,
429    system_prompt: Option<String>,
430}
431
432#[async_trait]
433impl Node<MessageState> for SwarmAgentNode {
434    async fn process(
435        &self,
436        mut state: MessageState,
437    ) -> Result<NodeOutput<MessageState>, SynapticError> {
438        let mut messages = Vec::new();
439        if let Some(ref prompt) = self.system_prompt {
440            messages.push(Message::system(prompt));
441        }
442        messages.extend(state.messages.clone());
443
444        let request = ChatRequest::new(messages).with_tools(self.tool_defs.clone());
445        let response = self.model.chat(request).await?;
446        state.messages.push(response.message);
447        Ok(state.into())
448    }
449}
450
451/// Swarm tool node: executes tool calls, but skips handoff tools.
452struct SwarmToolNode {
453    executor: SerialToolExecutor,
454    handoff_tool_names: Vec<String>,
455}
456
457#[async_trait]
458impl Node<MessageState> for SwarmToolNode {
459    async fn process(
460        &self,
461        mut state: MessageState,
462    ) -> Result<NodeOutput<MessageState>, SynapticError> {
463        let last = state
464            .last_message()
465            .ok_or_else(|| SynapticError::Graph("no messages in state".to_string()))?;
466
467        let tool_calls = last.tool_calls().to_vec();
468        for call in &tool_calls {
469            if self.handoff_tool_names.contains(&call.name) {
470                state.messages.push(Message::tool(
471                    "Transferring to agent.".to_string(),
472                    &call.id,
473                ));
474            } else {
475                let result = self
476                    .executor
477                    .execute(&call.name, call.arguments.clone())
478                    .await?;
479                state
480                    .messages
481                    .push(Message::tool(result.to_string(), &call.id));
482            }
483        }
484
485        Ok(state.into())
486    }
487}
488
489/// A swarm agent definition.
490pub struct SwarmAgent {
491    pub name: String,
492    pub model: Arc<dyn ChatModel>,
493    pub tools: Vec<Arc<dyn Tool>>,
494    pub system_prompt: Option<String>,
495}
496
497/// Create a swarm multi-agent graph.
498#[traceable(skip = "agents,options")]
499pub fn create_swarm(
500    agents: Vec<SwarmAgent>,
501    options: SwarmOptions,
502) -> Result<CompiledGraph<MessageState>, SynapticError> {
503    if agents.is_empty() {
504        return Err(SynapticError::Graph(
505            "swarm requires at least one agent".to_string(),
506        ));
507    }
508
509    let agent_names: Vec<String> = agents.iter().map(|a| a.name.clone()).collect();
510    let entry_agent = agent_names[0].clone();
511
512    let all_handoff_tools: HashMap<String, Arc<dyn Tool>> = agent_names
513        .iter()
514        .map(|name| {
515            (
516                name.clone(),
517                create_handoff_tool(
518                    name,
519                    &format!("Transfer the conversation to the '{name}' agent."),
520                ),
521            )
522        })
523        .collect();
524
525    let handoff_tool_names: Vec<String> = all_handoff_tools
526        .values()
527        .map(|t| t.name().to_string())
528        .collect();
529
530    let mut builder = StateGraph::new();
531
532    let global_registry = synaptic_tools::ToolRegistry::new();
533
534    for agent in agents {
535        let SwarmAgent {
536            name,
537            model,
538            tools,
539            system_prompt,
540        } = agent;
541
542        let mut tool_defs: Vec<ToolDefinition> = tools
543            .iter()
544            .map(|t| ToolDefinition {
545                name: t.name().to_string(),
546                description: t.description().to_string(),
547                parameters: serde_json::json!({}),
548                extras: None,
549            })
550            .collect();
551
552        for tool in &tools {
553            let _ = global_registry.register(tool.clone());
554        }
555
556        for other_name in &agent_names {
557            if other_name != &name {
558                if let Some(ht) = all_handoff_tools.get(other_name) {
559                    tool_defs.push(ToolDefinition {
560                        name: ht.name().to_string(),
561                        description: ht.description().to_string(),
562                        parameters: serde_json::json!({}),
563                        extras: None,
564                    });
565                }
566            }
567        }
568
569        let agent_node = SwarmAgentNode {
570            model,
571            tool_defs,
572            system_prompt,
573        };
574
575        builder = builder.add_node(&name, agent_node);
576    }
577
578    let executor = SerialToolExecutor::new(global_registry);
579    let swarm_tool_node = SwarmToolNode {
580        executor,
581        handoff_tool_names: handoff_tool_names.clone(),
582    };
583    builder = builder.add_node("tools", swarm_tool_node);
584
585    builder = builder.set_entry_point(&entry_agent);
586
587    for agent_name in &agent_names {
588        builder = builder.add_conditional_edges(agent_name, |state: &MessageState| {
589            if let Some(last) = state.last_message() {
590                if !last.tool_calls().is_empty() {
591                    return "tools".to_string();
592                }
593            }
594            END.to_string()
595        });
596    }
597
598    let all_agent_names = agent_names.clone();
599    builder = builder.add_conditional_edges("tools", move |state: &MessageState| {
600        for msg in state.messages.iter().rev() {
601            if msg.is_ai() && !msg.tool_calls().is_empty() {
602                for tc in msg.tool_calls() {
603                    for agent_name in &all_agent_names {
604                        if tc.name == format!("transfer_to_{agent_name}") {
605                            return agent_name.clone();
606                        }
607                    }
608                }
609                return all_agent_names[0].clone();
610            }
611        }
612        all_agent_names[0].clone()
613    });
614
615    let mut graph = builder.compile()?;
616
617    if let Some(checkpointer) = options.checkpointer {
618        graph = graph.with_checkpointer(checkpointer);
619    }
620
621    Ok(graph)
622}