1use axum::{middleware, Router};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::signal;
8use tower_http::compression::CompressionLayer;
9
10use crate::auth::JwtAuth;
11use crate::error::ApiError;
12use crate::middleware::{
13 auth_middleware, body_limit_layer, cors_layer, rate_limit_middleware, request_id_middleware,
14 timeout_layer, tracing_middleware,
15};
16use crate::routes::api_router;
17use vex_llm::{Metrics, RateLimitConfig, RateLimiter};
/// Locations of the TLS certificate and key files.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the key file.
    pub key_path: String,
}

impl TlsConfig {
    /// Creates a configuration that owns copies of the two given paths.
    pub fn new(cert_path: &str, key_path: &str) -> Self {
        let cert_path = cert_path.to_owned();
        let key_path = key_path.to_owned();
        Self { cert_path, key_path }
    }

    /// Builds a configuration from the `VEX_TLS_CERT` and `VEX_TLS_KEY`
    /// environment variables.
    ///
    /// Returns `None` unless both variables can be read.
    pub fn from_env() -> Option<Self> {
        match (std::env::var("VEX_TLS_CERT"), std::env::var("VEX_TLS_KEY")) {
            (Ok(cert_path), Ok(key_path)) => Some(Self { cert_path, key_path }),
            _ => None,
        }
    }
}
46
47#[derive(Debug, Clone)]
49pub struct ServerConfig {
50 pub addr: SocketAddr,
52 pub timeout: Duration,
54 pub max_body_size: usize,
56 pub compression: bool,
58 pub rate_limit: RateLimitConfig,
60 pub tls: Option<TlsConfig>,
62}
63
64impl Default for ServerConfig {
65 fn default() -> Self {
66 Self {
67 addr: "0.0.0.0:8080".parse().unwrap(),
68 timeout: Duration::from_secs(30),
69 max_body_size: 1024 * 1024, compression: true,
71 rate_limit: RateLimitConfig::default(),
72 tls: None,
73 }
74 }
75}
76
77impl ServerConfig {
78 pub fn from_env() -> Self {
80 let port: u16 = std::env::var("VEX_PORT")
81 .ok()
82 .and_then(|p| p.parse().ok())
83 .unwrap_or(8080);
84
85 let timeout_secs: u64 = std::env::var("VEX_TIMEOUT_SECS")
86 .ok()
87 .and_then(|t| t.parse().ok())
88 .unwrap_or(30);
89
90 Self {
91 addr: SocketAddr::from(([0, 0, 0, 0], port)),
92 timeout: Duration::from_secs(timeout_secs),
93 ..Default::default()
94 }
95 }
96}
97
98use crate::state::AppState;
99
100pub struct VexServer {
102 config: ServerConfig,
103 app_state: AppState,
104}
105
106impl VexServer {
107 pub async fn new(config: ServerConfig) -> Result<Self, ApiError> {
109 use crate::jobs::agent::{AgentExecutionJob, AgentJobPayload};
110 use vex_llm::{DeepSeekProvider, LlmProvider, MockProvider};
111 use vex_queue::{QueueBackend, WorkerConfig, WorkerPool};
112
113 let jwt_auth = JwtAuth::from_env()?;
114 let rate_limiter = Arc::new(RateLimiter::new(config.rate_limit.clone()));
115 let metrics = Arc::new(Metrics::new());
116
117 let db_url =
119 std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite::memory:".to_string());
120 let db = vex_persist::sqlite::SqliteBackend::new(&db_url)
121 .await
122 .map_err(|e| ApiError::Internal(format!("DB Init failed: {}", e)))?;
123
124 let queue_backend = vex_persist::queue::SqliteQueueBackend::new(db.pool().clone());
126
127 let worker_pool = WorkerPool::new_with_arc(
129 Arc::new(queue_backend) as Arc<dyn QueueBackend>,
130 WorkerConfig::default(),
131 );
132
133 let llm: Arc<dyn LlmProvider> = if let Ok(key) = std::env::var("DEEPSEEK_API_KEY") {
135 tracing::info!("Initializing DeepSeek Provider");
136 Arc::new(DeepSeekProvider::chat(&key))
137 } else {
138 tracing::warn!("DEEPSEEK_API_KEY not found. Using Mock Provider.");
139 Arc::new(MockProvider::smart())
140 };
141
142 let result_store = crate::jobs::new_result_store();
144
145 let llm_clone = llm.clone();
147 let result_store_clone = result_store.clone();
148 worker_pool.register_job_factory("agent_execution", move |payload| {
149 let job_payload: AgentJobPayload =
150 serde_json::from_value(payload).unwrap_or_else(|_| AgentJobPayload {
151 agent_id: "unknown".to_string(),
152 prompt: "payload error".to_string(),
153 context_id: None,
154 });
155 let job_id = uuid::Uuid::new_v4();
156 Box::new(AgentExecutionJob::new(
157 job_id,
158 job_payload,
159 llm_clone.clone(),
160 result_store_clone.clone(),
161 ))
162 });
163
164 let app_state = AppState::new(
165 jwt_auth,
166 rate_limiter,
167 metrics,
168 Arc::new(db),
169 Arc::new(worker_pool),
170 );
171
172 Ok(Self { config, app_state })
173 }
174
175 pub fn router(&self) -> Router {
177 let mut app = api_router(self.app_state.clone());
178
179 app = app
181 .layer(CompressionLayer::new())
183 .layer(body_limit_layer(self.config.max_body_size))
185 .layer(timeout_layer(self.config.timeout))
187 .layer(cors_layer())
189 .layer(middleware::from_fn(request_id_middleware))
191 .layer(middleware::from_fn_with_state(
193 self.app_state.clone(),
194 tracing_middleware,
195 ))
196 .layer(middleware::from_fn_with_state(
198 self.app_state.clone(),
199 rate_limit_middleware,
200 ))
201 .layer(middleware::from_fn_with_state(
203 self.app_state.clone(),
204 auth_middleware,
205 ));
206
207 app
208 }
209
210 pub async fn run(self) -> Result<(), ApiError> {
212 let app = self.router();
213 let addr = self.config.addr;
214
215 tracing::info!("Starting VEX API server on {}", addr);
216
217 let queue = self.app_state.queue();
219 tokio::spawn(async move {
220 queue.start().await;
221 });
222
223 let listener = tokio::net::TcpListener::bind(addr)
224 .await
225 .map_err(|e| ApiError::Internal(format!("Failed to bind: {}", e)))?;
226
227 axum::serve(listener, app)
228 .with_graceful_shutdown(shutdown_signal())
229 .await
230 .map_err(|e| ApiError::Internal(format!("Server error: {}", e)))?;
231
232 tracing::info!("Server shutdown complete");
233 Ok(())
234 }
235
236 pub fn metrics(&self) -> Arc<Metrics> {
238 self.app_state.metrics()
239 }
240}
241
242async fn shutdown_signal() {
244 let ctrl_c = async {
245 signal::ctrl_c()
246 .await
247 .expect("Failed to install Ctrl+C handler");
248 };
249
250 #[cfg(unix)]
251 let terminate = async {
252 signal::unix::signal(signal::unix::SignalKind::terminate())
253 .expect("Failed to install SIGTERM handler")
254 .recv()
255 .await;
256 };
257
258 #[cfg(not(unix))]
259 let terminate = std::future::pending::<()>();
260
261 tokio::select! {
262 _ = ctrl_c => {
263 tracing::info!("Received Ctrl+C, starting graceful shutdown");
264 }
265 _ = terminate => {
266 tracing::info!("Received SIGTERM, starting graceful shutdown");
267 }
268 }
269}
270
271pub fn init_tracing() {
273 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
274
275 let filter = EnvFilter::try_from_default_env()
276 .unwrap_or_else(|_| EnvFilter::new("info,vex_api=debug,tower_http=debug"));
277
278 tracing_subscriber::registry()
279 .with(filter)
280 .with(tracing_subscriber::fmt::layer().with_target(true))
281 .init();
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration listens on port 8080 with a 30-second timeout.
    #[test]
    fn test_server_config_default() {
        let ServerConfig { addr, timeout, .. } = ServerConfig::default();

        assert_eq!(addr.port(), 8080);
        assert_eq!(timeout, Duration::from_secs(30));
    }
}