1use crate::grpc::error::GrpcError;
2type Result<T> = std::result::Result<T, GrpcError>;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use tokio::sync::oneshot;
6use tokio::task::JoinHandle;
7use tonic::transport::Server;
8use tracing::{error, info};
9
10use crate::grpc::RuntimeAgentService;
11use steer_core::api::Client as ApiClient;
12use steer_core::app::domain::runtime::{RuntimeHandle, RuntimeService};
13use steer_core::app::domain::session::{SessionCatalog, SqliteEventStore};
14use steer_core::auth::storage::AuthStorage;
15use steer_core::catalog::CatalogConfig;
16use steer_core::tools::ToolSystemBuilder;
17use steer_proto::agent::v1::agent_service_server::AgentServiceServer;
18use steer_workspace::{LocalEnvironmentManager, LocalWorkspaceManager, RepoManager};
19
20#[derive(Clone)]
21pub struct ServiceHostConfig {
22 pub db_path: std::path::PathBuf,
23 pub bind_addr: SocketAddr,
24 pub auth_storage: Arc<dyn AuthStorage>,
25 pub catalog_config: CatalogConfig,
26 pub workspace_root: Option<std::path::PathBuf>,
27}
28
29impl std::fmt::Debug for ServiceHostConfig {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("ServiceHostConfig")
32 .field("db_path", &self.db_path)
33 .field("bind_addr", &self.bind_addr)
34 .field("auth_storage", &"Arc<dyn AuthStorage>")
35 .field("catalog_config", &self.catalog_config)
36 .field("workspace_root", &self.workspace_root)
37 .finish()
38 }
39}
40
41impl ServiceHostConfig {
42 pub fn new(db_path: std::path::PathBuf, bind_addr: SocketAddr) -> Result<Self> {
43 let auth_storage = Arc::new(
44 steer_core::auth::DefaultAuthStorage::new()
45 .map_err(|e| GrpcError::CoreError(e.into()))?,
46 );
47
48 Ok(Self {
49 db_path,
50 bind_addr,
51 auth_storage,
52 catalog_config: CatalogConfig::default(),
53 workspace_root: None,
54 })
55 }
56
57 pub fn with_catalog(
58 db_path: std::path::PathBuf,
59 bind_addr: SocketAddr,
60 catalog_config: CatalogConfig,
61 ) -> Result<Self> {
62 let auth_storage = Arc::new(
63 steer_core::auth::DefaultAuthStorage::new()
64 .map_err(|e| GrpcError::CoreError(e.into()))?,
65 );
66
67 Ok(Self {
68 db_path,
69 bind_addr,
70 auth_storage,
71 catalog_config,
72 workspace_root: None,
73 })
74 }
75}
76
77pub struct ServiceHost {
78 runtime_service: RuntimeService,
79 runtime_handle: RuntimeHandle,
80 catalog: Arc<dyn SessionCatalog>,
81 model_registry: Arc<steer_core::model_registry::ModelRegistry>,
82 provider_registry: Arc<steer_core::auth::ProviderRegistry>,
83 llm_config_provider: steer_core::config::LlmConfigProvider,
84 environment_root: std::path::PathBuf,
85 server_handle: Option<JoinHandle<Result<()>>>,
86 shutdown_tx: Option<oneshot::Sender<()>>,
87 config: ServiceHostConfig,
88}
89
90impl ServiceHost {
91 pub async fn new(config: ServiceHostConfig) -> Result<Self> {
92 let event_store = Arc::new(SqliteEventStore::new(&config.db_path).await.map_err(|e| {
93 GrpcError::InvalidSessionState {
94 reason: format!("Failed to create event store: {e}"),
95 }
96 })?);
97
98 let catalog: Arc<dyn SessionCatalog> = event_store.clone();
99
100 let model_registry = Arc::new(
101 steer_core::model_registry::ModelRegistry::load(&config.catalog_config.catalog_paths)
102 .map_err(|e| GrpcError::InvalidSessionState {
103 reason: format!("Failed to load model registry: {e}"),
104 })?,
105 );
106
107 let provider_registry = Arc::new(
108 steer_core::auth::ProviderRegistry::load(&config.catalog_config.catalog_paths)
109 .map_err(|e| GrpcError::InvalidSessionState {
110 reason: format!("Failed to load provider registry: {e}"),
111 })?,
112 );
113
114 let llm_config_provider =
115 steer_core::config::LlmConfigProvider::new(config.auth_storage.clone())
116 .map_err(GrpcError::CoreError)?;
117
118 let api_client = Arc::new(ApiClient::new_with_deps(
119 llm_config_provider.clone(),
120 provider_registry.clone(),
121 model_registry.clone(),
122 ));
123
124 let workspace_path = config.workspace_root.clone().unwrap_or_else(|| {
125 std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."))
126 });
127 let environment_root = steer_core::utils::paths::AppPaths::local_environment_root();
128 let workspace = steer_core::workspace::create_workspace(
129 &steer_core::workspace::WorkspaceConfig::Local {
130 path: workspace_path.clone(),
131 },
132 )
133 .await
134 .map_err(|e| GrpcError::InvalidSessionState {
135 reason: format!("Failed to create workspace: {e}"),
136 })?;
137 let workspace_manager = Arc::new(
138 LocalWorkspaceManager::new(environment_root.clone())
139 .await
140 .map_err(|e| GrpcError::InvalidSessionState {
141 reason: format!("Failed to create workspace manager: {e}"),
142 })?,
143 );
144 let repo_manager: Arc<dyn RepoManager> = workspace_manager.clone();
145
146 let tool_executor = ToolSystemBuilder::new(
147 workspace,
148 event_store.clone(),
149 api_client.clone(),
150 model_registry.clone(),
151 )
152 .with_workspace_manager(workspace_manager)
153 .with_repo_manager(repo_manager)
154 .build();
155
156 let runtime_service = RuntimeService::spawn(event_store, api_client, tool_executor);
157
158 let runtime_handle = runtime_service.handle();
159
160 info!(
161 "ServiceHost initialized with database at {:?}",
162 config.db_path
163 );
164
165 Ok(Self {
166 runtime_service,
167 runtime_handle,
168 catalog,
169 model_registry,
170 provider_registry,
171 llm_config_provider,
172 environment_root,
173 server_handle: None,
174 shutdown_tx: None,
175 config,
176 })
177 }
178
179 pub async fn start(&mut self) -> Result<()> {
180 if self.server_handle.is_some() {
181 return Err(GrpcError::InvalidSessionState {
182 reason: "Server is already running".to_string(),
183 });
184 }
185
186 let environment_root = self.environment_root.clone();
187 let workspace_manager = Arc::new(
188 LocalWorkspaceManager::new(environment_root.clone())
189 .await
190 .map_err(|e| GrpcError::InvalidSessionState {
191 reason: format!("Failed to create workspace manager: {e}"),
192 })?,
193 );
194 let repo_manager: Arc<dyn RepoManager> = workspace_manager.clone();
195 let environment_manager = Arc::new(LocalEnvironmentManager::new(environment_root));
196
197 let service = RuntimeAgentService::new(crate::grpc::RuntimeAgentDeps {
198 runtime: self.runtime_handle.clone(),
199 catalog: self.catalog.clone(),
200 llm_config_provider: self.llm_config_provider.clone(),
201 model_registry: self.model_registry.clone(),
202 provider_registry: self.provider_registry.clone(),
203 environment_manager,
204 workspace_manager,
205 repo_manager,
206 });
207
208 let (shutdown_tx, shutdown_rx) = oneshot::channel();
209 let addr = self.config.bind_addr;
210
211 info!("Starting gRPC server on {}", addr);
212
213 let server_handle = tokio::spawn(async move {
214 Server::builder()
215 .add_service(AgentServiceServer::new(service))
216 .serve_with_shutdown(addr, async {
217 shutdown_rx.await.ok();
218 info!("gRPC server shutdown signal received");
219 })
220 .await
221 .map_err(GrpcError::ConnectionFailed)
222 });
223
224 self.server_handle = Some(server_handle);
225 self.shutdown_tx = Some(shutdown_tx);
226
227 info!("gRPC server listening on {}", addr);
228 Ok(())
229 }
230
231 pub async fn shutdown(mut self) -> Result<()> {
232 info!("Initiating ServiceHost shutdown");
233
234 if let Some(shutdown_tx) = self.shutdown_tx.take() {
235 let _ = shutdown_tx.send(());
236 }
237
238 if let Some(server_handle) = self.server_handle.take() {
239 match server_handle.await {
240 Ok(Ok(())) => info!("gRPC server shut down successfully"),
241 Ok(Err(e)) => error!("gRPC server error during shutdown: {}", e),
242 Err(e) => error!("Failed to join server task: {}", e),
243 }
244 }
245
246 self.runtime_service.shutdown().await;
247
248 info!("ServiceHost shutdown complete");
249 Ok(())
250 }
251
252 pub fn runtime_handle(&self) -> &RuntimeHandle {
253 &self.runtime_handle
254 }
255
256 pub async fn wait(&mut self) -> Result<()> {
257 if let Some(server_handle) = &mut self.server_handle {
258 match server_handle.await {
259 Ok(result) => result,
260 Err(e) => Err(GrpcError::StreamError(format!("Server task panicked: {e}"))),
261 }
262 } else {
263 Err(GrpcError::InvalidSessionState {
264 reason: "Server is not running".to_string(),
265 })
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a `ServiceHostConfig` rooted in a fresh temporary directory.
    ///
    /// The config has the following setup:
    /// - the database file and the workspace both live inside the returned `TempDir`;
    /// - auth uses an in-memory store;
    /// - the server binds to an ephemeral localhost port.
    ///
    /// Keep the `TempDir` alive for as long as the config or host is in use.
    fn create_test_config() -> (ServiceHostConfig, TempDir) {
        let scratch = TempDir::new().unwrap();
        let root = scratch.path();

        let config = ServiceHostConfig {
            db_path: root.join("test.db"),
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            auth_storage: Arc::new(steer_core::test_utils::InMemoryAuthStorage::new()),
            catalog_config: CatalogConfig::default(),
            workspace_root: Some(root.to_path_buf()),
        };

        (config, scratch)
    }

    /// A freshly constructed host starts with no sessions.
    #[tokio::test]
    async fn test_service_host_creation() {
        let (config, _scratch_guard) = create_test_config();

        let host = ServiceHost::new(config).await.unwrap();

        let listed = host.runtime_handle.list_all_sessions().await.unwrap();
        assert!(listed.is_empty());
    }

    /// The host can start its gRPC server and then shut down cleanly.
    #[tokio::test]
    async fn test_service_host_lifecycle() {
        let (base, _scratch_guard) = create_test_config();
        let config = ServiceHostConfig {
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            ..base
        };

        let mut host = ServiceHost::new(config).await.unwrap();

        host.start().await.unwrap();
        assert!(host.server_handle.is_some());

        host.shutdown().await.unwrap();
    }
}