Skip to main content

smcp_server_core/
auth.rs

//! 认证接口抽象定义 / Authentication interface abstractions
2
3use async_trait::async_trait;
4use http::HeaderMap;
5use thiserror::Error;
6
7/// 认证错误类型
8#[derive(Error, Debug, serde::Serialize)]
9pub enum AuthError {
10    #[error("Missing API key")]
11    MissingApiKey,
12    #[error("Invalid API key")]
13    InvalidApiKey,
14    #[error("Authentication failed: {0}")]
15    Failed(String),
16}
17
18/// 认证提供者抽象 trait
19/// Authentication provider abstract trait
20#[async_trait]
21pub trait AuthenticationProvider: Send + Sync + 'static + std::fmt::Debug {
22    /// 认证连接请求
23    /// Authenticate connection request
24    ///
25    /// # Arguments
26    /// * `headers` - HTTP 请求头 / HTTP request headers
27    /// * `auth` - 原始认证数据 / Raw authentication data
28    ///
29    /// # Returns
30    /// 认证是否成功 / Whether authentication succeeded
31    async fn authenticate(
32        &self,
33        headers: &HeaderMap,
34        auth: Option<&serde_json::Value>,
35    ) -> Result<(), AuthError>;
36}
37
/// Default authentication provider: accepts a request only when its API-key header matches
/// the configured admin secret.
///
/// `Debug` is implemented by hand so that the admin secret is never rendered.
#[derive(Clone)]
pub struct DefaultAuthenticationProvider {
    /// Admin secret; when `None`, the default `authenticate` implementation rejects every key.
    admin_secret: Option<String>,
    /// Name of the HTTP header that carries the API key.
    api_key_name: String,
}

impl std::fmt::Debug for DefaultAuthenticationProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `AuthenticationProvider` requires `Debug`, so provider values may be formatted into
        // logs or panic messages; show only whether a secret is configured, never its value.
        let admin_secret = self.admin_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("DefaultAuthenticationProvider")
            .field("admin_secret", &admin_secret)
            .field("api_key_name", &self.api_key_name)
            .finish()
    }
}
47
48impl DefaultAuthenticationProvider {
49    /// 创建新的默认认证提供者
50    /// Create new default authentication provider
51    ///
52    /// # Arguments
53    /// * `admin_secret` - 管理员密钥 / Admin secret
54    /// * `api_key_name` - API 密钥字段名,默认为 "x-api-key" / API key field name, defaults to "x-api-key"
55    pub fn new(admin_secret: Option<String>, api_key_name: Option<String>) -> Self {
56        Self {
57            admin_secret,
58            api_key_name: api_key_name.unwrap_or_else(|| "x-api-key".to_string()),
59        }
60    }
61}
62
63#[async_trait]
64impl AuthenticationProvider for DefaultAuthenticationProvider {
65    async fn authenticate(
66        &self,
67        headers: &HeaderMap,
68        _auth: Option<&serde_json::Value>,
69    ) -> Result<(), AuthError> {
70        // 从 headers 中提取 API 密钥
71        // Extract API key from headers
72        let api_key = headers
73            .get(self.api_key_name.as_str())
74            .and_then(|value| value.to_str().ok())
75            .map(|s| s.to_string());
76
77        let api_key = api_key.ok_or(AuthError::MissingApiKey)?;
78
79        // 检查管理员权限:与配置的管理员密钥比较
80        // Check admin permission: compare with configured admin secret
81        if let Some(ref admin_secret) = self.admin_secret {
82            if api_key.as_str() == admin_secret {
83                return Ok(());
84            }
85        }
86
87        // 这里可以添加其他认证逻辑,如数据库验证等
88        // Additional authentication logic can be added here, such as database validation
89        Err(AuthError::InvalidApiKey)
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    /// Builds a header map containing a single `name: value` pair.
    fn headers_with(name: &'static str, value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, HeaderValue::from_static(value));
        headers
    }

    #[tokio::test]
    async fn test_default_auth_success() {
        let auth = DefaultAuthenticationProvider::new(Some("secret123".to_string()), None);
        let headers = headers_with("x-api-key", "secret123");

        let result = auth.authenticate(&headers, None).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_default_auth_missing_key() {
        let auth = DefaultAuthenticationProvider::new(Some("secret123".to_string()), None);
        let headers = HeaderMap::new();

        let result = auth.authenticate(&headers, None).await;
        assert!(matches!(result, Err(AuthError::MissingApiKey)));
    }

    #[tokio::test]
    async fn test_default_auth_invalid_key() {
        let auth = DefaultAuthenticationProvider::new(Some("secret123".to_string()), None);
        let headers = headers_with("x-api-key", "wrong");

        let result = auth.authenticate(&headers, None).await;
        assert!(matches!(result, Err(AuthError::InvalidApiKey)));
    }

    #[tokio::test]
    async fn test_default_auth_no_admin_secret() {
        let auth = DefaultAuthenticationProvider::new(None, None);
        let headers = headers_with("x-api-key", "anykey");

        let result = auth.authenticate(&headers, None).await;
        assert!(matches!(result, Err(AuthError::InvalidApiKey)));
    }

    #[tokio::test]
    async fn test_default_auth_custom_header_name() {
        let auth = DefaultAuthenticationProvider::new(
            Some("secret123".to_string()),
            Some("x-smcp-token".to_string()),
        );

        // The key is read from the configured header...
        let custom = headers_with("x-smcp-token", "secret123");
        assert!(auth.authenticate(&custom, None).await.is_ok());

        // ...and the default header is no longer consulted.
        let default_only = headers_with("x-api-key", "secret123");
        let result = auth.authenticate(&default_only, None).await;
        assert!(matches!(result, Err(AuthError::MissingApiKey)));
    }

    #[test]
    fn test_auth_error_messages() {
        assert_eq!(AuthError::MissingApiKey.to_string(), "Missing API key");
        assert_eq!(AuthError::InvalidApiKey.to_string(), "Invalid API key");
        assert_eq!(
            AuthError::Failed("expired".to_string()).to_string(),
            "Authentication failed: expired"
        );
    }
}