1use async_trait::async_trait;
2use casbin::{error::AdapterError, Adapter, Error as CasbinError, Filter, Model, Result};
3use dotenvy::dotenv;
4use std::sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7};
8
9use crate::{error::*, models::*};
10
11use crate::actions as adapter;
12
13#[cfg(feature = "mysql")]
14use sqlx::mysql::MySqlPoolOptions;
15#[cfg(feature = "postgres")]
16use sqlx::postgres::PgPoolOptions;
17#[cfg(feature = "sqlite")]
18use sqlx::sqlite::SqlitePoolOptions;
19
/// Casbin [`Adapter`] that stores policy rules through an sqlx connection pool.
///
/// The pool type is `adapter::ConnectionPool` from `crate::actions`.
/// NOTE(review): presumably that alias is selected by the enabled `postgres` /
/// `mysql` / `sqlite` feature — confirm in `crate::actions`.
#[derive(Clone)]
pub struct SqlxAdapter {
    /// Pool handed to every `crate::actions` query.
    pool: adapter::ConnectionPool,
    /// Set to `true` by `load_filtered_policy` and reported by
    /// `Adapter::is_filtered`. Because it sits behind an `Arc`, every clone of
    /// the adapter shares the same flag.
    is_filtered: Arc<AtomicBool>,
}
25
26impl<'a> SqlxAdapter {
29 pub async fn new<U: Into<String>>(url: U, pool_size: u32) -> Result<Self> {
30 dotenv().ok();
31
32 #[cfg(feature = "postgres")]
33 let pool = PgPoolOptions::new()
34 .max_connections(pool_size)
35 .connect(&url.into())
36 .await
37 .map_err(|err| CasbinError::from(AdapterError(Box::new(Error::SqlxError(err)))))?;
38
39 #[cfg(feature = "mysql")]
40 let pool = MySqlPoolOptions::new()
41 .max_connections(pool_size)
42 .connect(&url.into())
43 .await
44 .map_err(|err| CasbinError::from(AdapterError(Box::new(Error::SqlxError(err)))))?;
45
46 #[cfg(feature = "sqlite")]
47 let pool = SqlitePoolOptions::new()
48 .max_connections(pool_size)
49 .connect(&url.into())
50 .await
51 .map_err(|err| CasbinError::from(AdapterError(Box::new(Error::SqlxError(err)))))?;
52
53 adapter::new(&pool).await.map(|_| Self {
54 pool,
55 is_filtered: Arc::new(AtomicBool::new(false)),
56 })
57 }
58
59 pub async fn new_with_pool(pool: adapter::ConnectionPool) -> Result<Self> {
60 adapter::new(&pool).await.map(|_| Self {
61 pool,
62 is_filtered: Arc::new(AtomicBool::new(false)),
63 })
64 }
65
66 pub(crate) fn save_policy_line(
67 &self,
68 ptype: &'a str,
69 rule: &'a [String],
70 ) -> Option<NewCasbinRule<'a>> {
71 if ptype.trim().is_empty() || rule.is_empty() {
72 return None;
73 }
74
75 let mut new_rule = NewCasbinRule {
76 ptype,
77 v0: "",
78 v1: "",
79 v2: "",
80 v3: "",
81 v4: "",
82 v5: "",
83 };
84
85 new_rule.v0 = &rule[0];
86
87 if rule.len() > 1 {
88 new_rule.v1 = &rule[1];
89 }
90
91 if rule.len() > 2 {
92 new_rule.v2 = &rule[2];
93 }
94
95 if rule.len() > 3 {
96 new_rule.v3 = &rule[3];
97 }
98
99 if rule.len() > 4 {
100 new_rule.v4 = &rule[4];
101 }
102
103 if rule.len() > 5 {
104 new_rule.v5 = &rule[5];
105 }
106
107 Some(new_rule)
108 }
109
110 pub(crate) fn load_policy_line(&self, casbin_rule: &CasbinRule) -> Option<Vec<String>> {
111 if casbin_rule.ptype.chars().next().is_some() {
112 return self.normalize_policy(casbin_rule);
113 }
114
115 None
116 }
117
118 fn normalize_policy(&self, casbin_rule: &CasbinRule) -> Option<Vec<String>> {
119 let mut result = vec![
120 &casbin_rule.v0,
121 &casbin_rule.v1,
122 &casbin_rule.v2,
123 &casbin_rule.v3,
124 &casbin_rule.v4,
125 &casbin_rule.v5,
126 ];
127
128 while let Some(last) = result.last() {
129 if last.is_empty() {
130 result.pop();
131 } else {
132 break;
133 }
134 }
135
136 if !result.is_empty() {
137 return Some(result.iter().map(|&x| x.to_owned()).collect());
138 }
139
140 None
141 }
142}
143
144#[async_trait]
145impl Adapter for SqlxAdapter {
146 async fn load_policy(&mut self, m: &mut dyn Model) -> Result<()> {
147 let rules = adapter::load_policy(&self.pool).await?;
148
149 for casbin_rule in &rules {
150 let rule = self.load_policy_line(casbin_rule);
151
152 if let Some(ref sec) = casbin_rule.ptype.chars().next().map(|x| x.to_string()) {
153 if let Some(t1) = m.get_mut_model().get_mut(sec) {
154 if let Some(t2) = t1.get_mut(&casbin_rule.ptype) {
155 if let Some(rule) = rule {
156 t2.get_mut_policy().insert(rule);
157 }
158 }
159 }
160 }
161 }
162
163 Ok(())
164 }
165
166 async fn load_filtered_policy<'a>(&mut self, m: &mut dyn Model, f: Filter<'a>) -> Result<()> {
167 let rules = adapter::load_filtered_policy(&self.pool, &f).await?;
168 self.is_filtered.store(true, Ordering::SeqCst);
169
170 for casbin_rule in &rules {
171 if let Some(policy) = self.normalize_policy(casbin_rule) {
172 if let Some(ref sec) = casbin_rule.ptype.chars().next().map(|x| x.to_string()) {
173 if let Some(t1) = m.get_mut_model().get_mut(sec) {
174 if let Some(t2) = t1.get_mut(&casbin_rule.ptype) {
175 t2.get_mut_policy().insert(policy);
176 }
177 }
178 }
179 }
180 }
181
182 Ok(())
183 }
184
185 async fn save_policy(&mut self, m: &mut dyn Model) -> Result<()> {
186 let mut rules = vec![];
187
188 if let Some(ast_map) = m.get_model().get("p") {
189 for (ptype, ast) in ast_map {
190 let new_rules = ast
191 .get_policy()
192 .into_iter()
193 .filter_map(|x| self.save_policy_line(ptype, x));
194
195 rules.extend(new_rules);
196 }
197 }
198
199 if let Some(ast_map) = m.get_model().get("g") {
200 for (ptype, ast) in ast_map {
201 let new_rules = ast
202 .get_policy()
203 .into_iter()
204 .filter_map(|x| self.save_policy_line(ptype, x));
205
206 rules.extend(new_rules);
207 }
208 }
209 adapter::save_policy(&self.pool, rules).await
210 }
211
212 async fn add_policy(&mut self, _sec: &str, ptype: &str, rule: Vec<String>) -> Result<bool> {
213 if let Some(new_rule) = self.save_policy_line(ptype, rule.as_slice()) {
214 return adapter::add_policy(&self.pool, new_rule).await;
215 }
216
217 Ok(false)
218 }
219
220 async fn add_policies(
221 &mut self,
222 _sec: &str,
223 ptype: &str,
224 rules: Vec<Vec<String>>,
225 ) -> Result<bool> {
226 let new_rules = rules
227 .iter()
228 .filter_map(|x| self.save_policy_line(ptype, x))
229 .collect::<Vec<NewCasbinRule>>();
230
231 adapter::add_policies(&self.pool, new_rules).await
232 }
233
234 async fn remove_policy(&mut self, _sec: &str, pt: &str, rule: Vec<String>) -> Result<bool> {
235 adapter::remove_policy(&self.pool, pt, rule).await
236 }
237
238 async fn remove_policies(
239 &mut self,
240 _sec: &str,
241 pt: &str,
242 rules: Vec<Vec<String>>,
243 ) -> Result<bool> {
244 adapter::remove_policies(&self.pool, pt, rules).await
245 }
246
247 async fn remove_filtered_policy(
248 &mut self,
249 _sec: &str,
250 pt: &str,
251 field_index: usize,
252 field_values: Vec<String>,
253 ) -> Result<bool> {
254 if field_index <= 5 && !field_values.is_empty() && field_values.len() + field_index <= 6 {
255 adapter::remove_filtered_policy(&self.pool, pt, field_index, field_values).await
256 } else {
257 Ok(false)
258 }
259 }
260
261 async fn clear_policy(&mut self) -> Result<()> {
262 adapter::clear_policy(&self.pool).await
263 }
264
265 fn is_filtered(&self) -> bool {
266 self.is_filtered.load(Ordering::SeqCst)
267 }
268}
269
// Integration tests. They need a reachable database at the hard-coded URLs
// below (or the `casbin.db` SQLite file for the `sqlite` feature) and the
// example model/policy files under `examples/`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned `Vec<String>` form that the
    /// `Adapter` methods take for a rule.
    fn to_owned(v: Vec<&str>) -> Vec<String> {
        v.into_iter().map(|x| x.to_owned()).collect()
    }

    // Each test is registered with async-std or tokio depending on which
    // runtime feature is enabled.
    #[cfg_attr(
        any(
            feature = "runtime-async-std-native-tls",
            feature = "runtime-async-std-rustls"
        ),
        async_std::test
    )]
    #[cfg_attr(
        any(feature = "runtime-tokio-native-tls", feature = "runtime-tokio-rustls"),
        tokio::test(flavor = "multi_thread")
    )]
    /// An enforcer can be built on top of an adapter created with `SqlxAdapter::new`.
    async fn test_create() {
        use casbin::prelude::*;

        let m = DefaultModel::from_file("examples/rbac_model.conf")
            .await
            .unwrap();

        // Only the block for the enabled database feature is compiled in.
        let adapter = {
            #[cfg(feature = "postgres")]
            {
                SqlxAdapter::new("postgres://casbin_rs:casbin_rs@localhost:5432/casbin", 8)
                    .await
                    .unwrap()
            }

            #[cfg(feature = "mysql")]
            {
                SqlxAdapter::new("mysql://casbin_rs:casbin_rs@localhost:3306/casbin", 8)
                    .await
                    .unwrap()
            }

            #[cfg(feature = "sqlite")]
            {
                SqlxAdapter::new("sqlite:casbin.db", 8).await.unwrap()
            }
        };

        assert!(Enforcer::new(m, adapter).await.is_ok());
    }

    #[cfg_attr(
        any(
            feature = "runtime-async-std-native-tls",
            feature = "runtime-async-std-rustls"
        ),
        async_std::test
    )]
    #[cfg_attr(
        any(feature = "runtime-tokio-native-tls", feature = "runtime-tokio-rustls"),
        tokio::test(flavor = "multi_thread")
    )]
    /// An enforcer can be built on top of an adapter wrapping a caller-built pool.
    async fn test_create_with_pool() {
        use casbin::prelude::*;

        let m = DefaultModel::from_file("examples/rbac_model.conf")
            .await
            .unwrap();
        // Build the pool by hand for whichever database feature is enabled.
        let pool = {
            #[cfg(feature = "postgres")]
            {
                PgPoolOptions::new()
                    .max_connections(8)
                    .connect("postgres://casbin_rs:casbin_rs@localhost:5432/casbin")
                    .await
                    .unwrap()
            }

            #[cfg(feature = "mysql")]
            {
                MySqlPoolOptions::new()
                    .max_connections(8)
                    .connect("mysql://casbin_rs:casbin_rs@localhost:3306/casbin")
                    .await
                    .unwrap()
            }

            #[cfg(feature = "sqlite")]
            {
                SqlitePoolOptions::new()
                    .max_connections(8)
                    .connect("sqlite:casbin.db")
                    .await
                    .unwrap()
            }
        };

        let adapter = SqlxAdapter::new_with_pool(pool).await.unwrap();

        assert!(Enforcer::new(m, adapter).await.is_ok());
    }

    #[cfg_attr(
        any(
            feature = "runtime-async-std-native-tls",
            feature = "runtime-async-std-rustls"
        ),
        async_std::test
    )]
    #[cfg_attr(
        any(feature = "runtime-tokio-native-tls", feature = "runtime-tokio-rustls"),
        tokio::test(flavor = "multi_thread")
    )]
    /// Exercises the adapter's add/remove/filtered-remove/save/filtered-load
    /// operations against a live database. The steps depend on one another's
    /// database state, so their order matters.
    async fn test_adapter() {
        use casbin::prelude::*;

        let file_adapter = FileAdapter::new("examples/rbac_policy.csv");

        let m = DefaultModel::from_file("examples/rbac_model.conf")
            .await
            .unwrap();

        // `e` is backed by the CSV file, not by the database adapter.
        let mut e = Enforcer::new(m, file_adapter).await.unwrap();
        let mut adapter = {
            #[cfg(feature = "postgres")]
            {
                SqlxAdapter::new("postgres://casbin_rs:casbin_rs@localhost:5432/casbin", 8)
                    .await
                    .unwrap()
            }

            #[cfg(feature = "mysql")]
            {
                SqlxAdapter::new("mysql://casbin_rs:casbin_rs@localhost:3306/casbin", 8)
                    .await
                    .unwrap()
            }

            #[cfg(feature = "sqlite")]
            {
                SqlxAdapter::new("sqlite:casbin.db", 8).await.unwrap()
            }
        };

        // Persist the file-backed enforcer's model into the database.
        assert!(adapter.save_policy(e.get_mut_model()).await.is_ok());

        // Single-rule removal; the first removal must report that a row was
        // actually deleted.
        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["alice", "data1", "read"]))
            .await
            .unwrap());
        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["bob", "data2", "write"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["data2_admin", "data2", "read"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["data2_admin", "data2", "write"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_policy("", "g", to_owned(vec!["alice", "data2_admin"]))
            .await
            .is_ok());

        // Single-rule insertion.
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["alice", "data1", "read"]))
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["bob", "data2", "write"]))
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["data2_admin", "data2", "read"]))
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["data2_admin", "data2", "write"]))
            .await
            .is_ok());

        // Batch removal followed by batch re-insertion of the same rules.
        assert!(adapter
            .remove_policies(
                "",
                "p",
                vec![
                    to_owned(vec!["alice", "data1", "read"]),
                    to_owned(vec!["bob", "data2", "write"]),
                    to_owned(vec!["data2_admin", "data2", "read"]),
                    to_owned(vec!["data2_admin", "data2", "write"]),
                ]
            )
            .await
            .is_ok());

        assert!(adapter
            .add_policies(
                "",
                "p",
                vec![
                    to_owned(vec!["alice", "data1", "read"]),
                    to_owned(vec!["bob", "data2", "write"]),
                    to_owned(vec!["data2_admin", "data2", "read"]),
                    to_owned(vec!["data2_admin", "data2", "write"]),
                ]
            )
            .await
            .is_ok());

        assert!(adapter
            .add_policy("", "g", to_owned(vec!["alice", "data2_admin"]))
            .await
            .is_ok());

        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["alice", "data1", "read"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["bob", "data2", "write"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["data2_admin", "data2", "read"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_policy("", "p", to_owned(vec!["data2_admin", "data2", "write"]))
            .await
            .is_ok());
        assert!(adapter
            .remove_policy("", "g", to_owned(vec!["alice", "data2_admin"]))
            .await
            .is_ok());

        // Removing a rule that is not stored reports `false`.
        assert!(!adapter
            .remove_policy(
                "",
                "g",
                to_owned(vec!["alice", "data2_admin", "not_exists"])
            )
            .await
            .unwrap());

        // Inserting the same rule twice: the second insertion must error.
        assert!(adapter
            .add_policy("", "g", to_owned(vec!["alice", "data2_admin"]))
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "g", to_owned(vec!["alice", "data2_admin"]))
            .await
            .is_err());

        // Filtered removal: a non-matching filter removes nothing, a matching
        // one (from column 0, then from column 1) removes rows.
        assert!(!adapter
            .remove_filtered_policy(
                "",
                "g",
                0,
                to_owned(vec!["alice", "data2_admin", "not_exists"]),
            )
            .await
            .unwrap());

        assert!(adapter
            .remove_filtered_policy("", "g", 0, to_owned(vec!["alice", "data2_admin"]))
            .await
            .unwrap());

        assert!(adapter
            .add_policy(
                "",
                "g",
                to_owned(vec!["alice", "data2_admin", "domain1", "domain2"]),
            )
            .await
            .is_ok());
        assert!(adapter
            .remove_filtered_policy(
                "",
                "g",
                1,
                to_owned(vec!["data2_admin", "domain1", "domain2"]),
            )
            .await
            .unwrap());

        // Filtered removal with a single value.
        assert!(adapter
            .add_policy("", "g", to_owned(vec!["carol", "data1_admin"]),)
            .await
            .is_ok());
        assert!(adapter
            .remove_filtered_policy("", "g", 0, to_owned(vec!["carol"]),)
            .await
            .unwrap());
        // NOTE(review): `e` is the file-backed enforcer and never saw the
        // carol rule, so this check does not observe the database removal
        // above — consider asserting against the database instead.
        assert_eq!(vec![String::new(); 0], e.get_roles_for_user("carol", None));

        // Fixture rows for the filtered-removal cases below.
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["alice_rfp", "book_rfp", "read_rfp"]),)
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["bob_rfp", "book_rfp", "read_rfp"]),)
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["bob_rfp", "book_rfp", "write_rfp"]),)
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["alice_rfp", "pen_rfp", "get_rfp"]),)
            .await
            .is_ok());
        assert!(adapter
            .add_policy("", "p", to_owned(vec!["bob_rfp", "pen_rfp", "get_rfp"]),)
            .await
            .is_ok());
        assert!(adapter
            .add_policy(
                "",
                "p",
                to_owned(vec!["alice_rfp", "pencil_rfp", "get_rfp"]),
            )
            .await
            .is_ok());

        // Filter on column v1 only.
        assert!(adapter
            .remove_filtered_policy("", "p", 1, to_owned(vec!["book_rfp"]),)
            .await
            .unwrap());

        // NOTE(review): the empty middle value presumably acts as a wildcard
        // for v1 — depends on `actions::remove_filtered_policy`; confirm.
        assert!(adapter
            .remove_filtered_policy("", "p", 0, to_owned(vec!["alice_rfp", "", "get_rfp"]),)
            .await
            .unwrap());

        // Filtered loading: seed the database with the domain example, then
        // load only the domain1 rules through the adapter.
        let mut e = Enforcer::new(
            "examples/rbac_with_domains_model.conf",
            "examples/rbac_with_domains_policy.csv",
        )
        .await
        .unwrap();

        assert!(adapter.save_policy(e.get_mut_model()).await.is_ok());
        e.set_adapter(adapter).await.unwrap();

        let filter = Filter {
            p: vec!["", "domain1"],
            g: vec!["", "", "domain1"],
        };

        e.load_filtered_policy(filter).await.unwrap();
        // domain1 rules are visible; domain2 rules were filtered out.
        assert!(e.enforce(("alice", "domain1", "data1", "read")).unwrap());
        assert!(e.enforce(("alice", "domain1", "data1", "write")).unwrap());
        assert!(!e.enforce(("alice", "domain1", "data2", "read")).unwrap());
        assert!(!e.enforce(("alice", "domain1", "data2", "write")).unwrap());
        assert!(!e.enforce(("bob", "domain2", "data2", "read")).unwrap());
        assert!(!e.enforce(("bob", "domain2", "data2", "write")).unwrap());
    }
}
643}