// sh_layer0/access_controller.rs

//! Access control module.
//!
//! An RBAC (role-based access control) permission management system.
4
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9/// 角色
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct Role {
12    pub name: String,
13    pub permissions: HashSet<Permission>,
14}
15
16/// 权限
17#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
18pub struct Permission {
19    pub resource: String,
20    pub action: String,
21}
22
23impl Permission {
24    pub fn new(resource: impl Into<String>, action: impl Into<String>) -> Self {
25        Self {
26            resource: resource.into(),
27            action: action.into(),
28        }
29    }
30}
31
32/// 用户信息
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct User {
35    pub id: String,
36    pub roles: HashSet<String>,
37}
38
39/// 访问控制器
40pub struct AccessController {
41    /// 角色-权限映射
42    roles: RwLock<HashMap<String, Role>>,
43    /// 用户-角色映射
44    user_roles: RwLock<HashMap<String, HashSet<String>>>,
45}
46
47impl AccessController {
48    pub fn new() -> Self {
49        Self {
50            roles: RwLock::new(Self::default_roles()),
51            user_roles: RwLock::new(HashMap::new()),
52        }
53    }
54
55    /// 默认角色配置
56    fn default_roles() -> HashMap<String, Role> {
57        let mut roles = HashMap::new();
58
59        // Admin 角色 - 所有权限
60        let admin_perms: HashSet<Permission> = [Permission::new("*", "*")].into_iter().collect();
61        roles.insert(
62            "admin".to_string(),
63            Role {
64                name: "admin".to_string(),
65                permissions: admin_perms,
66            },
67        );
68
69        // User 角色 - 基本权限
70        let user_perms: HashSet<Permission> = [
71            Permission::new("session", "read"),
72            Permission::new("session", "write"),
73            Permission::new("tool", "execute"),
74            Permission::new("agent", "run"),
75        ]
76        .into_iter()
77        .collect();
78        roles.insert(
79            "user".to_string(),
80            Role {
81                name: "user".to_string(),
82                permissions: user_perms,
83            },
84        );
85
86        // Guest 角色 - 最小权限
87        let guest_perms: HashSet<Permission> =
88            [Permission::new("session", "read")].into_iter().collect();
89        roles.insert(
90            "guest".to_string(),
91            Role {
92                name: "guest".to_string(),
93                permissions: guest_perms,
94            },
95        );
96
97        roles
98    }
99
100    /// 检查权限
101    pub fn check(&self, user_id: &str, resource: &str, action: &str) -> bool {
102        let user_roles = self.user_roles.read();
103
104        // 获取用户角色
105        let roles = user_roles.get(user_id).cloned().unwrap_or_else(|| {
106            // 默认给 guest 角色
107            HashSet::from(["guest".to_string()])
108        });
109
110        let role_map = self.roles.read();
111
112        // 检查每个角色的权限
113        for role_name in roles {
114            if let Some(role) = role_map.get(&role_name) {
115                for perm in &role.permissions {
116                    // 检查通配符权限
117                    if (perm.resource == "*" || perm.resource == resource)
118                        && (perm.action == "*" || perm.action == action)
119                    {
120                        return true;
121                    }
122                }
123            }
124        }
125
126        false
127    }
128
129    /// 为用户添加角色
130    pub fn add_role(&self, user_id: &str, role_name: &str) {
131        let mut user_roles = self.user_roles.write();
132        user_roles
133            .entry(user_id.to_string())
134            .or_default()
135            .insert(role_name.to_string());
136    }
137
138    /// 为用户移除角色
139    pub fn remove_role(&self, user_id: &str, role_name: &str) {
140        let mut user_roles = self.user_roles.write();
141        if let Some(roles) = user_roles.get_mut(user_id) {
142            roles.remove(role_name);
143            // 如果角色集合为空,则删除该用户的记录,使其回退到 guest 默认权限
144            if roles.is_empty() {
145                user_roles.remove(user_id);
146            }
147        }
148    }
149
150    /// 创建自定义角色
151    pub fn create_role(&self, role: Role) {
152        let mut roles = self.roles.write();
153        roles.insert(role.name.clone(), role);
154    }
155
156    /// 获取用户所有权限
157    pub fn get_permissions(&self, user_id: &str) -> HashSet<Permission> {
158        let user_roles = self.user_roles.read();
159        let roles = user_roles.get(user_id).cloned().unwrap_or_default();
160        let role_map = self.roles.read();
161
162        let mut permissions = HashSet::new();
163        for role_name in roles {
164            if let Some(role) = role_map.get(&role_name) {
165                permissions.extend(role.permissions.clone());
166            }
167        }
168        permissions
169    }
170}
171
172impl Default for AccessController {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_roles() {
        let controller = AccessController::new();

        // Admin should hold every permission.
        controller.add_role("admin_user", "admin");
        assert!(controller.check("admin_user", "any_resource", "any_action"));

        // User should hold the basic permissions only.
        controller.add_role("normal_user", "user");
        assert!(controller.check("normal_user", "tool", "execute"));
        assert!(!controller.check("normal_user", "admin", "create"));
    }

    #[test]
    fn test_guest_default() {
        let controller = AccessController::new();

        // A user without any assigned role is treated as guest.
        assert!(controller.check("unknown_user", "session", "read"));
        assert!(!controller.check("unknown_user", "session", "write"));
    }

    #[test]
    fn test_remove_role() {
        let controller = AccessController::new();

        controller.add_role("test_user", "admin");
        assert!(controller.check("test_user", "any", "any"));

        controller.remove_role("test_user", "admin");
        // After removal the user should fall back to guest permissions.
        assert!(!controller.check("test_user", "any", "any"));
        assert!(controller.check("test_user", "session", "read"));
    }

    #[test]
    fn test_create_custom_role() {
        let controller = AccessController::new();

        let custom_role = Role {
            name: "custom".to_string(),
            permissions: HashSet::from([
                Permission::new("custom_resource", "read"),
                Permission::new("custom_resource", "write"),
            ]),
        };
        controller.create_role(custom_role);

        controller.add_role("custom_user", "custom");
        assert!(controller.check("custom_user", "custom_resource", "read"));
        assert!(controller.check("custom_user", "custom_resource", "write"));
        assert!(!controller.check("custom_user", "other_resource", "read"));
    }

    #[test]
    fn test_get_permissions() {
        let controller = AccessController::new();

        controller.add_role("multi_user", "user");
        controller.add_role("multi_user", "guest");

        let permissions = controller.get_permissions("multi_user");
        // The result should merge the permissions of `user` and `guest`.
        assert!(permissions.contains(&Permission::new("session", "read")));
        assert!(permissions.contains(&Permission::new("session", "write")));
        assert!(permissions.contains(&Permission::new("tool", "execute")));
    }

    #[test]
    fn test_permission_new() {
        let perm = Permission::new("resource", "action");
        assert_eq!(perm.resource, "resource");
        assert_eq!(perm.action, "action");
    }

    #[test]
    fn test_multiple_roles_same_user() {
        let controller = AccessController::new();

        controller.add_role("power_user", "user");
        controller.add_role("power_user", "admin");

        // Holding admin should allow access to any resource.
        assert!(controller.check("power_user", "super_secret", "delete"));
    }

    #[test]
    fn test_role_serialization() {
        let role = Role {
            name: "test".to_string(),
            permissions: HashSet::from([Permission::new("r", "a")]),
        };
        let json = serde_json::to_string(&role).unwrap();
        assert!(json.contains("test"));
    }

    #[test]
    fn test_permission_hash_equality() {
        let p1 = Permission::new("resource", "action");
        let p2 = Permission::new("resource", "action");
        let set: HashSet<Permission> = HashSet::from([p1, p2]);
        assert_eq!(set.len(), 1); // Identical permissions must collapse to one entry.
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let controller = Arc::new(AccessController::new());
        controller.add_role("user1", "admin");

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let c = Arc::clone(&controller);
                thread::spawn(move || {
                    let user = format!("user{}", i);
                    c.check(&user, "session", "read")
                })
            })
            .collect();

        // `user1` is admin and every other id falls back to guest, so each
        // thread must observe that reading a session is allowed.
        for handle in handles {
            let allowed = handle.join().expect("checker thread panicked");
            assert!(allowed);
        }
    }
}