// systemprompt_models/auth/types.rs

1use serde::{Deserialize, Serialize};
2use systemprompt_identifiers::ClientId;
3use uuid::Uuid;
4
5use super::enums::UserType;
6use super::permission::Permission;
7
/// Prefix for bearer-token credentials: the scheme name `Bearer` followed by a
/// single space.
///
/// NOTE(review): HTTP auth-scheme names are case-insensitive; if callers match
/// this with an exact `strip_prefix`, a header such as `bearer abc` would be
/// rejected — confirm that is intended.
pub const BEARER_PREFIX: &str = "Bearer ";
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct AuthenticatedUser {
12    pub id: Uuid,
13    pub username: String,
14    pub email: String,
15    pub permissions: Vec<Permission>,
16    #[serde(default)]
17    pub roles: Vec<String>,
18    #[serde(default)]
19    pub department: Option<String>,
20}
21
22impl AuthenticatedUser {
23    pub const fn new(
24        id: Uuid,
25        username: String,
26        email: String,
27        permissions: Vec<Permission>,
28    ) -> Self {
29        Self {
30            id,
31            username,
32            email,
33            permissions,
34            roles: Vec::new(),
35            department: None,
36        }
37    }
38
39    pub const fn new_with_roles(
40        id: Uuid,
41        username: String,
42        email: String,
43        permissions: Vec<Permission>,
44        roles: Vec<String>,
45    ) -> Self {
46        Self {
47            id,
48            username,
49            email,
50            permissions,
51            roles,
52            department: None,
53        }
54    }
55
56    #[must_use]
57    pub fn with_department(mut self, department: Option<String>) -> Self {
58        self.department = department;
59        self
60    }
61
62    #[must_use]
63    pub fn department(&self) -> Option<&str> {
64        self.department.as_deref()
65    }
66
67    pub fn has_permission(&self, permission: Permission) -> bool {
68        self.permissions.contains(&permission)
69            || self.permissions.iter().any(|p| p.implies(&permission))
70    }
71
72    pub fn is_admin(&self) -> bool {
73        self.has_permission(Permission::Admin)
74    }
75
76    pub fn permissions(&self) -> &[Permission] {
77        &self.permissions
78    }
79
80    pub fn has_role(&self, role: &str) -> bool {
81        self.roles.iter().any(|r| r == role)
82    }
83
84    pub fn roles(&self) -> &[String] {
85        &self.roles
86    }
87
88    pub fn user_type(&self) -> UserType {
89        if self.has_permission(Permission::Admin) {
90            UserType::Admin
91        } else if self.has_permission(Permission::User) {
92            UserType::User
93        } else if self.has_permission(Permission::A2a) {
94            UserType::A2a
95        } else if self.has_permission(Permission::Mcp) {
96            UserType::Mcp
97        } else if self.has_permission(Permission::Service) {
98            UserType::Service
99        } else {
100            UserType::Anon
101        }
102    }
103}
104
105#[derive(Debug, thiserror::Error)]
106pub enum AuthError {
107    #[error("Invalid token format")]
108    InvalidTokenFormat,
109
110    #[error("Token expired")]
111    TokenExpired,
112
113    #[error("Token signature invalid")]
114    InvalidSignature,
115
116    #[error("User not found")]
117    UserNotFound,
118
119    #[error("Insufficient permissions")]
120    InsufficientPermissions,
121
122    #[error("Authentication failed: {message}")]
123    AuthenticationFailed { message: String },
124
125    #[error("Invalid OAuth request: {reason}")]
126    InvalidRequest { reason: String },
127
128    #[error("CSRF token (state) is required")]
129    MissingState,
130
131    #[error("Redirect URI is required and must be registered")]
132    InvalidRedirectUri,
133
134    #[error("PKCE code_challenge is required")]
135    MissingCodeChallenge,
136
137    #[error("PKCE method '{method}' not allowed (must be S256)")]
138    WeakPkceMethod { method: String },
139
140    #[error("Client ID {client_id} not found")]
141    ClientNotFound { client_id: ClientId },
142
143    #[error("Scope '{scope}' is invalid")]
144    InvalidScope { scope: String },
145
146    #[error("Token revocation requires authenticated user")]
147    UnauthenticatedRevocation,
148
149    #[error("WebAuthn RP ID could not be determined")]
150    InvalidRpId,
151
152    #[error("Client registration validation failed: {reason}")]
153    RegistrationFailed { reason: String },
154
155    #[error("Internal error: {0}")]
156    Internal(String),
157}
158
/// PKCE code-challenge method.
///
/// `S256` is the only representable method. The string `"plain"` is refused
/// when parsing, with `AuthError::WeakPkceMethod`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PkceMethod {
    /// SHA-256 challenge, written as `"S256"`.
    S256,
}
163
164impl std::str::FromStr for PkceMethod {
165    type Err = AuthError;
166
167    fn from_str(s: &str) -> Result<Self, Self::Err> {
168        match s {
169            "S256" => Ok(Self::S256),
170            "plain" => Err(AuthError::WeakPkceMethod {
171                method: s.to_string(),
172            }),
173            _ => Err(AuthError::InvalidRequest {
174                reason: format!("Unknown PKCE method: {s}"),
175            }),
176        }
177    }
178}
179
180impl std::fmt::Display for PkceMethod {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            Self::S256 => write!(f, "S256"),
184        }
185    }
186}
187
/// OAuth `grant_type` values recognised by this module's parser.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum GrantType {
    /// `"authorization_code"`
    AuthorizationCode,
    /// `"refresh_token"`
    RefreshToken,
    /// `"client_credentials"`
    ClientCredentials,
}
194
195impl std::str::FromStr for GrantType {
196    type Err = AuthError;
197
198    fn from_str(s: &str) -> Result<Self, Self::Err> {
199        match s {
200            "authorization_code" => Ok(Self::AuthorizationCode),
201            "refresh_token" => Ok(Self::RefreshToken),
202            "client_credentials" => Ok(Self::ClientCredentials),
203            _ => Err(AuthError::InvalidRequest {
204                reason: format!("Unknown grant type: {s}"),
205            }),
206        }
207    }
208}
209
210impl std::fmt::Display for GrantType {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        match self {
213            Self::AuthorizationCode => write!(f, "authorization_code"),
214            Self::RefreshToken => write!(f, "refresh_token"),
215            Self::ClientCredentials => write!(f, "client_credentials"),
216        }
217    }
218}
219
/// OAuth `response_type` values recognised by this module's parser.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ResponseType {
    /// `"code"`
    Code,
    /// `"token"`
    Token,
}
225
226impl std::str::FromStr for ResponseType {
227    type Err = AuthError;
228
229    fn from_str(s: &str) -> Result<Self, Self::Err> {
230        match s {
231            "code" => Ok(Self::Code),
232            "token" => Ok(Self::Token),
233            _ => Err(AuthError::InvalidRequest {
234                reason: format!("Unknown response type: {s}"),
235            }),
236        }
237    }
238}
239
240impl std::fmt::Display for ResponseType {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        match self {
243            Self::Code => write!(f, "code"),
244            Self::Token => write!(f, "token"),
245        }
246    }
247}