// spiffe_rs/workloadapi/jwtsource.rs

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};

use crate::bundle::jwtbundle;
use crate::svid::jwtsvid;
use crate::workloadapi::option::{JWTSourceConfig, JWTSourceOption};
use crate::workloadapi::{Context, Result, Watcher};

7/// A source of JWT SVIDs and bundles that is kept up-to-date by watching the
8/// Workload API.
9pub struct JWTSource {
10    watcher: Watcher,
11    picker: Option<Arc<dyn Fn(&[jwtsvid::SVID]) -> jwtsvid::SVID + Send + Sync>>,
12    bundles: Arc<RwLock<Option<jwtbundle::Set>>>,
13    closed: std::sync::atomic::AtomicBool,
14}
15
16impl JWTSource {
17    /// Creates a new `JWTSource` with the given options.
18    ///
19    /// It starts watching the Workload API for updates.
20    pub async fn new<I>(ctx: &Context, options: I) -> Result<JWTSource>
21    where
22        I: IntoIterator<Item = Arc<dyn JWTSourceOption>>,
23    {
24        let mut config = JWTSourceConfig::default();
25        for opt in options {
26            opt.configure_jwt_source(&mut config);
27        }
28
29        let bundles_slot = Arc::new(RwLock::new(None));
30        let bundles_slot_clone = bundles_slot.clone();
31        let handler = Arc::new(move |bundles: jwtbundle::Set| {
32            if let Ok(mut guard) = bundles_slot_clone.write() {
33                *guard = Some(bundles);
34            }
35        });
36
37        let watcher = Watcher::new(ctx, config.watcher, None, Some(handler)).await?;
38        Ok(JWTSource {
39            watcher,
40            picker: config.picker.clone(),
41            bundles: bundles_slot,
42            closed: std::sync::atomic::AtomicBool::new(false),
43        })
44    }
45
46    /// Closes the source.
47    pub async fn close(&self) -> Result<()> {
48        self.closed.store(true, std::sync::atomic::Ordering::SeqCst);
49        self.watcher.close().await
50    }
51
52    /// Fetches a JWT SVID with the given parameters.
53    pub async fn fetch_jwt_svid(
54        &self,
55        ctx: &Context,
56        params: jwtsvid::Params,
57    ) -> Result<jwtsvid::SVID> {
58        self.check_closed()?;
59        if let Some(picker) = &self.picker {
60            let svids = self.watcher.client.fetch_jwt_svids(ctx, params).await?;
61            return Ok(picker(&svids));
62        }
63        self.watcher.client.fetch_jwt_svid(ctx, params).await
64    }
65
66    /// Fetches multiple JWT SVIDs with the given parameters.
67    pub async fn fetch_jwt_svids(
68        &self,
69        ctx: &Context,
70        params: jwtsvid::Params,
71    ) -> Result<Vec<jwtsvid::SVID>> {
72        self.check_closed()?;
73        self.watcher.client.fetch_jwt_svids(ctx, params).await
74    }
75
76    /// Returns the JWT bundle for the given trust domain.
77    pub fn get_jwt_bundle_for_trust_domain(
78        &self,
79        trust_domain: crate::spiffeid::TrustDomain,
80    ) -> Result<jwtbundle::Bundle> {
81        self.check_closed()?;
82        self.bundles
83            .read()
84            .ok()
85            .and_then(|guard| guard.as_ref().and_then(|b| b.get_jwt_bundle_for_trust_domain(trust_domain).ok()))
86            .ok_or_else(|| crate::workloadapi::Error::new("jwtsource: no JWT bundle found"))
87    }
88
89    /// Waits until the source has been updated for the first time.
90    pub async fn wait_until_updated(&self, ctx: &Context) -> Result<()> {
91        self.watcher.wait_until_updated(ctx).await
92    }
93
94    /// Returns a receiver that can be used to watch for updates to the source.
95    pub fn updated(&self) -> tokio::sync::watch::Receiver<u64> {
96        self.watcher.updated()
97    }
98
99    fn check_closed(&self) -> Result<()> {
100        if self.closed.load(std::sync::atomic::Ordering::SeqCst) {
101            return Err(crate::workloadapi::Error::new("jwtsource: source is closed"));
102        }
103        Ok(())
104    }
105}
106
107impl jwtbundle::Source for JWTSource {
108    fn get_jwt_bundle_for_trust_domain(
109        &self,
110        trust_domain: crate::spiffeid::TrustDomain,
111    ) -> jwtbundle::Result<jwtbundle::Bundle> {
112        self.get_jwt_bundle_for_trust_domain(trust_domain)
113            .map_err(|err| jwtbundle::Error::new(err.to_string()))
114    }
115}