Skip to main content

systemprompt_users/services/user/
provider.rs

1use std::str::FromStr;
2
3use async_trait::async_trait;
4use systemprompt_identifiers::UserId;
5use systemprompt_traits::auth::{
6    AuthProviderError, AuthResult, AuthUser, FederatedIdentityClaims, RoleProvider, UserProvider,
7};
8
9use super::UserService;
10use crate::models::{User, UserRole};
11
12#[async_trait]
13impl UserProvider for UserService {
14    async fn find_by_id(&self, id: &UserId) -> AuthResult<Option<AuthUser>> {
15        self.find_by_id(id)
16            .await
17            .map(|opt| opt.map(|u| user_to_auth_user(&u)))
18            .map_err(|e| AuthProviderError::Internal(e.to_string()))
19    }
20
21    async fn find_by_email(&self, email: &str) -> AuthResult<Option<AuthUser>> {
22        Self::find_by_email(self, email)
23            .await
24            .map(|opt| opt.map(|u| user_to_auth_user(&u)))
25            .map_err(|e| AuthProviderError::Internal(e.to_string()))
26    }
27
28    async fn find_by_name(&self, name: &str) -> AuthResult<Option<AuthUser>> {
29        Self::find_by_name(self, name)
30            .await
31            .map(|opt| opt.map(|u| user_to_auth_user(&u)))
32            .map_err(|e| AuthProviderError::Internal(e.to_string()))
33    }
34
35    async fn create_user(
36        &self,
37        name: &str,
38        email: &str,
39        full_name: Option<&str>,
40    ) -> AuthResult<AuthUser> {
41        Self::create(self, name, email, full_name, full_name)
42            .await
43            .map(|u| user_to_auth_user(&u))
44            .map_err(|e| AuthProviderError::Internal(e.to_string()))
45    }
46
47    async fn create_anonymous(&self, fingerprint: &str) -> AuthResult<AuthUser> {
48        Self::create_anonymous(self, fingerprint)
49            .await
50            .map(|u| user_to_auth_user(&u))
51            .map_err(|e| AuthProviderError::Internal(e.to_string()))
52    }
53
54    async fn assign_roles(&self, user_id: &UserId, roles: &[String]) -> AuthResult<()> {
55        Self::assign_roles(self, user_id, roles)
56            .await
57            .map(|_| ())
58            .map_err(|e| AuthProviderError::Internal(e.to_string()))
59    }
60
61    async fn find_or_create_federated(
62        &self,
63        issuer: &str,
64        external_sub: &str,
65        claims: &FederatedIdentityClaims,
66    ) -> AuthResult<UserId> {
67        Self::find_or_create_federated(self, issuer, external_sub, claims)
68            .await
69            .map(|u| u.id)
70            .map_err(|e| AuthProviderError::Internal(e.to_string()))
71    }
72}
73
74fn user_to_auth_user(user: &User) -> AuthUser {
75    AuthUser {
76        id: user.id.clone(),
77        name: user.name.clone(),
78        email: user.email.clone(),
79        roles: user.roles.clone(),
80        is_active: user.is_active(),
81    }
82}
83
84#[async_trait]
85impl RoleProvider for UserService {
86    async fn get_roles(&self, user_id: &UserId) -> AuthResult<Vec<String>> {
87        match Self::find_by_id(self, user_id).await {
88            Ok(Some(user)) => Ok(user.roles),
89            Ok(None) => Err(AuthProviderError::UserNotFound),
90            Err(e) => Err(AuthProviderError::Internal(e.to_string())),
91        }
92    }
93
94    async fn assign_role(&self, user_id: &UserId, role: &str) -> AuthResult<()> {
95        let user = match Self::find_by_id(self, user_id).await {
96            Ok(Some(u)) => u,
97            Ok(None) => return Err(AuthProviderError::UserNotFound),
98            Err(e) => return Err(AuthProviderError::Internal(e.to_string())),
99        };
100
101        let mut roles = user.roles;
102        let role_str = role.to_owned();
103        if !roles.contains(&role_str) {
104            roles.push(role_str);
105        }
106
107        Self::assign_roles(self, user_id, &roles)
108            .await
109            .map(|_| ())
110            .map_err(|e| AuthProviderError::Internal(e.to_string()))
111    }
112
113    async fn revoke_role(&self, user_id: &UserId, role: &str) -> AuthResult<()> {
114        let user = match Self::find_by_id(self, user_id).await {
115            Ok(Some(u)) => u,
116            Ok(None) => return Err(AuthProviderError::UserNotFound),
117            Err(e) => return Err(AuthProviderError::Internal(e.to_string())),
118        };
119
120        let roles: Vec<String> = user.roles.into_iter().filter(|r| r != role).collect();
121
122        Self::assign_roles(self, user_id, &roles)
123            .await
124            .map(|_| ())
125            .map_err(|e| AuthProviderError::Internal(e.to_string()))
126    }
127
128    async fn list_users_by_role(&self, role: &str) -> AuthResult<Vec<AuthUser>> {
129        let Ok(user_role) = UserRole::from_str(role) else {
130            return Ok(vec![]);
131        };
132
133        Self::find_by_role(self, user_role)
134            .await
135            .map(|users| users.iter().map(user_to_auth_user).collect())
136            .map_err(|e| AuthProviderError::Internal(e.to_string()))
137    }
138}