steer_grpc/
service_host.rs

1use crate::grpc::error::GrpcError;
2type Result<T> = std::result::Result<T, GrpcError>;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use tokio::sync::oneshot;
6use tokio::task::JoinHandle;
7use tonic::transport::Server;
8use tracing::{error, info};
9
10use crate::grpc::server::AgentServiceImpl;
11use steer_core::auth::storage::AuthStorage;
12use steer_core::session::{SessionManager, SessionManagerConfig, SessionStore};
13use steer_proto::agent::v1::agent_service_server::AgentServiceServer;
14
15/// Configuration for the ServiceHost
16#[derive(Clone)]
17pub struct ServiceHostConfig {
18    /// Path to the session database
19    pub db_path: std::path::PathBuf,
20    /// Session manager configuration
21    pub session_manager_config: SessionManagerConfig,
22    /// gRPC server bind address
23    pub bind_addr: SocketAddr,
24    /// Auth storage
25    pub auth_storage: Arc<dyn AuthStorage>,
26}
27
28impl std::fmt::Debug for ServiceHostConfig {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("ServiceHostConfig")
31            .field("db_path", &self.db_path)
32            .field("session_manager_config", &self.session_manager_config)
33            .field("bind_addr", &self.bind_addr)
34            .field("auth_storage", &"Arc<dyn AuthStorage>")
35            .finish()
36    }
37}
38
39impl ServiceHostConfig {
40    /// Create a new ServiceHostConfig with default auth storage
41    pub fn new(
42        db_path: std::path::PathBuf,
43        session_manager_config: SessionManagerConfig,
44        bind_addr: SocketAddr,
45    ) -> Result<Self> {
46        let auth_storage = Arc::new(
47            steer_core::auth::DefaultAuthStorage::new()
48                .map_err(|e| GrpcError::CoreError(e.into()))?,
49        );
50
51        Ok(Self {
52            db_path,
53            session_manager_config,
54            bind_addr,
55            auth_storage,
56        })
57    }
58}
59
60/// Main orchestrator for the service host system
61/// Manages the gRPC server, SessionManager, and component lifecycle
62pub struct ServiceHost {
63    session_manager: Arc<SessionManager>,
64    server_handle: Option<JoinHandle<Result<()>>>,
65    cleanup_handle: Option<JoinHandle<()>>,
66    shutdown_tx: Option<oneshot::Sender<()>>,
67    config: ServiceHostConfig,
68}
69
70impl ServiceHost {
71    /// Create a new ServiceHost with the given configuration
72    pub async fn new(config: ServiceHostConfig) -> Result<Self> {
73        // Initialize session store
74        let store = create_session_store(&config.db_path).await?;
75
76        // Create session manager
77        let session_manager = Arc::new(SessionManager::new(
78            store,
79            config.session_manager_config.clone(),
80        ));
81
82        info!(
83            "ServiceHost initialized with database at {:?}",
84            config.db_path
85        );
86
87        Ok(Self {
88            session_manager,
89            server_handle: None,
90            cleanup_handle: None,
91            shutdown_tx: None,
92            config,
93        })
94    }
95
96    /// Start the gRPC server
97    pub async fn start(&mut self) -> Result<()> {
98        if self.server_handle.is_some() {
99            return Err(GrpcError::InvalidSessionState {
100                reason: "Server is already running".to_string(),
101            });
102        }
103
104        // Use auth storage from config
105        let llm_config_provider =
106            steer_core::config::LlmConfigProvider::new(self.config.auth_storage.clone());
107
108        let service = AgentServiceImpl::new(self.session_manager.clone(), llm_config_provider);
109        let (shutdown_tx, shutdown_rx) = oneshot::channel();
110
111        let addr = self.config.bind_addr;
112
113        info!("Starting gRPC server on {}", addr);
114
115        let server_handle = tokio::spawn(async move {
116            Server::builder()
117                .add_service(AgentServiceServer::new(service))
118                .serve_with_shutdown(addr, async {
119                    shutdown_rx.await.ok();
120                    info!("gRPC server shutdown signal received");
121                })
122                .await
123                .map_err(GrpcError::ConnectionFailed)
124        });
125
126        // Start periodic cleanup task
127        let session_manager = self.session_manager.clone();
128        let cleanup_handle = tokio::spawn(async move {
129            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300)); // 5 minutes
130            loop {
131                interval.tick().await;
132
133                // Clean up sessions that have been idle for more than 30 minutes
134                let idle_duration = chrono::Duration::minutes(30);
135                match session_manager
136                    .cleanup_inactive_sessions(idle_duration)
137                    .await
138                {
139                    0 => {} // No sessions cleaned, don't log
140                    count => info!("Cleaned up {} inactive sessions", count),
141                }
142            }
143        });
144
145        self.server_handle = Some(server_handle);
146        self.cleanup_handle = Some(cleanup_handle);
147        self.shutdown_tx = Some(shutdown_tx);
148
149        info!("gRPC server listening on {}", addr);
150        Ok(())
151    }
152
153    /// Shutdown the server gracefully
154    pub async fn shutdown(mut self) -> Result<()> {
155        info!("Initiating ServiceHost shutdown");
156
157        // Send shutdown signal to server
158        if let Some(shutdown_tx) = self.shutdown_tx.take() {
159            let _ = shutdown_tx.send(());
160        }
161
162        // Abort cleanup task
163        if let Some(cleanup_handle) = self.cleanup_handle.take() {
164            cleanup_handle.abort();
165        }
166
167        // Wait for server to finish
168        if let Some(server_handle) = self.server_handle.take() {
169            match server_handle.await {
170                Ok(Ok(())) => info!("gRPC server shut down successfully"),
171                Ok(Err(e)) => error!("gRPC server error during shutdown: {}", e),
172                Err(e) => error!("Failed to join server task: {}", e),
173            }
174        }
175
176        // Clean up active sessions
177        let active_sessions = self.session_manager.get_active_sessions().await;
178        for session_id in active_sessions {
179            if let Err(e) = self.session_manager.suspend_session(&session_id).await {
180                error!(
181                    "Failed to suspend session {} during shutdown: {}",
182                    session_id, e
183                );
184            }
185        }
186
187        info!("ServiceHost shutdown complete");
188        Ok(())
189    }
190
191    /// Get a reference to the session manager
192    pub fn session_manager(&self) -> &Arc<SessionManager> {
193        &self.session_manager
194    }
195
196    /// Wait for the server to finish (blocks until shutdown)
197    pub async fn wait(&mut self) -> Result<()> {
198        if let Some(server_handle) = &mut self.server_handle {
199            match server_handle.await {
200                Ok(result) => result,
201                Err(e) => Err(GrpcError::StreamError(format!("Server task panicked: {e}"))),
202            }
203        } else {
204            Err(GrpcError::InvalidSessionState {
205                reason: "Server is not running".to_string(),
206            })
207        }
208    }
209}
210
211/// Create a session store from the given database path
212async fn create_session_store(db_path: &std::path::Path) -> Result<Arc<dyn SessionStore>> {
213    use steer_core::session::SessionStoreConfig;
214    use steer_core::utils::session::create_session_store_with_config;
215
216    let config = SessionStoreConfig::sqlite(db_path.to_path_buf());
217    create_session_store_with_config(config)
218        .await
219        .map_err(|e| GrpcError::InvalidSessionState {
220            reason: format!("Failed to create session store: {e}"),
221        })
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use steer_core::api::Model;
    use tempfile::TempDir;

    /// Builds a host configuration whose database lives in a fresh temporary
    /// directory. The `TempDir` is returned alongside the config so the caller
    /// decides how long the directory stays around.
    fn create_test_config() -> (ServiceHostConfig, TempDir) {
        let temp_dir = TempDir::new().unwrap();

        let session_manager_config = SessionManagerConfig {
            max_concurrent_sessions: 10,
            default_model: Model::ClaudeSonnet4_20250514,
            auto_persist: true,
        };

        let config = ServiceHostConfig {
            db_path: temp_dir.path().join("test.db"),
            session_manager_config,
            // Port 0 lets the OS pick any free port.
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            auth_storage: Arc::new(steer_core::test_utils::InMemoryAuthStorage::new()),
        };

        (config, temp_dir)
    }

    /// A freshly created host reports no active sessions.
    #[tokio::test]
    async fn test_service_host_creation() {
        let (config, _temp_dir) = create_test_config();
        let host = ServiceHost::new(config).await.unwrap();

        let active = host.session_manager().get_active_sessions().await;
        assert_eq!(active.len(), 0);
    }

    /// A host can be started and then shut down without errors.
    #[tokio::test]
    async fn test_service_host_lifecycle() {
        let (mut config, _temp_dir) = create_test_config();
        // Bind to any available port.
        let any_port: SocketAddr = "127.0.0.1:0".parse().unwrap();
        config.bind_addr = any_port;

        let mut host = ServiceHost::new(config).await.unwrap();

        // Starting must record the server task handle.
        host.start().await.unwrap();
        assert!(host.server_handle.is_some());

        // Shutting down consumes the host.
        host.shutdown().await.unwrap();
    }
}