1use crate::iam::{Auth, Level};
2use crate::rpc::Method;
3use ipnet::IpNet;
4use std::collections::HashSet;
5use std::fmt;
6use std::hash::Hash;
7use std::net::IpAddr;
8#[cfg(all(target_family = "wasm", feature = "http"))]
9use std::net::ToSocketAddrs;
10#[cfg(all(not(target_family = "wasm"), feature = "http"))]
11use tokio::net::lookup_host;
12use url::Url;
13
/// A capability target that can be checked against an element.
///
/// `Item` is the type of the element being checked. It defaults to `Self` and may be
/// unsized, so a target can also be matched against a plain `str`.
pub trait Target<Item = Self>
where
    Item: ?Sized,
{
    /// Returns `true` when this target covers `elem`.
    fn matches(&self, elem: &Item) -> bool;
}
17
/// A function capability target.
///
/// The first field is the function family (the part before the first `::`).
/// The second field is the rest of the function name; `None` targets the whole family.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub struct FuncTarget(pub String, pub Option<String>);

impl fmt::Display for FuncTarget {
    /// Renders the target as `family::name`, or `family::*` for a whole family.
    ///
    /// The separator is `::` so that the output uses the same syntax that is
    /// accepted when parsing a target and when matching a function name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.1 {
            Some(name) => write!(f, "{}::{name}", self.0),
            None => write!(f, "{}::*", self.0),
        }
    }
}
30
31impl Target for FuncTarget {
32 fn matches(&self, elem: &FuncTarget) -> bool {
33 match self {
34 Self(family, Some(name)) => {
35 family == &elem.0 && (elem.1.as_ref().is_some_and(|n| n == name))
36 }
37 Self(family, None) => family == &elem.0,
38 }
39 }
40}
41
42impl Target<str> for FuncTarget {
43 fn matches(&self, elem: &str) -> bool {
44 if let Some(x) = self.1.as_ref() {
45 let Some((f, r)) = elem.split_once("::") else {
46 return false;
47 };
48
49 f == self.0 && r == x
50 } else {
51 let f = elem.split_once("::").map(|(f, _)| f).unwrap_or(elem);
52 f == self.0
53 }
54 }
55}
56
/// Error returned when a string cannot be parsed into a `FuncTarget`.
#[derive(Debug, Clone)]
pub enum ParseFuncTargetError {
    /// A wildcard was used after more than one `::`-separated part.
    InvalidWildcardFamily,
    /// The name is empty or contains characters that are not accepted.
    InvalidName,
}

impl std::error::Error for ParseFuncTargetError {}

impl fmt::Display for ParseFuncTargetError {
    /// Writes a fixed, human-readable message for each error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidName => "invalid function target name",
            Self::InvalidWildcardFamily => {
                "invalid function target wildcard family, only first part of function can be wildcarded"
            }
        };
        f.write_str(message)
    }
}
79
80impl std::str::FromStr for FuncTarget {
81 type Err = ParseFuncTargetError;
82
83 fn from_str(s: &str) -> Result<Self, Self::Err> {
84 let s = s.trim();
85
86 if s.is_empty() {
87 return Err(ParseFuncTargetError::InvalidName);
88 }
89
90 if let Some(family) = s.strip_suffix("::*") {
91 if family.contains("::") {
92 return Err(ParseFuncTargetError::InvalidWildcardFamily);
93 }
94
95 if !family.bytes().all(|x| x.is_ascii_alphanumeric()) {
96 return Err(ParseFuncTargetError::InvalidName);
97 }
98
99 return Ok(FuncTarget(family.to_string(), None));
100 }
101
102 if !s.bytes().all(|x| x.is_ascii_alphanumeric() || x == b':') {
103 return Err(ParseFuncTargetError::InvalidName);
104 }
105
106 if let Some((first, rest)) = s.split_once("::") {
107 Ok(FuncTarget(first.to_string(), Some(rest.to_string())))
108 } else {
109 Ok(FuncTarget(s.to_string(), None))
110 }
111 }
112}
113
/// An experimental feature that can be allowed or denied.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub enum ExperimentalTarget {
    /// Displayed as `record_references`.
    RecordReferences,
    /// Displayed as `graphql`.
    GraphQL,
    /// Displayed as `bearer_access`.
    BearerAccess,
    /// Displayed as `define_api`.
    DefineApi,
}

impl fmt::Display for ExperimentalTarget {
    /// Writes the lowercase, snake_case name of the feature.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::RecordReferences => "record_references",
            Self::GraphQL => "graphql",
            Self::BearerAccess => "bearer_access",
            Self::DefineApi => "define_api",
        };
        f.write_str(name)
    }
}
133
134impl Target for ExperimentalTarget {
135 fn matches(&self, elem: &ExperimentalTarget) -> bool {
136 self == elem
137 }
138}
139
140impl Target<str> for ExperimentalTarget {
141 fn matches(&self, elem: &str) -> bool {
142 match self {
143 Self::RecordReferences => elem.eq_ignore_ascii_case("record_references"),
144 Self::GraphQL => elem.eq_ignore_ascii_case("graphql"),
145 Self::BearerAccess => elem.eq_ignore_ascii_case("bearer_access"),
146 Self::DefineApi => elem.eq_ignore_ascii_case("define_api"),
147 }
148 }
149}
150
/// Error returned when a string does not name a known experimental target.
#[derive(Debug, Clone)]
pub enum ParseExperimentalTargetError {
    /// The given name is not one of the known experimental targets.
    InvalidName,
}

impl std::error::Error for ParseExperimentalTargetError {}

impl fmt::Display for ParseExperimentalTargetError {
    /// Writes a fixed, human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidName => "invalid experimental target name",
        };
        f.write_str(message)
    }
}
166
167impl std::str::FromStr for ExperimentalTarget {
168 type Err = ParseExperimentalTargetError;
169
170 fn from_str(s: &str) -> Result<Self, Self::Err> {
171 match s.trim().to_lowercase().as_str() {
172 "record_references" => Ok(ExperimentalTarget::RecordReferences),
173 "graphql" => Ok(ExperimentalTarget::GraphQL),
174 "bearer_access" => Ok(ExperimentalTarget::BearerAccess),
175 "define_api" => Ok(ExperimentalTarget::DefineApi),
176 _ => Err(ParseExperimentalTargetError::InvalidName),
177 }
178 }
179}
180
/// A network capability target.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub enum NetTarget {
    /// A host, as represented by `url::Host`, with an optional port.
    Host(url::Host<String>, Option<u16>),
    /// An IP network. A single IP address is also stored in this form when parsed.
    IPNet(IpNet),
}
187
#[cfg(feature = "http")]
impl NetTarget {
    /// Resolves a `Host` target into one `IPNet` target per address found by a lookup.
    ///
    /// The lookup uses port 80 when the target carries no port. An `IPNet` target
    /// needs no resolution and yields an empty list.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the host lookup.
    #[cfg(not(target_family = "wasm"))]
    pub(crate) async fn resolve(&self) -> Result<Vec<Self>, std::io::Error> {
        let NetTarget::Host(host, port) = self else {
            return Ok(Vec::new());
        };

        let addrs = lookup_host((host.to_string(), port.unwrap_or(80))).await?;
        Ok(addrs.map(|addr| NetTarget::IPNet(addr.ip().into())).collect())
    }

    /// Resolves a `Host` target into one `IPNet` target per address found by a lookup.
    ///
    /// The lookup uses port 80 when the target carries no port. An `IPNet` target
    /// needs no resolution and yields an empty list.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the host lookup.
    #[cfg(target_family = "wasm")]
    pub(crate) fn resolve(&self) -> Result<Vec<Self>, std::io::Error> {
        let NetTarget::Host(host, port) = self else {
            return Ok(Vec::new());
        };

        let addrs = (host.to_string(), port.unwrap_or(80)).to_socket_addrs()?;
        Ok(addrs.map(|addr| NetTarget::IPNet(addr.ip().into())).collect())
    }
}
244
245impl fmt::Display for NetTarget {
247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 match self {
249 Self::Host(host, Some(port)) => write!(f, "{}:{}", host, port),
250 Self::Host(host, None) => write!(f, "{}", host),
251 Self::IPNet(ipnet) => write!(f, "{}", ipnet),
252 }
253 }
254}
255
256impl Target for NetTarget {
257 fn matches(&self, elem: &Self) -> bool {
258 match self {
259 Self::Host(host, Some(port)) => match elem {
261 Self::Host(_host, Some(_port)) => host == _host && port == _port,
262 _ => false,
263 },
264 Self::Host(host, None) => match elem {
266 Self::Host(_host, _) => host == _host,
267 _ => false,
268 },
269 Self::IPNet(ipnet) => match elem {
271 Self::IPNet(_ipnet) => ipnet.contains(_ipnet),
272 Self::Host(host, _) => match host {
273 url::Host::Ipv4(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
274 url::Host::Ipv6(ip) => ipnet.contains(&IpAddr::from(ip.to_owned())),
275 _ => false,
276 },
277 },
278 }
279 }
280}
281
/// Error returned when a string cannot be parsed into a `NetTarget`.
#[derive(Debug)]
pub struct ParseNetTargetError;

impl std::error::Error for ParseNetTargetError {}

impl fmt::Display for ParseNetTargetError {
    /// Writes a fixed, human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            "The provided network target is not a valid host name, IP address or CIDR block",
        )
    }
}
291
292impl std::str::FromStr for NetTarget {
293 type Err = ParseNetTargetError;
294
295 fn from_str(s: &str) -> Result<Self, Self::Err> {
296 if let Ok(ipnet) = s.parse::<IpNet>() {
298 return Ok(NetTarget::IPNet(ipnet));
299 }
300
301 if let Ok(ipnet) = s.parse::<IpAddr>() {
303 return Ok(NetTarget::IPNet(IpNet::from(ipnet)));
304 }
305
306 if let Ok(url) = Url::parse(format!("http://{s}").as_str()) {
308 if let Some(host) = url.host() {
309 if let Some(Ok(port)) = s.split(':').next_back().map(|p| p.parse::<u16>()) {
311 return Ok(NetTarget::Host(host.to_owned(), Some(port)));
312 } else {
313 return Ok(NetTarget::Host(host.to_owned(), None));
314 }
315 }
316 }
317
318 Err(ParseNetTargetError)
319 }
320}
321
322#[derive(Debug, Clone, Hash, Eq, PartialEq)]
323pub struct MethodTarget {
324 pub method: Method,
325}
326
327impl fmt::Display for MethodTarget {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 write!(f, "{}", self.method.to_str())
331 }
332}
333
334impl Target for MethodTarget {
335 fn matches(&self, elem: &Self) -> bool {
336 self.method == elem.method
337 }
338}
339
/// Error returned when a string does not name a valid RPC method.
#[derive(Debug)]
pub struct ParseMethodTargetError;

impl std::error::Error for ParseMethodTargetError {}

impl fmt::Display for ParseMethodTargetError {
    /// Writes a fixed, human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("The provided method target is not a valid RPC method")
    }
}
349
350impl std::str::FromStr for MethodTarget {
351 type Err = ParseMethodTargetError;
352
353 fn from_str(s: &str) -> Result<Self, Self::Err> {
354 match Method::parse_case_insensitive(s) {
355 Method::Unknown => Err(ParseMethodTargetError),
356 method => Ok(MethodTarget {
357 method,
358 }),
359 }
360 }
361}
362
/// An HTTP route capability target.
///
/// Each variant is displayed as its lowercase name, for example `health` or `graphql`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub enum RouteTarget {
    Health,
    Export,
    Import,
    Rpc,
    Version,
    Sync,
    Sql,
    Signin,
    Signup,
    Key,
    Ml,
    GraphQL,
    Api,
}

impl fmt::Display for RouteTarget {
    /// Writes the lowercase name of the route.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Health => "health",
            Self::Export => "export",
            Self::Import => "import",
            Self::Rpc => "rpc",
            Self::Version => "version",
            Self::Sync => "sync",
            Self::Sql => "sql",
            Self::Signin => "signin",
            Self::Signup => "signup",
            Self::Key => "key",
            Self::Ml => "ml",
            Self::GraphQL => "graphql",
            Self::Api => "api",
        };
        f.write_str(name)
    }
}
401
402impl Target for RouteTarget {
403 fn matches(&self, elem: &Self) -> bool {
404 *self == *elem
405 }
406}
407
/// Error returned when a string does not name a valid HTTP route.
#[derive(Debug)]
pub struct ParseRouteTargetError;

impl std::error::Error for ParseRouteTargetError {}

impl fmt::Display for ParseRouteTargetError {
    /// Writes a fixed, human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("The provided route target is not a valid HTTP route")
    }
}
417
418impl std::str::FromStr for RouteTarget {
419 type Err = ParseRouteTargetError;
420
421 fn from_str(s: &str) -> Result<Self, Self::Err> {
422 match s {
423 "health" => Ok(RouteTarget::Health),
424 "export" => Ok(RouteTarget::Export),
425 "import" => Ok(RouteTarget::Import),
426 "rpc" => Ok(RouteTarget::Rpc),
427 "version" => Ok(RouteTarget::Version),
428 "sync" => Ok(RouteTarget::Sync),
429 "sql" => Ok(RouteTarget::Sql),
430 "signin" => Ok(RouteTarget::Signin),
431 "signup" => Ok(RouteTarget::Signup),
432 "key" => Ok(RouteTarget::Key),
433 "ml" => Ok(RouteTarget::Ml),
434 "graphql" => Ok(RouteTarget::GraphQL),
435 "api" => Ok(RouteTarget::Api),
436 _ => Err(ParseRouteTargetError),
437 }
438 }
439}
440
/// A class of user for whom arbitrary queries can be allowed or denied.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
#[non_exhaustive]
pub enum ArbitraryQueryTarget {
    /// Displayed as `guest`.
    Guest,
    /// Displayed as `record`.
    Record,
    /// Displayed as `system`.
    System,
}

impl fmt::Display for ArbitraryQueryTarget {
    /// Writes the lowercase name of the target.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Guest => "guest",
            Self::Record => "record",
            Self::System => "system",
        };
        f.write_str(name)
    }
}
458
459impl<'a> From<&'a Level> for ArbitraryQueryTarget {
460 fn from(level: &'a Level) -> Self {
461 match level {
462 Level::No => ArbitraryQueryTarget::Guest,
463 Level::Root => ArbitraryQueryTarget::System,
464 Level::Namespace(_) => ArbitraryQueryTarget::System,
465 Level::Database(_, _) => ArbitraryQueryTarget::System,
466 Level::Record(_, _, _) => ArbitraryQueryTarget::Record,
467 }
468 }
469}
470
471impl<'a> From<&'a Auth> for ArbitraryQueryTarget {
472 fn from(auth: &'a Auth) -> Self {
473 auth.level().into()
474 }
475}
476
477impl Target for ArbitraryQueryTarget {
478 fn matches(&self, elem: &ArbitraryQueryTarget) -> bool {
479 self == elem
480 }
481}
482
483impl Target<str> for ArbitraryQueryTarget {
484 fn matches(&self, elem: &str) -> bool {
485 match self {
486 Self::Guest => elem.eq_ignore_ascii_case("guest"),
487 Self::Record => elem.eq_ignore_ascii_case("record"),
488 Self::System => elem.eq_ignore_ascii_case("system"),
489 }
490 }
491}
492
/// Error returned when a string does not name a known arbitrary query target.
#[derive(Debug, Clone)]
pub enum ParseArbitraryQueryTargetError {
    /// The given name is not one of the known query targets.
    InvalidName,
}

impl std::error::Error for ParseArbitraryQueryTargetError {}

impl fmt::Display for ParseArbitraryQueryTargetError {
    /// Writes a fixed, human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidName => "invalid query target name",
        };
        f.write_str(message)
    }
}
508
509impl std::str::FromStr for ArbitraryQueryTarget {
510 type Err = ParseArbitraryQueryTargetError;
511
512 fn from_str(s: &str) -> Result<Self, Self::Err> {
513 match s.trim().to_lowercase().as_str() {
514 "guest" => Ok(ArbitraryQueryTarget::Guest),
515 "record" => Ok(ArbitraryQueryTarget::Record),
516 "system" => Ok(ArbitraryQueryTarget::System),
517 _ => Err(ParseArbitraryQueryTargetError::InvalidName),
518 }
519 }
520}
521
522#[derive(Debug, Clone, Eq, PartialEq)]
523#[non_exhaustive]
524pub enum Targets<T: Hash + Eq + PartialEq> {
525 None,
526 Some(HashSet<T>),
527 All,
528}
529
530impl<T: Target + Hash + Eq + PartialEq> From<T> for Targets<T> {
531 fn from(t: T) -> Self {
532 let mut set = HashSet::new();
533 set.insert(t);
534 Self::Some(set)
535 }
536}
537
538impl<T: Hash + Eq + PartialEq + fmt::Debug + fmt::Display> Targets<T> {
539 pub(crate) fn matches<S>(&self, elem: &S) -> bool
540 where
541 S: ?Sized,
542 T: Target<S>,
543 {
544 match self {
545 Self::None => false,
546 Self::All => true,
547 Self::Some(targets) => targets.iter().any(|t| t.matches(elem)),
548 }
549 }
550}
551
552impl<T: Target + Hash + Eq + PartialEq + fmt::Display> fmt::Display for Targets<T> {
553 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
554 match self {
555 Self::None => write!(f, "none"),
556 Self::All => write!(f, "all"),
557 Self::Some(targets) => {
558 let targets =
559 targets.iter().map(|t| t.to_string()).collect::<Vec<String>>().join(", ");
560 write!(f, "{}", targets)
561 }
562 }
563 }
564}
565
566#[derive(Debug, Clone)]
567#[non_exhaustive]
568pub struct Capabilities {
569 scripting: bool,
570 guest_access: bool,
571 live_query_notifications: bool,
572
573 allow_funcs: Targets<FuncTarget>,
574 deny_funcs: Targets<FuncTarget>,
575 allow_net: Targets<NetTarget>,
576 deny_net: Targets<NetTarget>,
577 allow_rpc: Targets<MethodTarget>,
578 deny_rpc: Targets<MethodTarget>,
579 allow_http: Targets<RouteTarget>,
580 deny_http: Targets<RouteTarget>,
581 allow_experimental: Targets<ExperimentalTarget>,
582 deny_experimental: Targets<ExperimentalTarget>,
583 allow_arbitrary_query: Targets<ArbitraryQueryTarget>,
584 deny_arbitrary_query: Targets<ArbitraryQueryTarget>,
585}
586
587impl fmt::Display for Capabilities {
588 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
589 write!(
590 f,
591 "scripting={}, guest_access={}, live_query_notifications={}, allow_funcs={}, deny_funcs={}, allow_net={}, deny_net={}, allow_rpc={}, deny_rpc={}, allow_http={}, deny_http={}, allow_experimental={}, deny_experimental={}, allow_arbitrary_query={}, deny_arbitrary_query={}",
592 self.scripting, self.guest_access, self.live_query_notifications, self.allow_funcs, self.deny_funcs, self.allow_net, self.deny_net, self.allow_rpc, self.deny_rpc, self.allow_http, self.deny_http, self.allow_experimental, self.deny_experimental, self.allow_arbitrary_query, self.deny_arbitrary_query,
593 )
594 }
595}
596
597impl Default for Capabilities {
598 fn default() -> Self {
599 Self {
600 scripting: false,
601 guest_access: false,
602 live_query_notifications: true,
603
604 allow_funcs: Targets::All,
605 deny_funcs: Targets::None,
606 allow_net: Targets::None,
607 deny_net: Targets::None,
608 allow_rpc: Targets::All,
609 deny_rpc: Targets::None,
610 allow_http: Targets::All,
611 deny_http: Targets::None,
612 allow_experimental: Targets::None,
613 deny_experimental: Targets::None,
614 allow_arbitrary_query: Targets::All,
615 deny_arbitrary_query: Targets::None,
616 }
617 }
618}
619
620impl Capabilities {
621 pub fn all() -> Self {
622 Self {
623 scripting: true,
624 guest_access: true,
625 live_query_notifications: true,
626
627 allow_funcs: Targets::All,
628 deny_funcs: Targets::None,
629 allow_net: Targets::All,
630 deny_net: Targets::None,
631 allow_rpc: Targets::All,
632 deny_rpc: Targets::None,
633 allow_http: Targets::All,
634 deny_http: Targets::None,
635 allow_experimental: Targets::None,
636 deny_experimental: Targets::None,
637 allow_arbitrary_query: Targets::All,
638 deny_arbitrary_query: Targets::None,
639 }
640 }
641
642 pub fn none() -> Self {
643 Self {
644 scripting: false,
645 guest_access: false,
646 live_query_notifications: false,
647
648 allow_funcs: Targets::None,
649 deny_funcs: Targets::None,
650 allow_net: Targets::None,
651 deny_net: Targets::None,
652 allow_rpc: Targets::None,
653 deny_rpc: Targets::None,
654 allow_http: Targets::None,
655 deny_http: Targets::None,
656 allow_experimental: Targets::None,
657 deny_experimental: Targets::None,
658 allow_arbitrary_query: Targets::None,
659 deny_arbitrary_query: Targets::None,
660 }
661 }
662
663 pub fn with_scripting(mut self, scripting: bool) -> Self {
664 self.scripting = scripting;
665 self
666 }
667
668 pub fn with_guest_access(mut self, guest_access: bool) -> Self {
669 self.guest_access = guest_access;
670 self
671 }
672
673 pub fn with_live_query_notifications(mut self, live_query_notifications: bool) -> Self {
674 self.live_query_notifications = live_query_notifications;
675 self
676 }
677
678 pub fn with_functions(mut self, allow_funcs: Targets<FuncTarget>) -> Self {
679 self.allow_funcs = allow_funcs;
680 self
681 }
682
683 pub fn allowed_functions_mut(&mut self) -> &mut Targets<FuncTarget> {
684 &mut self.allow_funcs
685 }
686
687 pub fn without_functions(mut self, deny_funcs: Targets<FuncTarget>) -> Self {
688 self.deny_funcs = deny_funcs;
689 self
690 }
691
692 pub fn denied_functions_mut(&mut self) -> &mut Targets<FuncTarget> {
693 &mut self.deny_funcs
694 }
695
696 pub fn with_experimental(mut self, allow_experimental: Targets<ExperimentalTarget>) -> Self {
697 self.allow_experimental = allow_experimental;
698 self
699 }
700
701 pub fn allowed_experimental_features_mut(&mut self) -> &mut Targets<ExperimentalTarget> {
702 &mut self.allow_experimental
703 }
704
705 pub fn without_experimental(mut self, deny_experimental: Targets<ExperimentalTarget>) -> Self {
706 self.deny_experimental = deny_experimental;
707 self
708 }
709
710 pub fn denied_experimental_features_mut(&mut self) -> &mut Targets<ExperimentalTarget> {
711 &mut self.deny_experimental
712 }
713
714 pub fn with_arbitrary_query(
715 mut self,
716 allow_arbitrary_query: Targets<ArbitraryQueryTarget>,
717 ) -> Self {
718 self.allow_arbitrary_query = allow_arbitrary_query;
719 self
720 }
721
722 pub fn without_arbitrary_query(
723 mut self,
724 deny_arbitrary_query: Targets<ArbitraryQueryTarget>,
725 ) -> Self {
726 self.deny_arbitrary_query = deny_arbitrary_query;
727 self
728 }
729
730 pub fn with_network_targets(mut self, allow_net: Targets<NetTarget>) -> Self {
731 self.allow_net = allow_net;
732 self
733 }
734
735 pub fn allowed_network_targets_mut(&mut self) -> &mut Targets<NetTarget> {
736 &mut self.allow_net
737 }
738
739 pub fn without_network_targets(mut self, deny_net: Targets<NetTarget>) -> Self {
740 self.deny_net = deny_net;
741 self
742 }
743
744 pub fn denied_network_targets_mut(&mut self) -> &mut Targets<NetTarget> {
745 &mut self.deny_net
746 }
747
748 pub fn with_rpc_methods(mut self, allow_rpc: Targets<MethodTarget>) -> Self {
749 self.allow_rpc = allow_rpc;
750 self
751 }
752
753 pub fn without_rpc_methods(mut self, deny_rpc: Targets<MethodTarget>) -> Self {
754 self.deny_rpc = deny_rpc;
755 self
756 }
757
758 pub fn with_http_routes(mut self, allow_http: Targets<RouteTarget>) -> Self {
759 self.allow_http = allow_http;
760 self
761 }
762
763 pub fn without_http_routes(mut self, deny_http: Targets<RouteTarget>) -> Self {
764 self.deny_http = deny_http;
765 self
766 }
767
768 pub fn allows_scripting(&self) -> bool {
769 self.scripting
770 }
771
772 pub fn allows_guest_access(&self) -> bool {
773 self.guest_access
774 }
775
776 pub fn allows_live_query_notifications(&self) -> bool {
777 self.live_query_notifications
778 }
779
780 pub fn allows_function_name(&self, target: &str) -> bool {
781 self.allow_funcs.matches(target) && !self.deny_funcs.matches(target)
782 }
783
784 pub fn allows_experimental(&self, target: &ExperimentalTarget) -> bool {
785 self.allow_experimental.matches(target) && !self.deny_experimental.matches(target)
786 }
787
788 pub fn allows_experimental_name(&self, target: &str) -> bool {
789 self.allow_experimental.matches(target) && !self.deny_experimental.matches(target)
790 }
791
792 pub fn allows_query(&self, target: &ArbitraryQueryTarget) -> bool {
793 self.allow_arbitrary_query.matches(target) && !self.deny_arbitrary_query.matches(target)
794 }
795
796 pub fn allows_network_target(&self, target: &NetTarget) -> bool {
797 self.allow_net.matches(target) && !self.deny_net.matches(target)
798 }
799
800 #[cfg(feature = "http")]
801 pub(crate) fn matches_any_allow_net(&self, target: &NetTarget) -> bool {
802 self.allow_net.matches(target)
803 }
804
805 #[cfg(feature = "http")]
806 pub(crate) fn matches_any_deny_net(&self, target: &NetTarget) -> bool {
807 self.deny_net.matches(target)
808 }
809
810 pub fn allows_rpc_method(&self, target: &MethodTarget) -> bool {
811 self.allow_rpc.matches(target) && !self.deny_rpc.matches(target)
812 }
813
814 pub fn allows_http_route(&self, target: &RouteTarget) -> bool {
815 self.allow_http.matches(target) && !self.deny_http.matches(target)
816 }
817}
818
819#[cfg(test)]
820mod tests {
821 use std::str::FromStr;
822 use test_log::test;
823
824 use super::*;
825
826 #[test]
827 fn test_invalid_func_target() {
828 FuncTarget::from_str("te::*st").unwrap_err();
829 FuncTarget::from_str("\0::st").unwrap_err();
830 FuncTarget::from_str("").unwrap_err();
831 FuncTarget::from_str("❤️").unwrap_err();
832 }
833
834 #[test]
835 fn test_func_target() {
836 assert!(FuncTarget::from_str("test").unwrap().matches("test"));
837 assert!(!FuncTarget::from_str("test").unwrap().matches("test2"));
838
839 assert!(!FuncTarget::from_str("test::").unwrap().matches("test"));
840
841 assert!(FuncTarget::from_str("test::*").unwrap().matches("test::name"));
842 assert!(!FuncTarget::from_str("test::*").unwrap().matches("test2::name"));
843
844 assert!(FuncTarget::from_str("test::name").unwrap().matches("test::name"));
845 assert!(!FuncTarget::from_str("test::name").unwrap().matches("test::name2"));
846 }
847
848 #[test]
849 fn test_net_target() {
850 assert!(NetTarget::from_str("10.0.0.0/8")
852 .unwrap()
853 .matches(&NetTarget::from_str("10.0.1.0/24").unwrap()));
854 assert!(NetTarget::from_str("10.0.0.0/8")
855 .unwrap()
856 .matches(&NetTarget::from_str("10.0.1.2").unwrap()));
857 assert!(!NetTarget::from_str("10.0.0.0/8")
858 .unwrap()
859 .matches(&NetTarget::from_str("20.0.1.0/24").unwrap()));
860 assert!(!NetTarget::from_str("10.0.0.0/8")
861 .unwrap()
862 .matches(&NetTarget::from_str("20.0.1.0").unwrap()));
863
864 assert!(NetTarget::from_str("2001:db8::1")
866 .unwrap()
867 .matches(&NetTarget::from_str("2001:db8::1").unwrap()));
868 assert!(NetTarget::from_str("2001:db8::/32")
869 .unwrap()
870 .matches(&NetTarget::from_str("2001:db8::1").unwrap()));
871 assert!(NetTarget::from_str("2001:db8::/32")
872 .unwrap()
873 .matches(&NetTarget::from_str("2001:db8:abcd:12::/64").unwrap()));
874 assert!(!NetTarget::from_str("2001:db8::/32")
875 .unwrap()
876 .matches(&NetTarget::from_str("2001:db9::1").unwrap()));
877 assert!(!NetTarget::from_str("2001:db8::/32")
878 .unwrap()
879 .matches(&NetTarget::from_str("2001:db9:abcd:12::1/64").unwrap()));
880
881 assert!(NetTarget::from_str("example.com")
883 .unwrap()
884 .matches(&NetTarget::from_str("example.com").unwrap()));
885 assert!(NetTarget::from_str("example.com")
886 .unwrap()
887 .matches(&NetTarget::from_str("example.com:80").unwrap()));
888 assert!(!NetTarget::from_str("example.com")
889 .unwrap()
890 .matches(&NetTarget::from_str("www.example.com").unwrap()));
891 assert!(!NetTarget::from_str("example.com")
892 .unwrap()
893 .matches(&NetTarget::from_str("www.example.com:80").unwrap()));
894 assert!(NetTarget::from_str("example.com:80")
895 .unwrap()
896 .matches(&NetTarget::from_str("example.com:80").unwrap()));
897 assert!(!NetTarget::from_str("example.com:80")
898 .unwrap()
899 .matches(&NetTarget::from_str("example.com:443").unwrap()));
900 assert!(!NetTarget::from_str("example.com:80")
901 .unwrap()
902 .matches(&NetTarget::from_str("example.com").unwrap()));
903
904 assert!(
906 NetTarget::from_str("127.0.0.1")
907 .unwrap()
908 .matches(&NetTarget::from_str("127.0.0.1").unwrap()),
909 "Host IPv4 without port matches itself"
910 );
911 assert!(
912 NetTarget::from_str("127.0.0.1")
913 .unwrap()
914 .matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
915 "Host IPv4 without port matches Host IPv4 with port"
916 );
917 assert!(
918 NetTarget::from_str("10.0.0.0/8")
919 .unwrap()
920 .matches(&NetTarget::from_str("10.0.0.1:80").unwrap()),
921 "IPv4 network matches Host IPv4 with port"
922 );
923 assert!(
924 NetTarget::from_str("127.0.0.1:80")
925 .unwrap()
926 .matches(&NetTarget::from_str("127.0.0.1:80").unwrap()),
927 "Host IPv4 with port matches itself"
928 );
929 assert!(
930 !NetTarget::from_str("127.0.0.1:80")
931 .unwrap()
932 .matches(&NetTarget::from_str("127.0.0.1").unwrap()),
933 "Host IPv4 with port does not match Host IPv4 without port"
934 );
935 assert!(
936 !NetTarget::from_str("127.0.0.1:80")
937 .unwrap()
938 .matches(&NetTarget::from_str("127.0.0.1:443").unwrap()),
939 "Host IPv4 with port does not match Host IPv4 with different port"
940 );
941
942 assert!(
944 NetTarget::from_str("[2001:db8::1]")
945 .unwrap()
946 .matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
947 "Host IPv6 without port matches itself"
948 );
949 assert!(
950 NetTarget::from_str("[2001:db8::1]")
951 .unwrap()
952 .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
953 "Host IPv6 without port matches Host IPv6 with port"
954 );
955 assert!(
956 NetTarget::from_str("2001:db8::1")
957 .unwrap()
958 .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
959 "IPv6 addr matches Host IPv6 with port"
960 );
961 assert!(
962 NetTarget::from_str("2001:db8::/64")
963 .unwrap()
964 .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
965 "IPv6 network matches Host IPv6 with port"
966 );
967 assert!(
968 NetTarget::from_str("[2001:db8::1]:80")
969 .unwrap()
970 .matches(&NetTarget::from_str("[2001:db8::1]:80").unwrap()),
971 "Host IPv6 with port matches itself"
972 );
973 assert!(
974 !NetTarget::from_str("[2001:db8::1]:80")
975 .unwrap()
976 .matches(&NetTarget::from_str("[2001:db8::1]").unwrap()),
977 "Host IPv6 with port does not match Host IPv6 without port"
978 );
979 assert!(
980 !NetTarget::from_str("[2001:db8::1]:80")
981 .unwrap()
982 .matches(&NetTarget::from_str("[2001:db8::1]:443").unwrap()),
983 "Host IPv6 with port does not match Host IPv6 with different port"
984 );
985
986 assert!(NetTarget::from_str("exam^ple.com").is_err());
988 assert!(NetTarget::from_str("example.com:80:80").is_err());
989 assert!(NetTarget::from_str("11111.3.4.5").is_err());
990 assert!(NetTarget::from_str("2001:db8::1/129").is_err());
991 assert!(NetTarget::from_str("[2001:db8::1").is_err());
992 }
993
994 #[tokio::test]
995 #[cfg(all(not(target_family = "wasm"), feature = "http"))]
996 async fn test_net_target_resolve_async() {
997 let r = NetTarget::from_str("localhost").unwrap().resolve().await.unwrap();
998 assert!(r.contains(&NetTarget::from_str("127.0.0.1").unwrap()));
999 assert!(r.contains(&NetTarget::from_str("::1/128").unwrap()));
1000 }
1001
1002 #[test]
1003 #[cfg(all(target_family = "wasm", feature = "http"))]
1004 fn test_net_target_resolve_sync() {
1005 let r = NetTarget::from_str("localhost").unwrap().resolve().unwrap();
1006 assert!(r.contains(&NetTarget::from_str("127.0.0.1").unwrap()));
1007 assert!(r.contains(&NetTarget::from_str("::1/128").unwrap()));
1008 }
1009
1010 #[test]
1011 fn test_method_target() {
1012 assert!(MethodTarget::from_str("query")
1013 .unwrap()
1014 .matches(&MethodTarget::from_str("query").unwrap()));
1015 assert!(MethodTarget::from_str("query")
1016 .unwrap()
1017 .matches(&MethodTarget::from_str("Query").unwrap()));
1018 assert!(MethodTarget::from_str("query")
1019 .unwrap()
1020 .matches(&MethodTarget::from_str("QUERY").unwrap()));
1021 assert!(!MethodTarget::from_str("query")
1022 .unwrap()
1023 .matches(&MethodTarget::from_str("ping").unwrap()));
1024 }
1025
1026 #[test]
1027 fn test_targets() {
1028 assert!(Targets::<NetTarget>::All.matches(&NetTarget::from_str("example.com").unwrap()));
1029 assert!(Targets::<FuncTarget>::All.matches("http::get"));
1030 assert!(!Targets::<NetTarget>::None.matches(&NetTarget::from_str("example.com").unwrap()));
1031 assert!(!Targets::<FuncTarget>::None.matches("http::get"));
1032 assert!(Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
1033 .matches(&NetTarget::from_str("example.com").unwrap()));
1034 assert!(!Targets::<NetTarget>::Some([NetTarget::from_str("example.com").unwrap()].into())
1035 .matches(&NetTarget::from_str("www.example.com").unwrap()));
1036 assert!(Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
1037 .matches("http::get"));
1038 assert!(!Targets::<FuncTarget>::Some([FuncTarget::from_str("http::get").unwrap()].into())
1039 .matches("http::post"));
1040 }
1041
1042 #[test]
1043 fn test_capabilities() {
1044 {
1046 let caps = Capabilities::default().with_scripting(true);
1047 assert!(caps.allows_scripting());
1048 }
1049
1050 {
1052 let caps = Capabilities::default().with_scripting(false);
1053 assert!(!caps.allows_scripting());
1054 }
1055
1056 {
1058 let caps = Capabilities::default().with_guest_access(true);
1059 assert!(caps.allows_guest_access());
1060 }
1061
1062 {
1064 let caps = Capabilities::default().with_guest_access(false);
1065 assert!(!caps.allows_guest_access());
1066 }
1067
1068 {
1070 let cap = Capabilities::default().with_live_query_notifications(true);
1071 assert!(cap.allows_live_query_notifications());
1072 }
1073
1074 {
1076 let cap = Capabilities::default().with_live_query_notifications(false);
1077 assert!(!cap.allows_live_query_notifications());
1078 }
1079
1080 {
1082 let caps = Capabilities::default()
1083 .with_network_targets(Targets::<NetTarget>::All)
1084 .without_network_targets(Targets::<NetTarget>::None);
1085 assert!(caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
1086 assert!(caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
1087 }
1088
1089 {
1091 let caps = Capabilities::default()
1092 .with_network_targets(Targets::<NetTarget>::All)
1093 .without_network_targets(Targets::<NetTarget>::All);
1094 assert!(!caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
1095 assert!(!caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
1096 }
1097
1098 {
1100 let caps = Capabilities::default()
1101 .with_network_targets(Targets::<NetTarget>::Some(
1102 [NetTarget::from_str("example.com").unwrap()].into(),
1103 ))
1104 .without_network_targets(Targets::<NetTarget>::Some(
1105 [NetTarget::from_str("example.com:80").unwrap()].into(),
1106 ));
1107 assert!(caps.allows_network_target(&NetTarget::from_str("example.com").unwrap()));
1108 assert!(caps.allows_network_target(&NetTarget::from_str("example.com:443").unwrap()));
1109 assert!(!caps.allows_network_target(&NetTarget::from_str("example.com:80").unwrap()));
1110 }
1111
1112 {
1114 let caps = Capabilities::default()
1115 .with_functions(Targets::<FuncTarget>::All)
1116 .without_functions(Targets::<FuncTarget>::None);
1117 assert!(caps.allows_function_name("http::get"));
1118 assert!(caps.allows_function_name("http::post"));
1119 }
1120
1121 {
1123 let caps = Capabilities::default()
1124 .with_functions(Targets::<FuncTarget>::All)
1125 .without_functions(Targets::<FuncTarget>::All);
1126 assert!(!caps.allows_function_name("http::get"));
1127 assert!(!caps.allows_function_name("http::post"));
1128 }
1129
1130 {
1132 let caps = Capabilities::default()
1133 .with_functions(Targets::<FuncTarget>::Some(
1134 [FuncTarget::from_str("http::*").unwrap()].into(),
1135 ))
1136 .without_functions(Targets::<FuncTarget>::Some(
1137 [FuncTarget::from_str("http::post").unwrap()].into(),
1138 ));
1139 assert!(caps.allows_function_name("http::get"));
1140 assert!(caps.allows_function_name("http::put"));
1141 assert!(!caps.allows_function_name("http::post"));
1142 }
1143
1144 {
1146 let caps = Capabilities::default()
1147 .with_rpc_methods(Targets::<MethodTarget>::All)
1148 .without_rpc_methods(Targets::<MethodTarget>::None);
1149 assert!(caps.allows_rpc_method(&MethodTarget::from_str("ping").unwrap()));
1150 assert!(caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
1151 assert!(caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
1152 }
1153
1154 {
1156 let caps = Capabilities::default()
1157 .with_rpc_methods(Targets::<MethodTarget>::All)
1158 .without_rpc_methods(Targets::<MethodTarget>::All);
1159 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("ping").unwrap()));
1160 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
1161 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
1162 }
1163
1164 {
1166 let caps = Capabilities::default().without_rpc_methods(Targets::<MethodTarget>::All);
1167 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("ping").unwrap()));
1168 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
1169 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
1170 }
1171
1172 {
1174 let caps = Capabilities::default()
1175 .with_rpc_methods(Targets::<MethodTarget>::Some(
1176 [
1177 MethodTarget::from_str("select").unwrap(),
1178 MethodTarget::from_str("create").unwrap(),
1179 MethodTarget::from_str("insert").unwrap(),
1180 MethodTarget::from_str("update").unwrap(),
1181 MethodTarget::from_str("query").unwrap(),
1182 MethodTarget::from_str("run").unwrap(),
1183 ]
1184 .into(),
1185 ))
1186 .without_rpc_methods(Targets::<MethodTarget>::Some(
1187 [
1188 MethodTarget::from_str("query").unwrap(),
1189 MethodTarget::from_str("run").unwrap(),
1190 ]
1191 .into(),
1192 ));
1193
1194 assert!(caps.allows_rpc_method(&MethodTarget::from_str("select").unwrap()));
1195 assert!(caps.allows_rpc_method(&MethodTarget::from_str("create").unwrap()));
1196 assert!(caps.allows_rpc_method(&MethodTarget::from_str("insert").unwrap()));
1197 assert!(caps.allows_rpc_method(&MethodTarget::from_str("update").unwrap()));
1198 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("query").unwrap()));
1199 assert!(!caps.allows_rpc_method(&MethodTarget::from_str("run").unwrap()));
1200 }
1201
1202 {
1204 let caps = Capabilities::default()
1205 .with_http_routes(Targets::<RouteTarget>::All)
1206 .without_http_routes(Targets::<RouteTarget>::None);
1207 assert!(caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
1208 assert!(caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
1209 assert!(caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
1210 }
1211
1212 {
1214 let caps = Capabilities::default()
1215 .with_http_routes(Targets::<RouteTarget>::All)
1216 .without_http_routes(Targets::<RouteTarget>::All);
1217 assert!(!caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
1218 assert!(!caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
1219 assert!(!caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
1220 }
1221
1222 {
1224 let caps = Capabilities::default().without_http_routes(Targets::<RouteTarget>::All);
1225 assert!(!caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
1226 assert!(!caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
1227 assert!(!caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
1228 }
1229
1230 {
1232 let caps = Capabilities::default()
1233 .with_http_routes(Targets::<RouteTarget>::Some(
1234 [
1235 RouteTarget::from_str("version").unwrap(),
1236 RouteTarget::from_str("import").unwrap(),
1237 RouteTarget::from_str("export").unwrap(),
1238 RouteTarget::from_str("key").unwrap(),
1239 RouteTarget::from_str("sql").unwrap(),
1240 RouteTarget::from_str("rpc").unwrap(),
1241 ]
1242 .into(),
1243 ))
1244 .without_http_routes(Targets::<RouteTarget>::Some(
1245 [RouteTarget::from_str("sql").unwrap(), RouteTarget::from_str("rpc").unwrap()]
1246 .into(),
1247 ));
1248
1249 assert!(caps.allows_http_route(&RouteTarget::from_str("version").unwrap()));
1250 assert!(caps.allows_http_route(&RouteTarget::from_str("import").unwrap()));
1251 assert!(caps.allows_http_route(&RouteTarget::from_str("export").unwrap()));
1252 assert!(caps.allows_http_route(&RouteTarget::from_str("key").unwrap()));
1253 assert!(!caps.allows_http_route(&RouteTarget::from_str("sql").unwrap()));
1254 assert!(!caps.allows_http_route(&RouteTarget::from_str("rpc").unwrap()));
1255 }
1256
1257 {
1259 let caps = Capabilities::default()
1260 .with_arbitrary_query(Targets::<ArbitraryQueryTarget>::All)
1261 .without_arbitrary_query(Targets::<ArbitraryQueryTarget>::None);
1262 assert!(caps.allows_query(&ArbitraryQueryTarget::from_str("guest").unwrap()));
1263 assert!(caps.allows_query(&ArbitraryQueryTarget::from_str("record").unwrap()));
1264 assert!(caps.allows_query(&ArbitraryQueryTarget::from_str("system").unwrap()));
1265 }
1266
1267 {
1269 let caps = Capabilities::default()
1270 .with_arbitrary_query(Targets::<ArbitraryQueryTarget>::All)
1271 .without_arbitrary_query(Targets::<ArbitraryQueryTarget>::All);
1272 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("guest").unwrap()));
1273 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("record").unwrap()));
1274 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("system").unwrap()));
1275 }
1276
1277 {
1279 let caps = Capabilities::default()
1280 .without_arbitrary_query(Targets::<ArbitraryQueryTarget>::All);
1281 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("guest").unwrap()));
1282 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("record").unwrap()));
1283 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("system").unwrap()));
1284 }
1285
1286 {
1288 let caps = Capabilities::default()
1289 .with_arbitrary_query(Targets::<ArbitraryQueryTarget>::Some(
1290 [
1291 ArbitraryQueryTarget::from_str("guest").unwrap(),
1292 ArbitraryQueryTarget::from_str("record").unwrap(),
1293 ]
1294 .into(),
1295 ))
1296 .without_arbitrary_query(Targets::<ArbitraryQueryTarget>::Some(
1297 [
1298 ArbitraryQueryTarget::from_str("record").unwrap(),
1299 ArbitraryQueryTarget::from_str("system").unwrap(),
1300 ]
1301 .into(),
1302 ));
1303
1304 assert!(caps.allows_query(&ArbitraryQueryTarget::from_str("guest").unwrap()));
1305 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("record").unwrap()));
1306 assert!(!caps.allows_query(&ArbitraryQueryTarget::from_str("system").unwrap()));
1307 }
1308 }
1309}