Skip to main content

systemprompt_models/auth/
types.rs

1use serde::{Deserialize, Serialize};
2use systemprompt_identifiers::ClientId;
3use uuid::Uuid;
4
5use super::enums::UserType;
6use super::permission::Permission;
7
/// Prefix marking a bearer-token string: the word `Bearer` followed by a
/// single space. NOTE(review): presumably matched against `Authorization`
/// header values by callers — confirm at use sites.
pub const BEARER_PREFIX: &str = "Bearer ";
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct AuthenticatedUser {
12    pub id: Uuid,
13    pub username: String,
14    pub email: String,
15    pub permissions: Vec<Permission>,
16    #[serde(default)]
17    pub roles: Vec<String>,
18    #[serde(default)]
19    pub department: Option<String>,
20}
21
22impl AuthenticatedUser {
23    pub const fn new(
24        id: Uuid,
25        username: String,
26        email: String,
27        permissions: Vec<Permission>,
28    ) -> Self {
29        Self {
30            id,
31            username,
32            email,
33            permissions,
34            roles: Vec::new(),
35            department: None,
36        }
37    }
38
39    pub const fn new_with_roles(
40        id: Uuid,
41        username: String,
42        email: String,
43        permissions: Vec<Permission>,
44        roles: Vec<String>,
45    ) -> Self {
46        Self {
47            id,
48            username,
49            email,
50            permissions,
51            roles,
52            department: None,
53        }
54    }
55
56    #[must_use]
57    pub fn with_department(mut self, department: Option<String>) -> Self {
58        self.department = department;
59        self
60    }
61
62    #[must_use]
63    pub fn department(&self) -> Option<&str> {
64        self.department.as_deref()
65    }
66
67    pub fn has_permission(&self, permission: Permission) -> bool {
68        self.permissions.contains(&permission)
69            || self.permissions.iter().any(|p| p.implies(&permission))
70    }
71
72    pub fn is_admin(&self) -> bool {
73        self.has_permission(Permission::Admin)
74    }
75
76    pub fn permissions(&self) -> &[Permission] {
77        &self.permissions
78    }
79
80    pub fn has_role(&self, role: &str) -> bool {
81        self.roles.iter().any(|r| r == role)
82    }
83
84    pub fn roles(&self) -> &[String] {
85        &self.roles
86    }
87
88    pub fn user_type(&self) -> UserType {
89        UserType::from_permissions(&self.permissions)
90    }
91}
92
93#[derive(Debug, thiserror::Error)]
94pub enum AuthError {
95    #[error("Invalid token format")]
96    InvalidTokenFormat,
97
98    #[error("Token expired")]
99    TokenExpired,
100
101    #[error("Token signature invalid")]
102    InvalidSignature,
103
104    #[error("User not found")]
105    UserNotFound,
106
107    #[error("Insufficient permissions")]
108    InsufficientPermissions,
109
110    #[error("Authentication failed: {message}")]
111    AuthenticationFailed { message: String },
112
113    #[error("Invalid OAuth request: {reason}")]
114    InvalidRequest { reason: String },
115
116    #[error("CSRF token (state) is required")]
117    MissingState,
118
119    #[error("Redirect URI is required and must be registered")]
120    InvalidRedirectUri,
121
122    #[error("PKCE code_challenge is required")]
123    MissingCodeChallenge,
124
125    #[error("PKCE method '{method}' not allowed (must be S256)")]
126    WeakPkceMethod { method: String },
127
128    #[error("Client ID {client_id} not found")]
129    ClientNotFound { client_id: ClientId },
130
131    #[error("Scope '{scope}' is invalid")]
132    InvalidScope { scope: String },
133
134    #[error("Token revocation requires authenticated user")]
135    UnauthenticatedRevocation,
136
137    #[error("WebAuthn RP ID could not be determined")]
138    InvalidRpId,
139
140    #[error("Client registration validation failed: {reason}")]
141    RegistrationFailed { reason: String },
142
143    #[error("Internal error: {0}")]
144    Internal(String),
145}
146
/// PKCE code-challenge method. Only `S256` is representable; parsing
/// `"plain"` is rejected as too weak.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PkceMethod {
    /// SHA-256 based code challenge.
    S256,
}
151
152impl std::str::FromStr for PkceMethod {
153    type Err = AuthError;
154
155    fn from_str(s: &str) -> Result<Self, Self::Err> {
156        match s {
157            "S256" => Ok(Self::S256),
158            "plain" => Err(AuthError::WeakPkceMethod {
159                method: s.to_owned(),
160            }),
161            _ => Err(AuthError::InvalidRequest {
162                reason: format!("Unknown PKCE method: {s}"),
163            }),
164        }
165    }
166}
167
168impl std::fmt::Display for PkceMethod {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        match self {
171            Self::S256 => write!(f, "S256"),
172        }
173    }
174}
175
/// OAuth `response_type` values understood here: `code` and `token`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResponseType {
    /// `"code"`.
    Code,
    /// `"token"`.
    Token,
}
181
182impl std::str::FromStr for ResponseType {
183    type Err = AuthError;
184
185    fn from_str(s: &str) -> Result<Self, Self::Err> {
186        match s {
187            "code" => Ok(Self::Code),
188            "token" => Ok(Self::Token),
189            _ => Err(AuthError::InvalidRequest {
190                reason: format!("Unknown response type: {s}"),
191            }),
192        }
193    }
194}
195
196impl std::fmt::Display for ResponseType {
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        match self {
199            Self::Code => write!(f, "code"),
200            Self::Token => write!(f, "token"),
201        }
202    }
203}