Skip to main content

steer_core/tools/
services.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use tokio_util::sync::CancellationToken;
5
6use crate::api::Client as ApiClient;
7use crate::app::SystemContext;
8use crate::app::conversation::Message;
9use crate::app::domain::session::EventStore;
10use crate::app::domain::types::SessionId;
11use crate::config::model::ModelId;
12use crate::session::state::BackendConfig;
13use crate::workspace::{
14    RepoManager, RepoRef, Workspace, WorkspaceId, WorkspaceManager, WorkspaceRef,
15};
16
17use super::capability::Capabilities;
18use steer_tools::ToolSchema;
19
20#[async_trait]
21pub trait AgentSpawner: Send + Sync {
22    async fn spawn(
23        &self,
24        config: SubAgentConfig,
25        cancel_token: CancellationToken,
26    ) -> Result<SubAgentResult, SubAgentError>;
27}
28
29#[derive(Debug, Clone)]
30pub struct SubAgentConfig {
31    pub parent_session_id: SessionId,
32    pub prompt: String,
33    pub allowed_tools: Vec<String>,
34    pub model: ModelId,
35    pub system_context: Option<SystemContext>,
36    pub workspace: Option<Arc<dyn Workspace>>,
37    pub workspace_ref: Option<WorkspaceRef>,
38    pub workspace_id: Option<WorkspaceId>,
39    pub repo_ref: Option<RepoRef>,
40    pub workspace_name: Option<String>,
41    pub mcp_backends: Vec<BackendConfig>,
42    pub allow_mcp_tools: bool,
43}
44
45#[derive(Debug, Clone)]
46pub struct SubAgentResult {
47    pub session_id: SessionId,
48    pub final_message: Message,
49}
50
51#[derive(Debug, Clone, thiserror::Error)]
52pub enum SubAgentError {
53    #[error("API error: {0}")]
54    Api(String),
55
56    #[error("Agent error: {0}")]
57    Agent(String),
58
59    #[error("Event store error: {0}")]
60    EventStore(String),
61
62    #[error("Cancelled")]
63    Cancelled,
64}
65
66#[async_trait]
67pub trait ModelCaller: Send + Sync {
68    async fn call(
69        &self,
70        model: &ModelId,
71        messages: Vec<Message>,
72        system_context: Option<SystemContext>,
73        cancel_token: CancellationToken,
74    ) -> Result<Message, ModelCallError>;
75}
76
77#[derive(Debug, Clone, thiserror::Error)]
78pub enum ModelCallError {
79    #[error("API error: {0}")]
80    Api(String),
81
82    #[error("Cancelled")]
83    Cancelled,
84}
85
86pub struct ToolServices {
87    pub workspace: Arc<dyn Workspace>,
88    pub event_store: Arc<dyn EventStore>,
89    pub api_client: Arc<ApiClient>,
90
91    agent_spawner: Option<Arc<dyn AgentSpawner>>,
92    model_caller: Option<Arc<dyn ModelCaller>>,
93    workspace_manager: Option<Arc<dyn WorkspaceManager>>,
94    repo_manager: Option<Arc<dyn RepoManager>>,
95
96    available_capabilities: Capabilities,
97}
98
99impl ToolServices {
100    pub fn new(
101        workspace: Arc<dyn Workspace>,
102        event_store: Arc<dyn EventStore>,
103        api_client: Arc<ApiClient>,
104    ) -> Self {
105        Self {
106            workspace,
107            event_store,
108            api_client,
109            agent_spawner: None,
110            model_caller: None,
111            workspace_manager: None,
112            repo_manager: None,
113            available_capabilities: Capabilities::WORKSPACE,
114        }
115    }
116
117    pub fn with_agent_spawner(mut self, spawner: Arc<dyn AgentSpawner>) -> Self {
118        self.agent_spawner = Some(spawner);
119        self.available_capabilities |= Capabilities::AGENT_SPAWNER;
120        self
121    }
122
123    pub fn with_model_caller(mut self, caller: Arc<dyn ModelCaller>) -> Self {
124        self.model_caller = Some(caller);
125        self.available_capabilities |= Capabilities::MODEL_CALLER;
126        self
127    }
128
129    pub fn with_workspace_manager(mut self, manager: Arc<dyn WorkspaceManager>) -> Self {
130        self.workspace_manager = Some(manager);
131        self
132    }
133
134    pub fn with_repo_manager(mut self, manager: Arc<dyn RepoManager>) -> Self {
135        self.repo_manager = Some(manager);
136        self
137    }
138
139    pub fn with_network(mut self) -> Self {
140        self.available_capabilities |= Capabilities::NETWORK;
141        self
142    }
143
144    pub fn capabilities(&self) -> Capabilities {
145        self.available_capabilities
146    }
147
148    pub fn agent_spawner(&self) -> Option<&Arc<dyn AgentSpawner>> {
149        self.agent_spawner.as_ref()
150    }
151
152    pub fn model_caller(&self) -> Option<&Arc<dyn ModelCaller>> {
153        self.model_caller.as_ref()
154    }
155
156    pub fn workspace_manager(&self) -> Option<&Arc<dyn WorkspaceManager>> {
157        self.workspace_manager.as_ref()
158    }
159
160    pub fn repo_manager(&self) -> Option<&Arc<dyn RepoManager>> {
161        self.repo_manager.as_ref()
162    }
163}
164
165impl std::fmt::Debug for ToolServices {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        f.debug_struct("ToolServices")
168            .field("capabilities", &self.available_capabilities)
169            .finish_non_exhaustive()
170    }
171}
172
173pub fn filter_schemas_by_capabilities<'a>(
174    schemas: impl Iterator<Item = (&'a ToolSchema, Capabilities)>,
175    available: Capabilities,
176) -> Vec<ToolSchema> {
177    schemas
178        .filter(|(_, required)| available.satisfies(*required))
179        .map(|(schema, _)| schema.clone())
180        .collect()
181}