1use async_trait::async_trait;
2use casbin::{Adapter, Filter, Model, Result as CasbinResult};
3use surrealdb::{Surreal, engine::any::Any};
4use surrealdb_types::{RecordId, SurrealValue};
5
/// Default name of the SurrealDB table that stores casbin rules; used by
/// `SurrealAdapter::new` when no explicit table is supplied.
pub const TABLE: &str = "casbin_rule";
7
/// One persisted casbin rule row.
///
/// `sec` is the model section key (e.g. `"p"` or `"g"`), `ptype` the policy
/// type within that section, and `v0`..`v5` the rule's positional values.
/// Rows built by `CasbinRule::new` leave positions past the end of the rule
/// as `None`.
#[derive(Debug, Clone, SurrealValue)]
struct CasbinRule {
    /// SurrealDB record id; `None` for rows built locally by `CasbinRule::new`.
    id: Option<RecordId>,
    /// Model section key the rule belongs to.
    sec: String,
    /// Policy type within the section.
    ptype: String,
    // Positional rule values; at most six are stored.
    v0: Option<String>,
    v1: Option<String>,
    v2: Option<String>,
    v3: Option<String>,
    v4: Option<String>,
    v5: Option<String>,
}
22
23impl CasbinRule {
24 fn new(sec: &str, ptype: &str, rule: &[String]) -> Self {
25 let get = |i: usize| rule.get(i).cloned();
26 Self {
27 id: None,
28 sec: sec.to_owned(),
29 ptype: ptype.to_owned(),
30 v0: get(0),
31 v1: get(1),
32 v2: get(2),
33 v3: get(3),
34 v4: get(4),
35 v5: get(5),
36 }
37 }
38
39 fn to_rule(&self) -> Vec<String> {
40 [&self.v0, &self.v1, &self.v2, &self.v3, &self.v4, &self.v5]
41 .iter()
42 .filter_map(|v| v.as_deref().map(str::to_owned))
43 .collect()
44 }
45
46 fn bind_values<'a>(
47 &self,
48 q: surrealdb::method::Query<'a, Any>,
49 ) -> surrealdb::method::Query<'a, Any> {
50 q.bind(("v0", self.v0.clone()))
51 .bind(("v1", self.v1.clone()))
52 .bind(("v2", self.v2.clone()))
53 .bind(("v3", self.v3.clone()))
54 .bind(("v4", self.v4.clone()))
55 .bind(("v5", self.v5.clone()))
56 }
57}
58
59fn load_policy_line(m: &mut dyn Model, rule: &CasbinRule) {
62 let values = rule.to_rule();
63 if values.is_empty() {
64 return;
65 }
66 if let Some(sec_map) = m.get_mut_model().get_mut(&rule.sec)
67 && let Some(assertion) = sec_map.get_mut(&rule.ptype)
68 {
69 assertion.get_mut_policy().insert(values);
70 }
71}
72
/// Casbin [`Adapter`] that persists policy rules as rows of a SurrealDB table.
pub struct SurrealAdapter {
    /// Connection used for every query the adapter issues.
    db: Surreal<Any>,
    /// Name of the table holding the rules (defaults to [`TABLE`]).
    table: String,
    /// `true` after `load_filtered_policy`, reset to `false` by `load_policy`.
    is_filtered: bool,
}
80
81impl SurrealAdapter {
82 pub fn new(db: Surreal<Any>) -> Self {
83 Self {
84 db,
85 table: TABLE.to_owned(),
86 is_filtered: false,
87 }
88 }
89
90 pub fn with_table(db: Surreal<Any>, table: impl Into<String>) -> Self {
91 Self {
92 db,
93 table: table.into(),
94 is_filtered: false,
95 }
96 }
97
98 pub async fn create_table(&self) -> Result<(), surrealdb::Error> {
99 self.db
100 .query("DEFINE TABLE IF NOT EXISTS $table SCHEMALESS;")
101 .bind(("table", self.table.clone()))
102 .await?
103 .check()?;
104 Ok(())
105 }
106}
107
108#[async_trait]
109impl Adapter for SurrealAdapter {
110 async fn load_policy(&mut self, m: &mut dyn Model) -> CasbinResult<()> {
111 for rule in self.get_all_rules().await? {
112 load_policy_line(m, &rule);
113 }
114 self.is_filtered = false;
115 Ok(())
116 }
117
118 async fn load_filtered_policy<'a>(
119 &mut self,
120 m: &mut dyn Model,
121 f: Filter<'a>,
122 ) -> CasbinResult<()> {
123 for (sec, filter) in [("p", &f.p), ("g", &f.g)] {
124 let has_filter = filter.iter().any(|fv| !fv.is_empty());
125 let rules = if has_filter {
126 self.get_filtered_rules(sec, filter).await?
127 } else {
128 self.get_rules_by_sec(sec).await?
129 };
130 for rule in &rules {
131 load_policy_line(m, rule);
132 }
133 }
134 self.is_filtered = true;
135 Ok(())
136 }
137
138 async fn save_policy(&mut self, m: &mut dyn Model) -> CasbinResult<()> {
139 self.clear_policy().await?;
140
141 let mut all_rules: Vec<CasbinRule> = Vec::new();
142 for sec in ["p", "g"] {
143 if let Some(sec_map) = m.get_model().get(sec) {
144 for (ptype, assertion) in sec_map {
145 for policy in assertion.get_policy() {
146 all_rules.push(CasbinRule::new(sec, ptype, policy));
147 }
148 }
149 }
150 }
151
152 if !all_rules.is_empty() {
153 self.insert_entries(all_rules).await?;
154 }
155 Ok(())
156 }
157
158 async fn clear_policy(&mut self) -> CasbinResult<()> {
159 self.db
160 .query("DELETE type::table($table);")
161 .bind(("table", self.table.clone()))
162 .await
163 .map_err(io_err)?
164 .check()
165 .map_err(io_err)?;
166 Ok(())
167 }
168
169 fn is_filtered(&self) -> bool {
170 self.is_filtered
171 }
172
173 async fn add_policy(
174 &mut self,
175 sec: &str,
176 ptype: &str,
177 rule: Vec<String>,
178 ) -> CasbinResult<bool> {
179 if self.rule_exists(sec, ptype, &rule).await? {
180 return Ok(false);
181 }
182 let entry = CasbinRule::new(sec, ptype, &rule);
183 let _: Option<CasbinRule> = self
184 .db
185 .create(&*self.table)
186 .content(entry)
187 .await
188 .map_err(io_err)?;
189 Ok(true)
190 }
191
192 async fn add_policies(
193 &mut self,
194 sec: &str,
195 ptype: &str,
196 rules: Vec<Vec<String>>,
197 ) -> CasbinResult<bool> {
198 if self.any_rules_exist(sec, ptype, &rules).await? {
199 return Ok(false);
200 }
201 let entries: Vec<CasbinRule> = rules
202 .iter()
203 .map(|r| CasbinRule::new(sec, ptype, r))
204 .collect();
205 self.insert_entries(entries).await?;
206 Ok(true)
207 }
208
209 async fn remove_policy(
210 &mut self,
211 sec: &str,
212 ptype: &str,
213 rule: Vec<String>,
214 ) -> CasbinResult<bool> {
215 self.delete_exact(sec, ptype, &rule).await
216 }
217
218 async fn remove_policies(
219 &mut self,
220 sec: &str,
221 ptype: &str,
222 rules: Vec<Vec<String>>,
223 ) -> CasbinResult<bool> {
224 if rules.is_empty() {
225 return Ok(false);
226 }
227 self.delete_exact_batch(sec, ptype, &rules).await
228 }
229
230 async fn remove_filtered_policy(
231 &mut self,
232 sec: &str,
233 ptype: &str,
234 field_index: usize,
235 field_values: Vec<String>,
236 ) -> CasbinResult<bool> {
237 self.delete_filtered(sec, ptype, field_index, &field_values)
238 .await
239 }
240}
241
242impl SurrealAdapter {
245 async fn insert_entries(&self, entries: Vec<CasbinRule>) -> CasbinResult<bool> {
246 let _: Vec<CasbinRule> = self
247 .db
248 .insert(&*self.table)
249 .content(entries)
250 .await
251 .map_err(io_err)?;
252 Ok(true)
253 }
254
255 async fn get_all_rules(&self) -> CasbinResult<Vec<CasbinRule>> {
256 self.db.select(&*self.table).await.map_err(io_err)
257 }
258
259 async fn get_rules_by_sec(&self, sec: &str) -> CasbinResult<Vec<CasbinRule>> {
260 let rules: Vec<CasbinRule> = self
261 .db
262 .query("SELECT * FROM type::table($table) WHERE sec = $sec")
263 .bind(("table", self.table.clone()))
264 .bind(("sec", sec.to_owned()))
265 .await
266 .map_err(io_err)?
267 .check()
268 .map_err(io_err)?
269 .take(0)
270 .map_err(io_err)?;
271 Ok(rules)
272 }
273
274 async fn get_filtered_rules(
275 &self,
276 sec: &str,
277 filter: &[&str],
278 ) -> CasbinResult<Vec<CasbinRule>> {
279 let mut conditions = vec!["sec = $sec".to_owned()];
280 let mut binds: Vec<(String, String)> = Vec::new();
281
282 for (i, fv) in filter.iter().enumerate() {
283 if !fv.is_empty() {
284 let param = format!("fv{i}");
285 conditions.push(format!("v{i} = ${param}"));
286 binds.push((param, (*fv).to_owned()));
287 }
288 }
289
290 let query = format!(
291 "SELECT * FROM type::table($table) WHERE {}",
292 conditions.join(" AND ")
293 );
294 let mut q = self
295 .db
296 .query(&query)
297 .bind(("table", self.table.clone()))
298 .bind(("sec", sec.to_owned()));
299 for (k, v) in binds {
300 q = q.bind((k, v));
301 }
302
303 let rules: Vec<CasbinRule> = q
304 .await
305 .map_err(io_err)?
306 .check()
307 .map_err(io_err)?
308 .take(0)
309 .map_err(io_err)?;
310 Ok(rules)
311 }
312
313 async fn rule_exists(&self, sec: &str, ptype: &str, rule: &[String]) -> CasbinResult<bool> {
314 let entry = CasbinRule::new(sec, ptype, rule);
315 let q = self
316 .db
317 .query(
318 "SELECT * FROM type::table($table)
319 WHERE sec = $sec AND ptype = $ptype
320 AND v0 = $v0 AND v1 = $v1 AND v2 = $v2
321 AND v3 = $v3 AND v4 = $v4 AND v5 = $v5
322 LIMIT 1",
323 )
324 .bind(("table", self.table.clone()))
325 .bind(("sec", sec.to_owned()))
326 .bind(("ptype", ptype.to_owned()));
327
328 let found: Vec<CasbinRule> = entry
329 .bind_values(q)
330 .await
331 .map_err(io_err)?
332 .check()
333 .map_err(io_err)?
334 .take(0)
335 .map_err(io_err)?;
336
337 Ok(!found.is_empty())
338 }
339
340 async fn any_rules_exist(
341 &self,
342 sec: &str,
343 ptype: &str,
344 rules: &[Vec<String>],
345 ) -> CasbinResult<bool> {
346 if rules.is_empty() {
347 return Ok(false);
348 }
349
350 let mut or_clauses = Vec::new();
351 let mut binds: Vec<(String, Option<String>)> = Vec::new();
352
353 for (ri, rule) in rules.iter().enumerate() {
354 let entry = CasbinRule::new(sec, ptype, rule);
355 let fields = [
356 &entry.v0, &entry.v1, &entry.v2, &entry.v3, &entry.v4, &entry.v5,
357 ];
358 let mut field_conditions = Vec::new();
359 for (fi, val) in fields.iter().enumerate() {
360 let param = format!("r{ri}v{fi}");
361 field_conditions.push(format!("v{fi} = ${param}"));
362 binds.push((param, (*val).clone()));
363 }
364 or_clauses.push(format!("({})", field_conditions.join(" AND ")));
365 }
366
367 let query = format!(
368 "SELECT * FROM type::table($table) WHERE sec = $sec AND ptype = $ptype AND ({}) LIMIT 1",
369 or_clauses.join(" OR ")
370 );
371
372 let mut q = self
373 .db
374 .query(&query)
375 .bind(("table", self.table.clone()))
376 .bind(("sec", sec.to_owned()))
377 .bind(("ptype", ptype.to_owned()));
378
379 for (k, v) in binds {
380 q = q.bind((k, v));
381 }
382
383 let found: Vec<CasbinRule> = q
384 .await
385 .map_err(io_err)?
386 .check()
387 .map_err(io_err)?
388 .take(0)
389 .map_err(io_err)?;
390
391 Ok(!found.is_empty())
392 }
393
394 async fn delete_exact(&self, sec: &str, ptype: &str, rule: &[String]) -> CasbinResult<bool> {
395 let entry = CasbinRule::new(sec, ptype, rule);
396 let q = self
397 .db
398 .query(
399 "DELETE type::table($table)
400 WHERE sec = $sec AND ptype = $ptype
401 AND v0 = $v0 AND v1 = $v1 AND v2 = $v2
402 AND v3 = $v3 AND v4 = $v4 AND v5 = $v5
403 RETURN BEFORE",
404 )
405 .bind(("table", self.table.clone()))
406 .bind(("sec", sec.to_owned()))
407 .bind(("ptype", ptype.to_owned()));
408
409 let deleted: Vec<CasbinRule> = entry
410 .bind_values(q)
411 .await
412 .map_err(io_err)?
413 .check()
414 .map_err(io_err)?
415 .take(0)
416 .map_err(io_err)?;
417
418 Ok(!deleted.is_empty())
419 }
420
421 async fn delete_exact_batch(
422 &self,
423 sec: &str,
424 ptype: &str,
425 rules: &[Vec<String>],
426 ) -> CasbinResult<bool> {
427 let mut or_clauses = Vec::new();
428 let mut binds: Vec<(String, Option<String>)> = Vec::new();
429
430 for (ri, rule) in rules.iter().enumerate() {
431 let entry = CasbinRule::new(sec, ptype, rule);
432 let fields = [
433 &entry.v0, &entry.v1, &entry.v2, &entry.v3, &entry.v4, &entry.v5,
434 ];
435 let mut field_conditions = Vec::new();
436 for (fi, val) in fields.iter().enumerate() {
437 let param = format!("r{ri}v{fi}");
438 field_conditions.push(format!("v{fi} = ${param}"));
439 binds.push((param, (*val).clone()));
440 }
441 or_clauses.push(format!("({})", field_conditions.join(" AND ")));
442 }
443
444 let query = format!(
445 "DELETE type::table($table) WHERE sec = $sec AND ptype = $ptype AND ({}) RETURN BEFORE",
446 or_clauses.join(" OR ")
447 );
448
449 let mut q = self
450 .db
451 .query(&query)
452 .bind(("table", self.table.clone()))
453 .bind(("sec", sec.to_owned()))
454 .bind(("ptype", ptype.to_owned()));
455
456 for (k, v) in binds {
457 q = q.bind((k, v));
458 }
459
460 let deleted: Vec<CasbinRule> = q
461 .await
462 .map_err(io_err)?
463 .check()
464 .map_err(io_err)?
465 .take(0)
466 .map_err(io_err)?;
467
468 Ok(!deleted.is_empty())
469 }
470
471 async fn delete_filtered(
472 &self,
473 sec: &str,
474 ptype: &str,
475 field_index: usize,
476 field_values: &[String],
477 ) -> CasbinResult<bool> {
478 let mut col_conditions = Vec::new();
479 let mut binds: Vec<(String, String)> = Vec::new();
480
481 for (offset, v) in field_values.iter().enumerate() {
482 if !v.is_empty() {
483 let col = field_index + offset;
484 let param = format!("fv{offset}");
485 col_conditions.push(format!("v{col} = ${param}"));
486 binds.push((param, v.clone()));
487 }
488 }
489
490 let where_clause = if col_conditions.is_empty() {
491 "sec = $sec AND ptype = $ptype".to_owned()
492 } else {
493 format!(
494 "sec = $sec AND ptype = $ptype AND {}",
495 col_conditions.join(" AND ")
496 )
497 };
498
499 let query = format!("DELETE type::table($table) WHERE {where_clause} RETURN BEFORE");
500
501 let mut q = self
502 .db
503 .query(&query)
504 .bind(("table", self.table.clone()))
505 .bind(("sec", sec.to_owned()))
506 .bind(("ptype", ptype.to_owned()));
507
508 for (k, v) in binds {
509 q = q.bind((k, v));
510 }
511
512 let deleted: Vec<CasbinRule> = q
513 .await
514 .map_err(io_err)?
515 .check()
516 .map_err(io_err)?
517 .take(0)
518 .map_err(io_err)?;
519
520 Ok(!deleted.is_empty())
521 }
522}
523
524fn io_err(e: impl std::fmt::Display) -> casbin::Error {
527 casbin::Error::IoError(std::io::Error::other(e.to_string()))
528}