Skip to main content

systemprompt_models/auth/
types.rs

1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use super::enums::UserType;
5use super::permission::Permission;
6
/// Literal prefix that precedes a bearer token, as in `Bearer <token>`
/// (note the trailing space, which is part of the constant).
///
/// NOTE(review): call sites are outside this file. HTTP authentication scheme
/// names are case-insensitive (RFC 9110 §11.1), so an exact
/// `strip_prefix`/`starts_with` against this constant would reject e.g.
/// `bearer <token>` — confirm how callers compare against it.
pub const BEARER_PREFIX: &str = "Bearer ";
8
/// Identity, permissions and roles of an authenticated user.
///
/// NOTE(review): where instances are populated from (token claims, database,
/// session) is not visible in this file — confirm at the construction sites.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthenticatedUser {
    /// Identifier of the user.
    pub id: Uuid,
    pub username: String,
    pub email: String,
    /// Permissions granted to the user. `has_permission` checks these both
    /// for an exact match and via `Permission::implies`.
    pub permissions: Vec<Permission>,
    /// Role names attached to the user. `#[serde(default)]` lets this field be
    /// absent when deserializing, in which case it becomes an empty list.
    #[serde(default)]
    pub roles: Vec<String>,
}
18
19impl AuthenticatedUser {
20    pub const fn new(
21        id: Uuid,
22        username: String,
23        email: String,
24        permissions: Vec<Permission>,
25    ) -> Self {
26        Self {
27            id,
28            username,
29            email,
30            permissions,
31            roles: Vec::new(),
32        }
33    }
34
35    pub const fn new_with_roles(
36        id: Uuid,
37        username: String,
38        email: String,
39        permissions: Vec<Permission>,
40        roles: Vec<String>,
41    ) -> Self {
42        Self {
43            id,
44            username,
45            email,
46            permissions,
47            roles,
48        }
49    }
50
51    pub fn has_permission(&self, permission: Permission) -> bool {
52        self.permissions.contains(&permission)
53            || self.permissions.iter().any(|p| p.implies(&permission))
54    }
55
56    pub fn is_admin(&self) -> bool {
57        self.has_permission(Permission::Admin)
58    }
59
60    pub fn permissions(&self) -> &[Permission] {
61        &self.permissions
62    }
63
64    pub fn has_role(&self, role: &str) -> bool {
65        self.roles.iter().any(|r| r == role)
66    }
67
68    pub fn roles(&self) -> &[String] {
69        &self.roles
70    }
71
72    pub fn user_type(&self) -> UserType {
73        if self.has_permission(Permission::Admin) {
74            UserType::Admin
75        } else if self.has_permission(Permission::User) {
76            UserType::User
77        } else if self.has_permission(Permission::A2a) {
78            UserType::A2a
79        } else if self.has_permission(Permission::Mcp) {
80            UserType::Mcp
81        } else if self.has_permission(Permission::Service) {
82            UserType::Service
83        } else {
84            UserType::Anon
85        }
86    }
87}
88
/// Error type for the authentication and OAuth code in this module: token
/// checks, permission checks, PKCE / grant-type / response-type parsing,
/// client lookup and client registration.
///
/// `thiserror` derives `std::error::Error` for the enum and a `Display` impl
/// whose text comes from each variant's `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The token is not in the expected format.
    #[error("Invalid token format")]
    InvalidTokenFormat,

    /// The token's validity period has passed.
    #[error("Token expired")]
    TokenExpired,

    /// The token's signature did not verify.
    #[error("Token signature invalid")]
    InvalidSignature,

    /// No user matched the lookup.
    #[error("User not found")]
    UserNotFound,

    /// The caller lacks a required permission.
    #[error("Insufficient permissions")]
    InsufficientPermissions,

    /// Generic authentication failure; `message` carries the detail.
    #[error("Authentication failed: {message}")]
    AuthenticationFailed { message: String },

    /// Malformed OAuth request parameter. The `FromStr` impls for
    /// `PkceMethod`, `GrantType` and `ResponseType` return this for
    /// unrecognised values, with `reason` naming the offending input.
    #[error("Invalid OAuth request: {reason}")]
    InvalidRequest { reason: String },

    /// The OAuth `state` parameter, used as the CSRF token, was not supplied.
    #[error("CSRF token (state) is required")]
    MissingState,

    /// The redirect URI was missing or is not registered.
    #[error("Redirect URI is required and must be registered")]
    InvalidRedirectUri,

    /// No PKCE `code_challenge` was supplied.
    #[error("PKCE code_challenge is required")]
    MissingCodeChallenge,

    /// A PKCE method other than `S256` was requested; `PkceMethod::from_str`
    /// returns this for `plain`. `method` holds the rejected value.
    #[error("PKCE method '{method}' not allowed (must be S256)")]
    WeakPkceMethod { method: String },

    /// No client with the given `client_id` was found.
    #[error("Client ID {client_id} not found")]
    ClientNotFound { client_id: String },

    /// A requested scope is not valid; `scope` holds the offending value.
    #[error("Scope '{scope}' is invalid")]
    InvalidScope { scope: String },

    /// A token revocation was attempted without an authenticated user.
    #[error("Token revocation requires authenticated user")]
    UnauthenticatedRevocation,

    /// The WebAuthn relying-party ID could not be determined.
    #[error("WebAuthn RP ID could not be determined")]
    InvalidRpId,

    /// Client registration input failed validation; `reason` explains why.
    #[error("Client registration validation failed: {reason}")]
    RegistrationFailed { reason: String },

    /// Wraps an unexpected `anyhow::Error`. `#[from]` generates
    /// `From<anyhow::Error>` (so `?` converts it) and exposes the wrapped
    /// error as this error's `source()`.
    #[error("Internal error: {0}")]
    Internal(#[from] anyhow::Error),
}
142
143#[derive(Debug, Clone, Copy, PartialEq, Eq)]
144pub enum PkceMethod {
145    S256,
146}
147
148impl std::str::FromStr for PkceMethod {
149    type Err = AuthError;
150
151    fn from_str(s: &str) -> Result<Self, Self::Err> {
152        match s {
153            "S256" => Ok(Self::S256),
154            "plain" => Err(AuthError::WeakPkceMethod {
155                method: s.to_string(),
156            }),
157            _ => Err(AuthError::InvalidRequest {
158                reason: format!("Unknown PKCE method: {s}"),
159            }),
160        }
161    }
162}
163
164impl std::fmt::Display for PkceMethod {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        match self {
167            Self::S256 => write!(f, "S256"),
168        }
169    }
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173pub enum GrantType {
174    AuthorizationCode,
175    RefreshToken,
176    ClientCredentials,
177}
178
179impl std::str::FromStr for GrantType {
180    type Err = AuthError;
181
182    fn from_str(s: &str) -> Result<Self, Self::Err> {
183        match s {
184            "authorization_code" => Ok(Self::AuthorizationCode),
185            "refresh_token" => Ok(Self::RefreshToken),
186            "client_credentials" => Ok(Self::ClientCredentials),
187            _ => Err(AuthError::InvalidRequest {
188                reason: format!("Unknown grant type: {s}"),
189            }),
190        }
191    }
192}
193
194impl std::fmt::Display for GrantType {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        match self {
197            Self::AuthorizationCode => write!(f, "authorization_code"),
198            Self::RefreshToken => write!(f, "refresh_token"),
199            Self::ClientCredentials => write!(f, "client_credentials"),
200        }
201    }
202}
203
204#[derive(Debug, Clone, Copy, PartialEq, Eq)]
205pub enum ResponseType {
206    Code,
207    Token,
208}
209
210impl std::str::FromStr for ResponseType {
211    type Err = AuthError;
212
213    fn from_str(s: &str) -> Result<Self, Self::Err> {
214        match s {
215            "code" => Ok(Self::Code),
216            "token" => Ok(Self::Token),
217            _ => Err(AuthError::InvalidRequest {
218                reason: format!("Unknown response type: {s}"),
219            }),
220        }
221    }
222}
223
224impl std::fmt::Display for ResponseType {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        match self {
227            Self::Code => write!(f, "code"),
228            Self::Token => write!(f, "token"),
229        }
230    }
231}