1use std::fmt;
3use std::sync::Arc;
4
5use casbin::{CoreApi, DefaultModel, Enforcer, MgmtApi};
6use serde::{Deserialize, Serialize};
7use tokio::sync::RwLock;
8
9use super::user::Repository;
10use crate::errors::ServiceError;
11use crate::models::user::{UserCompact, UserId};
12
13#[derive(Debug, PartialEq, Serialize, Deserialize, Hash)]
14#[serde(rename_all = "lowercase")]
15enum UserRole {
16 Admin,
17 Registered,
18 Guest,
19}
20
21impl fmt::Display for UserRole {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 let role_str = match self {
24 UserRole::Admin => "admin",
25 UserRole::Registered => "registered",
26 UserRole::Guest => "guest",
27 };
28 write!(f, "{role_str}")
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, Hash)]
33pub enum ACTION {
34 GetAboutPage,
35 GetLicensePage,
36 AddCategory,
37 DeleteCategory,
38 GetCategories,
39 GetImageByUrl,
40 GetSettings,
41 GetSettingsSecret,
42 GetPublicSettings,
43 GetSiteName,
44 AddTag,
45 DeleteTag,
46 GetTags,
47 AddTorrent,
48 GetTorrent,
49 DeleteTorrent,
50 GetTorrentInfo,
51 GenerateTorrentInfoListing,
52 GetCanonicalInfoHash,
53 ChangePassword,
54 BanUser,
55}
56
/// Authorization service: resolves the caller's role and asks the Casbin
/// enforcer whether that role may perform a given `ACTION`.
pub struct Service {
    // Source of user records; `get_role` reads the `administrator` flag from it.
    user_repository: Arc<Box<dyn Repository>>,
    // Shared enforcer holding the Casbin model and policy.
    casbin_enforcer: Arc<CasbinEnforcer>,
}
61
62impl Service {
63 #[must_use]
64 pub fn new(user_repository: Arc<Box<dyn Repository>>, casbin_enforcer: Arc<CasbinEnforcer>) -> Self {
65 Self {
66 user_repository,
67 casbin_enforcer,
68 }
69 }
70
71 pub async fn authorize(&self, action: ACTION, maybe_user_id: Option<UserId>) -> std::result::Result<(), ServiceError> {
78 let role = self.get_role(maybe_user_id).await;
79
80 let enforcer = self.casbin_enforcer.enforcer.read().await;
81
82 let authorize = enforcer
83 .enforce((&role, action))
84 .map_err(|_| ServiceError::UnauthorizedAction)?;
85
86 if authorize {
87 Ok(())
88 } else if role == UserRole::Guest {
89 Err(ServiceError::UnauthorizedActionForGuests)
90 } else {
91 Err(ServiceError::UnauthorizedAction)
92 }
93 }
94
95 async fn get_user(&self, user_id: UserId) -> std::result::Result<UserCompact, ServiceError> {
101 self.user_repository.get_compact(&user_id).await
102 }
103
104 async fn get_role(&self, maybe_user_id: Option<UserId>) -> UserRole {
107 match maybe_user_id {
108 Some(user_id) => {
109 let user_guard = self.get_user(user_id).await;
111
112 match user_guard {
113 Ok(user_data) => {
114 if user_data.administrator {
115 UserRole::Admin
116 } else {
117 UserRole::Registered
118 }
119 }
120 Err(_) => UserRole::Guest,
121 }
122 }
123 None => UserRole::Guest,
124 }
125 }
126}
127
/// A Casbin `Enforcer` shared behind a tokio read-write lock.
pub struct CasbinEnforcer {
    // `Service::authorize` takes a read lock for each check.
    // NOTE(review): no writer is visible in this file; if the policy is never
    // mutated after construction, the lock may be unnecessary — confirm.
    enforcer: Arc<RwLock<Enforcer>>,
}
131
132impl CasbinEnforcer {
133 pub async fn with_default_configuration() -> Self {
140 let casbin_configuration = CasbinConfiguration::default();
141
142 let mut enforcer = Enforcer::new(casbin_configuration.default_model().await, ())
143 .await
144 .expect("Error creating the enforcer");
145
146 enforcer
147 .add_policies(casbin_configuration.policy_lines())
148 .await
149 .expect("Error loading the policy");
150
151 let enforcer = Arc::new(RwLock::new(enforcer));
152
153 Self { enforcer }
154 }
155
156 pub async fn with_configuration(casbin_configuration: CasbinConfiguration) -> Self {
163 let mut enforcer = Enforcer::new(casbin_configuration.default_model().await, ())
164 .await
165 .expect("Error creating the enforcer");
166
167 enforcer
168 .add_policies(casbin_configuration.policy_lines())
169 .await
170 .expect("Error loading the policy");
171
172 let enforcer = Arc::new(RwLock::new(enforcer));
173
174 Self { enforcer }
175 }
176}
177
/// Casbin model and policy, kept as raw text until an enforcer is built from them.
// NOTE(review): both fields are read (by `default_model` and `policy_lines`),
// so this `dead_code` allowance may no longer be needed — confirm before removing.
#[allow(dead_code)]
pub struct CasbinConfiguration {
    // Casbin model text, parsed by `DefaultModel::from_str`.
    model: String,
    // Policy rules: one comma-separated rule per non-blank line.
    policy: String,
}
183
184impl CasbinConfiguration {
185 #[must_use]
186 pub fn new(model: &str, policy: &str) -> Self {
187 Self {
188 model: model.to_owned(),
189 policy: policy.to_owned(),
190 }
191 }
192
193 async fn default_model(&self) -> DefaultModel {
197 DefaultModel::from_str(&self.model).await.expect("Error loading the model")
198 }
199
200 fn policy_lines(&self) -> Vec<Vec<String>> {
202 self.policy
203 .lines()
204 .filter(|line| !line.trim().is_empty())
205 .map(|line| line.split(',').map(|s| s.trim().to_owned()).collect::<Vec<String>>())
206 .collect()
207 }
208}
209
impl Default for CasbinConfiguration {
    /// Built-in role-based access-control configuration.
    ///
    /// The model has no role hierarchy. A request is allowed when some policy
    /// line has exactly the same role and action, so every role lists each of
    /// its permitted actions explicitly:
    ///
    /// - `admin`: all 21 `ACTION` variants.
    /// - `registered`: everything `guest` has, plus `GetImageByUrl`,
    ///   `AddTorrent` and `ChangePassword`.
    /// - `guest`: `GetAboutPage`, `GetLicensePage`, `GetCategories`,
    ///   `GetPublicSettings`, `GetSiteName`, `GetTags`, `GetTorrent`,
    ///   `GetTorrentInfo`, `GenerateTorrentInfoListing`, `GetCanonicalInfoHash`.
    fn default() -> Self {
        Self {
            // Requests and policies are both `(role, action)` pairs. The effect
            // allows a request if any policy matches, and the matcher requires
            // exact equality on both fields.
            model: String::from(
                "
                [request_definition]
                r = role, action

                [policy_definition]
                p = role, action

                [policy_effect]
                e = some(where (p.eft == allow))

                [matchers]
                m = r.role == p.role && r.action == p.action
                ",
            ),
            // One `role, action` rule per line. Blank lines are skipped and
            // fields are trimmed by `policy_lines`. Actions are spelled exactly
            // like the `ACTION` variant names.
            policy: String::from(
                "
                admin, GetAboutPage
                admin, GetLicensePage
                admin, AddCategory
                admin, DeleteCategory
                admin, GetCategories
                admin, GetImageByUrl
                admin, GetSettings
                admin, GetSettingsSecret
                admin, GetPublicSettings
                admin, GetSiteName
                admin, AddTag
                admin, DeleteTag
                admin, GetTags
                admin, AddTorrent
                admin, GetTorrent
                admin, DeleteTorrent
                admin, GetTorrentInfo
                admin, GenerateTorrentInfoListing
                admin, GetCanonicalInfoHash
                admin, ChangePassword
                admin, BanUser
                registered, GetAboutPage
                registered, GetLicensePage
                registered, GetCategories
                registered, GetImageByUrl
                registered, GetPublicSettings
                registered, GetSiteName
                registered, GetTags
                registered, AddTorrent
                registered, GetTorrent
                registered, GetTorrentInfo
                registered, GenerateTorrentInfoListing
                registered, GetCanonicalInfoHash
                registered, ChangePassword
                guest, GetAboutPage
                guest, GetLicensePage
                guest, GetCategories
                guest, GetPublicSettings
                guest, GetSiteName
                guest, GetTags
                guest, GetTorrent
                guest, GetTorrentInfo
                guest, GenerateTorrentInfoListing
                guest, GetCanonicalInfoHash
                ",
            ),
        }
    }
}