steer_grpc/service_host.rs

use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tonic::transport::Server;
use tracing::{error, info};

use crate::grpc::error::GrpcError;
use crate::grpc::server::AgentServiceImpl;
use steer_core::auth::storage::AuthStorage;
use steer_core::catalog::CatalogConfig;
use steer_core::session::{SessionManager, SessionManagerConfig, SessionStore};
use steer_proto::agent::v1::agent_service_server::AgentServiceServer;

type Result<T> = std::result::Result<T, GrpcError>;

/// Configuration for the ServiceHost
#[derive(Clone)]
pub struct ServiceHostConfig {
    /// Path to the session database
    pub db_path: std::path::PathBuf,
    /// Session manager configuration
    pub session_manager_config: SessionManagerConfig,
    /// gRPC server bind address
    pub bind_addr: SocketAddr,
    /// Auth storage
    pub auth_storage: Arc<dyn AuthStorage>,
    /// Catalog configuration for additional models/providers
    pub catalog_config: CatalogConfig,
}

impl std::fmt::Debug for ServiceHostConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceHostConfig")
            .field("db_path", &self.db_path)
            .field("session_manager_config", &self.session_manager_config)
            .field("bind_addr", &self.bind_addr)
            .field("auth_storage", &"Arc<dyn AuthStorage>")
            .field("catalog_config", &self.catalog_config)
            .finish()
    }
}

impl ServiceHostConfig {
    /// Create a new ServiceHostConfig with default auth storage and the
    /// default catalog configuration
    pub fn new(
        db_path: std::path::PathBuf,
        session_manager_config: SessionManagerConfig,
        bind_addr: SocketAddr,
    ) -> Result<Self> {
        Self::with_catalog(
            db_path,
            session_manager_config,
            bind_addr,
            CatalogConfig::default(),
        )
    }

    /// Create a new ServiceHostConfig with custom catalog configuration
    pub fn with_catalog(
        db_path: std::path::PathBuf,
        session_manager_config: SessionManagerConfig,
        bind_addr: SocketAddr,
        catalog_config: CatalogConfig,
    ) -> Result<Self> {
        let auth_storage = Arc::new(
            steer_core::auth::DefaultAuthStorage::new()
                .map_err(|e| GrpcError::CoreError(e.into()))?,
        );

        Ok(Self {
            db_path,
            session_manager_config,
            bind_addr,
            auth_storage,
            catalog_config,
        })
    }
}

/// Main orchestrator for the service host system.
///
/// Manages the gRPC server, SessionManager, and component lifecycle.
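///
/// A minimal lifecycle sketch (the database path, bind address, and the
/// pre-built `session_manager_config` value are illustrative placeholders,
/// not values this crate prescribes):
///
/// ```ignore
/// let config = ServiceHostConfig::new(
///     "/tmp/steer.db".into(),
///     session_manager_config,
///     "127.0.0.1:50051".parse()?,
/// )?;
/// let mut host = ServiceHost::new(config).await?;
/// host.start().await?;
/// // ... serve requests until shutdown is requested ...
/// host.shutdown().await?;
/// ```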
pub struct ServiceHost {
    session_manager: Arc<SessionManager>,
    model_registry: Arc<steer_core::model_registry::ModelRegistry>,
    provider_registry: Arc<steer_core::auth::ProviderRegistry>,
    server_handle: Option<JoinHandle<Result<()>>>,
    cleanup_handle: Option<JoinHandle<()>>,
    shutdown_tx: Option<oneshot::Sender<()>>,
    config: ServiceHostConfig,
}

impl ServiceHost {
    /// Create a new ServiceHost with the given configuration
    pub async fn new(config: ServiceHostConfig) -> Result<Self> {
        // Initialize session store
        let store = create_session_store(&config.db_path).await?;

        // Load model registry once at startup
        let model_registry = Arc::new(
            steer_core::model_registry::ModelRegistry::load(&config.catalog_config.catalog_paths)
                .map_err(|e| GrpcError::InvalidSessionState {
                    reason: format!("Failed to load model registry: {e}"),
                })?,
        );

        // Load provider registry once at startup
        let provider_registry = Arc::new(
            steer_core::auth::ProviderRegistry::load(&config.catalog_config.catalog_paths)
                .map_err(|e| GrpcError::InvalidSessionState {
                    reason: format!("Failed to load provider registry: {e}"),
                })?,
        );

        // Create session manager
        let session_manager = Arc::new(SessionManager::new(
            store,
            config.session_manager_config.clone(),
        ));

        info!(
            "ServiceHost initialized with database at {:?}",
            config.db_path
        );

        Ok(Self {
            session_manager,
            model_registry,
            provider_registry,
            server_handle: None,
            cleanup_handle: None,
            shutdown_tx: None,
            config,
        })
    }

    /// Start the gRPC server
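    ///
    /// Returns an error if a server task is already running.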
    pub async fn start(&mut self) -> Result<()> {
        if self.server_handle.is_some() {
            return Err(GrpcError::InvalidSessionState {
                reason: "Server is already running".to_string(),
            });
        }

        // Use auth storage from config
        let llm_config_provider =
            steer_core::config::LlmConfigProvider::new(self.config.auth_storage.clone());

        let service = AgentServiceImpl::new(
            self.session_manager.clone(),
            llm_config_provider,
            self.model_registry.clone(),
            self.provider_registry.clone(),
        );
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let addr = self.config.bind_addr;

        info!("Starting gRPC server on {}", addr);

        let server_handle = tokio::spawn(async move {
            Server::builder()
                .add_service(AgentServiceServer::new(service))
                .serve_with_shutdown(addr, async {
                    shutdown_rx.await.ok();
                    info!("gRPC server shutdown signal received");
                })
                .await
                .map_err(GrpcError::ConnectionFailed)
        });

        // Start periodic cleanup task
        let session_manager = self.session_manager.clone();
        let cleanup_handle = tokio::spawn(async move {
            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300)); // 5 minutes
            loop {
                interval.tick().await;

                // Clean up sessions that have been idle for more than 30 minutes
                let idle_duration = chrono::Duration::minutes(30);
                match session_manager
                    .cleanup_inactive_sessions(idle_duration)
                    .await
                {
                    0 => {} // No sessions cleaned, don't log
                    count => info!("Cleaned up {} inactive sessions", count),
                }
            }
        });

        self.server_handle = Some(server_handle);
        self.cleanup_handle = Some(cleanup_handle);
        self.shutdown_tx = Some(shutdown_tx);

        info!("gRPC server listening on {}", addr);
        Ok(())
    }

    /// Shut down the server gracefully
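    ///
    /// Sends the shutdown signal to the gRPC server, aborts the periodic
    /// cleanup task, waits for the server task to exit, and suspends any
    /// remaining active sessions.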
    pub async fn shutdown(mut self) -> Result<()> {
        info!("Initiating ServiceHost shutdown");

        // Send shutdown signal to server
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            let _ = shutdown_tx.send(());
        }

        // Abort cleanup task
        if let Some(cleanup_handle) = self.cleanup_handle.take() {
            cleanup_handle.abort();
        }

        // Wait for server to finish
        if let Some(server_handle) = self.server_handle.take() {
            match server_handle.await {
                Ok(Ok(())) => info!("gRPC server shut down successfully"),
                Ok(Err(e)) => error!("gRPC server error during shutdown: {}", e),
                Err(e) => error!("Failed to join server task: {}", e),
            }
        }

        // Clean up active sessions
        let active_sessions = self.session_manager.get_active_sessions().await;
        for session_id in active_sessions {
            if let Err(e) = self.session_manager.suspend_session(&session_id).await {
                error!(
                    "Failed to suspend session {} during shutdown: {}",
                    session_id, e
                );
            }
        }

        info!("ServiceHost shutdown complete");
        Ok(())
    }

    /// Get a reference to the session manager
    pub fn session_manager(&self) -> &Arc<SessionManager> {
        &self.session_manager
    }

    /// Wait for the server to finish (blocks until shutdown)
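    ///
    /// A typical pattern (illustrative, not something this crate prescribes) is
    /// to race this against a shutdown signal, e.g.
    /// `tokio::select! { _ = host.wait() => {}, _ = tokio::signal::ctrl_c() => {} }`,
    /// and then call [`ServiceHost::shutdown`].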
    pub async fn wait(&mut self) -> Result<()> {
        if let Some(server_handle) = &mut self.server_handle {
            match server_handle.await {
                Ok(result) => result,
                Err(e) => Err(GrpcError::StreamError(format!("Server task panicked: {e}"))),
            }
        } else {
            Err(GrpcError::InvalidSessionState {
                reason: "Server is not running".to_string(),
            })
        }
    }
}

/// Create a session store from the given database path
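///
/// Currently backed by SQLite via `SessionStoreConfig::sqlite`.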
async fn create_session_store(db_path: &std::path::Path) -> Result<Arc<dyn SessionStore>> {
    use steer_core::session::SessionStoreConfig;
    use steer_core::utils::session::create_session_store_with_config;

    let config = SessionStoreConfig::sqlite(db_path.to_path_buf());
    create_session_store_with_config(config)
        .await
        .map_err(|e| GrpcError::InvalidSessionState {
            reason: format!("Failed to create session store: {e}"),
        })
}

#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::TempDir;

    fn create_test_config() -> (ServiceHostConfig, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");

        let config = ServiceHostConfig {
            db_path,
            session_manager_config: SessionManagerConfig {
                max_concurrent_sessions: 10,
                default_model: steer_core::config::model::builtin::claude_3_7_sonnet_20250219(),
                auto_persist: true,
            },
            bind_addr: "127.0.0.1:0".parse().unwrap(), // Port 0: let the OS pick a free port
            auth_storage: Arc::new(steer_core::test_utils::InMemoryAuthStorage::new()),
            catalog_config: CatalogConfig::default(),
        };

        (config, temp_dir)
    }

    #[tokio::test]
    async fn test_service_host_creation() {
        let (config, _temp_dir) = create_test_config();

        let host = ServiceHost::new(config).await.unwrap();

        // Verify session manager was created
        assert_eq!(host.session_manager().get_active_sessions().await.len(), 0);
    }

    #[tokio::test]
    async fn test_service_host_lifecycle() {
        // create_test_config already binds to port 0 (any available port)
        let (config, _temp_dir) = create_test_config();

        let mut host = ServiceHost::new(config).await.unwrap();

        // Start server
        host.start().await.unwrap();

        // Verify it's running
        assert!(host.server_handle.is_some());

        // Shutdown
        host.shutdown().await.unwrap();
    }
}