steer_grpc/
service_host.rs1use crate::grpc::error::GrpcError;
2type Result<T> = std::result::Result<T, GrpcError>;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use tokio::sync::oneshot;
6use tokio::task::JoinHandle;
7use tonic::transport::Server;
8use tracing::{error, info};
9
10use crate::grpc::server::AgentServiceImpl;
11use steer_core::auth::storage::AuthStorage;
12use steer_core::session::{SessionManager, SessionManagerConfig, SessionStore};
13use steer_proto::agent::v1::agent_service_server::AgentServiceServer;
14
15#[derive(Clone)]
17pub struct ServiceHostConfig {
18 pub db_path: std::path::PathBuf,
20 pub session_manager_config: SessionManagerConfig,
22 pub bind_addr: SocketAddr,
24 pub auth_storage: Arc<dyn AuthStorage>,
26}
27
28impl std::fmt::Debug for ServiceHostConfig {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("ServiceHostConfig")
31 .field("db_path", &self.db_path)
32 .field("session_manager_config", &self.session_manager_config)
33 .field("bind_addr", &self.bind_addr)
34 .field("auth_storage", &"Arc<dyn AuthStorage>")
35 .finish()
36 }
37}
38
39impl ServiceHostConfig {
40 pub fn new(
42 db_path: std::path::PathBuf,
43 session_manager_config: SessionManagerConfig,
44 bind_addr: SocketAddr,
45 ) -> Result<Self> {
46 let auth_storage = Arc::new(
47 steer_core::auth::DefaultAuthStorage::new()
48 .map_err(|e| GrpcError::CoreError(e.into()))?,
49 );
50
51 Ok(Self {
52 db_path,
53 session_manager_config,
54 bind_addr,
55 auth_storage,
56 })
57 }
58}
59
60pub struct ServiceHost {
63 session_manager: Arc<SessionManager>,
64 server_handle: Option<JoinHandle<Result<()>>>,
65 cleanup_handle: Option<JoinHandle<()>>,
66 shutdown_tx: Option<oneshot::Sender<()>>,
67 config: ServiceHostConfig,
68}
69
70impl ServiceHost {
71 pub async fn new(config: ServiceHostConfig) -> Result<Self> {
73 let store = create_session_store(&config.db_path).await?;
75
76 let session_manager = Arc::new(SessionManager::new(
78 store,
79 config.session_manager_config.clone(),
80 ));
81
82 info!(
83 "ServiceHost initialized with database at {:?}",
84 config.db_path
85 );
86
87 Ok(Self {
88 session_manager,
89 server_handle: None,
90 cleanup_handle: None,
91 shutdown_tx: None,
92 config,
93 })
94 }
95
96 pub async fn start(&mut self) -> Result<()> {
98 if self.server_handle.is_some() {
99 return Err(GrpcError::InvalidSessionState {
100 reason: "Server is already running".to_string(),
101 });
102 }
103
104 let llm_config_provider =
106 steer_core::config::LlmConfigProvider::new(self.config.auth_storage.clone());
107
108 let service = AgentServiceImpl::new(self.session_manager.clone(), llm_config_provider);
109 let (shutdown_tx, shutdown_rx) = oneshot::channel();
110
111 let addr = self.config.bind_addr;
112
113 info!("Starting gRPC server on {}", addr);
114
115 let server_handle = tokio::spawn(async move {
116 Server::builder()
117 .add_service(AgentServiceServer::new(service))
118 .serve_with_shutdown(addr, async {
119 shutdown_rx.await.ok();
120 info!("gRPC server shutdown signal received");
121 })
122 .await
123 .map_err(GrpcError::ConnectionFailed)
124 });
125
126 let session_manager = self.session_manager.clone();
128 let cleanup_handle = tokio::spawn(async move {
129 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300)); loop {
131 interval.tick().await;
132
133 let idle_duration = chrono::Duration::minutes(30);
135 match session_manager
136 .cleanup_inactive_sessions(idle_duration)
137 .await
138 {
139 0 => {} count => info!("Cleaned up {} inactive sessions", count),
141 }
142 }
143 });
144
145 self.server_handle = Some(server_handle);
146 self.cleanup_handle = Some(cleanup_handle);
147 self.shutdown_tx = Some(shutdown_tx);
148
149 info!("gRPC server listening on {}", addr);
150 Ok(())
151 }
152
153 pub async fn shutdown(mut self) -> Result<()> {
155 info!("Initiating ServiceHost shutdown");
156
157 if let Some(shutdown_tx) = self.shutdown_tx.take() {
159 let _ = shutdown_tx.send(());
160 }
161
162 if let Some(cleanup_handle) = self.cleanup_handle.take() {
164 cleanup_handle.abort();
165 }
166
167 if let Some(server_handle) = self.server_handle.take() {
169 match server_handle.await {
170 Ok(Ok(())) => info!("gRPC server shut down successfully"),
171 Ok(Err(e)) => error!("gRPC server error during shutdown: {}", e),
172 Err(e) => error!("Failed to join server task: {}", e),
173 }
174 }
175
176 let active_sessions = self.session_manager.get_active_sessions().await;
178 for session_id in active_sessions {
179 if let Err(e) = self.session_manager.suspend_session(&session_id).await {
180 error!(
181 "Failed to suspend session {} during shutdown: {}",
182 session_id, e
183 );
184 }
185 }
186
187 info!("ServiceHost shutdown complete");
188 Ok(())
189 }
190
191 pub fn session_manager(&self) -> &Arc<SessionManager> {
193 &self.session_manager
194 }
195
196 pub async fn wait(&mut self) -> Result<()> {
198 if let Some(server_handle) = &mut self.server_handle {
199 match server_handle.await {
200 Ok(result) => result,
201 Err(e) => Err(GrpcError::StreamError(format!("Server task panicked: {e}"))),
202 }
203 } else {
204 Err(GrpcError::InvalidSessionState {
205 reason: "Server is not running".to_string(),
206 })
207 }
208 }
209}
210
211async fn create_session_store(db_path: &std::path::Path) -> Result<Arc<dyn SessionStore>> {
213 use steer_core::session::SessionStoreConfig;
214 use steer_core::utils::session::create_session_store_with_config;
215
216 let config = SessionStoreConfig::sqlite(db_path.to_path_buf());
217 create_session_store_with_config(config)
218 .await
219 .map_err(|e| GrpcError::InvalidSessionState {
220 reason: format!("Failed to create session store: {e}"),
221 })
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use steer_core::api::Model;
    use tempfile::TempDir;

    /// Builds a config backed by a fresh temporary database and in-memory auth storage.
    ///
    /// The returned `TempDir` must be kept alive while the host is in use, because
    /// dropping it deletes the directory holding the database.
    fn create_test_config() -> (ServiceHostConfig, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");

        let config = ServiceHostConfig {
            db_path,
            session_manager_config: SessionManagerConfig {
                max_concurrent_sessions: 10,
                default_model: Model::ClaudeSonnet4_20250514,
                auto_persist: true,
            },
            // Port 0 lets the OS pick a free port, so concurrent tests do not collide.
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            auth_storage: Arc::new(steer_core::test_utils::InMemoryAuthStorage::new()),
        };

        (config, temp_dir)
    }

    #[tokio::test]
    async fn test_service_host_creation() {
        let (config, _temp_dir) = create_test_config();

        let host = ServiceHost::new(config).await.unwrap();

        assert_eq!(host.session_manager().get_active_sessions().await.len(), 0);
    }

    #[tokio::test]
    async fn test_service_host_lifecycle() {
        let (config, _temp_dir) = create_test_config();
        let mut host = ServiceHost::new(config).await.unwrap();

        host.start().await.unwrap();
        assert!(host.server_handle.is_some());

        host.shutdown().await.unwrap();
    }

    /// A second `start` on a running host must be rejected, not spawn a second server.
    #[tokio::test]
    async fn test_start_twice_is_rejected() {
        let (config, _temp_dir) = create_test_config();
        let mut host = ServiceHost::new(config).await.unwrap();

        host.start().await.unwrap();
        assert!(host.start().await.is_err());

        host.shutdown().await.unwrap();
    }

    /// `wait` without a started server reports an error instead of hanging.
    #[tokio::test]
    async fn test_wait_before_start_is_rejected() {
        let (config, _temp_dir) = create_test_config();
        let mut host = ServiceHost::new(config).await.unwrap();

        assert!(host.wait().await.is_err());
    }
}