Skip to main content

surreal_casbin_adapter/
lib.rs

1use async_trait::async_trait;
2use casbin::{Adapter, Filter, Model, Result as CasbinResult};
3use surrealdb::{Surreal, engine::any::Any};
4use surrealdb_types::{RecordId, SurrealValue};
5
/// Default name of the SurrealDB table that holds the Casbin rules.
///
/// Used by `SurrealAdapter::new`; `SurrealAdapter::with_table` selects a
/// different table instead.
pub const TABLE: &str = "casbin_rule";
7
8// ─── CasbinRule ──────────────────────────────────────────────────────────────
9
/// One Casbin rule as persisted in the SurrealDB rule table.
///
/// `sec` is the model section and `ptype` the assertion key inside it; the
/// rule's values are stored positionally in `v0`..`v5`, with positions the
/// rule does not supply left as `None`.
#[derive(Debug, Clone, SurrealValue)]
struct CasbinRule {
    /// Record id; `None` for rules built locally before insertion
    /// (presumably filled from the stored record when read back — confirm).
    id: Option<RecordId>,
    /// Model section the rule belongs to (e.g. `"p"` or `"g"`).
    sec: String,
    /// Assertion key within `sec` (e.g. `"p"`, `"g"`).
    ptype: String,
    /// Rule value at position 0.
    v0: Option<String>,
    /// Rule value at position 1.
    v1: Option<String>,
    /// Rule value at position 2.
    v2: Option<String>,
    /// Rule value at position 3.
    v3: Option<String>,
    /// Rule value at position 4.
    v4: Option<String>,
    /// Rule value at position 5.
    v5: Option<String>,
}
22
23impl CasbinRule {
24    fn new(sec: &str, ptype: &str, rule: &[String]) -> Self {
25        let get = |i: usize| rule.get(i).cloned();
26        Self {
27            id: None,
28            sec: sec.to_owned(),
29            ptype: ptype.to_owned(),
30            v0: get(0),
31            v1: get(1),
32            v2: get(2),
33            v3: get(3),
34            v4: get(4),
35            v5: get(5),
36        }
37    }
38
39    fn to_rule(&self) -> Vec<String> {
40        [&self.v0, &self.v1, &self.v2, &self.v3, &self.v4, &self.v5]
41            .iter()
42            .filter_map(|v| v.as_deref().map(str::to_owned))
43            .collect()
44    }
45
46    fn bind_values<'a>(
47        &self,
48        q: surrealdb::method::Query<'a, Any>,
49    ) -> surrealdb::method::Query<'a, Any> {
50        q.bind(("v0", self.v0.clone()))
51            .bind(("v1", self.v1.clone()))
52            .bind(("v2", self.v2.clone()))
53            .bind(("v3", self.v3.clone()))
54            .bind(("v4", self.v4.clone()))
55            .bind(("v5", self.v5.clone()))
56    }
57}
58
59// ─── load helper ─────────────────────────────────────────────────────────────
60
61fn load_policy_line(m: &mut dyn Model, rule: &CasbinRule) {
62    let values = rule.to_rule();
63    if values.is_empty() {
64        return;
65    }
66    if let Some(sec_map) = m.get_mut_model().get_mut(&rule.sec)
67        && let Some(assertion) = sec_map.get_mut(&rule.ptype)
68    {
69        assertion.get_mut_policy().insert(values);
70    }
71}
72
73// ─── Adapter ─────────────────────────────────────────────────────────────────
74
/// Casbin [`Adapter`] that persists policy rules in a SurrealDB table.
pub struct SurrealAdapter {
    /// Database handle used for every query.
    db: Surreal<Any>,
    /// Name of the table holding the rules ([`TABLE`] unless overridden).
    table: String,
    /// Whether the most recent load was filtered: set by
    /// `load_filtered_policy`, cleared by `load_policy`.
    is_filtered: bool,
}
80
81impl SurrealAdapter {
82    pub fn new(db: Surreal<Any>) -> Self {
83        Self {
84            db,
85            table: TABLE.to_owned(),
86            is_filtered: false,
87        }
88    }
89
90    pub fn with_table(db: Surreal<Any>, table: impl Into<String>) -> Self {
91        Self {
92            db,
93            table: table.into(),
94            is_filtered: false,
95        }
96    }
97
98    pub async fn create_table(&self) -> Result<(), surrealdb::Error> {
99        self.db
100            .query("DEFINE TABLE IF NOT EXISTS $table SCHEMALESS;")
101            .bind(("table", self.table.clone()))
102            .await?
103            .check()?;
104        Ok(())
105    }
106}
107
108#[async_trait]
109impl Adapter for SurrealAdapter {
110    async fn load_policy(&mut self, m: &mut dyn Model) -> CasbinResult<()> {
111        for rule in self.get_all_rules().await? {
112            load_policy_line(m, &rule);
113        }
114        self.is_filtered = false;
115        Ok(())
116    }
117
118    async fn load_filtered_policy<'a>(
119        &mut self,
120        m: &mut dyn Model,
121        f: Filter<'a>,
122    ) -> CasbinResult<()> {
123        for (sec, filter) in [("p", &f.p), ("g", &f.g)] {
124            let has_filter = filter.iter().any(|fv| !fv.is_empty());
125            let rules = if has_filter {
126                self.get_filtered_rules(sec, filter).await?
127            } else {
128                self.get_rules_by_sec(sec).await?
129            };
130            for rule in &rules {
131                load_policy_line(m, rule);
132            }
133        }
134        self.is_filtered = true;
135        Ok(())
136    }
137
138    async fn save_policy(&mut self, m: &mut dyn Model) -> CasbinResult<()> {
139        self.clear_policy().await?;
140
141        let mut all_rules: Vec<CasbinRule> = Vec::new();
142        for sec in ["p", "g"] {
143            if let Some(sec_map) = m.get_model().get(sec) {
144                for (ptype, assertion) in sec_map {
145                    for policy in assertion.get_policy() {
146                        all_rules.push(CasbinRule::new(sec, ptype, policy));
147                    }
148                }
149            }
150        }
151
152        if !all_rules.is_empty() {
153            self.insert_entries(all_rules).await?;
154        }
155        Ok(())
156    }
157
158    async fn clear_policy(&mut self) -> CasbinResult<()> {
159        self.db
160            .query("DELETE type::table($table);")
161            .bind(("table", self.table.clone()))
162            .await
163            .map_err(io_err)?
164            .check()
165            .map_err(io_err)?;
166        Ok(())
167    }
168
169    fn is_filtered(&self) -> bool {
170        self.is_filtered
171    }
172
173    async fn add_policy(
174        &mut self,
175        sec: &str,
176        ptype: &str,
177        rule: Vec<String>,
178    ) -> CasbinResult<bool> {
179        if self.rule_exists(sec, ptype, &rule).await? {
180            return Ok(false);
181        }
182        let entry = CasbinRule::new(sec, ptype, &rule);
183        let _: Option<CasbinRule> = self
184            .db
185            .create(&*self.table)
186            .content(entry)
187            .await
188            .map_err(io_err)?;
189        Ok(true)
190    }
191
192    async fn add_policies(
193        &mut self,
194        sec: &str,
195        ptype: &str,
196        rules: Vec<Vec<String>>,
197    ) -> CasbinResult<bool> {
198        if self.any_rules_exist(sec, ptype, &rules).await? {
199            return Ok(false);
200        }
201        let entries: Vec<CasbinRule> = rules
202            .iter()
203            .map(|r| CasbinRule::new(sec, ptype, r))
204            .collect();
205        self.insert_entries(entries).await?;
206        Ok(true)
207    }
208
209    async fn remove_policy(
210        &mut self,
211        sec: &str,
212        ptype: &str,
213        rule: Vec<String>,
214    ) -> CasbinResult<bool> {
215        self.delete_exact(sec, ptype, &rule).await
216    }
217
218    async fn remove_policies(
219        &mut self,
220        sec: &str,
221        ptype: &str,
222        rules: Vec<Vec<String>>,
223    ) -> CasbinResult<bool> {
224        if rules.is_empty() {
225            return Ok(false);
226        }
227        self.delete_exact_batch(sec, ptype, &rules).await
228    }
229
230    async fn remove_filtered_policy(
231        &mut self,
232        sec: &str,
233        ptype: &str,
234        field_index: usize,
235        field_values: Vec<String>,
236    ) -> CasbinResult<bool> {
237        self.delete_filtered(sec, ptype, field_index, &field_values)
238            .await
239    }
240}
241
242// ─── Private helpers ─────────────────────────────────────────────────────────
243
244impl SurrealAdapter {
245    async fn insert_entries(&self, entries: Vec<CasbinRule>) -> CasbinResult<bool> {
246        let _: Vec<CasbinRule> = self
247            .db
248            .insert(&*self.table)
249            .content(entries)
250            .await
251            .map_err(io_err)?;
252        Ok(true)
253    }
254
255    async fn get_all_rules(&self) -> CasbinResult<Vec<CasbinRule>> {
256        self.db.select(&*self.table).await.map_err(io_err)
257    }
258
259    async fn get_rules_by_sec(&self, sec: &str) -> CasbinResult<Vec<CasbinRule>> {
260        let rules: Vec<CasbinRule> = self
261            .db
262            .query("SELECT * FROM type::table($table) WHERE sec = $sec")
263            .bind(("table", self.table.clone()))
264            .bind(("sec", sec.to_owned()))
265            .await
266            .map_err(io_err)?
267            .check()
268            .map_err(io_err)?
269            .take(0)
270            .map_err(io_err)?;
271        Ok(rules)
272    }
273
274    async fn get_filtered_rules(
275        &self,
276        sec: &str,
277        filter: &[&str],
278    ) -> CasbinResult<Vec<CasbinRule>> {
279        let mut conditions = vec!["sec = $sec".to_owned()];
280        let mut binds: Vec<(String, String)> = Vec::new();
281
282        for (i, fv) in filter.iter().enumerate() {
283            if !fv.is_empty() {
284                let param = format!("fv{i}");
285                conditions.push(format!("v{i} = ${param}"));
286                binds.push((param, (*fv).to_owned()));
287            }
288        }
289
290        let query = format!(
291            "SELECT * FROM type::table($table) WHERE {}",
292            conditions.join(" AND ")
293        );
294        let mut q = self
295            .db
296            .query(&query)
297            .bind(("table", self.table.clone()))
298            .bind(("sec", sec.to_owned()));
299        for (k, v) in binds {
300            q = q.bind((k, v));
301        }
302
303        let rules: Vec<CasbinRule> = q
304            .await
305            .map_err(io_err)?
306            .check()
307            .map_err(io_err)?
308            .take(0)
309            .map_err(io_err)?;
310        Ok(rules)
311    }
312
313    async fn rule_exists(&self, sec: &str, ptype: &str, rule: &[String]) -> CasbinResult<bool> {
314        let entry = CasbinRule::new(sec, ptype, rule);
315        let q = self
316            .db
317            .query(
318                "SELECT * FROM type::table($table)
319                 WHERE sec = $sec AND ptype = $ptype
320                   AND v0 = $v0 AND v1 = $v1 AND v2 = $v2
321                   AND v3 = $v3 AND v4 = $v4 AND v5 = $v5
322                 LIMIT 1",
323            )
324            .bind(("table", self.table.clone()))
325            .bind(("sec", sec.to_owned()))
326            .bind(("ptype", ptype.to_owned()));
327
328        let found: Vec<CasbinRule> = entry
329            .bind_values(q)
330            .await
331            .map_err(io_err)?
332            .check()
333            .map_err(io_err)?
334            .take(0)
335            .map_err(io_err)?;
336
337        Ok(!found.is_empty())
338    }
339
340    async fn any_rules_exist(
341        &self,
342        sec: &str,
343        ptype: &str,
344        rules: &[Vec<String>],
345    ) -> CasbinResult<bool> {
346        if rules.is_empty() {
347            return Ok(false);
348        }
349
350        let mut or_clauses = Vec::new();
351        let mut binds: Vec<(String, Option<String>)> = Vec::new();
352
353        for (ri, rule) in rules.iter().enumerate() {
354            let entry = CasbinRule::new(sec, ptype, rule);
355            let fields = [
356                &entry.v0, &entry.v1, &entry.v2, &entry.v3, &entry.v4, &entry.v5,
357            ];
358            let mut field_conditions = Vec::new();
359            for (fi, val) in fields.iter().enumerate() {
360                let param = format!("r{ri}v{fi}");
361                field_conditions.push(format!("v{fi} = ${param}"));
362                binds.push((param, (*val).clone()));
363            }
364            or_clauses.push(format!("({})", field_conditions.join(" AND ")));
365        }
366
367        let query = format!(
368            "SELECT * FROM type::table($table) WHERE sec = $sec AND ptype = $ptype AND ({}) LIMIT 1",
369            or_clauses.join(" OR ")
370        );
371
372        let mut q = self
373            .db
374            .query(&query)
375            .bind(("table", self.table.clone()))
376            .bind(("sec", sec.to_owned()))
377            .bind(("ptype", ptype.to_owned()));
378
379        for (k, v) in binds {
380            q = q.bind((k, v));
381        }
382
383        let found: Vec<CasbinRule> = q
384            .await
385            .map_err(io_err)?
386            .check()
387            .map_err(io_err)?
388            .take(0)
389            .map_err(io_err)?;
390
391        Ok(!found.is_empty())
392    }
393
394    async fn delete_exact(&self, sec: &str, ptype: &str, rule: &[String]) -> CasbinResult<bool> {
395        let entry = CasbinRule::new(sec, ptype, rule);
396        let q = self
397            .db
398            .query(
399                "DELETE type::table($table)
400                 WHERE sec = $sec AND ptype = $ptype
401                   AND v0 = $v0 AND v1 = $v1 AND v2 = $v2
402                   AND v3 = $v3 AND v4 = $v4 AND v5 = $v5
403                 RETURN BEFORE",
404            )
405            .bind(("table", self.table.clone()))
406            .bind(("sec", sec.to_owned()))
407            .bind(("ptype", ptype.to_owned()));
408
409        let deleted: Vec<CasbinRule> = entry
410            .bind_values(q)
411            .await
412            .map_err(io_err)?
413            .check()
414            .map_err(io_err)?
415            .take(0)
416            .map_err(io_err)?;
417
418        Ok(!deleted.is_empty())
419    }
420
421    async fn delete_exact_batch(
422        &self,
423        sec: &str,
424        ptype: &str,
425        rules: &[Vec<String>],
426    ) -> CasbinResult<bool> {
427        let mut or_clauses = Vec::new();
428        let mut binds: Vec<(String, Option<String>)> = Vec::new();
429
430        for (ri, rule) in rules.iter().enumerate() {
431            let entry = CasbinRule::new(sec, ptype, rule);
432            let fields = [
433                &entry.v0, &entry.v1, &entry.v2, &entry.v3, &entry.v4, &entry.v5,
434            ];
435            let mut field_conditions = Vec::new();
436            for (fi, val) in fields.iter().enumerate() {
437                let param = format!("r{ri}v{fi}");
438                field_conditions.push(format!("v{fi} = ${param}"));
439                binds.push((param, (*val).clone()));
440            }
441            or_clauses.push(format!("({})", field_conditions.join(" AND ")));
442        }
443
444        let query = format!(
445            "DELETE type::table($table) WHERE sec = $sec AND ptype = $ptype AND ({}) RETURN BEFORE",
446            or_clauses.join(" OR ")
447        );
448
449        let mut q = self
450            .db
451            .query(&query)
452            .bind(("table", self.table.clone()))
453            .bind(("sec", sec.to_owned()))
454            .bind(("ptype", ptype.to_owned()));
455
456        for (k, v) in binds {
457            q = q.bind((k, v));
458        }
459
460        let deleted: Vec<CasbinRule> = q
461            .await
462            .map_err(io_err)?
463            .check()
464            .map_err(io_err)?
465            .take(0)
466            .map_err(io_err)?;
467
468        Ok(!deleted.is_empty())
469    }
470
471    async fn delete_filtered(
472        &self,
473        sec: &str,
474        ptype: &str,
475        field_index: usize,
476        field_values: &[String],
477    ) -> CasbinResult<bool> {
478        let mut col_conditions = Vec::new();
479        let mut binds: Vec<(String, String)> = Vec::new();
480
481        for (offset, v) in field_values.iter().enumerate() {
482            if !v.is_empty() {
483                let col = field_index + offset;
484                let param = format!("fv{offset}");
485                col_conditions.push(format!("v{col} = ${param}"));
486                binds.push((param, v.clone()));
487            }
488        }
489
490        let where_clause = if col_conditions.is_empty() {
491            "sec = $sec AND ptype = $ptype".to_owned()
492        } else {
493            format!(
494                "sec = $sec AND ptype = $ptype AND {}",
495                col_conditions.join(" AND ")
496            )
497        };
498
499        let query = format!("DELETE type::table($table) WHERE {where_clause} RETURN BEFORE");
500
501        let mut q = self
502            .db
503            .query(&query)
504            .bind(("table", self.table.clone()))
505            .bind(("sec", sec.to_owned()))
506            .bind(("ptype", ptype.to_owned()));
507
508        for (k, v) in binds {
509            q = q.bind((k, v));
510        }
511
512        let deleted: Vec<CasbinRule> = q
513            .await
514            .map_err(io_err)?
515            .check()
516            .map_err(io_err)?
517            .take(0)
518            .map_err(io_err)?;
519
520        Ok(!deleted.is_empty())
521    }
522}
523
524// ─── Error helper ────────────────────────────────────────────────────────────
525
526fn io_err(e: impl std::fmt::Display) -> casbin::Error {
527    casbin::Error::IoError(std::io::Error::other(e.to_string()))
528}