1mod cidr;
2mod config;
3mod error;
4
5use std::{
6 collections::HashSet,
7 fmt,
8 net::IpAddr,
9 str::FromStr,
10 sync::{Arc, RwLock},
11};
12
13use cidr::{ParsedCidr, is_sensitive_ip_literal};
14pub use config::{
15 AllowedPropagationTarget, BearerPropagationPolicy, PropagatedTokenValidationConfig,
16 PropagationDestinationPolicy, PropagationScheme, TokenPropagatorConfig,
17};
18pub use error::{TokenPropagatorError, TokenPropagatorResult};
19use http::header::{AUTHORIZATION, HeaderMap, HeaderValue};
20use securitydept_oauth_resource_server::ResourceTokenPrincipal;
21use url::Url;
22
/// Default HTTP header name for propagation directives (`x-securitydept-propagation`).
///
/// NOTE(review): not referenced in this file; presumably the header whose value is read with
/// [`PropagationDirective::from_header_value`] and written with
/// [`PropagationDirective::to_header_value`] — confirm against callers.
pub const DEFAULT_PROPAGATION_HEADER_NAME: &str = "x-securitydept-propagation";
24
/// Maps a logical node id to the URL of that node.
///
/// [`TokenPropagator`] consults the installed resolver when a [`PropagationRequestTarget`]
/// carries only a `node_id` (no scheme/hostname). Resolution fails with
/// `NodeTargetUnresolved` when this returns `None`, and with `NodeTargetResolverRequired`
/// when no resolver is installed. Implementors must be `Debug + Send + Sync`.
pub trait PropagationNodeTargetResolver: fmt::Debug + Send + Sync {
    /// Returns the destination URL for `node_id`, or `None` if the node is unknown.
    fn resolve_url(&self, node_id: &str) -> Option<Url>;
}
28
/// A requested propagation destination.
///
/// Two shapes are accepted when the target is resolved:
/// - an explicit address: `scheme` and `hostname` set (optional `port`, optional `node_id`
///   label), or
/// - a bare node reference: only `node_id` set; its address comes from the configured
///   [`PropagationNodeTargetResolver`].
///
/// Any other combination is rejected as an incomplete target.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PropagationRequestTarget {
    /// Logical node id; compared against the destination policy's `allowed_node_ids` and used
    /// to look up the address when `scheme`/`hostname` are absent.
    pub node_id: Option<String>,
    /// Target scheme; must be set together with `hostname` for an explicit address.
    pub scheme: Option<PropagationScheme>,
    /// Target host. `new`/`from_url` store it normalized (trimmed, trailing dots removed,
    /// ASCII-lowercased); values assigned directly to this field are not normalized here.
    pub hostname: Option<String>,
    /// Explicit port; when `None` the scheme's default port is used.
    pub port: Option<u16>,
}
41
42impl PropagationRequestTarget {
43 pub fn new(
44 node_id: Option<String>,
45 scheme: PropagationScheme,
46 hostname: impl Into<String>,
47 port: impl Into<Option<u16>>,
48 ) -> Self {
49 Self {
50 node_id,
51 scheme: Some(scheme),
52 hostname: Some(normalize_host(&hostname.into())),
53 port: port.into(),
54 }
55 }
56
57 pub fn for_node(node_id: impl Into<String>) -> Self {
58 Self {
59 node_id: Some(node_id.into()),
60 scheme: None,
61 hostname: None,
62 port: None,
63 }
64 }
65
66 pub fn from_url(node_id: Option<String>, url: &Url) -> TokenPropagatorResult<Self> {
67 let scheme = parse_scheme(url.scheme())?;
68 let hostname = url
69 .host_str()
70 .ok_or_else(|| TokenPropagatorError::InvalidTargetHost {
71 host: String::new(),
72 })?;
73 let port = url.port();
74
75 Ok(Self::new(node_id, scheme, hostname, port))
76 }
77
78 fn display(&self) -> String {
79 match (&self.scheme, &self.hostname, self.port) {
80 (Some(scheme), Some(hostname), Some(port)) => {
81 format!("{}://{}:{port}", scheme.as_str(), hostname)
82 }
83 (Some(scheme), Some(hostname), None) => {
84 format!(
85 "{}://{}:{}",
86 scheme.as_str(),
87 hostname,
88 scheme.default_port()
89 )
90 }
91 (None, None, None) => self
92 .node_id
93 .as_ref()
94 .map(|node_id| format!("node:{node_id}"))
95 .unwrap_or_else(|| "incomplete-target".to_string()),
96 _ => "incomplete-target".to_string(),
97 }
98 }
99}
100
/// A destination with every address component known.
///
/// Produced by `TokenPropagator::resolve_target`, either from an explicit request target or
/// from the URL returned by the node resolver; the scheme's default port has already been
/// applied when no port was given.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ResolvedPropagationTarget {
    /// Node id carried over from the request target, if any.
    node_id: Option<String>,
    /// Destination scheme.
    scheme: PropagationScheme,
    /// Destination host (IPv6 literals are stored without brackets).
    hostname: String,
    /// Effective port: the explicit one, or the scheme's default.
    port: u16,
}
108
109impl ResolvedPropagationTarget {
110 fn display(&self) -> String {
111 format!("{}://{}:{}", self.scheme.as_str(), self.hostname, self.port)
112 }
113}
114
/// A propagation directive: `;`-separated `key=value` pairs with the keys `by`, `for`,
/// `host` and `proto` (field names mirror the RFC 7239 `Forwarded` parameters).
///
/// [`PropagationDirective::to_request_target`] turns it into a [`PropagationRequestTarget`],
/// using `for` as the node id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PropagationDirective {
    /// Optional `by` identifier (when parsed: ASCII alphanumerics plus `.`, `_`, `-`, `:`).
    pub by: Option<String>,
    /// Optional `for` identifier (same character set as `by`); becomes the target's node id.
    pub r#for: Option<String>,
    /// Host from `host=`; normalized when parsed.
    pub hostname: String,
    /// Port from `host=name:port`, if one was given.
    pub port: Option<u16>,
    /// Scheme from `proto=`.
    pub proto: PropagationScheme,
}
129
130impl PropagationDirective {
131 pub fn parse(value: &str) -> TokenPropagatorResult<Self> {
132 let mut by = None;
133 let mut for_identifier = None;
134 let mut hostname = None;
135 let mut port = None;
136 let mut proto = None;
137
138 for part in value.split(';') {
139 let part = part.trim();
140 if part.is_empty() {
141 continue;
142 }
143
144 let Some((raw_key, raw_value)) = part.split_once('=') else {
145 return Err(TokenPropagatorError::InvalidPropagationDirective {
146 message: format!("invalid propagation directive segment `{part}`"),
147 });
148 };
149 let key = raw_key.trim().to_ascii_lowercase();
150 let value = trim_quoted_value(raw_value.trim());
151
152 match key.as_str() {
153 "by" => by = Some(parse_directive_identifier("by", value)?),
154 "for" => for_identifier = Some(parse_directive_identifier("for", value)?),
155 "host" => {
156 let (parsed_hostname, parsed_port) = parse_directive_host(value)?;
157 hostname = Some(parsed_hostname);
158 port = parsed_port;
159 }
160 "proto" => proto = Some(parse_scheme(value)?),
161 _ => {
162 return Err(TokenPropagatorError::InvalidPropagationDirective {
163 message: format!("unsupported propagation directive field `{key}`"),
164 });
165 }
166 }
167 }
168
169 let hostname =
170 hostname.ok_or_else(|| TokenPropagatorError::InvalidPropagationDirective {
171 message: "propagation directive requires `host`".to_string(),
172 })?;
173 let proto = proto.ok_or_else(|| TokenPropagatorError::InvalidPropagationDirective {
174 message: "propagation directive requires `proto`".to_string(),
175 })?;
176
177 Ok(Self {
178 by,
179 r#for: for_identifier,
180 hostname,
181 port,
182 proto,
183 })
184 }
185
186 pub fn from_header_value(value: &HeaderValue) -> TokenPropagatorResult<Self> {
187 let value =
188 value
189 .to_str()
190 .map_err(|_| TokenPropagatorError::InvalidPropagationDirective {
191 message: "propagation header value must be valid ASCII".to_string(),
192 })?;
193
194 Self::parse(value)
195 }
196
197 pub fn to_header_value(&self) -> TokenPropagatorResult<HeaderValue> {
198 let mut segments = Vec::new();
199
200 if let Some(by) = &self.by {
201 segments.push(format!("by={by}"));
202 }
203 if let Some(for_identifier) = &self.r#for {
204 segments.push(format!("for={for_identifier}"));
205 }
206 let host = match self.port {
207 Some(port) => format!("{}:{port}", self.hostname),
208 None => self.hostname.clone(),
209 };
210 segments.push(format!("host={host}"));
211 segments.push(format!("proto={}", self.proto.as_str()));
212
213 HeaderValue::from_str(&segments.join(";"))
214 .map_err(|source| TokenPropagatorError::InvalidHeaderValue { source })
215 }
216
217 pub fn to_request_target(&self) -> PropagationRequestTarget {
218 PropagationRequestTarget::new(
219 self.r#for.clone(),
220 self.proto.clone(),
221 self.hostname.clone(),
222 self.port,
223 )
224 }
225}
226
227#[derive(Debug, Clone, Copy)]
229pub struct PropagatedBearer<'a> {
230 pub access_token: &'a str,
231 pub resource_token_principal: Option<&'a ResourceTokenPrincipal>,
232}
233
234impl<'a> PropagatedBearer<'a> {
235 pub fn authorization_value(&self) -> String {
236 format!("Bearer {}", self.access_token)
237 }
238}
239
/// Decides whether a bearer token may be forwarded to a destination and builds the
/// outgoing `Authorization` value.
///
/// Clones share one node-target-resolver slot (it sits behind `Arc<RwLock<…>>`), so
/// `set_node_target_resolver` on any clone is visible to all of them.
#[derive(Debug, Clone)]
pub struct TokenPropagator {
    /// Policy returned by `policy`/`resolve_policy` and applied by `authorization_value`.
    default_policy: BearerPropagationPolicy,
    /// Allowed node ids and targets, plus sensitive-IP-literal handling.
    destination_policy: PropagationDestinationPolicy,
    /// Issuer / audience / scope / `azp` requirements checked against the token facts.
    token_validation: PropagatedTokenValidationConfig,
    /// Replaceable resolver used to turn node ids into URLs.
    node_target_resolver: Arc<RwLock<Option<Arc<dyn PropagationNodeTargetResolver>>>>,
}
247
248impl TokenPropagatorConfig {
249 pub fn validate(&self) -> TokenPropagatorResult<()> {
250 for target in &self.destination_policy.allowed_targets {
251 match target {
252 AllowedPropagationTarget::ExactOrigin { hostname, port, .. } => {
253 validate_host(hostname)?;
254 validate_port(*port)?;
255 }
256 AllowedPropagationTarget::DomainSuffix {
257 domain_suffix,
258 port,
259 ..
260 } => {
261 let normalized = normalize_host(domain_suffix);
262 if normalized.is_empty() || normalized.parse::<IpAddr>().is_ok() {
263 return Err(TokenPropagatorError::PropagatorConfig {
264 message: format!(
265 "domain propagation target `{domain_suffix}` must be a non-IP \
266 domain suffix"
267 ),
268 });
269 }
270 if normalized.contains('*') {
271 return Err(TokenPropagatorError::PropagatorConfig {
272 message: format!(
273 "domain propagation target `{domain_suffix}` must not contain \
274 wildcards"
275 ),
276 });
277 }
278 validate_port(*port)?;
279 }
280 AllowedPropagationTarget::DomainRegex {
281 domain_regex, port, ..
282 } => validate_domain_regex_target(domain_regex, *port)?,
283 AllowedPropagationTarget::Cidr { cidr, port, .. } => {
284 if ParsedCidr::parse(cidr).is_none() {
285 return Err(TokenPropagatorError::InvalidCidr { cidr: cidr.clone() });
286 }
287 validate_port(*port)?;
288 }
289 }
290 }
291
292 Ok(())
293 }
294}
295
296impl PropagationScheme {
297 pub fn default_port(&self) -> u16 {
298 match self {
299 Self::Https => 443,
300 Self::Http => 80,
301 }
302 }
303}
304
305impl TokenPropagator {
306 pub fn from_config(config: &TokenPropagatorConfig) -> TokenPropagatorResult<Self> {
307 Self::from_config_with_node_target_resolver(config, None)
308 }
309
310 pub fn from_config_with_node_target_resolver(
311 config: &TokenPropagatorConfig,
312 node_target_resolver: Option<Arc<dyn PropagationNodeTargetResolver>>,
313 ) -> TokenPropagatorResult<Self> {
314 config.validate()?;
315
316 Ok(Self {
317 default_policy: config.default_policy.clone(),
318 destination_policy: config.destination_policy.clone(),
319 token_validation: config.token_validation.clone(),
320 node_target_resolver: Arc::new(RwLock::new(node_target_resolver)),
321 })
322 }
323
324 pub fn policy(&self) -> &BearerPropagationPolicy {
325 &self.default_policy
326 }
327
328 pub fn resolve_policy(&self) -> BearerPropagationPolicy {
329 self.default_policy.clone()
330 }
331
332 pub fn set_node_target_resolver(
333 &self,
334 node_target_resolver: Option<Arc<dyn PropagationNodeTargetResolver>>,
335 ) {
336 let mut guard = self
337 .node_target_resolver
338 .write()
339 .expect("node target resolver lock poisoned");
340 *guard = node_target_resolver;
341 }
342
343 pub fn validate_target(
344 &self,
345 bearer: &PropagatedBearer<'_>,
346 target: &PropagationRequestTarget,
347 ) -> TokenPropagatorResult<()> {
348 self.validate_destination(target)?;
349 self.validate_token(bearer)?;
350 Ok(())
351 }
352
353 pub fn authorization_value(
354 &self,
355 bearer: &PropagatedBearer<'_>,
356 target: &PropagationRequestTarget,
357 ) -> TokenPropagatorResult<String> {
358 match self.resolve_policy() {
359 BearerPropagationPolicy::ValidateThenForward => {
360 self.validate_target(bearer, target)?;
361 Ok(bearer.authorization_value())
362 }
363 BearerPropagationPolicy::ExchangeForDownstreamToken => {
364 Err(TokenPropagatorError::UnsupportedDirectAuthorization {
365 policy: BearerPropagationPolicy::ExchangeForDownstreamToken,
366 })
367 }
368 }
369 }
370
371 pub fn authorization_header_value(
372 &self,
373 bearer: &PropagatedBearer<'_>,
374 target: &PropagationRequestTarget,
375 ) -> TokenPropagatorResult<HeaderValue> {
376 let authorization_value = self.authorization_value(bearer, target)?;
377
378 HeaderValue::from_str(&authorization_value)
379 .map_err(|source| TokenPropagatorError::InvalidHeaderValue { source })
380 }
381
382 pub fn resolve_target_origin(
383 &self,
384 target: &PropagationRequestTarget,
385 ) -> TokenPropagatorResult<String> {
386 Ok(self.resolve_target(target)?.origin())
387 }
388
389 pub fn apply_authorization_header(
390 &self,
391 bearer: &PropagatedBearer<'_>,
392 target: &PropagationRequestTarget,
393 headers: &mut HeaderMap,
394 ) -> TokenPropagatorResult<()> {
395 headers.insert(
396 AUTHORIZATION,
397 self.authorization_header_value(bearer, target)?,
398 );
399 Ok(())
400 }
401
402 fn validate_destination(&self, target: &PropagationRequestTarget) -> TokenPropagatorResult<()> {
403 let target = self.resolve_target(target)?;
404 validate_host(&target.hostname)?;
405 validate_port(target.port)?;
406
407 let matched_by_node = target.node_id.as_ref().is_some_and(|node_id| {
408 self.destination_policy
409 .allowed_node_ids
410 .iter()
411 .any(|allowed| allowed == node_id)
412 });
413
414 let host_ip = IpAddr::from_str(&target.hostname).ok();
415 let matched_by_target = self
416 .destination_policy
417 .allowed_targets
418 .iter()
419 .any(|allowed_target| match_allowed_target(allowed_target, &target, host_ip));
420
421 if !matched_by_node && !matched_by_target {
422 return Err(TokenPropagatorError::DestinationNotAllowed {
423 target: target.display(),
424 });
425 }
426
427 if self.destination_policy.deny_sensitive_ip_literals
428 && host_ip.is_some_and(is_sensitive_ip_literal)
429 && !self
430 .destination_policy
431 .allowed_targets
432 .iter()
433 .any(|allowed_target| {
434 matches!(
435 allowed_target,
436 AllowedPropagationTarget::Cidr {
437 scheme,
438 port,
439 cidr,
440 } if scheme == &target.scheme
441 && port == &target.port
442 && ParsedCidr::parse(cidr)
443 .is_some_and(|parsed| parsed.contains(host_ip.expect("checked above")))
444 )
445 })
446 {
447 return Err(TokenPropagatorError::SensitiveIpLiteralDenied {
448 host: target.hostname.clone(),
449 });
450 }
451
452 Ok(())
453 }
454
455 fn resolve_target(
456 &self,
457 target: &PropagationRequestTarget,
458 ) -> TokenPropagatorResult<ResolvedPropagationTarget> {
459 match (&target.node_id, &target.scheme, &target.hostname) {
460 (node_id, Some(scheme), Some(hostname)) => Ok(ResolvedPropagationTarget {
461 node_id: node_id.clone(),
462 scheme: scheme.clone(),
463 hostname: hostname.clone(),
464 port: target.port.unwrap_or_else(|| scheme.default_port()),
465 }),
466 (Some(node_id), None, None) => {
467 let resolver = self
468 .node_target_resolver
469 .read()
470 .expect("node target resolver lock poisoned")
471 .clone()
472 .ok_or_else(|| TokenPropagatorError::NodeTargetResolverRequired {
473 node_id: node_id.clone(),
474 })?;
475 let url = resolver.resolve_url(node_id).ok_or_else(|| {
476 TokenPropagatorError::NodeTargetUnresolved {
477 node_id: node_id.clone(),
478 }
479 })?;
480 ResolvedPropagationTarget::from_url(Some(node_id.clone()), url)
481 }
482 _ => Err(TokenPropagatorError::IncompleteTarget {
483 target: target.display(),
484 }),
485 }
486 }
487
488 fn validate_token(&self, bearer: &PropagatedBearer<'_>) -> TokenPropagatorResult<()> {
489 let requires_token_facts = !self.token_validation.required_issuers.is_empty()
490 || !self.token_validation.allowed_audiences.is_empty()
491 || !self.token_validation.required_scopes.is_empty()
492 || !self.token_validation.allowed_azp.is_empty();
493 let resource_token_principal = bearer.resource_token_principal;
494
495 if requires_token_facts && resource_token_principal.is_none() {
496 return Err(TokenPropagatorError::TokenFactsUnavailable);
497 }
498
499 if !self.token_validation.required_issuers.is_empty() {
500 let issuer = resource_token_principal
501 .and_then(|principal| principal.issuer.clone())
502 .unwrap_or_default();
503
504 if !self
505 .token_validation
506 .required_issuers
507 .iter()
508 .any(|allowed| allowed == &issuer)
509 {
510 return Err(TokenPropagatorError::TokenIssuerNotAllowed { issuer });
511 }
512 }
513
514 if !self.token_validation.allowed_audiences.is_empty() {
515 let audiences: HashSet<String> = resource_token_principal
516 .map(|principal| principal.audiences.iter().cloned().collect())
517 .unwrap_or_default();
518 let allowed_audiences: HashSet<String> = self
519 .token_validation
520 .allowed_audiences
521 .iter()
522 .cloned()
523 .collect();
524
525 if audiences.is_disjoint(&allowed_audiences) {
526 return Err(TokenPropagatorError::TokenAudienceNotAllowed);
527 }
528 }
529
530 if !self.token_validation.required_scopes.is_empty() {
531 let scopes: HashSet<String> = resource_token_principal
532 .map(|principal| principal.scopes.iter().cloned().collect())
533 .unwrap_or_default();
534
535 for required_scope in &self.token_validation.required_scopes {
536 if !scopes.contains(required_scope) {
537 return Err(TokenPropagatorError::TokenScopeMissing {
538 scope: required_scope.clone(),
539 });
540 }
541 }
542 }
543
544 if !self.token_validation.allowed_azp.is_empty() {
545 let azp = resource_token_principal
546 .and_then(|principal| principal.authorized_party.as_deref())
547 .unwrap_or_default()
548 .to_string();
549
550 if !self
551 .token_validation
552 .allowed_azp
553 .iter()
554 .any(|allowed| allowed == &azp)
555 {
556 return Err(TokenPropagatorError::TokenAzpNotAllowed { azp });
557 }
558 }
559
560 Ok(())
561 }
562}
563
564fn parse_scheme(scheme: &str) -> TokenPropagatorResult<PropagationScheme> {
565 match scheme {
566 "https" => Ok(PropagationScheme::Https),
567 "http" => Ok(PropagationScheme::Http),
568 _ => Err(TokenPropagatorError::UnsupportedTargetScheme {
569 scheme: scheme.to_string(),
570 }),
571 }
572}
573
574fn validate_port(port: u16) -> TokenPropagatorResult<()> {
575 if port == 0 {
576 return Err(TokenPropagatorError::PropagatorConfig {
577 message: "propagation targets must use a non-zero port".to_string(),
578 });
579 }
580
581 Ok(())
582}
583
584fn validate_host(host: &str) -> TokenPropagatorResult<()> {
585 let normalized = normalize_host(host);
586 if normalized.is_empty() {
587 return Err(TokenPropagatorError::InvalidTargetHost {
588 host: host.to_string(),
589 });
590 }
591
592 if normalized.parse::<IpAddr>().is_ok() {
593 return Ok(());
594 }
595
596 if normalized
597 .chars()
598 .any(|ch| !(ch.is_ascii_alphanumeric() || ch == '.' || ch == '-'))
599 {
600 return Err(TokenPropagatorError::InvalidTargetHost {
601 host: host.to_string(),
602 });
603 }
604
605 Ok(())
606}
607
/// Canonical form used for host comparisons.
///
/// Strips surrounding whitespace, then any trailing dots (the DNS root), then ASCII-lowercases
/// what remains.
fn normalize_host(host: &str) -> String {
    let mut canonical = host.trim().trim_end_matches('.').to_owned();
    canonical.make_ascii_lowercase();
    canonical
}
611
612fn match_allowed_target(
613 allowed_target: &AllowedPropagationTarget,
614 target: &ResolvedPropagationTarget,
615 host_ip: Option<IpAddr>,
616) -> bool {
617 match allowed_target {
618 AllowedPropagationTarget::ExactOrigin {
619 scheme,
620 hostname,
621 port,
622 } => {
623 scheme == &target.scheme
624 && port == &target.port
625 && normalize_host(hostname) == target.hostname
626 }
627 AllowedPropagationTarget::DomainSuffix {
628 scheme,
629 domain_suffix,
630 port,
631 } => {
632 let suffix = normalize_host(domain_suffix);
633 scheme == &target.scheme
634 && port == &target.port
635 && host_ip.is_none()
636 && domain_suffix_matches(&target.hostname, &suffix)
637 }
638 AllowedPropagationTarget::DomainRegex {
639 scheme,
640 domain_regex,
641 port,
642 } => {
643 scheme == &target.scheme
644 && port == &target.port
645 && host_ip.is_none()
646 && domain_regex.is_match(&target.hostname)
647 }
648 AllowedPropagationTarget::Cidr { scheme, cidr, port } => {
649 scheme == &target.scheme
650 && port == &target.port
651 && host_ip.is_some_and(|ip| {
652 ParsedCidr::parse(cidr).is_some_and(|parsed| parsed.contains(ip))
653 })
654 }
655 }
656}
657
/// Returns whether `host` equals `suffix` or is a subdomain of it.
///
/// Matching is label-aligned: `badexample.com` does not match `example.com`. A single leading
/// dot on `suffix` (cookie-domain style `.example.com`) is ignored. Previously such a suffix
/// could never match a subdomain, because the check looked for `..example.com`. An empty suffix
/// matches nothing. Both arguments are expected to be normalized already.
fn domain_suffix_matches(host: &str, suffix: &str) -> bool {
    let suffix = suffix.strip_prefix('.').unwrap_or(suffix);
    if suffix.is_empty() {
        return false;
    }

    host == suffix
        || host
            .strip_suffix(suffix)
            .is_some_and(|subdomain_part| subdomain_part.ends_with('.'))
}
661
662fn validate_domain_regex_target(
663 domain_regex: ®ex::Regex,
664 port: u16,
665) -> TokenPropagatorResult<()> {
666 if domain_regex.as_str().is_empty() {
667 return Err(TokenPropagatorError::PropagatorConfig {
668 message: "domain regex propagation target must not be empty".to_string(),
669 });
670 }
671
672 validate_port(port)
673}
674
/// Removes one pair of surrounding double quotes, if both are present.
///
/// A lone `"` or an unbalanced quote is returned unchanged. Escapes inside the quotes are not
/// interpreted.
fn trim_quoted_value(value: &str) -> &str {
    let bytes = value.as_bytes();
    let is_quoted = bytes.len() >= 2 && bytes[0] == b'"' && bytes[bytes.len() - 1] == b'"';

    if is_quoted {
        // `"` is a single ASCII byte, so both slice bounds are char boundaries.
        &value[1..value.len() - 1]
    } else {
        value
    }
}
681
682fn parse_directive_identifier(field: &str, value: &str) -> TokenPropagatorResult<String> {
683 if value.is_empty() {
684 return Err(TokenPropagatorError::InvalidPropagationDirective {
685 message: format!("propagation directive `{field}` must not be empty"),
686 });
687 }
688
689 if value
690 .chars()
691 .any(|ch| !(ch.is_ascii_alphanumeric() || ch == '.' || ch == '_' || ch == '-' || ch == ':'))
692 {
693 return Err(TokenPropagatorError::InvalidPropagationDirective {
694 message: format!("propagation directive `{field}` contains unsupported characters"),
695 });
696 }
697
698 Ok(value.to_string())
699}
700
701fn parse_directive_host(value: &str) -> TokenPropagatorResult<(String, Option<u16>)> {
702 if value.is_empty() {
703 return Err(TokenPropagatorError::InvalidPropagationDirective {
704 message: "propagation directive `host` must not be empty".to_string(),
705 });
706 }
707
708 if let Some(host) = value
709 .strip_prefix('[')
710 .and_then(|value| value.strip_suffix(']'))
711 {
712 validate_host(host)?;
713 return Ok((normalize_host(host), None));
714 }
715
716 if let Some((host, port)) = value.rsplit_once(':')
717 && !host.is_empty()
718 && let Ok(port) = port.parse::<u16>()
719 {
720 validate_host(host)?;
721 validate_port(port)?;
722 return Ok((normalize_host(host), Some(port)));
723 }
724
725 validate_host(value)?;
726 Ok((normalize_host(value), None))
727}
728
729impl ResolvedPropagationTarget {
730 fn origin(&self) -> String {
731 format!("{}://{}:{}", self.scheme.as_str(), self.hostname, self.port)
732 }
733
734 fn from_url(node_id: Option<String>, url: Url) -> TokenPropagatorResult<Self> {
735 let target = PropagationRequestTarget::from_url(node_id, &url)?;
736 let scheme =
737 target
738 .scheme
739 .clone()
740 .ok_or_else(|| TokenPropagatorError::IncompleteTarget {
741 target: target.display(),
742 })?;
743 let hostname =
744 target
745 .hostname
746 .clone()
747 .ok_or_else(|| TokenPropagatorError::IncompleteTarget {
748 target: target.display(),
749 })?;
750
751 Ok(Self {
752 node_id: target.node_id,
753 port: target.port.unwrap_or_else(|| scheme.default_port()),
754 scheme,
755 hostname,
756 })
757 }
758}