Skip to main content

sa_token_core/
router.rs

1// Author: 金书记
2//
3// Path-based authentication router module
4// 基于路径的鉴权路由模块
5
6use std::sync::Arc;
7
8use sa_token_adapter::context::SaRequest;
9use sa_token_adapter::utils::extract_bearer_or_value;
10
/// Shared, thread-safe predicate that decides whether a login ID is acceptable.
type LoginIdValidator = Arc<dyn Fn(&str) -> bool + Sync + Send>;
12
/// Match a path against a pattern (Ant-style wildcard).
///
/// Wildcard prefixes only match on a `/` segment boundary. `/api/**` matches `/api`
/// and `/api/user`, but never `/apix` or `/api-admin`.
///
/// # Arguments
/// - `path`: The request path to match
/// - `pattern`: The pattern to match against
///
/// # Patterns Supported
/// - `/**`: Match all paths
/// - `/api/**`: Match `/api` itself and every path below `/api/`
/// - `/api/*`: Match `/api`, `/api/` and single-level paths under `/api/`
/// - `*.html`: Match paths ending with `.html`
/// - `/exact`: Exact match
///
/// # Examples
/// ```
/// use sa_token_core::router::match_path;
/// assert!(match_path("/api/user", "/api/**"));
/// assert!(!match_path("/apix/user", "/api/**"));
/// assert!(match_path("/api/user", "/api/*"));
/// assert!(!match_path("/api/user/profile", "/api/*"));
/// assert!(!match_path("/apiuser", "/api/*"));
/// ```
pub fn match_path(path: &str, pattern: &str) -> bool {
    if pattern == "/**" {
        return true;
    }
    if let Some(prefix) = pattern.strip_suffix("/**") {
        return strip_segment_prefix(path, prefix).is_some();
    }
    if let Some(suffix) = pattern.strip_prefix('*') {
        return path.ends_with(suffix);
    }
    // `/*`: at most one path segment after the prefix.
    // Example: `/api/*` matches `/api/user` but not `/api/a/b`.
    if let Some(prefix) = pattern.strip_suffix("/*") {
        let Some(rest) = strip_segment_prefix(path, prefix) else {
            return false;
        };
        // `rest` is "", "/" or "/<segments>". Redundant leading slashes are tolerated.
        return !rest.trim_start_matches('/').contains('/');
    }
    path == pattern
}

/// Strip `prefix` from `path`, but only when the match ends on a `/` segment boundary.
///
/// A plain `starts_with` would let `/public/**` also match `/publicadmin/...`. When
/// used as an exclude pattern, that would wrongly exempt unrelated paths from
/// authentication.
fn strip_segment_prefix<'a>(path: &'a str, prefix: &str) -> Option<&'a str> {
    let rest = path.strip_prefix(prefix)?;
    let on_boundary =
        rest.is_empty() || rest.starts_with('/') || prefix.is_empty() || prefix.ends_with('/');
    on_boundary.then_some(rest)
}
59
60/// Check if path matches any pattern in the list
61/// 检查路径是否匹配列表中的任意模式
62pub fn match_any(path: &str, patterns: &[&str]) -> bool {
63    patterns.iter().any(|p| match_path(path, p))
64}
65
66/// Determine if authentication is needed for a path
67/// 判断路径是否需要鉴权
68///
69/// Returns `true` if path matches include patterns but not exclude patterns
70/// 如果路径匹配包含模式但不匹配排除模式,返回 `true`
71pub fn need_auth(path: &str, include: &[&str], exclude: &[&str]) -> bool {
72    match_any(path, include) && !match_any(path, exclude)
73}
74
75/// Path-based authentication configuration
76/// 基于路径的鉴权配置
77///
78/// Configure which paths require authentication and which are excluded
79/// 配置哪些路径需要鉴权,哪些路径被排除
80#[derive(Clone)]
81pub struct PathAuthConfig {
82    /// Paths that require authentication (include patterns)
83    /// 需要鉴权的路径(包含模式)
84    include: Vec<String>,
85    /// Paths excluded from authentication (exclude patterns)
86    /// 排除鉴权的路径(排除模式)
87    exclude: Vec<String>,
88    /// Optional login ID validator function
89    /// 可选的登录ID验证函数
90    validator: Option<LoginIdValidator>,
91}
92
93impl PathAuthConfig {
94    /// Create a new path authentication configuration
95    /// 创建新的路径鉴权配置
96    pub fn new() -> Self {
97        Self {
98            include: Vec::new(),
99            exclude: Vec::new(),
100            validator: None,
101        }
102    }
103
104    /// Set paths that require authentication
105    /// 设置需要鉴权的路径
106    pub fn include(mut self, patterns: Vec<String>) -> Self {
107        self.include = patterns;
108        self
109    }
110
111    /// Set paths excluded from authentication
112    /// 设置排除鉴权的路径
113    pub fn exclude(mut self, patterns: Vec<String>) -> Self {
114        self.exclude = patterns;
115        self
116    }
117
118    /// Set a custom login ID validator function
119    /// 设置自定义的登录ID验证函数
120    pub fn validator<F>(mut self, f: F) -> Self
121    where
122        F: Fn(&str) -> bool + Send + Sync + 'static,
123    {
124        self.validator = Some(Arc::new(f));
125        self
126    }
127
128    /// Check if a path requires authentication
129    /// 检查路径是否需要鉴权
130    pub fn check(&self, path: &str) -> bool {
131        let inc: Vec<&str> = self.include.iter().map(|s| s.as_str()).collect();
132        let exc: Vec<&str> = self.exclude.iter().map(|s| s.as_str()).collect();
133        need_auth(path, &inc, &exc)
134    }
135
136    /// Validate a login ID using the configured validator
137    /// 使用配置的验证器验证登录ID
138    pub fn validate_login_id(&self, login_id: &str) -> bool {
139        self.validator.as_ref().is_none_or(|v| v(login_id))
140    }
141}
142
143impl Default for PathAuthConfig {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149use crate::{SaTokenManager, TokenValue, SaTokenContext, token::TokenInfo};
150
151/// Authentication result after processing
152/// 处理后的鉴权结果
153pub struct AuthResult {
154    /// Whether authentication is required for this path
155    /// 此路径是否需要鉴权
156    pub need_auth: bool,
157    /// Extracted token value
158    /// 提取的token值
159    pub token: Option<TokenValue>,
160    /// Token information if valid
161    /// 如果有效则包含token信息
162    pub token_info: Option<TokenInfo>,
163    /// Whether the token is valid
164    /// token是否有效
165    pub is_valid: bool,
166}
167
168impl AuthResult {
169    /// Check if the request should be rejected
170    /// 检查请求是否应该被拒绝
171    pub fn should_reject(&self) -> bool {
172        self.need_auth && (!self.is_valid || self.token.is_none())
173    }
174
175    /// Get the login ID from token info
176    /// 从token信息中获取登录ID
177    pub fn login_id(&self) -> Option<&str> {
178        self.token_info.as_ref().map(|t| t.login_id.as_str())
179    }
180}
181
182/// Process authentication for a request path
183/// 处理请求路径的鉴权
184///
185/// This function checks if the path requires authentication, validates the token,
186/// and returns an AuthResult with all relevant information.
187/// 此函数检查路径是否需要鉴权,验证token,并返回包含所有相关信息的AuthResult。
188///
189/// # Arguments
190/// - `path`: The request path
191/// - `token_str`: Optional token string from request
192/// - `config`: Path authentication configuration
193/// - `manager`: SaTokenManager instance
194pub async fn process_auth(
195    path: &str,
196    token_str: Option<String>,
197    config: &PathAuthConfig,
198    manager: &SaTokenManager,
199) -> AuthResult {
200    let need_auth = config.check(path);
201    
202    let token = token_str.map(TokenValue::new);
203    
204    let (is_valid, token_info) = if let Some(ref t) = token {
205        let valid = manager.is_valid(t).await;
206        let info = if valid {
207            manager.get_token_info(t).await.ok()
208        } else {
209            None
210        };
211        (valid, info)
212    } else {
213        (false, None)
214    };
215
216    let is_valid = is_valid && if need_auth {
217        token_info.as_ref().is_some_and(|info| config.validate_login_id(&info.login_id))
218    } else {
219        true
220    };
221
222    AuthResult {
223        need_auth,
224        token,
225        token_info,
226        is_valid,
227    }
228}
229
230/// Create SaTokenContext from authentication result
231/// 从鉴权结果创建SaTokenContext
232pub fn create_context(result: &AuthResult) -> SaTokenContext {
233    let mut ctx = SaTokenContext::new();
234    if let (Some(token), Some(info)) = (&result.token, &result.token_info) {
235        ctx.token = Some(token.clone());
236        ctx.token_info = Some(Arc::new(info.clone()));
237        ctx.login_id = Some(info.login_id.clone());
238    }
239    ctx
240}
241
242/// Generic token extraction from any [`SaRequest`] implementation.
243/// 从任意 [`SaRequest`] 实现中按统一顺序提取 Token。
244///
245/// Order | 顺序:
246/// 1. Header `[token_name]` (Bearer semantics via [`extract_bearer_or_value`]).
247/// 2. `Authorization` header if `token_name` is not already Authorization (case-insensitive match on read side is adapter-specific).
248/// 3. Cookie `[token_name]`.
249/// 4. Query parameter `[token_name]`.
250///
251/// Empty strings are skipped. Returns `None` if nothing found.
252/// 空字符串跳过;均未命中则返回 `None`。
253pub fn extract_token<R: SaRequest>(req: &R, token_name: &str) -> Option<String> {
254    if let Some(v) = req.get_header(token_name) {
255        let s = extract_bearer_or_value(&v);
256        if !s.is_empty() {
257            return Some(s);
258        }
259    }
260    if !token_name.eq_ignore_ascii_case("authorization")
261        && let Some(v) = req.get_header("Authorization") {
262            let s = extract_bearer_or_value(&v);
263            if !s.is_empty() {
264                return Some(s);
265            }
266        }
267    if let Some(v) = req.get_cookie(token_name) {
268        let s = v.trim().to_string();
269        if !s.is_empty() {
270            return Some(s);
271        }
272    }
273    if let Some(v) = req.get_param(token_name) {
274        let s = v.trim().to_string();
275        if !s.is_empty() {
276            return Some(s);
277        }
278    }
279    None
280}
281
282/// Outcome of [`run_auth_flow`]; bindings copy token/login_id/context into framework-specific storage (extensions, depot, etc.).
283/// [`run_auth_flow`] 的返回结果;各框架绑定把 token / login_id / context 写入自身存储(extensions、Depot 等)。
284pub struct AuthFlowResult {
285    /// Path rules + validation summary. | 路径规则与校验摘要。
286    pub auth: AuthResult,
287    /// Login id when token is valid. | 登录 id(token 有效时)。
288    pub login_id: Option<String>,
289    /// Parsed token value when present. | 解析后的 token(若有)。
290    pub token: Option<TokenValue>,
291    /// Request-scoped context for `StpUtil` / handlers. | 请求级上下文,供 `StpUtil` / 处理器使用。
292    pub context: SaTokenContext,
293}
294
295impl AuthFlowResult {
296    /// `true` if the binding should respond **401** (path requires auth but token missing or invalid).
297    /// 若路径要求鉴权但 token 缺失或无效,绑定层应返回 **401**,则返回 `true`。
298    pub fn should_reject(&self) -> bool {
299        self.auth.should_reject()
300    }
301
302    /// Run `fut` with [`SaTokenContext::scope`] using this flow's [`AuthFlowResult::context`] (await-safe).
303    /// 用本流的 [`AuthFlowResult::context`] 调用 [`SaTokenContext::scope`] 执行 `fut`(可跨 await)。
304    pub async fn run<F, R>(self, fut: F) -> R
305    where
306        F: std::future::Future<Output = R>,
307    {
308        SaTokenContext::scope(self.context, fut).await
309    }
310}
311
312/// Full auth pipeline: [`extract_token`] → optional [`PathAuthConfig`] via [`process_auth`], else default check → [`create_context`].
313/// 完整鉴权流水线:[`extract_token`] → 若有 [`PathAuthConfig`] 则 [`process_auth`],否则默认校验 → [`create_context`]。
314///
315/// Pass `path_config: None` for “validate token if present, no path-based reject”.
316/// `path_config` 为 `None` 时表示:有 token 则校验并填上下文,不按路径规则拒绝。
317pub async fn run_auth_flow<R: SaRequest>(
318    req: &R,
319    manager: &SaTokenManager,
320    path_config: Option<&PathAuthConfig>,
321) -> AuthFlowResult {
322    let token_name = manager.config.token_name.as_str();
323    let token_str = extract_token(req, token_name);
324    let path = req.get_path();
325
326    let (auth, ctx) = match path_config {
327        Some(cfg) => {
328            // Path-based rules: may set need_auth / should_reject.
329            // 基于路径的规则:可产生 need_auth / should_reject。
330            let auth = process_auth(path.as_str(), token_str.clone(), cfg, manager).await;
331            let ctx = create_context(&auth);
332            (auth, ctx)
333        }
334        None => {
335            // No path config: only validate token when present.
336            // 无路径配置:仅在有 token 时做有效性校验。
337            let token = token_str.map(TokenValue::new);
338            let (is_valid, token_info) = if let Some(ref t) = token {
339                let valid = manager.is_valid(t).await;
340                let info = if valid {
341                    manager.get_token_info(t).await.ok()
342                } else {
343                    None
344                };
345                (valid, info)
346            } else {
347                (false, None)
348            };
349            let auth = AuthResult {
350                need_auth: false,
351                token: token.clone(),
352                token_info,
353                is_valid,
354            };
355            let ctx = create_context(&auth);
356            (auth, ctx)
357        }
358    };
359
360    let login_id = auth.login_id().map(str::to_string);
361    let token = auth.token.clone();
362    AuthFlowResult {
363        auth,
364        login_id,
365        token,
366        context: ctx,
367    }
368}
369