systemprompt_api/services/middleware/
cors.rs1use axum::http::Method;
2use systemprompt_models::Config;
3use thiserror::Error;
4use tower_http::cors::{AllowOrigin, CorsLayer};
5
6#[derive(Debug, Error)]
7pub enum CorsError {
8 #[error("Invalid origin '{origin}' in cors_allowed_origins: {reason}")]
9 InvalidOrigin { origin: String, reason: String },
10 #[error("cors_allowed_origins must contain at least one valid origin")]
11 EmptyOrigins,
12}
13
/// Stateless namespace type for the CORS middleware.
///
/// It carries no data; it only groups the associated constructor functions
/// defined in its `impl` block.
#[derive(Clone, Copy, Debug)]
pub struct CorsMiddleware;
16
17impl CorsMiddleware {
18 pub fn build_layer(config: &Config) -> Result<CorsLayer, CorsError> {
19 let mut origins = Vec::new();
20 for origin in &config.cors_allowed_origins {
21 let trimmed = origin.trim();
22 if trimmed.is_empty() {
23 continue;
24 }
25 let header_value =
26 trimmed
27 .parse::<http::HeaderValue>()
28 .map_err(|e| CorsError::InvalidOrigin {
29 origin: origin.clone(),
30 reason: e.to_string(),
31 })?;
32 origins.push(header_value);
33 }
34
35 if origins.is_empty() {
36 return Err(CorsError::EmptyOrigins);
37 }
38
39 Ok(CorsLayer::new()
40 .allow_origin(AllowOrigin::list(origins))
41 .allow_credentials(true)
42 .allow_methods([
43 Method::GET,
44 Method::POST,
45 Method::PUT,
46 Method::DELETE,
47 Method::OPTIONS,
48 ])
49 .allow_headers([
50 http::header::AUTHORIZATION,
51 http::header::CONTENT_TYPE,
52 http::header::ACCEPT,
53 http::header::ORIGIN,
54 http::header::ACCESS_CONTROL_REQUEST_METHOD,
55 http::header::ACCESS_CONTROL_REQUEST_HEADERS,
56 http::HeaderName::from_static("mcp-protocol-version"),
57 http::HeaderName::from_static("x-context-id"),
58 http::HeaderName::from_static("x-trace-id"),
59 http::HeaderName::from_static("x-call-source"),
60 ])
61 .expose_headers([http::header::WWW_AUTHENTICATE]))
62 }
63}