1use crate::bundle::jwtbundle;
2use crate::bundle::x509bundle;
3use crate::spiffeid::{self, ID};
4use crate::svid::{jwtsvid, x509svid};
5use crate::workloadapi::proto::spiffe_workload_api_client::SpiffeWorkloadApiClient;
6use crate::workloadapi::proto::{
7 JwtBundlesRequest, JwtBundlesResponse, JwtsvidRequest, JwtsvidResponse, ValidateJwtsvidRequest,
8 X509BundlesRequest, X509BundlesResponse, X509svidRequest, X509svidResponse,
9};
10use crate::workloadapi::{target_from_address, wrap_error, Backoff, Error, Result};
11use crate::workloadapi::{option::ClientConfig, Context};
12use tower::service_fn;
13use std::collections::HashSet;
14use std::sync::Arc;
15use tokio::net::UnixStream;
16use tonic::metadata::MetadataValue;
17use tonic::transport::{Channel, Endpoint};
18use tonic::{Code, Request, Status};
19
/// Client for the SPIFFE Workload API.
///
/// Pairs the generated gRPC client with the configuration it was built from. The methods
/// in this file clone `inner` before each request, so a shared `&Client` can serve
/// concurrent calls.
pub struct Client {
    // Generated Workload API gRPC client over the connected tonic channel.
    inner: SpiffeWorkloadApiClient<Channel>,
    // Options applied at construction (logger, backoff strategy, dial options, address).
    config: ClientConfig,
}
28
29impl Client {
30 pub async fn new<I>(options: I) -> Result<Client>
32 where
33 I: IntoIterator<Item = Arc<dyn crate::workloadapi::ClientOption>>,
34 {
35 let mut config = ClientConfig::default();
36 for opt in options {
37 opt.configure_client(&mut config);
38 }
39
40 let address = match config.address.clone() {
41 Some(addr) => addr,
42 None => crate::workloadapi::get_default_address().ok_or_else(|| {
43 wrap_error("workload endpoint socket address is not configured")
44 })?,
45 };
46 let target = target_from_address(&address)?;
47 let channel = connect_channel(&target, &config.dial_options).await?;
48 let inner = SpiffeWorkloadApiClient::new(channel);
49 Ok(Client { inner, config })
50 }
51
52 pub async fn close(&self) -> Result<()> {
54 Ok(())
55 }
56
57 pub async fn fetch_x509_svid(&self, ctx: &Context) -> Result<x509svid::SVID> {
59 let mut client = self.inner.clone();
60 let request = with_header(Request::new(X509svidRequest {}));
61 let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
62 let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
63 let svids = parse_x509_svids(response, true)?;
64 Ok(svids
65 .into_iter()
66 .next()
67 .ok_or_else(|| wrap_error("no SVIDs in response"))?)
68 }
69
70 pub async fn fetch_x509_svids(&self, ctx: &Context) -> Result<Vec<x509svid::SVID>> {
72 let mut client = self.inner.clone();
73 let request = with_header(Request::new(X509svidRequest {}));
74 let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
75 let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
76 parse_x509_svids(response, false)
77 }
78
79 pub async fn fetch_x509_bundles(&self, ctx: &Context) -> Result<x509bundle::Set> {
81 let mut client = self.inner.clone();
82 let request = with_header(Request::new(X509BundlesRequest {}));
83 let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)).await?.into_inner();
84 let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
85 parse_x509_bundles_response(resp)
86 }
87
88 pub async fn watch_x509_bundles(&self, ctx: &Context, watcher: Arc<dyn X509BundleWatcher>) -> Result<()> {
90 let mut backoff = self.config.backoff_strategy.new_backoff();
91 loop {
92 if let Err(err) = self.watch_x509_bundles_once(ctx, watcher.clone(), &mut *backoff).await {
93 watcher.on_x509_bundles_watch_error(err.clone());
94 if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await {
95 return Err(err);
96 }
97 }
98 }
99 }
100
101 pub async fn fetch_x509_context(&self, ctx: &Context) -> Result<crate::workloadapi::X509Context> {
103 let mut client = self.inner.clone();
104 let request = with_header(Request::new(X509svidRequest {}));
105 let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
106 let response = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
107 parse_x509_context(response)
108 }
109
110 pub async fn watch_x509_context(
112 &self,
113 ctx: &Context,
114 watcher: Arc<dyn X509ContextWatcher>,
115 ) -> Result<()> {
116 let mut backoff = self.config.backoff_strategy.new_backoff();
117 loop {
118 if let Err(err) = self.watch_x509_context_once(ctx, watcher.clone(), &mut *backoff).await {
119 watcher.on_x509_context_watch_error(err.clone());
120 if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await {
121 return Err(err);
122 }
123 }
124 }
125 }
126
127 pub async fn fetch_jwt_svid(&self, ctx: &Context, params: jwtsvid::Params) -> Result<jwtsvid::SVID> {
129 let mut client = self.inner.clone();
130 let audience = params.audience_list();
131 let request = with_header(Request::new(JwtsvidRequest {
132 spiffe_id: params.subject.to_string(),
133 audience: audience.clone(),
134 }));
135 let response = cancelable(ctx, client.fetch_jwtsvid(request)).await?;
136 let svids = parse_jwt_svids(response.into_inner(), &audience, true)?;
137 Ok(svids
138 .into_iter()
139 .next()
140 .ok_or_else(|| wrap_error("there were no SVIDs in the response"))?)
141 }
142
143 pub async fn fetch_jwt_svids(&self, ctx: &Context, params: jwtsvid::Params) -> Result<Vec<jwtsvid::SVID>> {
145 let mut client = self.inner.clone();
146 let audience = params.audience_list();
147 let request = with_header(Request::new(JwtsvidRequest {
148 spiffe_id: params.subject.to_string(),
149 audience: audience.clone(),
150 }));
151 let response = cancelable(ctx, client.fetch_jwtsvid(request)).await?;
152 parse_jwt_svids(response.into_inner(), &audience, false)
153 }
154
155 pub async fn fetch_jwt_bundles(&self, ctx: &Context) -> Result<jwtbundle::Set> {
157 let mut client = self.inner.clone();
158 let request = with_header(Request::new(JwtBundlesRequest {}));
159 let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)).await?.into_inner();
160 let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
161 parse_jwt_bundles(resp)
162 }
163
164 pub async fn watch_jwt_bundles(&self, ctx: &Context, watcher: Arc<dyn JWTBundleWatcher>) -> Result<()> {
166 let mut backoff = self.config.backoff_strategy.new_backoff();
167 loop {
168 if let Err(err) = self.watch_jwt_bundles_once(ctx, watcher.clone(), &mut *backoff).await {
169 watcher.on_jwt_bundles_watch_error(err.clone());
170 if let Some(err) = self.handle_watch_error(ctx, err, &mut *backoff).await {
171 return Err(err);
172 }
173 }
174 }
175 }
176
177 pub async fn validate_jwt_svid(&self, ctx: &Context, token: &str, audience: &str) -> Result<jwtsvid::SVID> {
179 let mut client = self.inner.clone();
180 let request = with_header(Request::new(ValidateJwtsvidRequest {
181 svid: token.to_string(),
182 audience: audience.to_string(),
183 }));
184 cancelable(ctx, client.validate_jwtsvid(request)).await?;
185 jwtsvid::parse_insecure(token, &[audience.to_string()]).map_err(|err| wrap_error(err))
186 }
187
188 async fn handle_watch_error(
189 &self,
190 ctx: &Context,
191 err: Error,
192 backoff: &mut dyn Backoff,
193 ) -> Option<Error> {
194 let status = err.status().cloned().unwrap_or_else(|| Status::unknown(err.to_string()));
195 match status.code() {
196 Code::Cancelled => return Some(err),
197 Code::InvalidArgument => {
198 self.config
199 .log
200 .errorf(format_args!("Canceling watch: {}", status));
201 return Some(err);
202 }
203 _ => {
204 self.config
205 .log
206 .errorf(format_args!("Failed to watch the Workload API: {}", status));
207 }
208 }
209
210 let retry_after = backoff.next();
211 self.config
212 .log
213 .debugf(format_args!("Retrying watch in {:?}", retry_after));
214 tokio::select! {
215 _ = tokio::time::sleep(retry_after) => None,
216 _ = ctx.cancelled() => Some(wrap_error("context canceled")),
217 }
218 }
219
220 async fn watch_x509_context_once(
221 &self,
222 ctx: &Context,
223 watcher: Arc<dyn X509ContextWatcher>,
224 backoff: &mut dyn Backoff,
225 ) -> Result<()> {
226 let mut client = self.inner.clone();
227 let request = with_header(Request::new(X509svidRequest {}));
228 let mut stream = cancelable(ctx, client.fetch_x509svid(request)).await?.into_inner();
229 self.config.log.debugf(format_args!("Watching X.509 contexts"));
230 loop {
231 let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
232 backoff.reset();
233 match parse_x509_context(resp) {
234 Ok(context) => watcher.on_x509_context_update(context),
235 Err(err) => {
236 self.config
237 .log
238 .errorf(format_args!("Failed to parse X509-SVID response: {}", err));
239 watcher.on_x509_context_watch_error(err);
240 }
241 }
242 }
243 }
244
245 async fn watch_jwt_bundles_once(
246 &self,
247 ctx: &Context,
248 watcher: Arc<dyn JWTBundleWatcher>,
249 backoff: &mut dyn Backoff,
250 ) -> Result<()> {
251 let mut client = self.inner.clone();
252 let request = with_header(Request::new(JwtBundlesRequest {}));
253 let mut stream = cancelable(ctx, client.fetch_jwt_bundles(request)).await?.into_inner();
254 self.config.log.debugf(format_args!("Watching JWT bundles"));
255 loop {
256 let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
257 backoff.reset();
258 match parse_jwt_bundles(resp) {
259 Ok(bundles) => watcher.on_jwt_bundles_update(bundles),
260 Err(err) => {
261 self.config
262 .log
263 .errorf(format_args!("Failed to parse JWT bundle response: {}", err));
264 watcher.on_jwt_bundles_watch_error(err);
265 }
266 }
267 }
268 }
269
270 async fn watch_x509_bundles_once(
271 &self,
272 ctx: &Context,
273 watcher: Arc<dyn X509BundleWatcher>,
274 backoff: &mut dyn Backoff,
275 ) -> Result<()> {
276 let mut client = self.inner.clone();
277 let request = with_header(Request::new(X509BundlesRequest {}));
278 let mut stream = cancelable(ctx, client.fetch_x509_bundles(request)).await?.into_inner();
279 self.config.log.debugf(format_args!("Watching X.509 bundles"));
280 loop {
281 let resp = cancelable(ctx, stream.message()).await?.ok_or_else(|| wrap_error("stream closed"))?;
282 backoff.reset();
283 match parse_x509_bundles_response(resp) {
284 Ok(bundles) => watcher.on_x509_bundles_update(bundles),
285 Err(err) => {
286 self.config
287 .log
288 .errorf(format_args!("Failed to parse X.509 bundle response: {}", err));
289 watcher.on_x509_bundles_watch_error(err);
290 }
291 }
292 }
293 }
294}
295
296fn with_header<T>(mut request: Request<T>) -> Request<T> {
297 request
298 .metadata_mut()
299 .insert("workload.spiffe.io", MetadataValue::from_static("true"));
300 request
301}
302
303async fn connect_channel(target: &str, options: &[Arc<dyn crate::workloadapi::DialOption>]) -> Result<Channel> {
304 if let Ok(url) = url::Url::parse(target) {
305 if url.scheme() == "unix" {
306 let path = unix_path_from_url(&url)?;
307 let mut endpoint = Endpoint::try_from("http://[::]:0")
308 .map_err(|err| wrap_error(format!("invalid endpoint: {}", err)))?;
309 for opt in options {
310 endpoint = opt.apply(endpoint);
311 }
312 let connector = service_fn(move |_uri| UnixStream::connect(path.clone()));
313 let channel = endpoint
314 .connect_with_connector(connector)
315 .await
316 .map_err(|err| wrap_error(format!("unable to connect: {}", err)))?;
317 return Ok(channel);
318 }
319 }
320
321 let mut endpoint = Endpoint::from_shared(format!("http://{}", target))
322 .map_err(|err| wrap_error(format!("invalid endpoint: {}", err)))?;
323 for opt in options {
324 endpoint = opt.apply(endpoint);
325 }
326 endpoint
327 .connect()
328 .await
329 .map_err(|err| wrap_error(format!("unable to connect: {}", err)))
330}
331
332fn unix_path_from_url(url: &url::Url) -> Result<std::path::PathBuf> {
333 if url.cannot_be_a_base() {
334 return Err(wrap_error("workload endpoint unix socket URI must not be opaque"));
335 }
336 let host = url.host_str().unwrap_or("");
337 let raw_path = if host.is_empty() {
338 url.path().to_string()
339 } else if url.path().is_empty() {
340 format!("/{host}")
341 } else {
342 format!("/{host}{}", url.path())
343 };
344 if raw_path.is_empty() || raw_path == "/" {
345 return Err(wrap_error("workload endpoint unix socket URI must include a path"));
346 }
347 Ok(std::path::PathBuf::from(raw_path))
348}
349
350async fn cancelable<T, F>(ctx: &Context, fut: F) -> Result<T>
351where
352 F: std::future::Future<Output = std::result::Result<T, Status>>,
353{
354 tokio::select! {
355 result = fut => result.map_err(Error::from),
356 _ = ctx.cancelled() => Err(wrap_error("context canceled")),
357 }
358}
359
360fn parse_x509_context(resp: X509svidResponse) -> Result<crate::workloadapi::X509Context> {
361 let svids = parse_x509_svids(resp.clone(), false)?;
362 let bundles = parse_x509_bundles(resp)?;
363 Ok(crate::workloadapi::X509Context { svids, bundles })
364}
365
366fn parse_x509_svids(resp: X509svidResponse, first_only: bool) -> Result<Vec<x509svid::SVID>> {
367 let mut svids = resp.svids;
368 if svids.is_empty() {
369 return Err(wrap_error("no SVIDs in response"));
370 }
371 if first_only {
372 svids.truncate(1);
373 }
374
375 let mut seen = HashSet::new();
376 let mut out = Vec::new();
377 for svid in svids {
378 if !svid.hint.is_empty() && !seen.insert(svid.hint.clone()) {
379 continue;
380 }
381 let mut parsed = x509svid::SVID::parse_raw(&svid.x509_svid, &svid.x509_svid_key)
382 .map_err(|err| wrap_error(err))?;
383 parsed.hint = svid.hint;
384 out.push(parsed);
385 }
386 Ok(out)
387}
388
389fn parse_x509_bundles(resp: X509svidResponse) -> Result<x509bundle::Set> {
390 let mut bundles = Vec::new();
391 for svid in resp.svids {
392 let td = ID::from_string(&svid.spiffe_id)
393 .map_err(|err| wrap_error(err))?
394 .trust_domain();
395 bundles.push(x509bundle::Bundle::parse_raw(td, &svid.bundle).map_err(|err| wrap_error(err))?);
396 }
397 for (td_id, bundle) in resp.federated_bundles {
398 let td = spiffeid::trust_domain_from_string(&td_id).map_err(|err| wrap_error(err))?;
399 bundles.push(x509bundle::Bundle::parse_raw(td, &bundle).map_err(|err| wrap_error(err))?);
400 }
401 Ok(x509bundle::Set::new(&bundles))
402}
403
404fn parse_x509_bundles_response(resp: X509BundlesResponse) -> Result<x509bundle::Set> {
405 let mut bundles = Vec::new();
406 for (td_id, bundle) in resp.bundles {
407 let td = spiffeid::trust_domain_from_string(&td_id).map_err(|err| wrap_error(err))?;
408 bundles.push(x509bundle::Bundle::parse_raw(td, &bundle).map_err(|err| wrap_error(err))?);
409 }
410 Ok(x509bundle::Set::new(&bundles))
411}
412
413fn parse_jwt_svids(resp: JwtsvidResponse, audience: &[String], first_only: bool) -> Result<Vec<jwtsvid::SVID>> {
414 let mut svids = resp.svids;
415 if svids.is_empty() {
416 return Err(wrap_error("there were no SVIDs in the response"));
417 }
418 if first_only {
419 svids.truncate(1);
420 }
421
422 let mut seen = HashSet::new();
423 let mut out = Vec::new();
424 for svid in svids {
425 if !svid.hint.is_empty() && !seen.insert(svid.hint.clone()) {
426 continue;
427 }
428 let mut parsed = jwtsvid::parse_insecure(&svid.svid, audience).map_err(|err| wrap_error(err))?;
429 parsed.hint = svid.hint;
430 out.push(parsed);
431 }
432 Ok(out)
433}
434
435fn parse_jwt_bundles(resp: JwtBundlesResponse) -> Result<jwtbundle::Set> {
436 let mut bundles = Vec::new();
437 for (td_id, bundle) in resp.bundles {
438 let td = spiffeid::trust_domain_from_string(&td_id).map_err(|err| wrap_error(err))?;
439 bundles.push(jwtbundle::Bundle::parse(td, &bundle).map_err(|err| wrap_error(err))?);
440 }
441 Ok(jwtbundle::Set::new(&bundles))
442}
443
444pub trait X509ContextWatcher: Send + Sync {
445 fn on_x509_context_update(&self, context: crate::workloadapi::X509Context);
446 fn on_x509_context_watch_error(&self, err: Error);
447}
448
449pub trait JWTBundleWatcher: Send + Sync {
450 fn on_jwt_bundles_update(&self, bundles: jwtbundle::Set);
451 fn on_jwt_bundles_watch_error(&self, err: Error);
452}
453
454pub trait X509BundleWatcher: Send + Sync {
455 fn on_x509_bundles_update(&self, bundles: x509bundle::Set);
456 fn on_x509_bundles_watch_error(&self, err: Error);
457}