Skip to main content

sa_token_core/
sso.rs

1//! # SSO 单点登录模块 | SSO Single Sign-On Module
2//!
3//! 提供完整的单点登录功能实现,支持票据认证和统一登出。
4//! Provides complete Single Sign-On functionality with ticket-based authentication and unified logout.
5//!
6//! ## 代码流程逻辑 | Code Flow Logic
7//!
8//! ### 1. 核心组件 | Core Components
9//!
10//! ```text
11//! SsoServer(SSO 服务端)
12//!   ├── 票据管理 | Ticket Management
13//!   │   ├── 生成票据 create_ticket()
14//!   │   ├── 验证票据 validate_ticket()
15//!   │   └── 清理过期票据 cleanup_expired_tickets()
16//!   ├── 会话管理 | Session Management
17//!   │   ├── 创建会话 login()
18//!   │   ├── 获取会话 get_session()
19//!   │   └── 删除会话 logout()
20//!   └── 客户端追踪 | Client Tracking
21//!       └── 获取活跃客户端 get_active_clients()
22//!
23//! SsoClient(SSO 客户端)
24//!   ├── URL 生成 | URL Generation
25//!   │   ├── 登录 URL get_login_url()
26//!   │   └── 登出 URL get_logout_url()
27//!   ├── 本地会话 | Local Session
28//!   │   ├── 检查登录 check_local_login()
29//!   │   └── 票据登录 login_by_ticket()
30//!   └── 登出处理 | Logout Handling
31//!       └── 处理登出 handle_logout()
32//! ```
33//!
34//! ### 2. 登录流程 | Login Flow
35//!
36//! ```text
37//! 步骤 1: 用户访问应用 → 重定向到 SSO Server
38//! Step 1: User accesses app → Redirect to SSO Server
39//!
40//! 步骤 2: SSO Server 验证凭证
41//! Step 2: SSO Server validates credentials
42//!   └─> login(login_id, service) 
43//!       ├─> 创建 Token
44//!       ├─> 创建或更新 SsoSession
45//!       └─> 生成 SsoTicket
46//!
47//! 步骤 3: 客户端应用验证票据
48//! Step 3: Client app validates ticket
49//!   └─> validate_ticket(ticket_id, service)
50//!       ├─> 检查票据存在
51//!       ├─> 验证票据有效性(未过期、未使用)
52//!       ├─> 验证服务 URL 匹配
53//!       ├─> 标记票据为已使用
54//!       └─> 返回 login_id
55//!
56//! 步骤 4: 创建本地会话
57//! Step 4: Create local session
58//!   └─> client.login_by_ticket(login_id)
59//!       └─> manager.login(login_id) → 创建本地 Token
60//! ```
61//!
62//! ### 3. SSO 无缝登录流程 | SSO Seamless Login Flow
63//!
64//! ```text
65//! 用户已在应用1登录,访问应用2:
66//! User logged in App1, accessing App2:
67//!
68//! 应用2 → SSO Server: 请求认证
69//! App2 → SSO Server: Request authentication
70//!   └─> is_logged_in(login_id) → true
71//!       └─> create_ticket(login_id, app2_url)
72//!           └─> 直接返回票据(无需再次登录)
73//!               Return ticket (no re-login required)
74//!
75//! 应用2 → 验证票据 → 创建本地会话 → 访问授权
76//! App2 → Validate ticket → Create local session → Access granted
77//! ```
78//!
79//! ### 4. 统一登出流程 | Unified Logout Flow
80//!
81//! ```text
82//! 用户从任一应用登出:
83//! User logs out from any app:
84//!
85//! logout(login_id)
86//!   ├─> 获取 SsoSession
87//!   ├─> 获取所有已登录客户端列表
88//!   ├─> 删除 SsoSession
89//!   ├─> 删除用户的所有 Token
90//!   └─> 返回客户端列表
91//!
92//! 通知所有客户端:
93//! Notify all clients:
94//!   └─> for each client_url
95//!       └─> client.handle_logout(login_id)
96//!           └─> 清除本地会话 | Clear local session
97//! ```
98//!
99//! ### 5. 票据生命周期 | Ticket Lifecycle
100//!
101//! ```text
102//! 创建 | Create: ticket.create_time = now
103//!   └─> 设置过期时间 | Set expiration: expire_time = now + timeout
104//!   └─> 状态 | Status: used = false
105//!
106//! 验证 | Validate:
107//!   ├─> 检查过期 | Check expiration: now > expire_time?
108//!   ├─> 检查使用状态 | Check usage: used == true?
109//!   └─> 验证服务 | Verify service: service == expected?
110//!
111//! 使用 | Use: 验证成功后 | After validation
112//!   └─> ticket.used = true(一次性使用 | One-time use)
113//!
114//! 清理 | Cleanup: cleanup_expired_tickets()
115//!   └─> 删除所有过期或已使用的票据
116//!       Remove all expired or used tickets
117//! ```
118//!
119//! ### 6. 安全机制 | Security Mechanisms
120//!
121//! ```text
122//! 1. 票据一次性使用 | One-time ticket usage
123//!    └─> validate_ticket() 后立即设置 used = true
124//!
125//! 2. 服务 URL 匹配 | Service URL matching
126//!    └─> ticket.service 必须与请求的 service 完全匹配
127//!
128//! 3. 票据过期 | Ticket expiration
129//!    └─> 默认 5 分钟过期,可配置
130//!
131//! 4. 跨域保护 | Cross-domain protection
132//!    └─> SsoConfig.allowed_origins 白名单机制
133//!
134//! 5. UUID 票据 ID | UUID ticket ID
135//!    └─> 使用 UUID 防止票据 ID 被猜测
136//! ```
137
138use std::sync::Arc;
139use std::collections::HashMap;
140use chrono::{DateTime, Utc, Duration as ChronoDuration};
141use serde::{Serialize, Deserialize};
142use tokio::sync::RwLock;
143use crate::{SaTokenError, SaTokenResult, SaTokenManager};
144
145type LogoutCallback = Arc<dyn Fn(&str) -> bool + Send + Sync>;
146
147/// SSO 票据结构 | SSO Ticket Structure
148///
149/// 票据是一个短期、一次性使用的认证令牌
150/// A ticket is a short-lived, one-time use authentication token
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct SsoTicket {
153    /// 票据唯一标识符(UUID)| Unique ticket identifier (UUID)
154    pub ticket_id: String,
155    /// 目标服务 URL | Target service URL
156    pub service: String,
157    /// 用户登录 ID | User login ID
158    pub login_id: String,
159    /// 票据创建时间 | Ticket creation time
160    pub create_time: DateTime<Utc>,
161    /// 票据过期时间 | Ticket expiration time
162    pub expire_time: DateTime<Utc>,
163    /// 是否已使用(一次性使用)| Whether used (one-time use)
164    pub used: bool,
165}
166
167impl SsoTicket {
168    /// 创建新票据 | Create a new ticket
169    ///
170    /// # 参数 | Parameters
171    /// * `login_id` - 用户登录 ID | User login ID
172    /// * `service` - 目标服务 URL | Target service URL
173    /// * `timeout_seconds` - 票据有效期(秒)| Ticket validity period (seconds)
174    pub fn new(login_id: String, service: String, timeout_seconds: i64) -> Self {
175        let now = Utc::now();
176        Self {
177            ticket_id: uuid::Uuid::new_v4().to_string(),
178            service,
179            login_id,
180            create_time: now,
181            expire_time: now + ChronoDuration::seconds(timeout_seconds),
182            used: false,
183        }
184    }
185
186    /// 检查票据是否过期 | Check if ticket is expired
187    pub fn is_expired(&self) -> bool {
188        Utc::now() > self.expire_time
189    }
190
191    /// 检查票据是否有效(未使用且未过期)| Check if ticket is valid (not used and not expired)
192    pub fn is_valid(&self) -> bool {
193        !self.used && !self.is_expired()
194    }
195}
196
197/// SSO 全局会话 | SSO Global Session
198///
199/// 跟踪用户在所有应用中的登录状态
200/// Tracks user's login status across all applications
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct SsoSession {
203    /// 用户登录 ID | User login ID
204    pub login_id: String,
205    /// 已登录的客户端列表 | List of logged-in clients
206    pub clients: Vec<String>,
207    /// 会话创建时间 | Session creation time
208    pub create_time: DateTime<Utc>,
209    /// 最后活动时间 | Last activity time
210    pub last_active_time: DateTime<Utc>,
211}
212
213impl SsoSession {
214    /// 创建新会话 | Create a new session
215    pub fn new(login_id: String) -> Self {
216        let now = Utc::now();
217        Self {
218            login_id,
219            clients: Vec::new(),
220            create_time: now,
221            last_active_time: now,
222        }
223    }
224
225    /// 添加客户端到会话 | Add client to session
226    ///
227    /// 如果客户端不在列表中,则添加
228    /// Adds client if not already in the list
229    pub fn add_client(&mut self, service: String) {
230        if !self.clients.contains(&service) {
231            self.clients.push(service);
232        }
233        self.last_active_time = Utc::now();
234    }
235
236    /// 从会话中移除客户端 | Remove client from session
237    pub fn remove_client(&mut self, service: &str) {
238        self.clients.retain(|c| c != service);
239        self.last_active_time = Utc::now();
240    }
241}
242
243/// SSO 服务端 | SSO Server
244///
245/// 中央认证服务,负责票据生成、验证和会话管理
246/// Central authentication service responsible for ticket generation, validation, and session management
247pub struct SsoServer {
248    manager: Arc<SaTokenManager>,
249    tickets: Arc<RwLock<HashMap<String, SsoTicket>>>,
250    sessions: Arc<RwLock<HashMap<String, SsoSession>>>,
251    ticket_timeout: i64,
252}
253
254impl SsoServer {
255    /// 创建新的 SSO 服务端 | Create a new SSO Server
256    ///
257    /// # 参数 | Parameters
258    /// * `manager` - SaTokenManager 实例 | SaTokenManager instance
259    pub fn new(manager: Arc<SaTokenManager>) -> Self {
260        Self {
261            manager,
262            tickets: Arc::new(RwLock::new(HashMap::new())),
263            sessions: Arc::new(RwLock::new(HashMap::new())),
264            ticket_timeout: 300, // 默认 5 分钟 | Default 5 minutes
265        }
266    }
267
268    /// 设置票据超时时间 | Set ticket timeout
269    ///
270    /// # 参数 | Parameters
271    /// * `timeout` - 超时时间(秒)| Timeout in seconds
272    pub fn with_ticket_timeout(mut self, timeout: i64) -> Self {
273        self.ticket_timeout = timeout;
274        self
275    }
276
277    /// 检查用户是否已登录 | Check if user is logged in
278    ///
279    /// 通过检查 SSO 会话是否存在来判断
280    /// Determined by checking if SSO session exists
281    pub async fn is_logged_in(&self, login_id: &str) -> bool {
282        let sessions = self.sessions.read().await;
283        let has_session = sessions.contains_key(login_id);
284        drop(sessions);
285        
286        // 如果会话存在,进一步验证 Token 是否有效
287        if has_session {
288            let key = format!("sa:login:token:{}:sso", login_id);
289            matches!(self.manager.storage.get(&key).await, Ok(Some(_)))
290        } else {
291            false
292        }
293    }
294
295    /// 创建票据 | Create ticket
296    ///
297    /// 为已登录用户创建访问特定服务的票据
298    /// Creates a ticket for logged-in user to access specific service
299    ///
300    /// # 参数 | Parameters
301    /// * `login_id` - 用户登录 ID | User login ID
302    /// * `service` - 目标服务 URL | Target service URL
303    ///
304    /// # 返回 | Returns
305    /// 新创建的票据 | Newly created ticket
306    pub async fn create_ticket(&self, login_id: String, service: String) -> SaTokenResult<SsoTicket> {
307        // 生成票据 | Generate ticket
308        let ticket = SsoTicket::new(login_id.clone(), service.clone(), self.ticket_timeout);
309        
310        // 存储票据 | Store ticket
311        let mut tickets = self.tickets.write().await;
312        tickets.insert(ticket.ticket_id.clone(), ticket.clone());
313
314        // 更新会话,添加客户端 | Update session, add client
315        let mut sessions = self.sessions.write().await;
316        sessions.entry(login_id.clone())
317            .or_insert_with(|| SsoSession::new(login_id))
318            .add_client(service);
319
320        Ok(ticket)
321    }
322
323    /// 验证票据 | Validate ticket
324    ///
325    /// 验证票据的有效性并将其标记为已使用(一次性使用)
326    /// Validates ticket and marks it as used (one-time use)
327    ///
328    /// # 参数 | Parameters
329    /// * `ticket_id` - 票据 ID | Ticket ID
330    /// * `service` - 请求的服务 URL | Requested service URL
331    ///
332    /// # 返回 | Returns
333    /// 用户登录 ID | User login ID
334    ///
335    /// # 错误 | Errors
336    /// * `InvalidTicket` - 票据不存在 | Ticket not found
337    /// * `TicketExpired` - 票据已过期或已使用 | Ticket expired or used
338    /// * `ServiceMismatch` - 服务 URL 不匹配 | Service URL mismatch
339    pub async fn validate_ticket(&self, ticket_id: &str, service: &str) -> SaTokenResult<String> {
340        let mut tickets = self.tickets.write().await;
341        
342        // 1. 检查票据是否存在 | Check if ticket exists
343        let ticket = tickets.get_mut(ticket_id)
344            .ok_or(SaTokenError::InvalidTicket)?;
345
346        // 2. 验证票据有效性(未过期、未使用)| Validate ticket (not expired, not used)
347        if !ticket.is_valid() {
348            return Err(SaTokenError::TicketExpired);
349        }
350
351        // 3. 验证服务 URL 匹配 | Verify service URL matches
352        if ticket.service != service {
353            return Err(SaTokenError::ServiceMismatch);
354        }
355
356        // 4. 标记票据为已使用(一次性使用)| Mark ticket as used (one-time use)
357        ticket.used = true;
358        let login_id = ticket.login_id.clone();
359
360        Ok(login_id)
361    }
362
363    /// 用户登录 | User login
364    ///
365    /// 完整的登录流程:创建 Token、会话和票据
366    /// Complete login flow: create Token, session, and ticket
367    ///
368    /// # 参数 | Parameters
369    /// * `login_id` - 用户登录 ID | User login ID
370    /// * `service` - 目标服务 URL | Target service URL
371    ///
372    /// # 返回 | Returns
373    /// 生成的票据 | Generated ticket
374    pub async fn login(&self, login_id: String, service: String) -> SaTokenResult<SsoTicket> {
375        // 使用 login_with_options 创建 SSO 类型的 Token
376        let _token = self.manager.login_with_options(
377            &login_id,
378            Some("sso".to_string()), // 设置 login_type 为 "sso"
379            None,
380            Some(serde_json::json!({
381                "sso_mode": true,
382                "service": service.clone()
383            })),
384            None,
385            None,
386        ).await?;
387        
388        // 更新会话
389        let mut sessions = self.sessions.write().await;
390        sessions.entry(login_id.clone())
391            .or_insert_with(|| SsoSession::new(login_id.clone()))
392            .add_client(service.clone());
393
394        drop(sessions);
395
396        // 创建并返回票据
397        self.create_ticket(login_id, service).await
398    }
399
400    /// 统一登出 | Unified logout
401    ///
402    /// 从 SSO 服务端登出,并返回需要通知的客户端列表
403    /// Logout from SSO Server and return list of clients to notify
404    ///
405    /// # 参数 | Parameters
406    /// * `login_id` - 用户登录 ID | User login ID
407    ///
408    /// # 返回 | Returns
409    /// 需要清除会话的客户端 URL 列表 | List of client URLs to clear sessions
410    pub async fn logout(&self, login_id: &str) -> SaTokenResult<Vec<String>> {
411        // 1. 获取并删除 SSO 会话 | Get and remove SSO session
412        let mut sessions = self.sessions.write().await;
413        let session = sessions.remove(login_id);
414        
415        // 2. 提取客户端列表 | Extract client list
416        let clients = session.map(|s| s.clients).unwrap_or_default();
417
418        drop(sessions);
419
420        // 3. 从 Token 管理器中登出(登出所有类型的 Token)| Logout from Token manager (all token types)
421        // 3.1 登出 SSO 服务端 Token
422        let sso_key = format!("sa:login:token:{}:sso", login_id);
423        let _ = self.manager.storage.delete(&sso_key).await;
424        
425        // 3.2 登出默认类型 Token
426        self.manager.logout_by_login_id(login_id).await?;
427
428        // 4. 返回客户端列表供通知 | Return client list for notification
429        Ok(clients)
430    }
431
432    /// 获取 SSO 会话 | Get SSO session
433    ///
434    /// # 参数 | Parameters
435    /// * `login_id` - 用户登录 ID | User login ID
436    ///
437    /// # 返回 | Returns
438    /// SSO 会话信息(如果存在)| SSO session info (if exists)
439    pub async fn get_session(&self, login_id: &str) -> Option<SsoSession> {
440        let sessions = self.sessions.read().await;
441        sessions.get(login_id).cloned()
442    }
443
444    /// 检查会话是否存在 | Check if session exists
445    ///
446    /// # 参数 | Parameters
447    /// * `login_id` - 用户登录 ID | User login ID
448    ///
449    /// # 返回 | Returns
450    /// 会话是否存在 | Whether session exists
451    pub async fn check_session(&self, login_id: &str) -> bool {
452        let sessions = self.sessions.read().await;
453        sessions.contains_key(login_id)
454    }
455
456    /// 清理过期票据 | Cleanup expired tickets
457    ///
458    /// 删除所有过期或已使用的票据
459    /// Removes all expired or used tickets
460    pub async fn cleanup_expired_tickets(&self) {
461        let mut tickets = self.tickets.write().await;
462        tickets.retain(|_, ticket| ticket.is_valid());
463    }
464
465    /// 获取活跃客户端列表 | Get active clients list
466    ///
467    /// # 参数 | Parameters
468    /// * `login_id` - 用户登录 ID | User login ID
469    ///
470    /// # 返回 | Returns
471    /// 客户端 URL 列表 | List of client URLs
472    pub async fn get_active_clients(&self, login_id: &str) -> Vec<String> {
473        let sessions = self.sessions.read().await;
474        sessions.get(login_id)
475            .map(|s| s.clients.clone())
476            .unwrap_or_default()
477    }
478}
479
480/// SSO 客户端 | SSO Client
481///
482/// 每个应用作为 SSO 客户端,处理本地会话和票据验证
483/// Each application acts as SSO Client, handling local sessions and ticket validation
484pub struct SsoClient {
485    /// Token 管理器 | Token manager
486    manager: Arc<SaTokenManager>,
487    /// SSO 服务端 URL | SSO Server URL
488    server_url: String,
489    /// 当前服务 URL | Current service URL
490    service_url: String,
491    /// 登出回调函数 | Logout callback function
492    logout_callback: Option<LogoutCallback>,
493}
494
495impl SsoClient {
496    /// 创建新的 SSO 客户端 | Create a new SSO Client
497    ///
498    /// # 参数 | Parameters
499    /// * `manager` - SaTokenManager 实例 | SaTokenManager instance
500    /// * `server_url` - SSO 服务端 URL | SSO Server URL
501    /// * `service_url` - 当前服务 URL | Current service URL
502    pub fn new(
503        manager: Arc<SaTokenManager>,
504        server_url: String,
505        service_url: String,
506    ) -> Self {
507        Self {
508            manager,
509            server_url,
510            service_url,
511            logout_callback: None,
512        }
513    }
514
515    /// 设置登出回调函数 | Set logout callback
516    ///
517    /// # 参数 | Parameters
518    /// * `callback` - 登出时执行的回调函数 | Callback function to execute on logout
519    pub fn with_logout_callback<F>(mut self, callback: F) -> Self
520    where
521        F: Fn(&str) -> bool + Send + Sync + 'static,
522    {
523        self.logout_callback = Some(Arc::new(callback));
524        self
525    }
526
527    /// 生成登录 URL | Generate login URL
528    ///
529    /// # 返回 | Returns
530    /// SSO 服务端登录 URL,包含当前服务的回调地址
531    /// SSO Server login URL with current service callback
532    pub fn get_login_url(&self) -> String {
533        format!("{}?service={}", self.server_url, urlencoding::encode(&self.service_url))
534    }
535
536    /// 生成登出 URL | Generate logout URL
537    ///
538    /// # 返回 | Returns
539    /// SSO 服务端登出 URL,包含当前服务的回调地址
540    /// SSO Server logout URL with current service callback
541    pub fn get_logout_url(&self) -> String {
542        format!("{}/logout?service={}", self.server_url, urlencoding::encode(&self.service_url))
543    }
544
545    /// 检查本地是否已登录 | Check if locally logged in
546    ///
547    /// # 参数 | Parameters
548    /// * `login_id` - 用户登录 ID | User login ID
549    ///
550    /// # 返回 | Returns
551    /// 是否已登录 | Whether logged in
552    pub async fn check_local_login(&self, login_id: &str) -> bool {
553        // 检查 SSO 客户端类型的登录
554        let key = format!("sa:login:token:{}:sso_client", login_id);
555        match self.manager.storage.get(&key).await {
556            Ok(Some(_)) => true,
557            _ => {
558                // 兼容旧的无类型登录
559                let key_default = format!("sa:login:token:{}", login_id);
560                matches!(self.manager.storage.get(&key_default).await, Ok(Some(_)))
561            }
562        }
563    }
564
565    /// 处理票据(验证票据合法性)| Process ticket (validate ticket)
566    ///
567    /// # 参数 | Parameters
568    /// * `ticket` - 票据 ID | Ticket ID
569    /// * `service` - 服务 URL | Service URL
570    ///
571    /// # 返回 | Returns
572    /// 处理后的票据信息 | Processed ticket info
573    ///
574    /// # 错误 | Errors
575    /// * `ServiceMismatch` - 服务 URL 不匹配 | Service URL mismatch
576    pub async fn process_ticket(&self, ticket: &str, service: &str) -> SaTokenResult<String> {
577        // 验证服务 URL 是否匹配
578        if service != self.service_url {
579            return Err(SaTokenError::ServiceMismatch);
580        }
581
582        Ok(ticket.to_string())
583    }
584
585    /// 通过票据登录(客户端本地登录)| Login by ticket (client-side local login)
586    ///
587    /// # 参数 | Parameters
588    /// * `login_id` - 用户登录 ID | User login ID
589    ///
590    /// # 返回 | Returns
591    /// 生成的本地 Token | Generated local token
592    pub async fn login_by_ticket(&self, login_id: String) -> SaTokenResult<String> {
593        // 使用 login_with_options 创建客户端 Token,标记为 SSO 客户端登录
594        let token = self.manager.login_with_options(
595            &login_id,
596            Some("sso_client".to_string()), // 标记为 SSO 客户端
597            None,
598            Some(serde_json::json!({
599                "sso_client": true,
600                "service_url": self.service_url.clone()
601            })),
602            None,
603            None,
604        ).await?;
605        Ok(token.to_string())
606    }
607
608    /// 处理登出(客户端)| Handle logout (client-side)
609    ///
610    /// # 参数 | Parameters
611    /// * `login_id` - 用户登录 ID | User login ID
612    pub async fn handle_logout(&self, login_id: &str) -> SaTokenResult<()> {
613        // 1. 执行登出回调 | Execute logout callback
614        if let Some(callback) = &self.logout_callback {
615            callback(login_id);
616        }
617        
618        // 2. 登出 SSO 客户端类型的 Token | Logout SSO client token
619        let sso_client_key = format!("sa:login:token:{}:sso_client", login_id);
620        let _ = self.manager.storage.delete(&sso_client_key).await;
621        
622        // 3. 登出默认类型的 Token(兼容)| Logout default token (compatibility)
623        self.manager.logout_by_login_id(login_id).await?;
624        
625        Ok(())
626    }
627
628    /// 获取 SSO 服务端 URL | Get SSO Server URL
629    pub fn server_url(&self) -> &str {
630        &self.server_url
631    }
632
633    /// 获取当前服务 URL | Get current service URL
634    pub fn service_url(&self) -> &str {
635        &self.service_url
636    }
637}
638
639#[derive(Debug, Clone, Serialize, Deserialize)]
640pub struct SsoConfig {
641    pub server_url: String,
642    pub ticket_timeout: i64,
643    pub allow_cross_domain: bool,
644    pub allowed_origins: Vec<String>,
645}
646
647impl Default for SsoConfig {
648    fn default() -> Self {
649        Self {
650            server_url: "http://localhost:8080/sso".to_string(),
651            ticket_timeout: 300,
652            allow_cross_domain: true,
653            allowed_origins: vec!["*".to_string()],
654        }
655    }
656}
657
658impl SsoConfig {
659    pub fn builder() -> SsoConfigBuilder {
660        SsoConfigBuilder::default()
661    }
662}
663
664#[derive(Default)]
665pub struct SsoConfigBuilder {
666    config: SsoConfig,
667}
668
669impl SsoConfigBuilder {
670    pub fn server_url(mut self, url: impl Into<String>) -> Self {
671        self.config.server_url = url.into();
672        self
673    }
674
675    pub fn ticket_timeout(mut self, timeout: i64) -> Self {
676        self.config.ticket_timeout = timeout;
677        self
678    }
679
680    pub fn allow_cross_domain(mut self, allow: bool) -> Self {
681        self.config.allow_cross_domain = allow;
682        self
683    }
684
685    pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
686        self.config.allowed_origins = origins;
687        self
688    }
689
690    pub fn add_allowed_origin(mut self, origin: String) -> Self {
691        if self.config.allowed_origins == vec!["*".to_string()] {
692            self.config.allowed_origins = vec![origin];
693        } else {
694            self.config.allowed_origins.push(origin);
695        }
696        self
697    }
698
699    pub fn build(self) -> SsoConfig {
700        self.config
701    }
702}
703
704pub struct SsoManager {
705    server: Option<Arc<SsoServer>>,
706    client: Option<Arc<SsoClient>>,
707    config: SsoConfig,
708}
709
710impl SsoManager {
711    pub fn new(config: SsoConfig) -> Self {
712        Self {
713            server: None,
714            client: None,
715            config,
716        }
717    }
718
719    pub fn with_server(mut self, server: Arc<SsoServer>) -> Self {
720        self.server = Some(server);
721        self
722    }
723
724    pub fn with_client(mut self, client: Arc<SsoClient>) -> Self {
725        self.client = Some(client);
726        self
727    }
728
729    pub fn server(&self) -> Option<&Arc<SsoServer>> {
730        self.server.as_ref()
731    }
732
733    pub fn client(&self) -> Option<&Arc<SsoClient>> {
734        self.client.as_ref()
735    }
736
737    pub fn config(&self) -> &SsoConfig {
738        &self.config
739    }
740
741    pub fn is_allowed_origin(&self, origin: &str) -> bool {
742        if !self.config.allow_cross_domain {
743            return false;
744        }
745
746        self.config.allowed_origins.contains(&"*".to_string()) ||
747        self.config.allowed_origins.contains(&origin.to_string())
748    }
749}
750