vex_api/
server.rs

1//! VEX API Server with graceful shutdown
2
3use axum::{middleware, Router};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::signal;
8use tower_http::compression::CompressionLayer;
9
10use crate::auth::JwtAuth;
11use crate::error::ApiError;
12use crate::middleware::{
13    auth_middleware, body_limit_layer, cors_layer, rate_limit_middleware, request_id_middleware,
14    timeout_layer, tracing_middleware,
15};
16use crate::routes::api_router;
17use vex_llm::{Metrics, RateLimitConfig, RateLimiter};
18// use vex_persist::StorageBackend; // Not dealing with trait directly here
19// use vex_queue::WorkerPool;
20
/// TLS configuration for HTTPS
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to certificate file (PEM format)
    pub cert_path: String,
    /// Path to private key file (PEM format)
    pub key_path: String,
}

impl TlsConfig {
    /// Create TLS config from paths.
    ///
    /// Accepts anything convertible into `String` (`&str`, `String`, …),
    /// so callers that already own a `String` avoid an extra allocation.
    pub fn new(cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
        Self {
            cert_path: cert_path.into(),
            key_path: key_path.into(),
        }
    }

    /// Create from environment variables `VEX_TLS_CERT` and `VEX_TLS_KEY`.
    ///
    /// Returns `None` unless BOTH variables are set. The owned strings from
    /// `env::var` are moved directly into the struct (no re-allocation).
    pub fn from_env() -> Option<Self> {
        let cert_path = std::env::var("VEX_TLS_CERT").ok()?;
        let key_path = std::env::var("VEX_TLS_KEY").ok()?;
        Some(Self {
            cert_path,
            key_path,
        })
    }
}
46
/// Server configuration
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address the HTTP listener binds to.
    pub addr: SocketAddr,
    /// Per-request timeout enforced by the timeout middleware layer.
    pub timeout: Duration,
    /// Max request body size (bytes), enforced by the body-limit layer.
    pub max_body_size: usize,
    /// Enable response compression.
    pub compression: bool,
    /// Rate limit config handed to the shared `RateLimiter`.
    pub rate_limit: RateLimitConfig,
    /// Optional TLS configuration for HTTPS.
    /// NOTE(review): `VexServer::run()` does not read this field — the server
    /// always binds a plain TCP listener. Confirm TLS serving is wired up.
    pub tls: Option<TlsConfig>,
}
63
64impl Default for ServerConfig {
65    fn default() -> Self {
66        Self {
67            addr: "0.0.0.0:8080".parse().unwrap(),
68            timeout: Duration::from_secs(30),
69            max_body_size: 1024 * 1024, // 1MB
70            compression: true,
71            rate_limit: RateLimitConfig::default(),
72            tls: None,
73        }
74    }
75}
76
77impl ServerConfig {
78    /// Create from environment variables
79    pub fn from_env() -> Self {
80        let port: u16 = std::env::var("VEX_PORT")
81            .ok()
82            .and_then(|p| p.parse().ok())
83            .unwrap_or(8080);
84
85        let timeout_secs: u64 = std::env::var("VEX_TIMEOUT_SECS")
86            .ok()
87            .and_then(|t| t.parse().ok())
88            .unwrap_or(30);
89
90        Self {
91            addr: SocketAddr::from(([0, 0, 0, 0], port)),
92            timeout: Duration::from_secs(timeout_secs),
93            ..Default::default()
94        }
95    }
96}
97
98use crate::state::AppState;
99
/// VEX API Server: owns the immutable configuration and the shared
/// application state (auth, rate limiter, metrics, DB, worker pool) that is
/// cloned into every middleware layer and request handler.
pub struct VexServer {
    // Listener address, timeouts, limits, optional TLS paths.
    config: ServerConfig,
    // Cheap-to-clone shared state; see `AppState::new` in `new()` below.
    app_state: AppState,
}
105
impl VexServer {
    /// Create a new server: wires JWT auth, rate limiting, metrics,
    /// persistence, the job queue/worker pool, and an LLM provider into a
    /// shared `AppState`.
    ///
    /// # Errors
    /// Returns `ApiError` if JWT configuration cannot be loaded from the
    /// environment or the database backend fails to initialize.
    pub async fn new(config: ServerConfig) -> Result<Self, ApiError> {
        use crate::jobs::agent::{AgentExecutionJob, AgentJobPayload};
        use vex_llm::{DeepSeekProvider, LlmProvider, MockProvider};
        use vex_queue::{QueueBackend, WorkerConfig, WorkerPool};

        let jwt_auth = JwtAuth::from_env()?;
        let rate_limiter = Arc::new(RateLimiter::new(config.rate_limit.clone()));
        let metrics = Arc::new(Metrics::new());

        // Persistence: SQLite selected by DATABASE_URL. Falls back to an
        // in-memory database, so all data is lost on restart unless the
        // variable is set.
        let db_url =
            std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite::memory:".to_string());
        let db = vex_persist::sqlite::SqliteBackend::new(&db_url)
            .await
            .map_err(|e| ApiError::Internal(format!("DB Init failed: {}", e)))?;

        // Queue: persisted in the same SQLite connection pool as the main DB.
        let queue_backend = vex_persist::queue::SqliteQueueBackend::new(db.pool().clone());

        // Use dynamic dispatch so the worker pool accepts any QueueBackend.
        let worker_pool = WorkerPool::new_with_arc(
            Arc::new(queue_backend) as Arc<dyn QueueBackend>,
            WorkerConfig::default(),
        );

        // LLM provider: real DeepSeek when an API key is present; otherwise a
        // mock provider so the server still runs in development and tests.
        let llm: Arc<dyn LlmProvider> = if let Ok(key) = std::env::var("DEEPSEEK_API_KEY") {
            tracing::info!("Initializing DeepSeek Provider");
            Arc::new(DeepSeekProvider::chat(&key))
        } else {
            tracing::warn!("DEEPSEEK_API_KEY not found. Using Mock Provider.");
            Arc::new(MockProvider::smart())
        };

        // Shared store where completed jobs publish their results.
        // NOTE(review): `result_store` is NOT placed into `AppState` below —
        // only the clone captured by the job factory keeps it alive. Verify
        // that request handlers have some other path to read job results.
        let result_store = crate::jobs::new_result_store();

        // Register the factory that builds an AgentExecutionJob per queued
        // "agent_execution" payload. The closure owns clones of the LLM
        // provider and result store.
        let llm_clone = llm.clone();
        let result_store_clone = result_store.clone();
        worker_pool.register_job_factory("agent_execution", move |payload| {
            // NOTE(review): a payload that fails to deserialize is silently
            // replaced with a sentinel job ("unknown"/"payload error") rather
            // than rejected — consider surfacing the deserialization error.
            let job_payload: AgentJobPayload =
                serde_json::from_value(payload).unwrap_or_else(|_| AgentJobPayload {
                    agent_id: "unknown".to_string(),
                    prompt: "payload error".to_string(),
                    context_id: None,
                });
            let job_id = uuid::Uuid::new_v4();
            Box::new(AgentExecutionJob::new(
                job_id,
                job_payload,
                llm_clone.clone(),
                result_store_clone.clone(),
            ))
        });

        let app_state = AppState::new(
            jwt_auth,
            rate_limiter,
            metrics,
            Arc::new(db),
            Arc::new(worker_pool),
        );

        Ok(Self { config, app_state })
    }

    /// Build the complete router with all middleware applied.
    pub fn router(&self) -> Router {
        let mut app = api_router(self.app_state.clone());

        // Middleware stack. With chained `Router::layer` calls the LAST layer
        // added is the OUTERMOST: requests pass through `auth_middleware`
        // first and `CompressionLayer` last, just before the handlers.
        // NOTE(review): that means auth runs BEFORE rate limiting, request-id
        // and tracing — unauthenticated floods hit JWT verification without
        // being rate-limited, and auth failures carry no request id / trace.
        // Confirm this ordering is intentional.
        app = app
            // Innermost: compresses responses as they leave the handlers.
            .layer(CompressionLayer::new())
            // Body size limit (config.max_body_size bytes)
            .layer(body_limit_layer(self.config.max_body_size))
            // Per-request timeout
            .layer(timeout_layer(self.config.timeout))
            // CORS
            .layer(cors_layer())
            // Request ID
            .layer(middleware::from_fn(request_id_middleware))
            // Tracing
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                tracing_middleware,
            ))
            // Rate limiting
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                rate_limit_middleware,
            ))
            // Outermost: authentication — first middleware to see a request.
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                auth_middleware,
            ));

        app
    }

    /// Run the server with graceful shutdown (Ctrl+C / SIGTERM).
    ///
    /// Consumes the server and resolves after the listener has drained
    /// following a shutdown signal.
    ///
    /// # Errors
    /// Returns `ApiError::Internal` if binding the address or serving fails.
    pub async fn run(self) -> Result<(), ApiError> {
        let app = self.router();
        let addr = self.config.addr;

        tracing::info!("Starting VEX API server on {}", addr);

        // Start the worker pool in the background. The spawned task is
        // detached: it is not awaited on shutdown, so in-flight jobs may be
        // cut off when the process exits.
        let queue = self.app_state.queue();
        tokio::spawn(async move {
            queue.start().await;
        });

        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .map_err(|e| ApiError::Internal(format!("Failed to bind: {}", e)))?;

        // NOTE(review): `self.config.tls` is never consulted here — the
        // server always serves plain HTTP even when TLS paths are configured.
        axum::serve(listener, app)
            .with_graceful_shutdown(shutdown_signal())
            .await
            .map_err(|e| ApiError::Internal(format!("Server error: {}", e)))?;

        tracing::info!("Server shutdown complete");
        Ok(())
    }

    /// Handle to the shared server metrics (same `Arc` used by middleware).
    pub fn metrics(&self) -> Arc<Metrics> {
        self.app_state.metrics()
    }
}
241
242/// Graceful shutdown signal handler
243async fn shutdown_signal() {
244    let ctrl_c = async {
245        signal::ctrl_c()
246            .await
247            .expect("Failed to install Ctrl+C handler");
248    };
249
250    #[cfg(unix)]
251    let terminate = async {
252        signal::unix::signal(signal::unix::SignalKind::terminate())
253            .expect("Failed to install SIGTERM handler")
254            .recv()
255            .await;
256    };
257
258    #[cfg(not(unix))]
259    let terminate = std::future::pending::<()>();
260
261    tokio::select! {
262        _ = ctrl_c => {
263            tracing::info!("Received Ctrl+C, starting graceful shutdown");
264        }
265        _ = terminate => {
266            tracing::info!("Received SIGTERM, starting graceful shutdown");
267        }
268    }
269}
270
271/// Initialize tracing subscriber
272pub fn init_tracing() {
273    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
274
275    let filter = EnvFilter::try_from_default_env()
276        .unwrap_or_else(|_| EnvFilter::new("info,vex_api=debug,tower_http=debug"));
277
278    tracing_subscriber::registry()
279        .with(filter)
280        .with(tracing_subscriber::fmt::layer().with_target(true))
281        .init();
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config should listen on port 8080 with a 30 s timeout.
    #[test]
    fn test_server_config_default() {
        let ServerConfig { addr, timeout, .. } = ServerConfig::default();
        assert_eq!(8080, addr.port());
        assert_eq!(Duration::from_secs(30), timeout);
    }
}