Skip to main content

systemprompt_users/services/user/
provider.rs

1use std::str::FromStr;
2
3use async_trait::async_trait;
4use systemprompt_identifiers::UserId;
5use systemprompt_traits::auth::{
6    AuthProviderError, AuthResult, AuthUser, RoleProvider, UserProvider,
7};
8
9use super::UserService;
10use crate::models::{User, UserRole};
11
12#[async_trait]
13impl UserProvider for UserService {
14    async fn find_by_id(&self, id: &str) -> AuthResult<Option<AuthUser>> {
15        let user_id = UserId::new(id);
16        self.find_by_id(&user_id)
17            .await
18            .map(|opt| opt.map(|u| user_to_auth_user(&u)))
19            .map_err(|e| AuthProviderError::Internal(e.to_string()))
20    }
21
22    async fn find_by_email(&self, email: &str) -> AuthResult<Option<AuthUser>> {
23        Self::find_by_email(self, email)
24            .await
25            .map(|opt| opt.map(|u| user_to_auth_user(&u)))
26            .map_err(|e| AuthProviderError::Internal(e.to_string()))
27    }
28
29    async fn find_by_name(&self, name: &str) -> AuthResult<Option<AuthUser>> {
30        Self::find_by_name(self, name)
31            .await
32            .map(|opt| opt.map(|u| user_to_auth_user(&u)))
33            .map_err(|e| AuthProviderError::Internal(e.to_string()))
34    }
35
36    async fn create_user(
37        &self,
38        name: &str,
39        email: &str,
40        full_name: Option<&str>,
41    ) -> AuthResult<AuthUser> {
42        Self::create(self, name, email, full_name, full_name)
43            .await
44            .map(|u| user_to_auth_user(&u))
45            .map_err(|e| AuthProviderError::Internal(e.to_string()))
46    }
47
48    async fn create_anonymous(&self, fingerprint: &str) -> AuthResult<AuthUser> {
49        Self::create_anonymous(self, fingerprint)
50            .await
51            .map(|u| user_to_auth_user(&u))
52            .map_err(|e| AuthProviderError::Internal(e.to_string()))
53    }
54
55    async fn assign_roles(&self, user_id: &str, roles: &[String]) -> AuthResult<()> {
56        let id = UserId::new(user_id);
57        Self::assign_roles(self, &id, roles)
58            .await
59            .map(|_| ())
60            .map_err(|e| AuthProviderError::Internal(e.to_string()))
61    }
62}
63
64fn user_to_auth_user(user: &User) -> AuthUser {
65    AuthUser {
66        id: user.id.to_string(),
67        name: user.name.clone(),
68        email: user.email.clone(),
69        roles: user.roles.clone(),
70        is_active: user.is_active(),
71    }
72}
73
74#[async_trait]
75impl RoleProvider for UserService {
76    async fn get_roles(&self, user_id: &str) -> AuthResult<Vec<String>> {
77        let id = UserId::new(user_id);
78        match Self::find_by_id(self, &id).await {
79            Ok(Some(user)) => Ok(user.roles),
80            Ok(None) => Err(AuthProviderError::UserNotFound),
81            Err(e) => Err(AuthProviderError::Internal(e.to_string())),
82        }
83    }
84
85    async fn assign_role(&self, user_id: &str, role: &str) -> AuthResult<()> {
86        let id = UserId::new(user_id);
87        let user = match Self::find_by_id(self, &id).await {
88            Ok(Some(u)) => u,
89            Ok(None) => return Err(AuthProviderError::UserNotFound),
90            Err(e) => return Err(AuthProviderError::Internal(e.to_string())),
91        };
92
93        let mut roles = user.roles;
94        let role_str = role.to_string();
95        if !roles.contains(&role_str) {
96            roles.push(role_str);
97        }
98
99        Self::assign_roles(self, &id, &roles)
100            .await
101            .map(|_| ())
102            .map_err(|e| AuthProviderError::Internal(e.to_string()))
103    }
104
105    async fn revoke_role(&self, user_id: &str, role: &str) -> AuthResult<()> {
106        let id = UserId::new(user_id);
107        let user = match Self::find_by_id(self, &id).await {
108            Ok(Some(u)) => u,
109            Ok(None) => return Err(AuthProviderError::UserNotFound),
110            Err(e) => return Err(AuthProviderError::Internal(e.to_string())),
111        };
112
113        let roles: Vec<String> = user.roles.into_iter().filter(|r| r != role).collect();
114
115        Self::assign_roles(self, &id, &roles)
116            .await
117            .map(|_| ())
118            .map_err(|e| AuthProviderError::Internal(e.to_string()))
119    }
120
121    async fn list_users_by_role(&self, role: &str) -> AuthResult<Vec<AuthUser>> {
122        let Ok(user_role) = UserRole::from_str(role) else {
123            return Ok(vec![]);
124        };
125
126        Self::find_by_role(self, user_role)
127            .await
128            .map(|users| users.iter().map(user_to_auth_user).collect())
129            .map_err(|e| AuthProviderError::Internal(e.to_string()))
130    }
131}