spiffe_rs/workloadapi/
client.rs

1use crate::bundle::jwtbundle;
2use crate::bundle::x509bundle;
3use crate::spiffeid::{self, ID};
4use crate::svid::{jwtsvid, x509svid};
5use crate::workloadapi::proto::spiffe_workload_api_client::SpiffeWorkloadApiClient;
6use crate::workloadapi::proto::{
7    JwtBundlesRequest, JwtBundlesResponse, JwtsvidRequest, JwtsvidResponse, ValidateJwtsvidRequest,
8    X509BundlesRequest, X509BundlesResponse, X509svidRequest, X509svidResponse,
9};
10use crate::workloadapi::{target_from_address, wrap_error, Backoff, Error, Result};
11use crate::workloadapi::{option::ClientConfig, Context};
12use tower::service_fn;
13use std::collections::HashSet;
14use std::sync::Arc;
15use tokio::net::UnixStream;
16use tonic::metadata::MetadataValue;
17use tonic::transport::{Channel, Endpoint};
18use tonic::{Code, Request, Status};
19
20/// A client for the SPIFFE Workload API.
21///
22/// This client can be used to fetch X.509 and JWT SVIDs and bundles from a
23/// Workload API endpoint.
24pub struct Client {
25    inner: SpiffeWorkloadApiClient<Channel>,
26    config: ClientConfig,
27}
28
29impl Client {
30    /// Creates a new `Client` with the given options.
31    pub async fn new<I>(options: I) -> Result<Client>
32    where
33        I: IntoIterator<Item = Arc<dyn crate::workloadapi::ClientOption>>,
34    {
35        let mut config = ClientConfig::default();
36        for opt in options {
37            opt.configure_client(&mut config);
38        }
39
40        let address = match config.address.clone() {
41            Some(addr) => addr,
42            None => crate::workloadapi::get_default_address().ok_or_else(|| {
43                wrap_error("workload endpoint socket address is not configured")
44            })?,
45        };
46        let target = target_from_address(&address)?;
47        let channel = connect_channel(&target, &config.dial_options).await?;
48        let inner = SpiffeWorkloadApiClient::new(channel);
49        Ok(Client { inner, config })
50    }
51
52    /// Closes the client.
53    pub async fn close(&self) -> Result<()> {
54        Ok(())
55    }
56
57    /// Fetches a single X.509 SVID from the Workload API.
58    pub async fn fetch_x509_svid(&self, ctx: &Context) -> Result<x509svid::SVID> {
59        let mut client = self.inner.clone();
60        let request = with_header(Request::new(X509svidRequest {}));
61        let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
62        let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
63        let svids = parse_x509_svids(response, true)?;
64        Ok(svids
65            .into_iter()
66            .next()
67            .ok_or_else(|| wrap_error("no SVIDs in response"))?)
68    }
69
70    /// Fetches all X.509 SVIDs from the Workload API.
71    pub async fn fetch_x509_svids(&self, ctx: &Context) -> Result<Vec<x509svid::SVID>> {
72        let mut client = self.inner.clone();
73        let request = with_header(Request::new(X509svidRequest {}));
74        let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
75        let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
76        parse_x509_svids(response, false)
77    }
78
79    /// Fetches X.509 bundles from the Workload API.
80    pub async fn fetch_x509_bundles(&self, ctx: &Context) -> Result<x509bundle::Set> {
81        let mut client = self.inner.clone();
82        let request = with_header(Request::new(X509BundlesRequest {}));
83        let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)).await?.into_inner();
84        let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
85        parse_x509_bundles_response(resp)
86    }
87
88    /// Watches for X.509 bundle updates from the Workload API.
89    pub async fn watch_x509_bundles(&self, ctx: &Context, watcher: Arc<dyn X509BundleWatcher>) -> Result<()> {
90        let mut backoff = self.config.backoff_strategy.new_backoff();
91        loop {
92            if let Err(err) = self.watch_x509_bundles_once(ctx, watcher.clone(), &mut *backoff).await {
93                watcher.on_x509_bundles_watch_error(err.clone());
94                if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await {
95                    return Err(err);
96                }
97            }
98        }
99    }
100
101    /// Fetches the X.509 context (SVIDs and bundles) from the Workload API.
102    pub async fn fetch_x509_context(&self, ctx: &Context) -> Result<crate::workloadapi::X509Context> {
103        let mut client = self.inner.clone();
104        let request = with_header(Request::new(X509svidRequest {}));
105        let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
106        let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
107        parse_x509_context(response)
108    }
109
110    /// Watches for X.509 context updates from the Workload API.
111    pub async fn watch_x509_context(
112        &self,
113        ctx: &Context,
114        watcher: Arc<dyn X509ContextWatcher>,
115    ) -> Result<()> {
116        let mut backoff = self.config.backoff_strategy.new_backoff();
117        loop {
118            if let Err(err) = self.watch_x509_context_once(ctx, watcher.clone(), &mut *backoff).await {
119                watcher.on_x509_context_watch_error(err.clone());
120                if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await {
121                    return Err(err);
122                }
123            }
124        }
125    }
126
127    /// Fetches a single JWT SVID from the Workload API.
128    pub async fn fetch_jwt_svid(&self, ctx: &Context, params: jwtsvid::Params) -> Result<jwtsvid::SVID> {
129        let mut client = self.inner.clone();
130        let audience = params.audience_list();
131        let request = with_header(Request::new(JwtsvidRequest {
132            spiffe_id: params.subject.to_string(),
133            audience: audience.clone(),
134        }));
135        let response = cancelable(ctx, client.fetch_jwtsvid(request)).await?;
136        let svids = parse_jwt_svids(response.into_inner(), &audience, true)?;
137        Ok(svids
138            .into_iter()
139            .next()
140            .ok_or_else(|| wrap_error("there were no SVIDs in the response"))?)
141    }
142
143    /// Fetches multiple JWT SVIDs from the Workload API.
144    pub async fn fetch_jwt_svids(&self, ctx: &Context, params: jwtsvid::Params) -> Result<Vec<jwtsvid::SVID>> {
145        let mut client = self.inner.clone();
146        let audience = params.audience_list();
147        let request = with_header(Request::new(JwtsvidRequest {
148            spiffe_id: params.subject.to_string(),
149            audience: audience.clone(),
150        }));
151        let response = cancelable(ctx, client.fetch_jwtsvid(request)).await?;
152        parse_jwt_svids(response.into_inner(), &audience, false)
153    }
154
155    /// Fetches JWT bundles from the Workload API.
156    pub async fn fetch_jwt_bundles(&self, ctx: &Context) -> Result<jwtbundle::Set> {
157        let mut client = self.inner.clone();
158        let request = with_header(Request::new(JwtBundlesRequest {}));
159        let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)).await?.into_inner();
160        let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
161        parse_jwt_bundles(resp)
162    }
163
164    /// Watches for JWT bundle updates from the Workload API.
165    pub async fn watch_jwt_bundles(&self, ctx: &Context, watcher: Arc<dyn JWTBundleWatcher>) -> Result<()> {
166        let mut backoff = self.config.backoff_strategy.new_backoff();
167        loop {
168            if let Err(err) = self.watch_jwt_bundles_once(ctx, watcher.clone(), &mut *backoff).await {
169                watcher.on_jwt_bundles_watch_error(err.clone());
170                if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await {
171                    return Err(err);
172                }
173            }
174        }
175    }
176
177    /// Validates a JWT SVID token using the Workload API.
178    pub async fn validate_jwt_svid(&self, ctx: &Context, token: &str, audience: &str) -> Result<jwtsvid::SVID> {
179        let mut client = self.inner.clone();
180        let request = with_header(Request::new(ValidateJwtsvidRequest {
181            svid: token.to_string(),
182            audience: audience.to_string(),
183        }));
184        cancelable(ctx, client.validate_jwtsvid(request)).await?;
185        jwtsvid::parse_insecure(token, &[audience.to_string()]).map_err(|err| wrap_error(err))
186    }
187
188    async fn handle_watch_error(
189        &self,
190        ctx: &Context,
191        err: Error,
192        backoff: &mut dyn Backoff,
193    ) -> Option<Error> {
194        let status = err.status().cloned().unwrap_or_else(|| Status::unknown(err.to_string()));
195        match status.code() {
196            Code::Cancelled => return Some(err),
197            Code::InvalidArgument => {
198                self.config
199                    .log
200                    .errorf(format_args!("Canceling watch: {}", status));
201                return Some(err);
202            }
203            _ => {
204                self.config
205                    .log
206                    .errorf(format_args!("Failed to watch the Workload API: {}", status));
207            }
208        }
209
210        let retry_after = backoff.next();
211        self.config
212            .log
213            .debugf(format_args!("Retrying watch in {:?}", retry_after));
214        tokio::select! {
215            _ = tokio::time::sleep(retry_after) => None,
216            _ = ctx.cancelled() => Some(wrap_error("context canceled")),
217        }
218    }
219
220    async fn watch_x509_context_once(
221        &self,
222        ctx: &Context,
223        watcher: Arc<dyn X509ContextWatcher>,
224        backoff: &mut dyn Backoff,
225    ) -> Result<()> {
226        let mut client = self.inner.clone();
227        let request = with_header(Request::new(X509svidRequest {}));
228        let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
229        self.config.log.debugf(format_args!("Watching X.509 contexts"));
230        loop {
231            let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
232            backoff.reset();
233            match parse_x509_context(resp) {
234                Ok(context) => watcher.on_x509_context_update(context),
235                Err(err) => {
236                    self.config
237                        .log
238                        .errorf(format_args!("Failed to parse X509-SVID response: {}", err));
239                    watcher.on_x509_context_watch_error(err);
240                }
241            }
242        }
243    }
244
245    async fn watch_jwt_bundles_once(
246        &self,
247        ctx: &Context,
248        watcher: Arc<dyn JWTBundleWatcher>,
249        backoff: &mut dyn Backoff,
250    ) -> Result<()> {
251        let mut client = self.inner.clone();
252        let request = with_header(Request::new(JwtBundlesRequest {}));
253        let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)).await?.into_inner();
254        self.config.log.debugf(format_args!("Watching JWT bundles"));
255        loop {
256            let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
257            backoff.reset();
258            match parse_jwt_bundles(resp) {
259                Ok(bundles) => watcher.on_jwt_bundles_update(bundles),
260                Err(err) => {
261                    self.config
262                        .log
263                        .errorf(format_args!("Failed to parse JWT bundle response: {}", err));
264                    watcher.on_jwt_bundles_watch_error(err);
265                }
266            }
267        }
268    }
269
270    async fn watch_x509_bundles_once(
271        &self,
272        ctx: &Context,
273        watcher: Arc<dyn X509BundleWatcher>,
274        backoff: &mut dyn Backoff,
275    ) -> Result<()> {
276        let mut client = self.inner.clone();
277        let request = with_header(Request::new(X509BundlesRequest {}));
278        let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)).await?.into_inner();
279        self.config.log.debugf(format_args!("Watching X.509 bundles"));
280        loop {
281            let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
282            backoff.reset();
283            match parse_x509_bundles_response(resp) {
284                Ok(bundles) => watcher.on_x509_bundles_update(bundles),
285                Err(err) => {
286                    self.config
287                        .log
288                        .errorf(format_args!("Failed to parse X.509 bundle response: {}", err));
289                    watcher.on_x509_bundles_watch_error(err);
290                }
291            }
292        }
293    }
294}
295
296fn with_header<T>(mut request: Request<T>) -> Request<T> {
297    request
298        .metadata_mut()
299        .insert("workload.spiffe.io", MetadataValue::from_static("true"));
300    request
301}
302
303async fn connect_channel(target: &str, options: &[Arc<dyn crate::workloadapi::DialOption>]) -> Result<Channel> {
304    if let Ok(url) = url::Url::parse(target) {
305        if url.scheme() == "unix" {
306            let path = unix_path_from_url(&url)?;
307            let mut endpoint = Endpoint::try_from("http://[::]:0")
308                .map_err(|err| wrap_error(format!("invalid endpoint: {}", err)))?;
309            for opt in options {
310                endpoint = opt.apply(endpoint);
311            }
312            let connector = service_fn(move |_uri| UnixStream::connect(path.clone()));
313            let channel = endpoint
314                .connect_with_connector(connector)
315                .await
316                .map_err(|err| wrap_error(format!("unable to connect: {}", err)))?;
317            return Ok(channel);
318        }
319    }
320
321    let mut endpoint = Endpoint::from_shared(format!("http://{}", target))
322        .map_err(|err| wrap_error(format!("invalid endpoint: {}", err)))?;
323    for opt in options {
324        endpoint = opt.apply(endpoint);
325    }
326    endpoint
327        .connect()
328        .await
329        .map_err(|err| wrap_error(format!("unable to connect: {}", err)))
330}
331
332fn unix_path_from_url(url: &url::Url) -> Result<std::path::PathBuf> {
333    if url.cannot_be_a_base() {
334        return Err(wrap_error("workload endpoint unix socket URI must not be opaque"));
335    }
336    let host = url.host_str().unwrap_or("");
337    let raw_path = if host.is_empty() {
338        url.path().to_string()
339    } else if url.path().is_empty() {
340        format!("/{host}")
341    } else {
342        format!("/{host}{}", url.path())
343    };
344    if raw_path.is_empty() || raw_path == "/" {
345        return Err(wrap_error("workload endpoint unix socket URI must include a path"));
346    }
347    Ok(std::path::PathBuf::from(raw_path))
348}
349
350async fn cancelable<T, F>(ctx: &Context, fut: F) -> Result<T>
351where
352    F: std::future::Future<Output = std::result::Result<T, Status>>,
353{
354    tokio::select! {
355        result = fut => result.map_err(Error::from),
356        _ = ctx.cancelled() => Err(wrap_error("context canceled")),
357    }
358}
359
360fn parse_x509_context(resp: X509svidResponse) -> Result<crate::workloadapi::X509Context> {
361    let svids = parse_x509_svids(resp.clone(), false)?;
362    let bundles = parse_x509_bundles(resp)?;
363    Ok(crate::workloadapi::X509Context { svids, bundles })
364}
365
366fn parse_x509_svids(resp: X509svidResponse, first_only: bool) -> Result<Vec<x509svid::SVID>> {
367    let mut svids = resp.svids;
368    if svids.is_empty() {
369        return Err(wrap_error("no SVIDs in response"));
370    }
371    if first_only {
372        svids.truncate(1);
373    }
374
375    let mut seen = HashSet::new();
376    let mut out = Vec::new();
377    for svid in svids {
378        if !svid.hint.is_empty() && !seen.insert(svid.hint.clone()) {
379            continue;
380        }
381        let mut parsed = x509svid::SVID::parse_raw(&svid.x509_svid, &svid.x509_svid_key)
382            .map_err(|err| wrap_error(err))?;
383        parsed.hint = svid.hint;
384        out.push(parsed);
385    }
386    Ok(out)
387}
388
389fn parse_x509_bundles(resp: X509svidResponse) -> Result<x509bundle::Set> {
390    let mut bundles = Vec::new();
391    for svid in resp.svids {
392        let td = ID::from_string(&svid.spiffe_id)
393            .map_err(|err| wrap_error(err))?
394            .trust_domain();
395        bundles.push(x509bundle::Bundle::parse_raw(td, &svid.bundle).map_err(|err| wrap_error(err))?);
396    }
397    for (td_id, bundle) in resp.federated_bundles {
398        let td = spiffeid::trust_domain_from_string(&td_id).map_err(|err| wrap_error(err))?;
399        bundles.push(x509bundle::Bundle::parse_raw(td, &bundle).map_err(|err| wrap_error(err))?);
400    }
401    Ok(x509bundle::Set::new(&bundles))
402}
403
404fn parse_x509_bundles_response(resp: X509BundlesResponse) -> Result<x509bundle::Set> {
405    let mut bundles = Vec::new();
406    for (td_id, bundle) in resp.bundles {
407        let td = spiffeid::trust_domain_from_string(&td_id).map_err(|err| wrap_error(err))?;
408        bundles.push(x509bundle::Bundle::parse_raw(td, &bundle).map_err(|err| wrap_error(err))?);
409    }
410    Ok(x509bundle::Set::new(&bundles))
411}
412
413fn parse_jwt_svids(resp: JwtsvidResponse, audience: &[String], first_only: bool) -> Result<Vec<jwtsvid::SVID>> {
414    let mut svids = resp.svids;
415    if svids.is_empty() {
416        return Err(wrap_error("there were no SVIDs in the response"));
417    }
418    if first_only {
419        svids.truncate(1);
420    }
421
422    let mut seen = HashSet::new();
423    let mut out = Vec::new();
424    for svid in svids {
425        if !svid.hint.is_empty() && !seen.insert(svid.hint.clone()) {
426            continue;
427        }
428        let mut parsed = jwtsvid::parse_insecure(&svid.svid, audience).map_err(|err| wrap_error(err))?;
429        parsed.hint = svid.hint;
430        out.push(parsed);
431    }
432    Ok(out)
433}
434
435fn parse_jwt_bundles(resp: JwtBundlesResponse) -> Result<jwtbundle::Set> {
436    let mut bundles = Vec::new();
437    for (td_id, bundle) in resp.bundles {
438        let td = spiffeid::trust_domain_from_string(&td_id).map_err(|err| wrap_error(err))?;
439        bundles.push(jwtbundle::Bundle::parse(td, &bundle).map_err(|err| wrap_error(err))?);
440    }
441    Ok(jwtbundle::Set::new(&bundles))
442}
443
444pub trait X509ContextWatcher: Send + Sync {
445    fn on_x509_context_update(&self, context: crate::workloadapi::X509Context);
446    fn on_x509_context_watch_error(&self, err: Error);
447}
448
449pub trait JWTBundleWatcher: Send + Sync {
450    fn on_jwt_bundles_update(&self, bundles: jwtbundle::Set);
451    fn on_jwt_bundles_watch_error(&self, err: Error);
452}
453
454pub trait X509BundleWatcher: Send + Sync {
455    fn on_x509_bundles_update(&self, bundles: x509bundle::Set);
456    fn on_x509_bundles_watch_error(&self, err: Error);
457}