Skip to main content

synaptic_deep/middleware/
subagent.rs

1use async_trait::async_trait;
2use serde_json::{json, Value};
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use synaptic_core::{ChatModel, Message, SynapticError, Tool};
6use synaptic_graph::MessageState;
7use synaptic_middleware::AgentMiddleware;
8
9use crate::backend::Backend;
10
11/// Definition of a custom sub-agent type available to the task tool.
12#[derive(Clone)]
13pub struct SubAgentDef {
14    pub name: String,
15    pub description: String,
16    pub system_prompt: String,
17    pub tools: Vec<Arc<dyn Tool>>,
18}
19
20/// Middleware that provides a `task` tool for spawning child agents.
21///
22/// The `task` tool creates a child deep agent and invokes it with the given description.
23/// Recursion is bounded by `max_depth`.
24pub struct SubAgentMiddleware {
25    backend: Arc<dyn Backend>,
26    model: Arc<dyn ChatModel>,
27    max_depth: usize,
28    current_depth: Arc<AtomicUsize>,
29    custom_agents: Vec<SubAgentDef>,
30}
31
32impl SubAgentMiddleware {
33    pub fn new(
34        backend: Arc<dyn Backend>,
35        model: Arc<dyn ChatModel>,
36        max_depth: usize,
37        custom_agents: Vec<SubAgentDef>,
38    ) -> Self {
39        Self {
40            backend,
41            model,
42            max_depth,
43            current_depth: Arc::new(AtomicUsize::new(0)),
44            custom_agents,
45        }
46    }
47
48    /// Create the `task` tool that spawns sub-agents.
49    pub fn create_task_tool(&self) -> Arc<dyn Tool> {
50        Arc::new(TaskTool {
51            backend: self.backend.clone(),
52            model: self.model.clone(),
53            max_depth: self.max_depth,
54            current_depth: self.current_depth.clone(),
55            custom_agents: self.custom_agents.clone(),
56        })
57    }
58}
59
60#[async_trait]
61impl AgentMiddleware for SubAgentMiddleware {}
62
63// ---------------------------------------------------------------------------
64
65struct TaskTool {
66    backend: Arc<dyn Backend>,
67    model: Arc<dyn ChatModel>,
68    max_depth: usize,
69    current_depth: Arc<AtomicUsize>,
70    custom_agents: Vec<SubAgentDef>,
71}
72
73#[async_trait]
74impl Tool for TaskTool {
75    fn name(&self) -> &'static str {
76        "task"
77    }
78
79    fn description(&self) -> &'static str {
80        "Spawn a sub-agent to handle a complex, multi-step task autonomously"
81    }
82
83    fn parameters(&self) -> Option<Value> {
84        Some(json!({
85            "type": "object",
86            "properties": {
87                "description": {
88                    "type": "string",
89                    "description": "A detailed description of the task for the sub-agent"
90                },
91                "agent_type": {
92                    "type": "string",
93                    "description": "Type of agent to spawn (default: general-purpose)"
94                }
95            },
96            "required": ["description"]
97        }))
98    }
99
100    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
101        let depth = self.current_depth.load(Ordering::Relaxed);
102        if depth >= self.max_depth {
103            return Err(SynapticError::Tool(format!(
104                "max subagent depth ({}) exceeded",
105                self.max_depth
106            )));
107        }
108
109        let description = args
110            .get("description")
111            .and_then(|v| v.as_str())
112            .ok_or_else(|| SynapticError::Tool("missing 'description' parameter".into()))?;
113
114        let agent_type = args
115            .get("agent_type")
116            .and_then(|v| v.as_str())
117            .unwrap_or("general-purpose");
118
119        self.current_depth.fetch_add(1, Ordering::Relaxed);
120        let result = self.run_subagent(description, agent_type).await;
121        self.current_depth.fetch_sub(1, Ordering::Relaxed);
122
123        result
124    }
125}
126
127impl TaskTool {
128    async fn run_subagent(
129        &self,
130        description: &str,
131        agent_type: &str,
132    ) -> Result<Value, SynapticError> {
133        let custom = self.custom_agents.iter().find(|a| a.name == agent_type);
134
135        let mut options = crate::DeepAgentOptions::new(self.backend.clone());
136        options.enable_subagents = self.current_depth.load(Ordering::Relaxed) < self.max_depth;
137        options.max_subagent_depth = self.max_depth;
138
139        if let Some(def) = custom {
140            options.system_prompt = Some(def.system_prompt.clone());
141            options.tools = def.tools.clone();
142        }
143
144        let agent = crate::create_deep_agent(self.model.clone(), options)?;
145
146        let state = MessageState::with_messages(vec![Message::human(description)]);
147        let result = agent.invoke(state).await?;
148        let final_state = result.into_state();
149
150        let response = final_state
151            .last_message()
152            .map(|m| m.content().to_string())
153            .unwrap_or_else(|| "Sub-agent completed with no response".to_string());
154
155        Ok(Value::String(response))
156    }
157}