torrust_index/services/
authorization.rs

1//! Authorization service.
2use std::fmt;
3use std::sync::Arc;
4
5use casbin::{CoreApi, DefaultModel, Enforcer, MgmtApi};
6use serde::{Deserialize, Serialize};
7use tokio::sync::RwLock;
8
9use super::user::Repository;
10use crate::errors::ServiceError;
11use crate::models::user::{UserCompact, UserId};
12
13#[derive(Debug, PartialEq, Serialize, Deserialize, Hash)]
14#[serde(rename_all = "lowercase")]
15enum UserRole {
16    Admin,
17    Registered,
18    Guest,
19}
20
21impl fmt::Display for UserRole {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        let role_str = match self {
24            UserRole::Admin => "admin",
25            UserRole::Registered => "registered",
26            UserRole::Guest => "guest",
27        };
28        write!(f, "{role_str}")
29    }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, Hash)]
33pub enum ACTION {
34    GetAboutPage,
35    GetLicensePage,
36    AddCategory,
37    DeleteCategory,
38    GetCategories,
39    GetImageByUrl,
40    GetSettings,
41    GetSettingsSecret,
42    GetPublicSettings,
43    GetSiteName,
44    AddTag,
45    DeleteTag,
46    GetTags,
47    AddTorrent,
48    GetTorrent,
49    DeleteTorrent,
50    GetTorrentInfo,
51    GenerateTorrentInfoListing,
52    GetCanonicalInfoHash,
53    ChangePassword,
54    BanUser,
55}
56
57pub struct Service {
58    user_repository: Arc<Box<dyn Repository>>,
59    casbin_enforcer: Arc<CasbinEnforcer>,
60}
61
62impl Service {
63    #[must_use]
64    pub fn new(user_repository: Arc<Box<dyn Repository>>, casbin_enforcer: Arc<CasbinEnforcer>) -> Self {
65        Self {
66            user_repository,
67            casbin_enforcer,
68        }
69    }
70
71    ///Allows or denies an user to perform an action based on the user's privileges
72    ///
73    /// # Errors
74    ///
75    /// Will return an error if:
76    /// - The user is not authorized to perform the action.
77    pub async fn authorize(&self, action: ACTION, maybe_user_id: Option<UserId>) -> std::result::Result<(), ServiceError> {
78        let role = self.get_role(maybe_user_id).await;
79
80        let enforcer = self.casbin_enforcer.enforcer.read().await;
81
82        let authorize = enforcer
83            .enforce((&role, action))
84            .map_err(|_| ServiceError::UnauthorizedAction)?;
85
86        if authorize {
87            Ok(())
88        } else if role == UserRole::Guest {
89            Err(ServiceError::UnauthorizedActionForGuests)
90        } else {
91            Err(ServiceError::UnauthorizedAction)
92        }
93    }
94
95    /// It returns the compact user.
96    ///
97    /// # Errors
98    ///
99    /// It returns an error if there is a database error.
100    async fn get_user(&self, user_id: UserId) -> std::result::Result<UserCompact, ServiceError> {
101        self.user_repository.get_compact(&user_id).await
102    }
103
104    /// It returns the role of the user.
105    /// If the user found in the request does not exist in the database or there is no user id, a guest role is returned
106    async fn get_role(&self, maybe_user_id: Option<UserId>) -> UserRole {
107        match maybe_user_id {
108            Some(user_id) => {
109                // Checks if the user found in the request exists in the database
110                let user_guard = self.get_user(user_id).await;
111
112                match user_guard {
113                    Ok(user_data) => {
114                        if user_data.administrator {
115                            UserRole::Admin
116                        } else {
117                            UserRole::Registered
118                        }
119                    }
120                    Err(_) => UserRole::Guest,
121                }
122            }
123            None => UserRole::Guest,
124        }
125    }
126}
127
128pub struct CasbinEnforcer {
129    enforcer: Arc<RwLock<Enforcer>>,
130}
131
132impl CasbinEnforcer {
133    /// # Panics
134    ///
135    /// Will panic if:
136    ///
137    /// - The enforcer can't be created.
138    /// - The policies can't be loaded.
139    pub async fn with_default_configuration() -> Self {
140        let casbin_configuration = CasbinConfiguration::default();
141
142        let mut enforcer = Enforcer::new(casbin_configuration.default_model().await, ())
143            .await
144            .expect("Error creating the enforcer");
145
146        enforcer
147            .add_policies(casbin_configuration.policy_lines())
148            .await
149            .expect("Error loading the policy");
150
151        let enforcer = Arc::new(RwLock::new(enforcer));
152
153        Self { enforcer }
154    }
155
156    /// # Panics
157    ///
158    /// Will panic if:
159    ///
160    /// - The enforcer can't be created.
161    /// - The policies can't be loaded.
162    pub async fn with_configuration(casbin_configuration: CasbinConfiguration) -> Self {
163        let mut enforcer = Enforcer::new(casbin_configuration.default_model().await, ())
164            .await
165            .expect("Error creating the enforcer");
166
167        enforcer
168            .add_policies(casbin_configuration.policy_lines())
169            .await
170            .expect("Error loading the policy");
171
172        let enforcer = Arc::new(RwLock::new(enforcer));
173
174        Self { enforcer }
175    }
176}
177
/// Casbin model definition and policy rules, both kept as plain text.
///
/// The policy holds one rule per line with comma-separated fields,
/// e.g. `admin, GetAboutPage`.
#[derive(Debug, Clone)]
pub struct CasbinConfiguration {
    /// Casbin model definition (request, policy, effect and matcher sections).
    model: String,
    /// Policy rules, one per line.
    policy: String,
}
183
184impl CasbinConfiguration {
185    #[must_use]
186    pub fn new(model: &str, policy: &str) -> Self {
187        Self {
188            model: model.to_owned(),
189            policy: policy.to_owned(),
190        }
191    }
192
193    /// # Panics
194    ///
195    /// It panics if the model cannot be loaded.
196    async fn default_model(&self) -> DefaultModel {
197        DefaultModel::from_str(&self.model).await.expect("Error loading the model")
198    }
199
200    /// Converts the policy from a string type to a vector.
201    fn policy_lines(&self) -> Vec<Vec<String>> {
202        self.policy
203            .lines()
204            .filter(|line| !line.trim().is_empty())
205            .map(|line| line.split(',').map(|s| s.trim().to_owned()).collect::<Vec<String>>())
206            .collect()
207    }
208}
209
210impl Default for CasbinConfiguration {
211    fn default() -> Self {
212        Self {
213            model: String::from(
214                "
215                [request_definition]
216                r = role, action
217                
218                [policy_definition]
219                p = role, action
220                
221                [policy_effect]
222                e = some(where (p.eft == allow))
223                
224                [matchers]
225                m = r.role == p.role && r.action == p.action
226            ",
227            ),
228            policy: String::from(
229                "
230                admin, GetAboutPage
231                admin, GetLicensePage
232                admin, AddCategory
233                admin, DeleteCategory
234                admin, GetCategories
235                admin, GetImageByUrl
236                admin, GetSettings
237                admin, GetSettingsSecret
238                admin, GetPublicSettings
239                admin, GetSiteName
240                admin, AddTag
241                admin, DeleteTag
242                admin, GetTags
243                admin, AddTorrent
244                admin, GetTorrent
245                admin, DeleteTorrent
246                admin, GetTorrentInfo
247                admin, GenerateTorrentInfoListing
248                admin, GetCanonicalInfoHash
249                admin, ChangePassword
250                admin, BanUser
251                registered, GetAboutPage
252                registered, GetLicensePage
253                registered, GetCategories
254                registered, GetImageByUrl
255                registered, GetPublicSettings
256                registered, GetSiteName
257                registered, GetTags
258                registered, AddTorrent
259                registered, GetTorrent
260                registered, GetTorrentInfo
261                registered, GenerateTorrentInfoListing
262                registered, GetCanonicalInfoHash
263                registered, ChangePassword
264                guest, GetAboutPage
265                guest, GetLicensePage
266                guest, GetCategories
267                guest, GetPublicSettings
268                guest, GetSiteName
269                guest, GetTags
270                guest, GetTorrent
271                guest, GetTorrentInfo
272                guest, GenerateTorrentInfoListing
273                guest, GetCanonicalInfoHash
274                ",
275            ),
276        }
277    }
278}