sqlx_adapter/
adapter.rs

1use async_trait::async_trait;
2use casbin::{error::AdapterError, Adapter, Error as CasbinError, Filter, Model, Result};
3use dotenvy::dotenv;
4use std::sync::{
5    atomic::{AtomicBool, Ordering},
6    Arc,
7};
8
9use crate::{error::*, models::*};
10
11use crate::actions as adapter;
12
13#[cfg(feature = "mysql")]
14use sqlx::mysql::MySqlPoolOptions;
15#[cfg(feature = "postgres")]
16use sqlx::postgres::PgPoolOptions;
17#[cfg(feature = "sqlite")]
18use sqlx::sqlite::SqlitePoolOptions;
19
20#[derive(Clone)]
21pub struct SqlxAdapter {
22    pool: adapter::ConnectionPool,
23    is_filtered: Arc<AtomicBool>,
24}
25
26//pub const TABLE_NAME: &str = "casbin_rule";
27
28impl<'a> SqlxAdapter {
29    pub async fn new<U: Into<String>>(url: U, pool_size: u32) -> Result<Self> {
30        dotenv().ok();
31
32        #[cfg(feature = "postgres")]
33        let pool = PgPoolOptions::new()
34            .max_connections(pool_size)
35            .connect(&url.into())
36            .await
37            .map_err(|err| CasbinError::from(AdapterError(Box::new(Error::SqlxError(err)))))?;
38
39        #[cfg(feature = "mysql")]
40        let pool = MySqlPoolOptions::new()
41            .max_connections(pool_size)
42            .connect(&url.into())
43            .await
44            .map_err(|err| CasbinError::from(AdapterError(Box::new(Error::SqlxError(err)))))?;
45
46        #[cfg(feature = "sqlite")]
47        let pool = SqlitePoolOptions::new()
48            .max_connections(pool_size)
49            .connect(&url.into())
50            .await
51            .map_err(|err| CasbinError::from(AdapterError(Box::new(Error::SqlxError(err)))))?;
52
53        adapter::new(&pool).await.map(|_| Self {
54            pool,
55            is_filtered: Arc::new(AtomicBool::new(false)),
56        })
57    }
58
59    pub async fn new_with_pool(pool: adapter::ConnectionPool) -> Result<Self> {
60        adapter::new(&pool).await.map(|_| Self {
61            pool,
62            is_filtered: Arc::new(AtomicBool::new(false)),
63        })
64    }
65
66    pub(crate) fn save_policy_line(
67        &self,
68        ptype: &'a str,
69        rule: &'a [String],
70    ) -> Option<NewCasbinRule<'a>> {
71        if ptype.trim().is_empty() || rule.is_empty() {
72            return None;
73        }
74
75        let mut new_rule = NewCasbinRule {
76            ptype,
77            v0: "",
78            v1: "",
79            v2: "",
80            v3: "",
81            v4: "",
82            v5: "",
83        };
84
85        new_rule.v0 = &rule[0];
86
87        if rule.len() > 1 {
88            new_rule.v1 = &rule[1];
89        }
90
91        if rule.len() > 2 {
92            new_rule.v2 = &rule[2];
93        }
94
95        if rule.len() > 3 {
96            new_rule.v3 = &rule[3];
97        }
98
99        if rule.len() > 4 {
100            new_rule.v4 = &rule[4];
101        }
102
103        if rule.len() > 5 {
104            new_rule.v5 = &rule[5];
105        }
106
107        Some(new_rule)
108    }
109
110    pub(crate) fn load_policy_line(&self, casbin_rule: &CasbinRule) -> Option<Vec<String>> {
111        if casbin_rule.ptype.chars().next().is_some() {
112            return self.normalize_policy(casbin_rule);
113        }
114
115        None
116    }
117
118    fn normalize_policy(&self, casbin_rule: &CasbinRule) -> Option<Vec<String>> {
119        let mut result = vec![
120            &casbin_rule.v0,
121            &casbin_rule.v1,
122            &casbin_rule.v2,
123            &casbin_rule.v3,
124            &casbin_rule.v4,
125            &casbin_rule.v5,
126        ];
127
128        while let Some(last) = result.last() {
129            if last.is_empty() {
130                result.pop();
131            } else {
132                break;
133            }
134        }
135
136        if !result.is_empty() {
137            return Some(result.iter().map(|&x| x.to_owned()).collect());
138        }
139
140        None
141    }
142}
143
144#[async_trait]
145impl Adapter for SqlxAdapter {
146    async fn load_policy(&mut self, m: &mut dyn Model) -> Result<()> {
147        let rules = adapter::load_policy(&self.pool).await?;
148
149        for casbin_rule in &rules {
150            let rule = self.load_policy_line(casbin_rule);
151
152            if let Some(ref sec) = casbin_rule.ptype.chars().next().map(|x| x.to_string()) {
153                if let Some(t1) = m.get_mut_model().get_mut(sec) {
154                    if let Some(t2) = t1.get_mut(&casbin_rule.ptype) {
155                        if let Some(rule) = rule {
156                            t2.get_mut_policy().insert(rule);
157                        }
158                    }
159                }
160            }
161        }
162
163        Ok(())
164    }
165
166    async fn load_filtered_policy<'a>(&mut self, m: &mut dyn Model, f: Filter<'a>) -> Result<()> {
167        let rules = adapter::load_filtered_policy(&self.pool, &f).await?;
168        self.is_filtered.store(true, Ordering::SeqCst);
169
170        for casbin_rule in &rules {
171            if let Some(policy) = self.normalize_policy(casbin_rule) {
172                if let Some(ref sec) = casbin_rule.ptype.chars().next().map(|x| x.to_string()) {
173                    if let Some(t1) = m.get_mut_model().get_mut(sec) {
174                        if let Some(t2) = t1.get_mut(&casbin_rule.ptype) {
175                            t2.get_mut_policy().insert(policy);
176                        }
177                    }
178                }
179            }
180        }
181
182        Ok(())
183    }
184
185    async fn save_policy(&mut self, m: &mut dyn Model) -> Result<()> {
186        let mut rules = vec![];
187
188        if let Some(ast_map) = m.get_model().get("p") {
189            for (ptype, ast) in ast_map {
190                let new_rules = ast
191                    .get_policy()
192                    .into_iter()
193                    .filter_map(|x| self.save_policy_line(ptype, x));
194
195                rules.extend(new_rules);
196            }
197        }
198
199        if let Some(ast_map) = m.get_model().get("g") {
200            for (ptype, ast) in ast_map {
201                let new_rules = ast
202                    .get_policy()
203                    .into_iter()
204                    .filter_map(|x| self.save_policy_line(ptype, x));
205
206                rules.extend(new_rules);
207            }
208        }
209        adapter::save_policy(&self.pool, rules).await
210    }
211
212    async fn add_policy(&mut self, _sec: &str, ptype: &str, rule: Vec<String>) -> Result<bool> {
213        if let Some(new_rule) = self.save_policy_line(ptype, rule.as_slice()) {
214            return adapter::add_policy(&self.pool, new_rule).await;
215        }
216
217        Ok(false)
218    }
219
220    async fn add_policies(
221        &mut self,
222        _sec: &str,
223        ptype: &str,
224        rules: Vec<Vec<String>>,
225    ) -> Result<bool> {
226        let new_rules = rules
227            .iter()
228            .filter_map(|x| self.save_policy_line(ptype, x))
229            .collect::<Vec<NewCasbinRule>>();
230
231        adapter::add_policies(&self.pool, new_rules).await
232    }
233
234    async fn remove_policy(&mut self, _sec: &str, pt: &str, rule: Vec<String>) -> Result<bool> {
235        adapter::remove_policy(&self.pool, pt, rule).await
236    }
237
238    async fn remove_policies(
239        &mut self,
240        _sec: &str,
241        pt: &str,
242        rules: Vec<Vec<String>>,
243    ) -> Result<bool> {
244        adapter::remove_policies(&self.pool, pt, rules).await
245    }
246
247    async fn remove_filtered_policy(
248        &mut self,
249        _sec: &str,
250        pt: &str,
251        field_index: usize,
252        field_values: Vec<String>,
253    ) -> Result<bool> {
254        if field_index <= 5 && !field_values.is_empty() && field_values.len() + field_index <= 6 {
255            adapter::remove_filtered_policy(&self.pool, pt, field_index, field_values).await
256        } else {
257            Ok(false)
258        }
259    }
260
261    async fn clear_policy(&mut self) -> Result<()> {
262        adapter::clear_policy(&self.pool).await
263    }
264
265    fn is_filtered(&self) -> bool {
266        self.is_filtered.load(Ordering::SeqCst)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns borrowed string slices into the owned values the adapter API takes.
    fn to_owned(v: Vec<&str>) -> Vec<String> {
        v.into_iter().map(String::from).collect()
    }

    /// Opens an adapter against the local test database of the enabled backend.
    async fn connect_adapter() -> SqlxAdapter {
        #[cfg(feature = "postgres")]
        let url = "postgres://casbin_rs:casbin_rs@localhost:5432/casbin";
        #[cfg(feature = "mysql")]
        let url = "mysql://casbin_rs:casbin_rs@localhost:3306/casbin";
        #[cfg(feature = "sqlite")]
        let url = "sqlite:casbin.db";

        SqlxAdapter::new(url, 8).await.unwrap()
    }

    #[cfg_attr(
        any(
            feature = "runtime-async-std-native-tls",
            feature = "runtime-async-std-rustls"
        ),
        async_std::test
    )]
    #[cfg_attr(
        any(feature = "runtime-tokio-native-tls", feature = "runtime-tokio-rustls"),
        tokio::test(flavor = "multi_thread")
    )]
    async fn test_create() {
        use casbin::prelude::*;

        let model = DefaultModel::from_file("examples/rbac_model.conf")
            .await
            .unwrap();
        let adapter = connect_adapter().await;

        let enforcer = Enforcer::new(model, adapter).await;
        assert!(enforcer.is_ok());
    }

    #[cfg_attr(
        any(
            feature = "runtime-async-std-native-tls",
            feature = "runtime-async-std-rustls"
        ),
        async_std::test
    )]
    #[cfg_attr(
        any(feature = "runtime-tokio-native-tls", feature = "runtime-tokio-rustls"),
        tokio::test(flavor = "multi_thread")
    )]
    async fn test_create_with_pool() {
        use casbin::prelude::*;

        let model = DefaultModel::from_file("examples/rbac_model.conf")
            .await
            .unwrap();

        #[cfg(feature = "postgres")]
        let pool = PgPoolOptions::new()
            .max_connections(8)
            .connect("postgres://casbin_rs:casbin_rs@localhost:5432/casbin")
            .await
            .unwrap();
        #[cfg(feature = "mysql")]
        let pool = MySqlPoolOptions::new()
            .max_connections(8)
            .connect("mysql://casbin_rs:casbin_rs@localhost:3306/casbin")
            .await
            .unwrap();
        #[cfg(feature = "sqlite")]
        let pool = SqlitePoolOptions::new()
            .max_connections(8)
            .connect("sqlite:casbin.db")
            .await
            .unwrap();

        let adapter = SqlxAdapter::new_with_pool(pool).await.unwrap();

        let enforcer = Enforcer::new(model, adapter).await;
        assert!(enforcer.is_ok());
    }

    #[cfg_attr(
        any(
            feature = "runtime-async-std-native-tls",
            feature = "runtime-async-std-rustls"
        ),
        async_std::test
    )]
    #[cfg_attr(
        any(feature = "runtime-tokio-native-tls", feature = "runtime-tokio-rustls"),
        tokio::test(flavor = "multi_thread")
    )]
    async fn test_adapter() {
        use casbin::prelude::*;

        // The four "p" rules and the single "g" rule of examples/rbac_policy.csv
        // that the first half of the test keeps removing and re-adding.
        let p_rules: Vec<Vec<&str>> = vec![
            vec!["alice", "data1", "read"],
            vec!["bob", "data2", "write"],
            vec!["data2_admin", "data2", "read"],
            vec!["data2_admin", "data2", "write"],
        ];
        let g_rule: Vec<&str> = vec!["alice", "data2_admin"];

        let file_adapter = FileAdapter::new("examples/rbac_policy.csv");
        let model = DefaultModel::from_file("examples/rbac_model.conf")
            .await
            .unwrap();
        let mut enforcer = Enforcer::new(model, file_adapter).await.unwrap();

        let mut adapter = connect_adapter().await;

        assert!(adapter.save_policy(enforcer.get_mut_model()).await.is_ok());

        // Remove everything that was just saved; the very first removal must
        // report that a row was actually deleted.
        assert!(adapter
            .remove_policy("", "p", to_owned(p_rules[0].clone()))
            .await
            .unwrap());
        for rule in p_rules.iter().skip(1).cloned() {
            assert!(adapter
                .remove_policy("", "p", to_owned(rule))
                .await
                .is_ok());
        }
        assert!(adapter
            .remove_policy("", "g", to_owned(g_rule.clone()))
            .await
            .is_ok());

        // Single-rule inserts.
        for rule in p_rules.iter().cloned() {
            assert!(adapter.add_policy("", "p", to_owned(rule)).await.is_ok());
        }

        // Batch removal followed by batch insert of the same rules.
        assert!(adapter
            .remove_policies("", "p", p_rules.iter().cloned().map(to_owned).collect())
            .await
            .is_ok());
        assert!(adapter
            .add_policies("", "p", p_rules.iter().cloned().map(to_owned).collect())
            .await
            .is_ok());

        assert!(adapter
            .add_policy("", "g", to_owned(g_rule.clone()))
            .await
            .is_ok());

        // Clean the table again, one rule at a time.
        for rule in p_rules.iter().cloned() {
            assert!(adapter
                .remove_policy("", "p", to_owned(rule))
                .await
                .is_ok());
        }
        assert!(adapter
            .remove_policy("", "g", to_owned(g_rule.clone()))
            .await
            .is_ok());

        // Removing a rule that was never stored reports `false`.
        assert!(!adapter
            .remove_policy(
                "",
                "g",
                to_owned(vec!["alice", "data2_admin", "not_exists"])
            )
            .await
            .unwrap());

        // Inserting the same rule twice fails the second time.
        assert!(adapter
            .add_policy("", "g", to_owned(g_rule.clone()))
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "g", to_owned(g_rule.clone()))
            .await
            .is_err());

        assert!(!adapter
            .remove_filtered_policy(
                "",
                "g",
                0,
                to_owned(vec!["alice", "data2_admin", "not_exists"]),
            )
            .await
            .unwrap());

        assert!(adapter
            .remove_filtered_policy("", "g", 0, to_owned(g_rule.clone()))
            .await
            .unwrap());

        assert!(adapter
            .add_policy(
                "",
                "g",
                to_owned(vec!["alice", "data2_admin", "domain1", "domain2"]),
            )
            .await
            .is_ok());
        assert!(adapter
            .remove_filtered_policy(
                "",
                "g",
                1,
                to_owned(vec!["data2_admin", "domain1", "domain2"]),
            )
            .await
            .unwrap());

        // GitHub issue: https://github.com/casbin-rs/sqlx-adapter/issues/64
        assert!(adapter
            .add_policy("", "g", to_owned(vec!["carol", "data1_admin"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_filtered_policy("", "g", 0, to_owned(vec!["carol"]))
            .await
            .unwrap());
        assert_eq!(
            Vec::<String>::new(),
            enforcer.get_roles_for_user("carol", None)
        );

        // GitHub issue: https://github.com/casbin-rs/sqlx-adapter/pull/90
        let rfp_rules: Vec<Vec<&str>> = vec![
            vec!["alice_rfp", "book_rfp", "read_rfp"],
            vec!["bob_rfp", "book_rfp", "read_rfp"],
            vec!["bob_rfp", "book_rfp", "write_rfp"],
            vec!["alice_rfp", "pen_rfp", "get_rfp"],
            vec!["bob_rfp", "pen_rfp", "get_rfp"],
            vec!["alice_rfp", "pencil_rfp", "get_rfp"],
        ];
        for rule in rfp_rules {
            assert!(adapter.add_policy("", "p", to_owned(rule)).await.is_ok());
        }

        // should remove (return true) all policies where "book_rfp" is in the second position
        assert!(adapter
            .remove_filtered_policy("", "p", 1, to_owned(vec!["book_rfp"]))
            .await
            .unwrap());

        // should remove (return true) all policies which match "alice_rfp" on first position
        // and "get_rfp" on third position
        assert!(adapter
            .remove_filtered_policy("", "p", 0, to_owned(vec!["alice_rfp", "", "get_rfp"]))
            .await
            .unwrap());

        // shadow the previous enforcer
        let mut enforcer = Enforcer::new(
            "examples/rbac_with_domains_model.conf",
            "examples/rbac_with_domains_policy.csv",
        )
        .await
        .unwrap();

        assert!(adapter.save_policy(enforcer.get_mut_model()).await.is_ok());
        enforcer.set_adapter(adapter).await.unwrap();

        let filter = Filter {
            p: vec!["", "domain1"],
            g: vec!["", "", "domain1"],
        };
        enforcer.load_filtered_policy(filter).await.unwrap();

        // Only domain1 rules were loaded, so only alice's data1 access is granted.
        let expectations = [
            (("alice", "domain1", "data1", "read"), true),
            (("alice", "domain1", "data1", "write"), true),
            (("alice", "domain1", "data2", "read"), false),
            (("alice", "domain1", "data2", "write"), false),
            (("bob", "domain2", "data2", "read"), false),
            (("bob", "domain2", "data2", "write"), false),
        ];
        for (request, allowed) in expectations {
            assert_eq!(enforcer.enforce(request).unwrap(), allowed);
        }
    }
}