1use anyhow::{anyhow, Result};
4use bytes::Bytes;
5use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStreamExt};
6use indexmap::IndexMap;
7use reqwest::{
8 header::{HeaderMap, HeaderValue},
9 Body, IntoUrl, Method, RequestBuilder, Response, StatusCode,
10};
11use secrecy::{ExposeSecret, Secret};
12use serde::de::DeserializeOwned;
13use std::borrow::Cow;
14use thiserror::Error;
15use warg_api::{
16 v1::{
17 content::{ContentError, ContentSourcesResponse},
18 fetch::{
19 FetchError, FetchLogsRequest, FetchLogsResponse, FetchPackageNamesRequest,
20 FetchPackageNamesResponse,
21 },
22 ledger::{LedgerError, LedgerSourcesResponse},
23 monitor::{CheckpointVerificationResponse, MonitorError},
24 package::{ContentSource, PackageError, PackageRecord, PublishRecordRequest},
25 paths,
26 proof::{
27 ConsistencyRequest, ConsistencyResponse, InclusionRequest, InclusionResponse,
28 ProofError,
29 },
30 REGISTRY_HEADER_NAME, REGISTRY_HINT_HEADER_NAME,
31 },
32 WellKnownConfig, WELL_KNOWN_PATH,
33};
34use warg_crypto::hash::{AnyHash, HashError, Sha256};
35use warg_protocol::{
36 registry::{Checkpoint, LogId, LogLeaf, MapLeaf, RecordId, TimestampedCheckpoint},
37 SerdeEnvelope,
38};
39use warg_transparency::{
40 log::{ConsistencyProofError, InclusionProofError, LogProofBundle, ProofBundle},
41 map::MapProofBundle,
42};
43
44use crate::{registry_url::RegistryUrl, storage::RegistryDomain};
/// Represents an error that occurred while communicating with a registry server.
#[derive(Debug, Error)]
pub enum ClientError {
    /// An error response returned by the registry's fetch API.
    #[error(transparent)]
    Fetch(#[from] FetchError),
    /// An error response returned by the registry's package API.
    #[error(transparent)]
    Package(#[from] PackageError),
    /// An error response returned by the registry's content API.
    #[error(transparent)]
    Content(#[from] ContentError),
    /// An error response returned by the registry's proof API.
    #[error(transparent)]
    Proof(#[from] ProofError),
    /// An error response returned by the registry's monitor API.
    #[error(transparent)]
    Monitor(#[from] MonitorError),
    /// An error response returned by the registry's ledger API.
    #[error(transparent)]
    Ledger(#[from] LedgerError),
    /// The HTTP request could not be sent or its response could not be received.
    #[error("failed to send request to registry server: {0}")]
    Communication(#[from] reqwest::Error),
    /// The server's response could not be interpreted (bad content type, unreadable or
    /// undeserializable body).
    #[error("{message} (status code: {status})")]
    UnexpectedResponse {
        /// The HTTP status code of the response.
        status: StatusCode,
        /// A description of what was wrong with the response.
        message: String,
    },
    /// A consistency proof evaluated to a root other than the expected one.
    #[error(
        "the client failed to prove consistency: found root `{found}` but was given root `{root}`"
    )]
    IncorrectConsistencyProof {
        /// The root the client expected.
        root: AnyHash,
        /// The root the proof evaluated to.
        found: AnyHash,
    },
    /// The server returned a hash that could not be used.
    #[error("the server returned an invalid hash: {0}")]
    Hash(#[from] HashError),
    /// Evaluating a consistency proof failed.
    #[error("the client failed a consistency proof: {0}")]
    ConsistencyProof(#[from] ConsistencyProofError),
    /// Evaluating an inclusion proof failed.
    #[error("the client failed an inclusion proof: {0}")]
    InclusionProof(#[from] InclusionProofError),
    /// The given record has not been published.
    #[error("record `{0}` has not been published")]
    RecordNotPublished(RecordId),
    /// The registry knows no download location for the given content digest.
    #[error("no download location could be found for content digest `{0}`")]
    NoSourceForContent(AnyHash),
    /// Every download location for the given content digest returned an error response.
    #[error("all sources for content digest `{0}` returned an error response")]
    AllSourcesFailed(AnyHash),
    /// The server asked for an upload using an HTTP method the client does not support.
    #[error("server returned an invalid HTTP method `{0}`")]
    InvalidHttpMethod(String),
    /// The server asked for an upload header (name, value) that is unsupported or invalid.
    #[error("server returned an invalid HTTP header `{0}: {1}`")]
    InvalidHttpHeader(String, String),
    /// The log was not found, but the registry responded with a hint header.
    #[error("log `{0}` was not found in this registry, but the registry provided the hint header: `{1:?}`")]
    LogNotFoundWithHint(LogId, HeaderValue),
    /// The registry's `.well-known` config could not be parsed; holds the registry domain.
    #[error("registry `{0}` returned an invalid well-known config")]
    InvalidWellKnownConfig(String),
    /// Any other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
121
122async fn deserialize<T: DeserializeOwned>(response: Response) -> Result<T, ClientError> {
123 let status = response.status();
124 match response.headers().get("content-type") {
125 Some(content_type) if content_type == "application/json" => {
126 let bytes = response
127 .bytes()
128 .await
129 .map_err(|e| ClientError::UnexpectedResponse {
130 status,
131 message: format!("failed to read response: {e}"),
132 })?;
133 serde_json::from_slice(&bytes).map_err(|e| {
134 tracing::debug!(
135 "Unexpected response body: {}",
136 String::from_utf8_lossy(&bytes)
137 );
138 ClientError::UnexpectedResponse {
139 status,
140 message: format!("failed to deserialize JSON response: {e}"),
141 }
142 })
143 }
144 Some(ty) => Err(ClientError::UnexpectedResponse {
145 status,
146 message: format!(
147 "the server returned an unsupported content type of `{ty}`",
148 ty = ty.to_str().unwrap_or("")
149 ),
150 }),
151 None => Err(ClientError::UnexpectedResponse {
152 status,
153 message: "the server response did not include a content type header".into(),
154 }),
155 }
156}
157
158async fn into_result<T: DeserializeOwned, E: DeserializeOwned + Into<ClientError>>(
159 response: Response,
160) -> Result<T, ClientError> {
161 if response.status().is_success() {
162 deserialize::<T>(response).await
163 } else {
164 Err(deserialize::<E>(response).await?.into())
165 }
166}
167
/// Extension for attaching the Warg registry header to an outgoing request.
trait WithWargHeader {
    /// Adds the `REGISTRY_HEADER_NAME` header for `registry_header` when it is `Some`, and
    /// returns the builder unchanged when it is `None`.
    ///
    /// Fails if the registry domain cannot be converted into a valid header value.
    fn warg_header(self, registry_header: Option<&RegistryDomain>) -> Result<RequestBuilder>;
}
171
172impl WithWargHeader for RequestBuilder {
173 fn warg_header(self, registry_header: Option<&RegistryDomain>) -> Result<RequestBuilder> {
174 if let Some(reg) = registry_header {
175 Ok(self.header(REGISTRY_HEADER_NAME, HeaderValue::try_from(reg.clone())?))
176 } else {
177 Ok(self)
178 }
179 }
180}
181
/// Extension for attaching an optional bearer token to an outgoing request.
trait WithAuth {
    /// Adds bearer authentication with `auth_token` when it is `Some`, and returns the
    /// builder unchanged when it is `None`.
    fn auth(self, auth_token: &Option<Secret<String>>) -> RequestBuilder;
}
185
186impl WithAuth for RequestBuilder {
187 fn auth(self, auth_token: &Option<Secret<String>>) -> reqwest::RequestBuilder {
188 if let Some(tok) = auth_token {
189 self.bearer_auth(tok.expose_secret())
190 } else {
191 self
192 }
193 }
194}
195
/// A client for a Warg registry's HTTP API.
pub struct Client {
    /// Base URL of the registry; API paths are joined onto it.
    url: RegistryUrl,
    /// The underlying HTTP client used for every request.
    client: reqwest::Client,
    /// Registry domain set through `set_warg_registry`.
    ///
    /// NOTE(review): no request method in this impl reads this field. Requests take the
    /// registry header from their explicit `registry_domain` argument instead — confirm
    /// whether this field is still needed.
    warg_registry_header: Option<RegistryDomain>,
    /// Optional bearer token attached to registry API requests.
    auth_token: Option<Secret<String>>,
}
204
205impl Client {
206 pub fn new(url: impl IntoUrl, auth_token: Option<Secret<String>>) -> Result<Self> {
208 let url = RegistryUrl::new(url)?;
209 Ok(Self {
210 url,
211 client: reqwest::Client::new(),
212 warg_registry_header: None,
213 auth_token,
214 })
215 }
216
217 pub fn auth_token(&self) -> &Option<Secret<String>> {
219 &self.auth_token
220 }
221
222 pub fn url(&self) -> &RegistryUrl {
224 &self.url
225 }
226 pub async fn well_known_config(&self) -> Result<Option<RegistryUrl>, ClientError> {
228 let url = self.url.join(WELL_KNOWN_PATH);
229 tracing::debug!(url, "getting `.well-known` config",);
230
231 let res = self.client.get(url).send().await?;
232
233 if !res.status().is_success() {
234 tracing::debug!(
235 "the `.well-known` config request returned HTTP status `{status}`",
236 status = res.status()
237 );
238 return Ok(None);
239 }
240
241 if let Some(warg_url) = res
242 .json::<WellKnownConfig>()
243 .await
244 .map_err(|e| {
245 tracing::debug!("parsing `.well-known` config failed: {e}");
246 ClientError::InvalidWellKnownConfig(self.url.registry_domain().to_string())
247 })?
248 .warg_url
249 {
250 Ok(Some(RegistryUrl::new(warg_url)?))
251 } else {
252 tracing::debug!("the `.well-known` config did not have a `wargUrl` set");
253 Ok(None)
254 }
255 }
256
257 pub async fn latest_checkpoint(
259 &self,
260 registry_domain: Option<&RegistryDomain>,
261 ) -> Result<SerdeEnvelope<TimestampedCheckpoint>, ClientError> {
262 let url = self.url.join(paths::fetch_checkpoint());
263 tracing::debug!(
264 url,
265 registry_header = ?registry_domain,
266 "getting latest checkpoint",
267 );
268 into_result::<_, FetchError>(
269 self.client
270 .get(url)
271 .warg_header(registry_domain)?
272 .auth(self.auth_token())
273 .send()
274 .await?,
275 )
276 .await
277 }
278
279 pub async fn verify_checkpoint(
281 &self,
282 registry_domain: Option<&RegistryDomain>,
283 request: SerdeEnvelope<TimestampedCheckpoint>,
284 ) -> Result<CheckpointVerificationResponse, ClientError> {
285 let url = self.url.join(paths::verify_checkpoint());
286 tracing::debug!(
287 url,
288 registry_header = ?registry_domain,
289 "verifying checkpoint",
290 );
291
292 let response = self
293 .client
294 .post(url)
295 .json(&request)
296 .warg_header(registry_domain)?
297 .auth(self.auth_token())
298 .send()
299 .await?;
300 into_result::<_, MonitorError>(response).await
301 }
302
303 pub async fn fetch_logs(
305 &self,
306 registry_domain: Option<&RegistryDomain>,
307 request: FetchLogsRequest<'_>,
308 ) -> Result<FetchLogsResponse, ClientError> {
309 let url = self.url.join(paths::fetch_logs());
310 tracing::debug!(
311 url,
312 registry_header = ?registry_domain,
313 "fetching logs",
314 );
315 let response = self
316 .client
317 .post(&url)
318 .json(&request)
319 .warg_header(registry_domain)?
320 .auth(self.auth_token())
321 .send()
322 .await?;
323
324 let header = response.headers().get(REGISTRY_HINT_HEADER_NAME).cloned();
325 into_result::<_, FetchError>(response)
326 .await
327 .map_err(|err| match err {
328 ClientError::Fetch(FetchError::LogNotFound(log_id)) if header.is_some() => {
329 ClientError::LogNotFoundWithHint(log_id, header.unwrap())
330 }
331 _ => err,
332 })
333 }
334
335 pub async fn fetch_package_names(
337 &self,
338 registry_domain: Option<&RegistryDomain>,
339 request: FetchPackageNamesRequest<'_>,
340 ) -> Result<FetchPackageNamesResponse, ClientError> {
341 let url = self.url.join(paths::fetch_package_names());
342 tracing::debug!(
343 url,
344 registry_header = ?registry_domain,
345 "fetching package names",
346 );
347 let response = self
348 .client
349 .post(url)
350 .warg_header(registry_domain)?
351 .auth(self.auth_token())
352 .json(&request)
353 .send()
354 .await?;
355 into_result::<_, FetchError>(response).await
356 }
357
358 pub async fn ledger_sources(
360 &self,
361 registry_domain: Option<&RegistryDomain>,
362 ) -> Result<LedgerSourcesResponse, ClientError> {
363 let url = self.url.join(paths::ledger_sources());
364 tracing::debug!(
365 url,
366 registry_header = ?registry_domain,
367 "getting ledger sources",
368 );
369 into_result::<_, LedgerError>(
370 self.client
371 .get(url)
372 .warg_header(registry_domain)?
373 .auth(self.auth_token())
374 .send()
375 .await?,
376 )
377 .await
378 }
379
380 pub async fn publish_package_record(
382 &self,
383 registry_domain: Option<&RegistryDomain>,
384 log_id: &LogId,
385 request: PublishRecordRequest<'_>,
386 ) -> Result<PackageRecord, ClientError> {
387 let url = self.url.join(&paths::publish_package_record(log_id));
388 tracing::debug!(
389 log_id = log_id.to_string(),
390 url,
391 registry_header = ?registry_domain,
392 "publishing to package",
393 );
394 let response = self
395 .client
396 .post(url)
397 .json(&request)
398 .warg_header(registry_domain)?
399 .auth(self.auth_token())
400 .send()
401 .await?;
402 into_result::<_, PackageError>(response).await
403 }
404
405 pub async fn get_package_record(
407 &self,
408 registry_domain: Option<&RegistryDomain>,
409 log_id: &LogId,
410 record_id: &RecordId,
411 ) -> Result<PackageRecord, ClientError> {
412 let url = self.url.join(&paths::package_record(log_id, record_id));
413 tracing::debug!(
414 log_id = log_id.to_string(),
415 record_id = record_id.to_string(),
416 url,
417 registry_header = ?registry_domain,
418 "getting package record",
419 );
420 into_result::<_, PackageError>(
421 self.client
422 .get(url)
423 .warg_header(registry_domain)?
424 .auth(self.auth_token())
425 .send()
426 .await?,
427 )
428 .await
429 }
430
431 pub async fn content_sources(
433 &self,
434 registry_domain: Option<&RegistryDomain>,
435 digest: &AnyHash,
436 ) -> Result<ContentSourcesResponse, ClientError> {
437 let url = self.url.join(&paths::content_sources(digest));
438 tracing::debug!(
439 digest = digest.to_string(),
440 url,
441 registry_header = ?registry_domain,
442 "getting content sources for digest",
443 );
444 into_result::<_, ContentError>(
445 self.client
446 .get(url)
447 .warg_header(registry_domain)?
448 .auth(self.auth_token())
449 .send()
450 .await?,
451 )
452 .await
453 }
454
455 pub async fn download_content(
457 &self,
458 registry_domain: Option<&RegistryDomain>,
459 digest: &AnyHash,
460 ) -> Result<impl Stream<Item = Result<Bytes>>, ClientError> {
461 let ContentSourcesResponse { content_sources } =
462 self.content_sources(registry_domain, digest).await?;
463
464 let sources = content_sources
465 .get(digest)
466 .ok_or(ClientError::AllSourcesFailed(digest.clone()))?;
467
468 for source in sources {
469 let ContentSource::HttpGet { url, .. } = source;
470
471 tracing::debug!("downloading content `{digest}` from `{url}`");
472
473 let response = self.client.get(url).send().await?;
474 if !response.status().is_success() {
475 tracing::debug!(
476 "failed to download content `{digest}` from `{url}`: {status}",
477 status = response.status()
478 );
479 continue;
480 }
481
482 return Ok(validate_stream(
483 digest,
484 response.bytes_stream().map_err(|e| anyhow!(e)),
485 ));
486 }
487
488 Err(ClientError::AllSourcesFailed(digest.clone()))
489 }
490
491 pub fn set_warg_registry(&mut self, registry: Option<RegistryDomain>) {
493 self.warg_registry_header = registry;
494 }
495
496 pub async fn prove_inclusion(
498 &self,
499 registry_domain: Option<&RegistryDomain>,
500 request: InclusionRequest,
501 checkpoint: &Checkpoint,
502 leafs: &[LogLeaf],
503 ) -> Result<(), ClientError> {
504 let url = self.url.join(paths::prove_inclusion());
505 tracing::debug!(
506 url,
507 registry_header = ?registry_domain,
508 "proving checkpoint inclusion",
509 );
510 let response = into_result::<InclusionResponse, ProofError>(
511 self.client
512 .post(url)
513 .json(&request)
514 .warg_header(registry_domain)?
515 .auth(self.auth_token())
516 .send()
517 .await?,
518 )
519 .await?;
520
521 Self::validate_inclusion_response(response, checkpoint, leafs)
522 }
523
524 pub async fn prove_log_consistency(
526 &self,
527 registry_domain: Option<&RegistryDomain>,
528 request: ConsistencyRequest,
529 from_log_root: Cow<'_, AnyHash>,
530 to_log_root: Cow<'_, AnyHash>,
531 ) -> Result<(), ClientError> {
532 let url = self.url.join(paths::prove_consistency());
533 let response = into_result::<ConsistencyResponse, ProofError>(
534 self.client
535 .post(url)
536 .json(&request)
537 .warg_header(registry_domain)?
538 .auth(self.auth_token())
539 .send()
540 .await?,
541 )
542 .await?;
543
544 let proof = ProofBundle::<Sha256, LogLeaf>::decode(&response.proof).unwrap();
545 let (log_data, consistencies, inclusions) = proof.unbundle();
546 if !inclusions.is_empty() {
547 return Err(ClientError::Proof(ProofError::BundleFailure(
548 "expected no inclusion proofs".into(),
549 )));
550 }
551
552 if consistencies.len() != 1 {
553 return Err(ClientError::Proof(ProofError::BundleFailure(
554 "expected exactly one consistency proof".into(),
555 )));
556 }
557
558 let (from, to) = consistencies
559 .first()
560 .unwrap()
561 .evaluate(&log_data)
562 .map(|(from, to)| (AnyHash::from(from), AnyHash::from(to)))?;
563
564 if from_log_root.as_ref() != &from {
565 return Err(ClientError::IncorrectConsistencyProof {
566 root: from_log_root.into_owned(),
567 found: from,
568 });
569 }
570
571 if to_log_root.as_ref() != &to {
572 return Err(ClientError::IncorrectConsistencyProof {
573 root: to_log_root.into_owned(),
574 found: to,
575 });
576 }
577
578 Ok(())
579 }
580
581 pub async fn upload_content(
583 &self,
584 method: &str,
585 url: &str,
586 headers: &IndexMap<String, String>,
587 content: impl Into<Body>,
588 ) -> Result<(), ClientError> {
589 let url = self.url.join(url);
591
592 let method = match method {
593 "POST" => Method::POST,
594 "PUT" => Method::PUT,
595 method => return Err(ClientError::InvalidHttpMethod(method.to_string())),
596 };
597
598 let headers = headers
599 .iter()
600 .map(|(k, v)| {
601 let name = match k.as_str() {
602 "authorization" => reqwest::header::AUTHORIZATION,
603 "content-type" => reqwest::header::CONTENT_TYPE,
604 _ => return Err(ClientError::InvalidHttpHeader(k.to_string(), v.to_string())),
605 };
606 let value = HeaderValue::try_from(k)
607 .map_err(|_| ClientError::InvalidHttpHeader(k.to_string(), v.to_string()))?;
608 Ok((name, value))
609 })
610 .collect::<Result<HeaderMap, ClientError>>()?;
611
612 tracing::debug!("uploading content to `{url}`");
613
614 let response = self
615 .client
616 .request(method, url)
617 .headers(headers)
618 .body(content)
619 .send()
620 .await?;
621 if !response.status().is_success() {
622 return Err(ClientError::Package(
623 deserialize::<PackageError>(response).await?,
624 ));
625 }
626
627 Ok(())
628 }
629
630 fn validate_inclusion_response(
631 response: InclusionResponse,
632 checkpoint: &Checkpoint,
633 leafs: &[LogLeaf],
634 ) -> Result<(), ClientError> {
635 let log_proof_bundle: LogProofBundle<Sha256, LogLeaf> =
636 LogProofBundle::decode(response.log.as_slice())?;
637 let (log_data, _, log_inclusions) = log_proof_bundle.unbundle();
638 for (leaf, proof) in leafs.iter().zip(log_inclusions.iter()) {
639 let found = proof.evaluate_value(&log_data, leaf)?;
640 let root = checkpoint.log_root.clone().try_into()?;
641 if found != root {
642 return Err(ClientError::Proof(ProofError::IncorrectProof {
643 root: checkpoint.log_root.clone(),
644 found: found.into(),
645 }));
646 }
647 }
648
649 let map_proof_bundle: MapProofBundle<Sha256, LogId, MapLeaf> =
650 MapProofBundle::decode(response.map.as_slice())?;
651 let map_inclusions = map_proof_bundle.unbundle();
652 for (leaf, proof) in leafs.iter().zip(map_inclusions.iter()) {
653 let found = proof.evaluate(
654 &leaf.log_id,
655 &MapLeaf {
656 record_id: leaf.record_id.clone(),
657 },
658 );
659 let root = checkpoint.map_root.clone().try_into()?;
660 if found != root {
661 return Err(ClientError::Proof(ProofError::IncorrectProof {
662 root: checkpoint.map_root.clone(),
663 found: found.into(),
664 }));
665 }
666 }
667
668 Ok(())
669 }
670}
671
672fn validate_stream(
673 digest: &AnyHash,
674 stream: impl Stream<Item = Result<Bytes>>,
675) -> impl Stream<Item = Result<Bytes>> {
676 let hasher = Some(digest.algorithm().hasher());
677 let expected = digest.clone();
678 stream
679 .map_ok(Some)
680 .chain(once(async { Ok(None) }))
681 .scan(hasher, move |hasher, res| {
682 ready(match res {
683 Ok(Some(bytes)) => {
684 hasher.as_mut().unwrap().update(&bytes);
685 Some(Ok(bytes))
686 }
687 Ok(None) => {
688 let hasher = std::mem::take(hasher).unwrap();
689 let computed = hasher.finalize();
690 if expected == computed {
691 None
692 } else {
693 Some(Err(anyhow!(
694 "expected digest `{expected}` but computed digest `{computed}`"
695 )))
696 }
697 }
698 Err(err) => Some(Err(err)),
699 })
700 })
701}