1#![cfg_attr(docsrs, feature(doc_cfg))]
166
167use std::collections::HashMap;
168use std::fmt::Debug;
169use std::result::Result as StdResult;
170use std::sync::Arc;
171
172use ipnet::IpNet;
173use serde_json::{Map, Value as JsonValue, json};
174
175#[cfg(feature = "validation")]
177#[cfg_attr(docsrs, doc(cfg(feature = "validation")))]
178pub use jsonschema::ValidationError;
179
180pub use error::Error;
182pub use matcher::{
183 BoolMatcher, DefaultMatcher, IpMatcher, Matcher, NumberMatcher, Operator, RegexMatcher,
184 StringMatcher,
185};
186pub use types::{AsyncCheckFn, BoxFuture, CheckFn, MaybeSend, MaybeSync, ToOperator};
187pub use value::Value;
188
189use crate::types::{AsyncEvalFn, AsyncFetcherFn, DynError, EvalFn, FetcherFn};
190
191pub(crate) type Result<T> = StdResult<T, error::Error>;
192
193pub enum Rule<Ctx: ?Sized + 'static> {
201 Any(Vec<Self>),
202 All(Vec<Self>),
203 Not(Box<Self>),
204 Leaf(Condition<Ctx>),
205}
206
207impl<Ctx: ?Sized> Debug for Rule<Ctx> {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 match self {
210 Rule::Any(rules) => f.debug_tuple("Any").field(rules).finish(),
211 Rule::All(rules) => f.debug_tuple("All").field(rules).finish(),
212 Rule::Not(rule) => f.debug_tuple("Not").field(rule).finish(),
213 Rule::Leaf(_) => f.debug_tuple("Leaf").finish(),
214 }
215 }
216}
217
218impl<Ctx: ?Sized> Clone for Rule<Ctx> {
219 fn clone(&self) -> Self {
220 match self {
221 Rule::Any(rules) => Rule::Any(rules.clone()),
222 Rule::All(rules) => Rule::All(rules.clone()),
223 Rule::Not(rule) => Rule::Not(rule.clone()),
224 Rule::Leaf(condition) => Rule::Leaf(condition.clone()),
225 }
226 }
227}
228
/// Opaque leaf condition of a [`Rule`], wrapping a compiled evaluation closure.
///
/// It has to be `pub` because it appears in the public `Rule::Leaf` variant. It is
/// hidden from docs and its field is private.
#[doc(hidden)]
pub struct Condition<Ctx: ?Sized>(AnyEvalFn<Ctx>);
234
235impl<Ctx: ?Sized> Clone for Condition<Ctx> {
236 fn clone(&self) -> Self {
237 Condition(self.0.clone())
238 }
239}
240
241impl<Ctx: ?Sized> Rule<Ctx> {
242 #[inline(always)]
243 fn any(mut rules: Vec<Rule<Ctx>>) -> Self {
244 if rules.len() == 1 {
245 return rules.pop().unwrap();
246 }
247 Rule::Any(rules)
248 }
249
250 #[inline(always)]
251 fn all(mut rules: Vec<Rule<Ctx>>) -> Self {
252 if rules.len() == 1 {
253 return rules.pop().unwrap();
254 }
255 Rule::All(rules)
256 }
257
258 #[inline(always)]
259 fn not(mut rules: Vec<Rule<Ctx>>) -> Self {
260 if rules.len() == 1 {
261 return Rule::Not(Box::new(rules.pop().unwrap()));
262 }
263 Rule::Not(Box::new(Rule::All(rules)))
264 }
265
266 #[inline(always)]
267 fn leaf(eval_fn: AnyEvalFn<Ctx>) -> Self {
268 Rule::Leaf(Condition(eval_fn))
269 }
270}
271
/// A rule key such as `name` or `name(args)`, split into its fetcher name and argument list.
#[derive(Debug)]
pub(crate) struct FetcherKey {
    /// Fetcher name: the text before the first `(`, or the whole key if there is none.
    name: String,
    /// Zero or one unsplit argument string taken from between the parentheses.
    args: Vec<String>,
}
278
/// A registered fetcher implementation, either synchronous or future-returning.
enum AnyFetcherFn<Ctx: ?Sized> {
    /// Synchronous fetcher; called with the arguments as a `&[String]` slice.
    Sync(Arc<FetcherFn<Ctx>>),
    /// Asynchronous fetcher; called with the arguments as a shared `Arc<[String]>`.
    Async(Arc<AsyncFetcherFn<Ctx>>),
}
283
284impl<Ctx: ?Sized> Clone for AnyFetcherFn<Ctx> {
285 fn clone(&self) -> Self {
286 match self {
287 AnyFetcherFn::Sync(func) => AnyFetcherFn::Sync(func.clone()),
288 AnyFetcherFn::Async(func) => AnyFetcherFn::Async(func.clone()),
289 }
290 }
291}
292
/// A compiled condition closure that evaluates one fetcher/operator pair against a context.
enum AnyEvalFn<Ctx: ?Sized> {
    /// Evaluates immediately.
    Sync(EvalFn<Ctx>),
    /// Returns a future; only usable through `Rule::evaluate_async`.
    Async(AsyncEvalFn<Ctx>),
}
297
298impl<Ctx: ?Sized> Clone for AnyEvalFn<Ctx> {
299 fn clone(&self) -> Self {
300 match self {
301 AnyEvalFn::Sync(func) => AnyEvalFn::Sync(func.clone()),
302 AnyEvalFn::Async(func) => AnyEvalFn::Async(func.clone()),
303 }
304 }
305}
306
/// A registered data source for rule conditions, together with how its values are matched.
///
/// Returned by [`Engine::register_fetcher`] and [`Engine::register_async_fetcher`] so the
/// matcher and argument handling can be customised after registration.
pub struct Fetcher<Ctx: ?Sized> {
    /// Compiles the JSON value of a rule entry into an [`Operator`].
    matcher: Arc<dyn Matcher<Ctx>>,
    /// The function that pulls the value out of the evaluation context.
    func: AnyFetcherFn<Ctx>,
    /// When `true`, the text inside `name(...)` is passed as one unsplit argument;
    /// otherwise it is split on `,` and each piece trimmed.
    raw_args: bool,
}
318
319impl<Ctx: ?Sized> Clone for Fetcher<Ctx> {
320 fn clone(&self) -> Self {
321 Fetcher {
322 matcher: self.matcher.clone(),
323 func: self.func.clone(),
324 raw_args: self.raw_args,
325 }
326 }
327}
328
329impl<Ctx: ?Sized> Fetcher<Ctx> {
330 pub fn with_matcher<M>(&mut self, matcher: M) -> &mut Self
332 where
333 M: Matcher<Ctx> + 'static,
334 {
335 self.matcher = Arc::new(matcher);
336 self
337 }
338
339 pub fn with_raw_args(&mut self, raw_args: bool) -> &mut Self {
341 self.raw_args = raw_args;
342 self
343 }
344}
345
/// A rule engine: registries of named fetchers and custom operators, used to compile JSON
/// rule definitions into [`Rule`]s that are evaluated against a context of type `Ctx`.
pub struct Engine<Ctx: MaybeSync + ?Sized + 'static> {
    /// Fetchers keyed by the name used in rule keys (`name` / `name(args)`).
    fetchers: HashMap<String, Fetcher<Ctx>>,
    /// Custom operator builders, consulted when a matcher reports an unknown operator.
    operators: HashMap<String, Arc<dyn ToOperator<Ctx>>>,
}
388
389impl<Ctx: MaybeSync + ?Sized> Default for Engine<Ctx> {
390 fn default() -> Self {
391 Self::new()
392 }
393}
394
395impl<Ctx: MaybeSync + ?Sized> Clone for Engine<Ctx> {
396 fn clone(&self) -> Self {
397 Engine {
398 fetchers: self.fetchers.clone(),
399 operators: self.operators.clone(),
400 }
401 }
402}
403
404impl<Ctx: MaybeSync + ?Sized> Engine<Ctx> {
405 pub fn new() -> Self {
407 Engine {
408 fetchers: HashMap::new(),
409 operators: HashMap::new(),
410 }
411 }
412
413 pub fn register_fetcher<F>(&mut self, name: &str, func: F) -> &mut Fetcher<Ctx>
423 where
424 F: for<'a> Fn(&'a Ctx, &[String]) -> StdResult<Value<'a>, DynError>
425 + MaybeSend
426 + MaybeSync
427 + 'static,
428 {
429 let fetcher = Fetcher {
430 matcher: Arc::new(DefaultMatcher),
431 func: AnyFetcherFn::Sync(Arc::new(func)),
432 raw_args: false,
433 };
434 self.fetchers
435 .entry(name.to_string())
436 .insert_entry(fetcher)
437 .into_mut()
438 }
439
440 pub fn register_async_fetcher<F>(&mut self, name: &str, func: F) -> &mut Fetcher<Ctx>
444 where
445 F: for<'a> Fn(&'a Ctx, Arc<[String]>) -> BoxFuture<'a, StdResult<Value<'a>, DynError>>
446 + MaybeSend
447 + MaybeSync
448 + 'static,
449 {
450 let fetcher = Fetcher {
451 matcher: Arc::new(DefaultMatcher),
452 func: AnyFetcherFn::Async(Arc::new(func)),
453 raw_args: false,
454 };
455 self.fetchers
456 .entry(name.to_string())
457 .insert_entry(fetcher)
458 .into_mut()
459 }
460
461 pub fn register_operator<O>(&mut self, name: &str, op: O)
463 where
464 O: ToOperator<Ctx> + 'static,
465 {
466 self.operators.insert(name.to_string(), Arc::new(op));
467 }
468
469 pub fn compile_rule(&self, value: &JsonValue) -> Result<Rule<Ctx>> {
471 self.compile_rule_inner(value).map(Rule::all)
472 }
473
474 fn compile_rule_inner(&self, value: &JsonValue) -> Result<Vec<Rule<Ctx>>> {
476 match value {
477 JsonValue::Object(map) => {
478 let mut subrules = Vec::with_capacity(map.len());
479 for (key, value) in map {
480 match key.as_str() {
481 "any" => subrules.push(Rule::any(self.compile_rule_inner(value)?)),
482 "all" => subrules.push(Rule::all(self.compile_rule_inner(value)?)),
483 "not" => subrules.push(Rule::not(self.compile_rule_inner(value)?)),
484 _ => {
485 let FetcherKey { name, args } = Self::parse_fetcher_key(key)?;
486 let fetcher = (self.fetchers.get(&name)).ok_or_else(|| {
487 Error::fetcher(&name, "fetcher is not registered")
488 })?;
489 let args = Self::parse_fetcher_args(args.clone(), fetcher.raw_args);
490
491 let mut operator = fetcher.matcher.compile(value);
492 if let Err(Error::UnknownOperator(ref op)) = operator
494 && let Some(op_builder) = self.operators.get(op)
495 {
496 operator = op_builder
497 .to_operator(&value[op])
498 .map_err(|err| Error::operator(op, err));
499 }
500 let operator = operator.map_err(|err| Error::matcher(&name, err))?;
501 let fetcher_fn = fetcher.func.clone();
502 let eval_fn =
503 Self::compile_condition(fetcher_fn, args.into(), operator);
504
505 subrules.push(Rule::leaf(eval_fn));
506 }
507 }
508 }
509 Ok(subrules)
510 }
511 JsonValue::Array(seq) => {
512 (seq.iter()).try_fold(Vec::with_capacity(seq.len()), |mut subrules, inner| {
513 subrules.push(self.compile_rule_inner(inner).map(Rule::all)?);
514 Result::Ok(subrules)
515 })
516 }
517 _ => Err(Error::json("rule must be a JSON object or array")),
518 }
519 }
520
521 #[cfg(feature = "validation")]
523 #[cfg_attr(docsrs, doc(cfg(feature = "validation")))]
524 #[allow(clippy::result_large_err)]
525 pub fn validate_rule<'a>(&self, value: &'a JsonValue) -> StdResult<(), ValidationError<'a>> {
526 let schema = self.json_schema();
528 let validator = jsonschema::options()
529 .with_pattern_options(jsonschema::PatternOptions::regex())
530 .build(&schema)?;
531 validator.validate(value)
532 }
533
534 pub fn json_schema(&self) -> JsonValue {
536 let mut pattern_props = Map::new();
537
538 let custom_ops: Vec<(&str, JsonValue)> = (self.operators.iter())
540 .map(|(k, v)| (k.as_str(), v.json_schema()))
541 .collect();
542
543 for (name, fetcher) in &self.fetchers {
545 let pattern = format!(r"^{}(:?\(([^)]*)\))?$", regex::escape(name));
546 let schema = fetcher.matcher.json_schema(&custom_ops);
547 pattern_props.insert(pattern, schema);
548 }
549
550 json!({
551 "$schema": "http://json-schema.org/draft-07/schema",
552 "$ref": "#/definitions/rule_object",
553 "definitions": {
554 "rule_object": {
555 "type": "object",
556 "properties": {
557 "any": { "$ref": "#/definitions/rule" },
558 "all": { "$ref": "#/definitions/rule" },
559 "not": { "$ref": "#/definitions/rule" }
560 },
561 "patternProperties": pattern_props,
562 "additionalProperties": false,
563 },
564 "rule_array": {
565 "type": "array",
566 "minItems": 1,
567 "items": { "$ref": "#/definitions/rule_object" },
568 },
569 "rule": {
570 "if": { "type": "array" },
571 "then": { "$ref": "#/definitions/rule_array" },
572 "else": { "$ref": "#/definitions/rule_object" }
573 },
574 }
575 })
576 }
577
578 fn compile_condition(
579 fetcher_fn: AnyFetcherFn<Ctx>,
580 fetcher_args: Arc<[String]>,
581 operator: Operator<Ctx>,
582 ) -> AnyEvalFn<Ctx> {
583 match (fetcher_fn, operator) {
584 (AnyFetcherFn::Sync(fetcher_fn), Operator::Equal(right)) => {
586 AnyEvalFn::Sync(Arc::new(move |ctx| {
587 Ok(fetcher_fn(ctx, &fetcher_args)? == right)
588 }))
589 }
590 (AnyFetcherFn::Async(fetcher_fn), Operator::Equal(right)) => {
591 let right = Arc::new(right);
592 AnyEvalFn::Async(Arc::new(move |ctx| {
593 let right = right.clone();
594 let value = fetcher_fn(ctx, fetcher_args.clone());
595 Box::pin(async move { Ok(value.await? == *right) })
596 }))
597 }
598
599 (AnyFetcherFn::Sync(fetcher_fn), Operator::LessThan(right)) => {
601 AnyEvalFn::Sync(Arc::new(move |ctx| {
602 Ok(fetcher_fn(ctx, &fetcher_args)? < right)
603 }))
604 }
605 (AnyFetcherFn::Async(fetcher_fn), Operator::LessThan(right)) => {
606 let right = Arc::new(right);
607 AnyEvalFn::Async(Arc::new(move |ctx| {
608 let right = right.clone();
609 let value = fetcher_fn(ctx, fetcher_args.clone());
610 Box::pin(async move { Ok(value.await? < *right) })
611 }))
612 }
613
614 (AnyFetcherFn::Sync(fetcher_fn), Operator::LessThanOrEqual(right)) => {
616 AnyEvalFn::Sync(Arc::new(move |ctx| {
617 Ok(fetcher_fn(ctx, &fetcher_args)? <= right)
618 }))
619 }
620 (AnyFetcherFn::Async(fetcher_fn), Operator::LessThanOrEqual(right)) => {
621 let right = Arc::new(right);
622 AnyEvalFn::Async(Arc::new(move |ctx| {
623 let right = right.clone();
624 let value = fetcher_fn(ctx, fetcher_args.clone());
625 Box::pin(async move { Ok(value.await? <= *right) })
626 }))
627 }
628
629 (AnyFetcherFn::Sync(fetcher_fn), Operator::GreaterThan(right)) => {
631 AnyEvalFn::Sync(Arc::new(move |ctx| {
632 Ok(fetcher_fn(ctx, &fetcher_args)? > right)
633 }))
634 }
635 (AnyFetcherFn::Async(fetcher_fn), Operator::GreaterThan(right)) => {
636 let right = Arc::new(right);
637 AnyEvalFn::Async(Arc::new(move |ctx| {
638 let right = right.clone();
639 let value = fetcher_fn(ctx, fetcher_args.clone());
640 Box::pin(async move { Ok(value.await? > *right) })
641 }))
642 }
643
644 (AnyFetcherFn::Sync(fetcher_fn), Operator::GreaterThanOrEqual(right)) => {
646 AnyEvalFn::Sync(Arc::new(move |ctx| {
647 Ok(fetcher_fn(ctx, &fetcher_args)? >= right)
648 }))
649 }
650 (AnyFetcherFn::Async(fetcher_fn), Operator::GreaterThanOrEqual(right)) => {
651 let right = Arc::new(right);
652 AnyEvalFn::Async(Arc::new(move |ctx| {
653 let right = right.clone();
654 let value = fetcher_fn(ctx, fetcher_args.clone());
655 Box::pin(async move { Ok(value.await? >= *right) })
656 }))
657 }
658
659 (AnyFetcherFn::Sync(fetcher_fn), Operator::InSet(set)) => {
661 AnyEvalFn::Sync(Arc::new(move |ctx| {
662 fetcher_fn(ctx, &fetcher_args).map(|val| set.contains(&val))
663 }))
664 }
665 (AnyFetcherFn::Async(fetcher_fn), Operator::InSet(set)) => {
666 let set = Arc::new(set);
667 AnyEvalFn::Async(Arc::new(move |ctx| {
668 let set = set.clone();
669 let value = fetcher_fn(ctx, fetcher_args.clone());
670 Box::pin(async move { value.await.map(|val| set.contains(&val)) })
671 }))
672 }
673
674 (AnyFetcherFn::Sync(fetcher_fn), Operator::Regex(regex)) => {
676 AnyEvalFn::Sync(Arc::new(move |ctx| {
677 fetcher_fn(ctx, &fetcher_args)
678 .map(|val| val.as_str().map(|s| regex.is_match(s)).unwrap_or(false))
679 }))
680 }
681 (AnyFetcherFn::Async(fetcher_fn), Operator::Regex(regex)) => {
682 let regex = Arc::new(regex);
683 AnyEvalFn::Async(Arc::new(move |ctx| {
684 let regex = regex.clone();
685 let value = fetcher_fn(ctx, fetcher_args.clone());
686 Box::pin(async move {
687 (value.await)
688 .map(|val| val.as_str().map(|s| regex.is_match(s)).unwrap_or(false))
689 })
690 }))
691 }
692
693 (AnyFetcherFn::Sync(fetcher_fn), Operator::RegexSet(regex_set)) => {
695 AnyEvalFn::Sync(Arc::new(move |ctx| {
696 fetcher_fn(ctx, &fetcher_args)
697 .map(|val| val.as_str().map(|s| regex_set.is_match(s)).unwrap_or(false))
698 }))
699 }
700 (AnyFetcherFn::Async(fetcher_fn), Operator::RegexSet(regex_set)) => {
701 let regex_set = Arc::new(regex_set);
702 AnyEvalFn::Async(Arc::new(move |ctx| {
703 let regex_set = regex_set.clone();
704 let value = fetcher_fn(ctx, fetcher_args.clone());
705 Box::pin(async move {
706 (value.await)
707 .map(|val| val.as_str().map(|s| regex_set.is_match(s)).unwrap_or(false))
708 })
709 }))
710 }
711
712 (AnyFetcherFn::Sync(fetcher_fn), Operator::IpSet(set)) => {
714 AnyEvalFn::Sync(Arc::new(move |ctx| {
715 Ok((fetcher_fn(ctx, &fetcher_args)?.to_ip())
716 .map(|ip| set.longest_match(&IpNet::from(ip)).is_some())
717 .unwrap_or(false))
718 }))
719 }
720 (AnyFetcherFn::Async(fetcher_fn), Operator::IpSet(set)) => {
721 let set = Arc::new(set);
722 AnyEvalFn::Async(Arc::new(move |ctx| {
723 let set = set.clone();
724 let value = fetcher_fn(ctx, fetcher_args.clone());
725 Box::pin(async move {
726 Ok((value.await?.to_ip())
727 .map(|ip| set.longest_match(&IpNet::from(ip)).is_some())
728 .unwrap_or(false))
729 })
730 }))
731 }
732
733 (AnyFetcherFn::Sync(fetcher_fn), Operator::Custom(op_fn)) => {
735 AnyEvalFn::Sync(Arc::new(move |ctx| {
736 let value = fetcher_fn(ctx, &fetcher_args)?;
737 op_fn(ctx, value)
738 }))
739 }
740 (AnyFetcherFn::Async(fetcher_fn), Operator::Custom(op_fn)) => {
741 let op_fn: Arc<CheckFn<Ctx>> = op_fn.into();
742 AnyEvalFn::Async(Arc::new(move |ctx| {
743 let op_fn = op_fn.clone();
744 let value = fetcher_fn(ctx, fetcher_args.clone());
745 Box::pin(async move { op_fn(ctx, value.await?) })
746 }))
747 }
748
749 (AnyFetcherFn::Sync(fetcher_fn), Operator::CustomAsync(op_fn)) => {
751 let op_fn: Arc<AsyncCheckFn<Ctx>> = op_fn.into();
752 AnyEvalFn::Async(Arc::new(move |ctx| {
753 let op_fn = op_fn.clone();
754 let value = fetcher_fn(ctx, &fetcher_args);
755 Box::pin(async move { op_fn(ctx, value?).await })
756 }))
757 }
758 (AnyFetcherFn::Async(fetcher_fn), Operator::CustomAsync(op_fn)) => {
759 let op_fn: Arc<AsyncCheckFn<Ctx>> = op_fn.into();
760 AnyEvalFn::Async(Arc::new(move |ctx| {
761 let op_fn = op_fn.clone();
762 let value = fetcher_fn(ctx, fetcher_args.clone());
763 Box::pin(async move { op_fn(ctx, value.await?).await })
764 }))
765 }
766 }
767 }
768
769 fn parse_fetcher_key(key: &str) -> Result<FetcherKey> {
771 if let Some((name, args_str)) = key.split_once('(') {
772 if !args_str.ends_with(')') {
773 return Err(Error::fetcher(name, "missing closing parenthesis"));
774 }
775 let args_str = &args_str[..args_str.len() - 1];
776 let args = if args_str.is_empty() {
777 vec![]
778 } else {
779 vec![args_str.to_string()]
780 };
781 Ok(FetcherKey {
782 name: name.to_string(),
783 args,
784 })
785 } else {
786 Ok(FetcherKey {
787 name: key.to_string(),
788 args: Vec::new(),
789 })
790 }
791 }
792
793 fn parse_fetcher_args(mut args: Vec<String>, raw: bool) -> Vec<String> {
795 if raw || args.is_empty() {
796 args
797 } else {
798 let arg = args.pop().unwrap_or_default();
799 arg.split(',').map(|s| s.trim().to_string()).collect()
800 }
801 }
802}
803
804impl<Ctx: ?Sized> Rule<Ctx> {
805 pub fn evaluate(&self, context: &Ctx) -> StdResult<bool, DynError> {
811 match self {
812 Rule::Leaf(Condition(AnyEvalFn::Sync(eval_fn))) => eval_fn(context),
813 Rule::Leaf(Condition(AnyEvalFn::Async(_))) => {
814 Err("async operations are not supported in sync context".into())
815 }
816 Rule::Any(subrules) => {
817 for rule in subrules {
818 if rule.evaluate(context)? {
819 return Ok(true);
820 }
821 }
822 Ok(false)
823 }
824 Rule::All(subrules) => {
825 for rule in subrules {
826 if !rule.evaluate(context)? {
827 return Ok(false);
828 }
829 }
830 Ok(true)
831 }
832 Rule::Not(subrule) => Ok(!subrule.evaluate(context)?),
833 }
834 }
835
836 pub async fn evaluate_async(&self, context: &Ctx) -> StdResult<bool, DynError> {
842 match self {
843 Rule::Leaf(Condition(AnyEvalFn::Sync(eval_fn))) => eval_fn(context),
844 Rule::Leaf(Condition(AnyEvalFn::Async(eval_fn))) => eval_fn(context).await,
845 Rule::Any(subrules) => {
846 for rule in subrules {
847 if Box::pin(rule.evaluate_async(context)).await? {
848 return Ok(true);
849 }
850 }
851 Ok(false)
852 }
853 Rule::All(subrules) => {
854 for rule in subrules {
855 if !Box::pin(rule.evaluate_async(context)).await? {
856 return Ok(false);
857 }
858 }
859 Ok(true)
860 }
861 Rule::Not(subrule) => Ok(!Box::pin(subrule.evaluate_async(context)).await?),
862 }
863 }
864}
865
/// Extension helpers for [`serde_json::Value`].
pub(crate) trait JsonValueExt {
    /// Returns the JSON type name of the value: `"string"`, `"number"`, `"boolean"`,
    /// `"array"`, `"object"` or `"null"`.
    fn type_name(&self) -> &'static str;
}
869
870impl JsonValueExt for JsonValue {
871 fn type_name(&self) -> &'static str {
872 match self {
873 JsonValue::String(_) => "string",
874 JsonValue::Number(_) => "number",
875 JsonValue::Bool(_) => "boolean",
876 JsonValue::Array(_) => "array",
877 JsonValue::Object(_) => "object",
878 JsonValue::Null => "null",
879 }
880 }
881}
882
883mod error;
884mod matcher;
885mod types;
886mod value;
887
#[cfg(test)]
mod tests {
    /// Compile-time checks that, with the `send` feature enabled, engines and compiled
    /// rules are `Send + Sync`.
    #[cfg(feature = "send")]
    mod send_sync {
        use super::super::{Engine, Rule};

        static_assertions::assert_impl_all!(Engine<()>: Send, Sync);
        static_assertions::assert_impl_all!(Rule<()>: Send, Sync);
    }
}