Skip to main content

steer_grpc/
service_host.rs

1use crate::grpc::error::GrpcError;
2type Result<T> = std::result::Result<T, GrpcError>;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use tokio::sync::oneshot;
6use tokio::task::JoinHandle;
7use tonic::transport::Server;
8use tracing::{error, info};
9
10use crate::grpc::RuntimeAgentService;
11use steer_core::api::Client as ApiClient;
12use steer_core::app::domain::runtime::{RuntimeHandle, RuntimeService};
13use steer_core::app::domain::session::{SessionCatalog, SqliteEventStore};
14use steer_core::auth::storage::AuthStorage;
15use steer_core::catalog::CatalogConfig;
16use steer_core::tools::ToolSystemBuilder;
17use steer_proto::agent::v1::agent_service_server::AgentServiceServer;
18use steer_workspace::{LocalEnvironmentManager, LocalWorkspaceManager, RepoManager};
19
20#[derive(Clone)]
21pub struct ServiceHostConfig {
22    pub db_path: std::path::PathBuf,
23    pub bind_addr: SocketAddr,
24    pub auth_storage: Arc<dyn AuthStorage>,
25    pub catalog_config: CatalogConfig,
26    pub workspace_root: Option<std::path::PathBuf>,
27}
28
29impl std::fmt::Debug for ServiceHostConfig {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("ServiceHostConfig")
32            .field("db_path", &self.db_path)
33            .field("bind_addr", &self.bind_addr)
34            .field("auth_storage", &"Arc<dyn AuthStorage>")
35            .field("catalog_config", &self.catalog_config)
36            .field("workspace_root", &self.workspace_root)
37            .finish()
38    }
39}
40
41impl ServiceHostConfig {
42    pub fn new(db_path: std::path::PathBuf, bind_addr: SocketAddr) -> Result<Self> {
43        let auth_storage = Arc::new(
44            steer_core::auth::DefaultAuthStorage::new()
45                .map_err(|e| GrpcError::CoreError(e.into()))?,
46        );
47
48        Ok(Self {
49            db_path,
50            bind_addr,
51            auth_storage,
52            catalog_config: CatalogConfig::default(),
53            workspace_root: None,
54        })
55    }
56
57    pub fn with_catalog(
58        db_path: std::path::PathBuf,
59        bind_addr: SocketAddr,
60        catalog_config: CatalogConfig,
61    ) -> Result<Self> {
62        let auth_storage = Arc::new(
63            steer_core::auth::DefaultAuthStorage::new()
64                .map_err(|e| GrpcError::CoreError(e.into()))?,
65        );
66
67        Ok(Self {
68            db_path,
69            bind_addr,
70            auth_storage,
71            catalog_config,
72            workspace_root: None,
73        })
74    }
75}
76
77pub struct ServiceHost {
78    runtime_service: RuntimeService,
79    runtime_handle: RuntimeHandle,
80    catalog: Arc<dyn SessionCatalog>,
81    model_registry: Arc<steer_core::model_registry::ModelRegistry>,
82    provider_registry: Arc<steer_core::auth::ProviderRegistry>,
83    llm_config_provider: steer_core::config::LlmConfigProvider,
84    environment_root: std::path::PathBuf,
85    server_handle: Option<JoinHandle<Result<()>>>,
86    shutdown_tx: Option<oneshot::Sender<()>>,
87    config: ServiceHostConfig,
88}
89
90impl ServiceHost {
91    pub async fn new(config: ServiceHostConfig) -> Result<Self> {
92        let event_store = Arc::new(SqliteEventStore::new(&config.db_path).await.map_err(|e| {
93            GrpcError::InvalidSessionState {
94                reason: format!("Failed to create event store: {e}"),
95            }
96        })?);
97
98        let catalog: Arc<dyn SessionCatalog> = event_store.clone();
99
100        let model_registry = Arc::new(
101            steer_core::model_registry::ModelRegistry::load(&config.catalog_config.catalog_paths)
102                .map_err(|e| GrpcError::InvalidSessionState {
103                reason: format!("Failed to load model registry: {e}"),
104            })?,
105        );
106
107        let provider_registry = Arc::new(
108            steer_core::auth::ProviderRegistry::load(&config.catalog_config.catalog_paths)
109                .map_err(|e| GrpcError::InvalidSessionState {
110                    reason: format!("Failed to load provider registry: {e}"),
111                })?,
112        );
113
114        let llm_config_provider =
115            steer_core::config::LlmConfigProvider::new(config.auth_storage.clone())
116                .map_err(GrpcError::CoreError)?;
117
118        let api_client = Arc::new(ApiClient::new_with_deps(
119            llm_config_provider.clone(),
120            provider_registry.clone(),
121            model_registry.clone(),
122        ));
123
124        let workspace_path = config.workspace_root.clone().unwrap_or_else(|| {
125            std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."))
126        });
127        let environment_root = steer_core::utils::paths::AppPaths::local_environment_root();
128        let workspace = steer_core::workspace::create_workspace(
129            &steer_core::workspace::WorkspaceConfig::Local {
130                path: workspace_path.clone(),
131            },
132        )
133        .await
134        .map_err(|e| GrpcError::InvalidSessionState {
135            reason: format!("Failed to create workspace: {e}"),
136        })?;
137        let workspace_manager = Arc::new(
138            LocalWorkspaceManager::new(environment_root.clone())
139                .await
140                .map_err(|e| GrpcError::InvalidSessionState {
141                    reason: format!("Failed to create workspace manager: {e}"),
142                })?,
143        );
144        let repo_manager: Arc<dyn RepoManager> = workspace_manager.clone();
145
146        let tool_executor = ToolSystemBuilder::new(
147            workspace,
148            event_store.clone(),
149            api_client.clone(),
150            model_registry.clone(),
151        )
152        .with_workspace_manager(workspace_manager)
153        .with_repo_manager(repo_manager)
154        .build();
155
156        let runtime_service = RuntimeService::spawn(event_store, api_client, tool_executor);
157
158        let runtime_handle = runtime_service.handle();
159
160        info!(
161            "ServiceHost initialized with database at {:?}",
162            config.db_path
163        );
164
165        Ok(Self {
166            runtime_service,
167            runtime_handle,
168            catalog,
169            model_registry,
170            provider_registry,
171            llm_config_provider,
172            environment_root,
173            server_handle: None,
174            shutdown_tx: None,
175            config,
176        })
177    }
178
179    pub async fn start(&mut self) -> Result<()> {
180        if self.server_handle.is_some() {
181            return Err(GrpcError::InvalidSessionState {
182                reason: "Server is already running".to_string(),
183            });
184        }
185
186        let environment_root = self.environment_root.clone();
187        let workspace_manager = Arc::new(
188            LocalWorkspaceManager::new(environment_root.clone())
189                .await
190                .map_err(|e| GrpcError::InvalidSessionState {
191                    reason: format!("Failed to create workspace manager: {e}"),
192                })?,
193        );
194        let repo_manager: Arc<dyn RepoManager> = workspace_manager.clone();
195        let environment_manager = Arc::new(LocalEnvironmentManager::new(environment_root));
196
197        let service = RuntimeAgentService::new(crate::grpc::RuntimeAgentDeps {
198            runtime: self.runtime_handle.clone(),
199            catalog: self.catalog.clone(),
200            llm_config_provider: self.llm_config_provider.clone(),
201            model_registry: self.model_registry.clone(),
202            provider_registry: self.provider_registry.clone(),
203            environment_manager,
204            workspace_manager,
205            repo_manager,
206        });
207
208        let (shutdown_tx, shutdown_rx) = oneshot::channel();
209        let addr = self.config.bind_addr;
210
211        info!("Starting gRPC server on {}", addr);
212
213        let server_handle = tokio::spawn(async move {
214            Server::builder()
215                .add_service(AgentServiceServer::new(service))
216                .serve_with_shutdown(addr, async {
217                    shutdown_rx.await.ok();
218                    info!("gRPC server shutdown signal received");
219                })
220                .await
221                .map_err(GrpcError::ConnectionFailed)
222        });
223
224        self.server_handle = Some(server_handle);
225        self.shutdown_tx = Some(shutdown_tx);
226
227        info!("gRPC server listening on {}", addr);
228        Ok(())
229    }
230
231    pub async fn shutdown(mut self) -> Result<()> {
232        info!("Initiating ServiceHost shutdown");
233
234        if let Some(shutdown_tx) = self.shutdown_tx.take() {
235            let _ = shutdown_tx.send(());
236        }
237
238        if let Some(server_handle) = self.server_handle.take() {
239            match server_handle.await {
240                Ok(Ok(())) => info!("gRPC server shut down successfully"),
241                Ok(Err(e)) => error!("gRPC server error during shutdown: {}", e),
242                Err(e) => error!("Failed to join server task: {}", e),
243            }
244        }
245
246        self.runtime_service.shutdown().await;
247
248        info!("ServiceHost shutdown complete");
249        Ok(())
250    }
251
252    pub fn runtime_handle(&self) -> &RuntimeHandle {
253        &self.runtime_handle
254    }
255
256    pub async fn wait(&mut self) -> Result<()> {
257        if let Some(server_handle) = &mut self.server_handle {
258            match server_handle.await {
259                Ok(result) => result,
260                Err(e) => Err(GrpcError::StreamError(format!("Server task panicked: {e}"))),
261            }
262        } else {
263            Err(GrpcError::InvalidSessionState {
264                reason: "Server is not running".to_string(),
265            })
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a config whose database and workspace live in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive for the whole test.
    /// `tempfile::TempDir` removes its directory when dropped.
    fn create_test_config() -> (ServiceHostConfig, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let root = temp_dir.path();

        let config = ServiceHostConfig {
            db_path: root.join("test.db"),
            // Port 0 lets the OS pick a free port.
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            auth_storage: Arc::new(steer_core::test_utils::InMemoryAuthStorage::new()),
            catalog_config: CatalogConfig::default(),
            workspace_root: Some(root.to_path_buf()),
        };

        (config, temp_dir)
    }

    #[tokio::test]
    async fn test_service_host_creation() {
        let (config, _temp_dir) = create_test_config();
        let host = ServiceHost::new(config).await.unwrap();

        // A brand-new database contains no sessions.
        let sessions = host.runtime_handle.list_all_sessions().await.unwrap();
        assert!(sessions.is_empty());
    }

    #[tokio::test]
    async fn test_service_host_lifecycle() {
        let (config, _temp_dir) = create_test_config();
        let mut host = ServiceHost::new(config).await.unwrap();

        host.start().await.unwrap();
        assert!(host.server_handle.is_some());

        host.shutdown().await.unwrap();
    }
}