spiffe_rs/workloadapi/
watcher.rs

1use crate::workloadapi::{JWTBundleWatcher, Result, X509Context, X509ContextWatcher};
2use crate::workloadapi::{option::WatcherConfig, Client, Context};
3use std::sync::{Arc, Mutex};
4use tokio::sync::{oneshot, watch};
5use tokio::task::JoinHandle;
6use tokio_util::sync::CancellationToken;
7
8pub struct Watcher {
9    updated_tx: watch::Sender<u64>,
10    updated_rx: watch::Receiver<u64>,
11    pub(crate) client: Arc<Client>,
12    owns_client: bool,
13    cancel: CancellationToken,
14    tasks: Mutex<Vec<JoinHandle<()>>>,
15}
16
17impl Watcher {
18    pub async fn new(
19        ctx: &Context,
20        config: WatcherConfig,
21        x509_context_fn: Option<Arc<dyn Fn(X509Context) + Send + Sync>>,
22        jwt_bundles_fn: Option<Arc<dyn Fn(crate::bundle::jwtbundle::Set) + Send + Sync>>,
23    ) -> Result<Watcher> {
24        let owns_client = config.client.is_none();
25        let client = match config.client {
26            Some(client) => client,
27            None => Arc::new(Client::new(config.client_options).await?),
28        };
29        let cancel = CancellationToken::new();
30        let (updated_tx, updated_rx) = watch::channel(0u64);
31        let watcher = Watcher {
32            updated_tx,
33            updated_rx,
34            client,
35            owns_client,
36            cancel,
37            tasks: Mutex::new(Vec::new()),
38        };
39
40        watcher
41            .spawn_watchers(ctx, x509_context_fn, jwt_bundles_fn)
42            .await?;
43        Ok(watcher)
44    }
45
46    pub async fn close(&self) -> Result<()> {
47        self.cancel.cancel();
48        if let Ok(mut tasks) = self.tasks.lock() {
49            for task in tasks.drain(..) {
50                let _ = task.await;
51            }
52        }
53        if self.owns_client {
54            self.client.close().await?;
55        }
56        Ok(())
57    }
58
59    pub async fn wait_until_updated(&self, ctx: &Context) -> Result<()> {
60        let mut rx = self.updated_rx.clone();
61        tokio::select! {
62            _ = rx.changed() => Ok(()),
63            _ = ctx.cancelled() => Err(crate::workloadapi::wrap_error("context canceled")),
64        }
65    }
66
67    pub fn updated(&self) -> watch::Receiver<u64> {
68        self.updated_rx.clone()
69    }
70
71
72    async fn spawn_watchers(
73        &self,
74        ctx: &Context,
75        x509_context_fn: Option<Arc<dyn Fn(X509Context) + Send + Sync>>,
76        jwt_bundles_fn: Option<Arc<dyn Fn(crate::bundle::jwtbundle::Set) + Send + Sync>>,
77    ) -> Result<()> {
78        let mut tasks = self.tasks.lock().expect("watcher task lock");
79        let (err_tx, mut err_rx) = tokio::sync::mpsc::channel(2);
80
81        if let Some(handler) = x509_context_fn.clone() {
82            let (ready_tx, ready_rx) = oneshot::channel();
83            let watcher = Arc::new(InternalX509Watcher {
84                handler,
85                ready: Mutex::new(Some(ready_tx)),
86                updated: self.updated_tx.clone(),
87            });
88            let client = self.client.clone();
89            let cancel = self.cancel.clone();
90            let err_tx = err_tx.clone();
91            tasks.push(tokio::spawn(async move {
92                if let Err(err) = client.watch_x509_context(&cancel, watcher).await {
93                    let _ = err_tx.send(err).await;
94                }
95            }));
96            wait_for_ready(&mut err_rx, ctx, ready_rx).await?;
97        }
98
99        if let Some(handler) = jwt_bundles_fn.clone() {
100            let (ready_tx, ready_rx) = oneshot::channel();
101            let watcher = Arc::new(InternalJWTWatcher {
102                handler,
103                ready: Mutex::new(Some(ready_tx)),
104                updated: self.updated_tx.clone(),
105            });
106            let client = self.client.clone();
107            let cancel = self.cancel.clone();
108            let err_tx = err_tx.clone();
109            tasks.push(tokio::spawn(async move {
110                if let Err(err) = client.watch_jwt_bundles(&cancel, watcher).await {
111                    let _ = err_tx.send(err).await;
112                }
113            }));
114            wait_for_ready(&mut err_rx, ctx, ready_rx).await?;
115        }
116
117        Ok(())
118    }
119}
120
121struct InternalX509Watcher {
122    handler: Arc<dyn Fn(X509Context) + Send + Sync>,
123    ready: Mutex<Option<oneshot::Sender<()>>>,
124    updated: watch::Sender<u64>,
125}
126
127impl X509ContextWatcher for InternalX509Watcher {
128    fn on_x509_context_update(&self, context: X509Context) {
129        (self.handler)(context);
130        let _ = self.updated.send(*self.updated.borrow() + 1);
131        if let Some(tx) = self.ready.lock().ok().and_then(|mut lock| lock.take()) {
132            let _ = tx.send(());
133        }
134    }
135
136    fn on_x509_context_watch_error(&self, _err: crate::workloadapi::Error) {}
137}
138
139struct InternalJWTWatcher {
140    handler: Arc<dyn Fn(crate::bundle::jwtbundle::Set) + Send + Sync>,
141    ready: Mutex<Option<oneshot::Sender<()>>>,
142    updated: watch::Sender<u64>,
143}
144
145async fn wait_for_ready(
146    err_rx: &mut tokio::sync::mpsc::Receiver<crate::workloadapi::Error>,
147    ctx: &Context,
148    mut ready_rx: oneshot::Receiver<()>,
149) -> Result<()> {
150    tokio::select! {
151        _ = &mut ready_rx => Ok(()),
152        err = err_rx.recv() => Err(err.unwrap_or_else(|| crate::workloadapi::wrap_error("watcher failed"))),
153        _ = ctx.cancelled() => Err(crate::workloadapi::wrap_error("context canceled")),
154    }
155}
156
157impl JWTBundleWatcher for InternalJWTWatcher {
158    fn on_jwt_bundles_update(&self, bundles: crate::bundle::jwtbundle::Set) {
159        (self.handler)(bundles);
160        let _ = self.updated.send(*self.updated.borrow() + 1);
161        if let Some(tx) = self.ready.lock().ok().and_then(|mut lock| lock.take()) {
162            let _ = tx.send(());
163        }
164    }
165
166    fn on_jwt_bundles_watch_error(&self, _err: crate::workloadapi::Error) {}
167}