Skip to main content

systemprompt_api/services/middleware/
cors.rs

1use axum::http::Method;
2use systemprompt_models::Config;
3use thiserror::Error;
4use tower_http::cors::{AllowOrigin, CorsLayer};
5
6#[derive(Debug, Error)]
7pub enum CorsError {
8    #[error("Invalid origin '{origin}' in cors_allowed_origins: {reason}")]
9    InvalidOrigin { origin: String, reason: String },
10    #[error("cors_allowed_origins must contain at least one valid origin")]
11    EmptyOrigins,
12}
13
/// Stateless marker type that groups the constructor for the CORS layer.
#[derive(Clone, Copy, Debug)]
pub struct CorsMiddleware;
16
17impl CorsMiddleware {
18    pub fn build_layer(config: &Config) -> Result<CorsLayer, CorsError> {
19        let mut origins = Vec::new();
20        for origin in &config.cors_allowed_origins {
21            let trimmed = origin.trim();
22            if trimmed.is_empty() {
23                continue;
24            }
25            let header_value =
26                trimmed
27                    .parse::<http::HeaderValue>()
28                    .map_err(|e| CorsError::InvalidOrigin {
29                        origin: origin.clone(),
30                        reason: e.to_string(),
31                    })?;
32            origins.push(header_value);
33        }
34
35        if origins.is_empty() {
36            return Err(CorsError::EmptyOrigins);
37        }
38
39        Ok(CorsLayer::new()
40            .allow_origin(AllowOrigin::list(origins))
41            .allow_credentials(true)
42            .allow_methods([
43                Method::GET,
44                Method::POST,
45                Method::PUT,
46                Method::DELETE,
47                Method::OPTIONS,
48            ])
49            .allow_headers([
50                http::header::AUTHORIZATION,
51                http::header::CONTENT_TYPE,
52                http::header::ACCEPT,
53                http::header::ORIGIN,
54                http::header::ACCESS_CONTROL_REQUEST_METHOD,
55                http::header::ACCESS_CONTROL_REQUEST_HEADERS,
56                http::HeaderName::from_static("mcp-protocol-version"),
57                http::HeaderName::from_static("x-context-id"),
58                http::HeaderName::from_static("x-trace-id"),
59                http::HeaderName::from_static("x-call-source"),
60            ]))
61    }
62}