warg_client/
api.rs

1//! A module for Warg registry API clients.
2
3use anyhow::{anyhow, Result};
4use bytes::Bytes;
5use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStreamExt};
6use indexmap::IndexMap;
7use reqwest::{
8    header::{HeaderMap, HeaderValue},
9    Body, IntoUrl, Method, RequestBuilder, Response, StatusCode,
10};
11use secrecy::{ExposeSecret, Secret};
12use serde::de::DeserializeOwned;
13use std::borrow::Cow;
14use thiserror::Error;
15use warg_api::{
16    v1::{
17        content::{ContentError, ContentSourcesResponse},
18        fetch::{
19            FetchError, FetchLogsRequest, FetchLogsResponse, FetchPackageNamesRequest,
20            FetchPackageNamesResponse,
21        },
22        ledger::{LedgerError, LedgerSourcesResponse},
23        monitor::{CheckpointVerificationResponse, MonitorError},
24        package::{ContentSource, PackageError, PackageRecord, PublishRecordRequest},
25        paths,
26        proof::{
27            ConsistencyRequest, ConsistencyResponse, InclusionRequest, InclusionResponse,
28            ProofError,
29        },
30        REGISTRY_HEADER_NAME, REGISTRY_HINT_HEADER_NAME,
31    },
32    WellKnownConfig, WELL_KNOWN_PATH,
33};
34use warg_crypto::hash::{AnyHash, HashError, Sha256};
35use warg_protocol::{
36    registry::{Checkpoint, LogId, LogLeaf, MapLeaf, RecordId, TimestampedCheckpoint},
37    SerdeEnvelope,
38};
39use warg_transparency::{
40    log::{ConsistencyProofError, InclusionProofError, LogProofBundle, ProofBundle},
41    map::MapProofBundle,
42};
43
44use crate::{registry_url::RegistryUrl, storage::RegistryDomain};
45/// Represents an error that occurred while communicating with the registry.
46#[derive(Debug, Error)]
47pub enum ClientError {
48    /// An error was returned from the fetch API.
49    #[error(transparent)]
50    Fetch(#[from] FetchError),
51    /// An error was returned from the package API.
52    #[error(transparent)]
53    Package(#[from] PackageError),
54    /// An error was returned from the content API.
55    #[error(transparent)]
56    Content(#[from] ContentError),
57    /// An error was returned from the proof API.
58    #[error(transparent)]
59    Proof(#[from] ProofError),
60    /// An error was returned from the monitor API.
61    #[error(transparent)]
62    Monitor(#[from] MonitorError),
63    /// An error was returned from the ledger API.
64    #[error(transparent)]
65    Ledger(#[from] LedgerError),
66    /// An error occurred while communicating with the registry.
67    #[error("failed to send request to registry server: {0}")]
68    Communication(#[from] reqwest::Error),
69    /// An unexpected response was received from the server.
70    #[error("{message} (status code: {status})")]
71    UnexpectedResponse {
72        /// The response from the server.
73        status: StatusCode,
74        /// The error message.
75        message: String,
76    },
77    /// The provided root for a consistency proof was incorrect.
78    #[error(
79        "the client failed to prove consistency: found root `{found}` but was given root `{root}`"
80    )]
81    IncorrectConsistencyProof {
82        /// The provided root.
83        root: AnyHash,
84        /// The found root.
85        found: AnyHash,
86    },
87    /// A hash returned from the server was incorrect.
88    #[error("the server returned an invalid hash: {0}")]
89    Hash(#[from] HashError),
90    /// The client failed a consistency proof.
91    #[error("the client failed a consistency proof: {0}")]
92    ConsistencyProof(#[from] ConsistencyProofError),
93    /// The client failed an inclusion proof.
94    #[error("the client failed an inclusion proof: {0}")]
95    InclusionProof(#[from] InclusionProofError),
96    /// The record was not published.
97    #[error("record `{0}` has not been published")]
98    RecordNotPublished(RecordId),
99    /// Could not find a source for the given content digest.
100    #[error("no download location could be found for content digest `{0}`")]
101    NoSourceForContent(AnyHash),
102    /// All sources for the given content digest returned an error response.
103    #[error("all sources for content digest `{0}` returned an error response")]
104    AllSourcesFailed(AnyHash),
105    /// Invalid upload HTTP method.
106    #[error("server returned an invalid HTTP method `{0}`")]
107    InvalidHttpMethod(String),
108    /// Invalid upload HTTP method.
109    #[error("server returned an invalid HTTP header `{0}: {1}`")]
110    InvalidHttpHeader(String, String),
111    /// The provided log was not found with hint header.
112    #[error("log `{0}` was not found in this registry, but the registry provided the hint header: `{1:?}`")]
113    LogNotFoundWithHint(LogId, HeaderValue),
114    /// Invalid well-known config.
115    #[error("registry `{0}` returned an invalid well-known config")]
116    InvalidWellKnownConfig(String),
117    /// An other error occurred during the requested operation.
118    #[error(transparent)]
119    Other(#[from] anyhow::Error),
120}
121
122async fn deserialize<T: DeserializeOwned>(response: Response) -> Result<T, ClientError> {
123    let status = response.status();
124    match response.headers().get("content-type") {
125        Some(content_type) if content_type == "application/json" => {
126            let bytes = response
127                .bytes()
128                .await
129                .map_err(|e| ClientError::UnexpectedResponse {
130                    status,
131                    message: format!("failed to read response: {e}"),
132                })?;
133            serde_json::from_slice(&bytes).map_err(|e| {
134                tracing::debug!(
135                    "Unexpected response body: {}",
136                    String::from_utf8_lossy(&bytes)
137                );
138                ClientError::UnexpectedResponse {
139                    status,
140                    message: format!("failed to deserialize JSON response: {e}"),
141                }
142            })
143        }
144        Some(ty) => Err(ClientError::UnexpectedResponse {
145            status,
146            message: format!(
147                "the server returned an unsupported content type of `{ty}`",
148                ty = ty.to_str().unwrap_or("")
149            ),
150        }),
151        None => Err(ClientError::UnexpectedResponse {
152            status,
153            message: "the server response did not include a content type header".into(),
154        }),
155    }
156}
157
158async fn into_result<T: DeserializeOwned, E: DeserializeOwned + Into<ClientError>>(
159    response: Response,
160) -> Result<T, ClientError> {
161    if response.status().is_success() {
162        deserialize::<T>(response).await
163    } else {
164        Err(deserialize::<E>(response).await?.into())
165    }
166}
167
168trait WithWargHeader {
169    fn warg_header(self, registry_header: Option<&RegistryDomain>) -> Result<RequestBuilder>;
170}
171
172impl WithWargHeader for RequestBuilder {
173    fn warg_header(self, registry_header: Option<&RegistryDomain>) -> Result<RequestBuilder> {
174        if let Some(reg) = registry_header {
175            Ok(self.header(REGISTRY_HEADER_NAME, HeaderValue::try_from(reg.clone())?))
176        } else {
177            Ok(self)
178        }
179    }
180}
181
182trait WithAuth {
183    fn auth(self, auth_token: &Option<Secret<String>>) -> RequestBuilder;
184}
185
186impl WithAuth for RequestBuilder {
187    fn auth(self, auth_token: &Option<Secret<String>>) -> reqwest::RequestBuilder {
188        if let Some(tok) = auth_token {
189            self.bearer_auth(tok.expose_secret())
190        } else {
191            self
192        }
193    }
194}
195
196/// Represents a Warg API client for communicating with
197/// a Warg registry server.
198pub struct Client {
199    url: RegistryUrl,
200    client: reqwest::Client,
201    warg_registry_header: Option<RegistryDomain>,
202    auth_token: Option<Secret<String>>,
203}
204
205impl Client {
206    /// Creates a new API client with the given URL.
207    pub fn new(url: impl IntoUrl, auth_token: Option<Secret<String>>) -> Result<Self> {
208        let url = RegistryUrl::new(url)?;
209        Ok(Self {
210            url,
211            client: reqwest::Client::new(),
212            warg_registry_header: None,
213            auth_token,
214        })
215    }
216
217    /// Gets auth token
218    pub fn auth_token(&self) -> &Option<Secret<String>> {
219        &self.auth_token
220    }
221
222    /// Gets the URL of the API client.
223    pub fn url(&self) -> &RegistryUrl {
224        &self.url
225    }
226    /// Gets the `.well-known` configuration registry URL.
227    pub async fn well_known_config(&self) -> Result<Option<RegistryUrl>, ClientError> {
228        let url = self.url.join(WELL_KNOWN_PATH);
229        tracing::debug!(url, "getting `.well-known` config",);
230
231        let res = self.client.get(url).send().await?;
232
233        if !res.status().is_success() {
234            tracing::debug!(
235                "the `.well-known` config request returned HTTP status `{status}`",
236                status = res.status()
237            );
238            return Ok(None);
239        }
240
241        if let Some(warg_url) = res
242            .json::<WellKnownConfig>()
243            .await
244            .map_err(|e| {
245                tracing::debug!("parsing `.well-known` config failed: {e}");
246                ClientError::InvalidWellKnownConfig(self.url.registry_domain().to_string())
247            })?
248            .warg_url
249        {
250            Ok(Some(RegistryUrl::new(warg_url)?))
251        } else {
252            tracing::debug!("the `.well-known` config did not have a `wargUrl` set");
253            Ok(None)
254        }
255    }
256
257    /// Gets the latest checkpoint from the registry.
258    pub async fn latest_checkpoint(
259        &self,
260        registry_domain: Option<&RegistryDomain>,
261    ) -> Result<SerdeEnvelope<TimestampedCheckpoint>, ClientError> {
262        let url = self.url.join(paths::fetch_checkpoint());
263        tracing::debug!(
264            url,
265            registry_header = ?registry_domain,
266            "getting latest checkpoint",
267        );
268        into_result::<_, FetchError>(
269            self.client
270                .get(url)
271                .warg_header(registry_domain)?
272                .auth(self.auth_token())
273                .send()
274                .await?,
275        )
276        .await
277    }
278
279    /// Verify checkpoint of the registry.
280    pub async fn verify_checkpoint(
281        &self,
282        registry_domain: Option<&RegistryDomain>,
283        request: SerdeEnvelope<TimestampedCheckpoint>,
284    ) -> Result<CheckpointVerificationResponse, ClientError> {
285        let url = self.url.join(paths::verify_checkpoint());
286        tracing::debug!(
287            url,
288            registry_header = ?registry_domain,
289            "verifying checkpoint",
290        );
291
292        let response = self
293            .client
294            .post(url)
295            .json(&request)
296            .warg_header(registry_domain)?
297            .auth(self.auth_token())
298            .send()
299            .await?;
300        into_result::<_, MonitorError>(response).await
301    }
302
303    /// Fetches package log entries from the registry.
304    pub async fn fetch_logs(
305        &self,
306        registry_domain: Option<&RegistryDomain>,
307        request: FetchLogsRequest<'_>,
308    ) -> Result<FetchLogsResponse, ClientError> {
309        let url = self.url.join(paths::fetch_logs());
310        tracing::debug!(
311            url,
312            registry_header = ?registry_domain,
313            "fetching logs",
314        );
315        let response = self
316            .client
317            .post(&url)
318            .json(&request)
319            .warg_header(registry_domain)?
320            .auth(self.auth_token())
321            .send()
322            .await?;
323
324        let header = response.headers().get(REGISTRY_HINT_HEADER_NAME).cloned();
325        into_result::<_, FetchError>(response)
326            .await
327            .map_err(|err| match err {
328                ClientError::Fetch(FetchError::LogNotFound(log_id)) if header.is_some() => {
329                    ClientError::LogNotFoundWithHint(log_id, header.unwrap())
330                }
331                _ => err,
332            })
333    }
334
335    /// Fetches package names from the registry.
336    pub async fn fetch_package_names(
337        &self,
338        registry_domain: Option<&RegistryDomain>,
339        request: FetchPackageNamesRequest<'_>,
340    ) -> Result<FetchPackageNamesResponse, ClientError> {
341        let url = self.url.join(paths::fetch_package_names());
342        tracing::debug!(
343            url,
344            registry_header = ?registry_domain,
345            "fetching package names",
346        );
347        let response = self
348            .client
349            .post(url)
350            .warg_header(registry_domain)?
351            .auth(self.auth_token())
352            .json(&request)
353            .send()
354            .await?;
355        into_result::<_, FetchError>(response).await
356    }
357
358    /// Gets ledger sources from the registry.
359    pub async fn ledger_sources(
360        &self,
361        registry_domain: Option<&RegistryDomain>,
362    ) -> Result<LedgerSourcesResponse, ClientError> {
363        let url = self.url.join(paths::ledger_sources());
364        tracing::debug!(
365            url,
366            registry_header = ?registry_domain,
367            "getting ledger sources",
368        );
369        into_result::<_, LedgerError>(
370            self.client
371                .get(url)
372                .warg_header(registry_domain)?
373                .auth(self.auth_token())
374                .send()
375                .await?,
376        )
377        .await
378    }
379
380    /// Publish a new record to a package log.
381    pub async fn publish_package_record(
382        &self,
383        registry_domain: Option<&RegistryDomain>,
384        log_id: &LogId,
385        request: PublishRecordRequest<'_>,
386    ) -> Result<PackageRecord, ClientError> {
387        let url = self.url.join(&paths::publish_package_record(log_id));
388        tracing::debug!(
389            log_id = log_id.to_string(),
390            url,
391            registry_header = ?registry_domain,
392            "publishing to package",
393        );
394        let response = self
395            .client
396            .post(url)
397            .json(&request)
398            .warg_header(registry_domain)?
399            .auth(self.auth_token())
400            .send()
401            .await?;
402        into_result::<_, PackageError>(response).await
403    }
404
405    /// Gets a package record from the registry.
406    pub async fn get_package_record(
407        &self,
408        registry_domain: Option<&RegistryDomain>,
409        log_id: &LogId,
410        record_id: &RecordId,
411    ) -> Result<PackageRecord, ClientError> {
412        let url = self.url.join(&paths::package_record(log_id, record_id));
413        tracing::debug!(
414            log_id = log_id.to_string(),
415            record_id = record_id.to_string(),
416            url,
417            registry_header = ?registry_domain,
418            "getting package record",
419        );
420        into_result::<_, PackageError>(
421            self.client
422                .get(url)
423                .warg_header(registry_domain)?
424                .auth(self.auth_token())
425                .send()
426                .await?,
427        )
428        .await
429    }
430
431    /// Gets a content sources from the registry.
432    pub async fn content_sources(
433        &self,
434        registry_domain: Option<&RegistryDomain>,
435        digest: &AnyHash,
436    ) -> Result<ContentSourcesResponse, ClientError> {
437        let url = self.url.join(&paths::content_sources(digest));
438        tracing::debug!(
439            digest = digest.to_string(),
440            url,
441            registry_header = ?registry_domain,
442            "getting content sources for digest",
443        );
444        into_result::<_, ContentError>(
445            self.client
446                .get(url)
447                .warg_header(registry_domain)?
448                .auth(self.auth_token())
449                .send()
450                .await?,
451        )
452        .await
453    }
454
455    /// Downloads the content associated with a given record.
456    pub async fn download_content(
457        &self,
458        registry_domain: Option<&RegistryDomain>,
459        digest: &AnyHash,
460    ) -> Result<impl Stream<Item = Result<Bytes>>, ClientError> {
461        let ContentSourcesResponse { content_sources } =
462            self.content_sources(registry_domain, digest).await?;
463
464        let sources = content_sources
465            .get(digest)
466            .ok_or(ClientError::AllSourcesFailed(digest.clone()))?;
467
468        for source in sources {
469            let ContentSource::HttpGet { url, .. } = source;
470
471            tracing::debug!("downloading content `{digest}` from `{url}`");
472
473            let response = self.client.get(url).send().await?;
474            if !response.status().is_success() {
475                tracing::debug!(
476                    "failed to download content `{digest}` from `{url}`: {status}",
477                    status = response.status()
478                );
479                continue;
480            }
481
482            return Ok(validate_stream(
483                digest,
484                response.bytes_stream().map_err(|e| anyhow!(e)),
485            ));
486        }
487
488        Err(ClientError::AllSourcesFailed(digest.clone()))
489    }
490
491    /// Set warg-registry header value
492    pub fn set_warg_registry(&mut self, registry: Option<RegistryDomain>) {
493        self.warg_registry_header = registry;
494    }
495
496    /// Proves the inclusion of the given package log heads in the registry.
497    pub async fn prove_inclusion(
498        &self,
499        registry_domain: Option<&RegistryDomain>,
500        request: InclusionRequest,
501        checkpoint: &Checkpoint,
502        leafs: &[LogLeaf],
503    ) -> Result<(), ClientError> {
504        let url = self.url.join(paths::prove_inclusion());
505        tracing::debug!(
506            url,
507            registry_header = ?registry_domain,
508            "proving checkpoint inclusion",
509        );
510        let response = into_result::<InclusionResponse, ProofError>(
511            self.client
512                .post(url)
513                .json(&request)
514                .warg_header(registry_domain)?
515                .auth(self.auth_token())
516                .send()
517                .await?,
518        )
519        .await?;
520
521        Self::validate_inclusion_response(response, checkpoint, leafs)
522    }
523
524    /// Proves consistency between two log roots.
525    pub async fn prove_log_consistency(
526        &self,
527        registry_domain: Option<&RegistryDomain>,
528        request: ConsistencyRequest,
529        from_log_root: Cow<'_, AnyHash>,
530        to_log_root: Cow<'_, AnyHash>,
531    ) -> Result<(), ClientError> {
532        let url = self.url.join(paths::prove_consistency());
533        let response = into_result::<ConsistencyResponse, ProofError>(
534            self.client
535                .post(url)
536                .json(&request)
537                .warg_header(registry_domain)?
538                .auth(self.auth_token())
539                .send()
540                .await?,
541        )
542        .await?;
543
544        let proof = ProofBundle::<Sha256, LogLeaf>::decode(&response.proof).unwrap();
545        let (log_data, consistencies, inclusions) = proof.unbundle();
546        if !inclusions.is_empty() {
547            return Err(ClientError::Proof(ProofError::BundleFailure(
548                "expected no inclusion proofs".into(),
549            )));
550        }
551
552        if consistencies.len() != 1 {
553            return Err(ClientError::Proof(ProofError::BundleFailure(
554                "expected exactly one consistency proof".into(),
555            )));
556        }
557
558        let (from, to) = consistencies
559            .first()
560            .unwrap()
561            .evaluate(&log_data)
562            .map(|(from, to)| (AnyHash::from(from), AnyHash::from(to)))?;
563
564        if from_log_root.as_ref() != &from {
565            return Err(ClientError::IncorrectConsistencyProof {
566                root: from_log_root.into_owned(),
567                found: from,
568            });
569        }
570
571        if to_log_root.as_ref() != &to {
572            return Err(ClientError::IncorrectConsistencyProof {
573                root: to_log_root.into_owned(),
574                found: to,
575            });
576        }
577
578        Ok(())
579    }
580
581    /// Uploads package content to the registry.
582    pub async fn upload_content(
583        &self,
584        method: &str,
585        url: &str,
586        headers: &IndexMap<String, String>,
587        content: impl Into<Body>,
588    ) -> Result<(), ClientError> {
589        // Upload URLs may be relative to the registry URL.
590        let url = self.url.join(url);
591
592        let method = match method {
593            "POST" => Method::POST,
594            "PUT" => Method::PUT,
595            method => return Err(ClientError::InvalidHttpMethod(method.to_string())),
596        };
597
598        let headers = headers
599            .iter()
600            .map(|(k, v)| {
601                let name = match k.as_str() {
602                    "authorization" => reqwest::header::AUTHORIZATION,
603                    "content-type" => reqwest::header::CONTENT_TYPE,
604                    _ => return Err(ClientError::InvalidHttpHeader(k.to_string(), v.to_string())),
605                };
606                let value = HeaderValue::try_from(k)
607                    .map_err(|_| ClientError::InvalidHttpHeader(k.to_string(), v.to_string()))?;
608                Ok((name, value))
609            })
610            .collect::<Result<HeaderMap, ClientError>>()?;
611
612        tracing::debug!("uploading content to `{url}`");
613
614        let response = self
615            .client
616            .request(method, url)
617            .headers(headers)
618            .body(content)
619            .send()
620            .await?;
621        if !response.status().is_success() {
622            return Err(ClientError::Package(
623                deserialize::<PackageError>(response).await?,
624            ));
625        }
626
627        Ok(())
628    }
629
630    fn validate_inclusion_response(
631        response: InclusionResponse,
632        checkpoint: &Checkpoint,
633        leafs: &[LogLeaf],
634    ) -> Result<(), ClientError> {
635        let log_proof_bundle: LogProofBundle<Sha256, LogLeaf> =
636            LogProofBundle::decode(response.log.as_slice())?;
637        let (log_data, _, log_inclusions) = log_proof_bundle.unbundle();
638        for (leaf, proof) in leafs.iter().zip(log_inclusions.iter()) {
639            let found = proof.evaluate_value(&log_data, leaf)?;
640            let root = checkpoint.log_root.clone().try_into()?;
641            if found != root {
642                return Err(ClientError::Proof(ProofError::IncorrectProof {
643                    root: checkpoint.log_root.clone(),
644                    found: found.into(),
645                }));
646            }
647        }
648
649        let map_proof_bundle: MapProofBundle<Sha256, LogId, MapLeaf> =
650            MapProofBundle::decode(response.map.as_slice())?;
651        let map_inclusions = map_proof_bundle.unbundle();
652        for (leaf, proof) in leafs.iter().zip(map_inclusions.iter()) {
653            let found = proof.evaluate(
654                &leaf.log_id,
655                &MapLeaf {
656                    record_id: leaf.record_id.clone(),
657                },
658            );
659            let root = checkpoint.map_root.clone().try_into()?;
660            if found != root {
661                return Err(ClientError::Proof(ProofError::IncorrectProof {
662                    root: checkpoint.map_root.clone(),
663                    found: found.into(),
664                }));
665            }
666        }
667
668        Ok(())
669    }
670}
671
672fn validate_stream(
673    digest: &AnyHash,
674    stream: impl Stream<Item = Result<Bytes>>,
675) -> impl Stream<Item = Result<Bytes>> {
676    let hasher = Some(digest.algorithm().hasher());
677    let expected = digest.clone();
678    stream
679        .map_ok(Some)
680        .chain(once(async { Ok(None) }))
681        .scan(hasher, move |hasher, res| {
682            ready(match res {
683                Ok(Some(bytes)) => {
684                    hasher.as_mut().unwrap().update(&bytes);
685                    Some(Ok(bytes))
686                }
687                Ok(None) => {
688                    let hasher = std::mem::take(hasher).unwrap();
689                    let computed = hasher.finalize();
690                    if expected == computed {
691                        None
692                    } else {
693                        Some(Err(anyhow!(
694                            "expected digest `{expected}` but computed digest `{computed}`"
695                        )))
696                    }
697                }
698                Err(err) => Some(Err(err)),
699            })
700        })
701}