wasm_pkg_loader/source/oci.rs

1mod config;
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use config::{BasicCredentials, OciConfig};
6use docker_credential::{CredentialRetrievalError, DockerCredential};
7use futures_util::{stream::BoxStream, StreamExt, TryStreamExt};
8use oci_distribution::{
9    errors::OciDistributionError, manifest::OciDescriptor, secrets::RegistryAuth, Reference,
10};
11use secrecy::ExposeSecret;
12use serde::Deserialize;
13use wasm_pkg_common::{
14    config::RegistryConfig,
15    metadata::RegistryMetadata,
16    package::{PackageRef, Version},
17    registry::Registry,
18    Error,
19};
20
21use crate::{
22    source::{PackageSource, VersionInfo},
23    Release,
24};
25
26#[derive(Default, Deserialize)]
27#[serde(rename_all = "camelCase")]
28struct OciRegistryMetadata {
29    registry: Option<String>,
30    namespace_prefix: Option<String>,
31}
32
33pub struct OciSource {
34    client: oci_wasm::WasmClient,
35    oci_registry: String,
36    namespace_prefix: Option<String>,
37    credentials: Option<BasicCredentials>,
38    registry_auth: Option<RegistryAuth>,
39}
40
41impl OciSource {
42    pub fn new(
43        registry: &Registry,
44        registry_config: &RegistryConfig,
45        registry_meta: &RegistryMetadata,
46    ) -> Result<Self, Error> {
47        let OciConfig {
48            client_config,
49            credentials,
50        } = registry_config.try_into()?;
51        let client = oci_distribution::Client::new(client_config);
52        let client = oci_wasm::WasmClient::new(client);
53
54        let oci_meta = registry_meta
55            .protocol_config::<OciRegistryMetadata>("oci")?
56            .unwrap_or_default();
57        let oci_registry = oci_meta.registry.unwrap_or_else(|| registry.to_string());
58
59        Ok(Self {
60            client,
61            oci_registry,
62            namespace_prefix: oci_meta.namespace_prefix,
63            credentials,
64            registry_auth: None,
65        })
66    }
67
68    async fn auth(&mut self, reference: &Reference) -> Result<RegistryAuth, Error> {
69        if self.registry_auth.is_none() {
70            let mut auth = self.get_credentials()?;
71            // Preflight auth to check for validity; this isn't wasted
72            // effort because the oci_distribution::Client caches it
73            use oci_distribution::errors::OciDistributionError::AuthenticationFailure;
74            use oci_distribution::RegistryOperation::Pull;
75            match self.client.auth(reference, &auth, Pull).await {
76                Ok(_) => (),
77                Err(err @ AuthenticationFailure(_)) if auth != RegistryAuth::Anonymous => {
78                    // The failed credentials might not even be required for this image; retry anonymously
79                    if self
80                        .client
81                        .auth(reference, &RegistryAuth::Anonymous, Pull)
82                        .await
83                        .is_ok()
84                    {
85                        auth = RegistryAuth::Anonymous;
86                    } else {
87                        return Err(oci_registry_error(err));
88                    }
89                }
90                Err(err) => return Err(oci_registry_error(err)),
91            }
92            self.registry_auth = Some(auth);
93        }
94        Ok(self.registry_auth.clone().unwrap())
95    }
96
97    fn get_credentials(&self) -> Result<RegistryAuth, Error> {
98        if let Some(BasicCredentials { username, password }) = &self.credentials {
99            return Ok(RegistryAuth::Basic(
100                username.clone(),
101                password.expose_secret().clone(),
102            ));
103        }
104
105        let server_url = format!("https://{}", self.oci_registry);
106        match docker_credential::get_credential(&server_url) {
107            Ok(DockerCredential::UsernamePassword(username, password)) => {
108                return Ok(RegistryAuth::Basic(username, password));
109            }
110            Ok(DockerCredential::IdentityToken(_)) => {
111                return Err(Error::CredentialError(anyhow::anyhow!(
112                    "identity tokens not supported"
113                )));
114            }
115            Err(err) => {
116                if matches!(
117                    err,
118                    CredentialRetrievalError::ConfigNotFound
119                        | CredentialRetrievalError::ConfigReadError
120                        | CredentialRetrievalError::NoCredentialConfigured
121                ) {
122                    tracing::debug!("Failed to look up OCI credentials: {err}");
123                } else {
124                    tracing::warn!("Failed to look up OCI credentials: {err}");
125                };
126            }
127        }
128
129        Ok(RegistryAuth::Anonymous)
130    }
131
132    fn make_reference(&self, package: &PackageRef, version: Option<&Version>) -> Reference {
133        let repository = format!(
134            "{}{}/{}",
135            self.namespace_prefix.as_deref().unwrap_or_default(),
136            package.namespace(),
137            package.name()
138        );
139        let tag = version
140            .map(|ver| ver.to_string())
141            .unwrap_or_else(|| "latest".into());
142        Reference::with_tag(self.oci_registry.clone(), repository, tag)
143    }
144}
145
146#[async_trait]
147impl PackageSource for OciSource {
148    async fn list_all_versions(&mut self, package: &PackageRef) -> Result<Vec<VersionInfo>, Error> {
149        let reference = self.make_reference(package, None);
150
151        tracing::debug!(?reference, "Listing tags for OCI reference");
152        let auth = self.auth(&reference).await?;
153        let resp = self
154            .client
155            .list_tags(&reference, &auth, None, None)
156            .await
157            .map_err(oci_registry_error)?;
158        tracing::trace!(response = ?resp, "List tags response");
159
160        // Return only tags that parse as valid semver versions.
161        let versions = resp
162            .tags
163            .iter()
164            .flat_map(|tag| match Version::parse(tag) {
165                Ok(version) => Some(VersionInfo {
166                    version,
167                    yanked: false,
168                }),
169                Err(err) => {
170                    tracing::warn!(?tag, error = ?err, "Ignoring invalid version tag");
171                    None
172                }
173            })
174            .collect();
175        Ok(versions)
176    }
177
178    async fn get_release(
179        &mut self,
180        package: &PackageRef,
181        version: &Version,
182    ) -> Result<Release, Error> {
183        let reference = self.make_reference(package, Some(version));
184
185        tracing::debug!(?reference, "Fetching image manifest for OCI reference");
186        let auth = self.auth(&reference).await?;
187        let (manifest, _config, _digest) = self
188            .client
189            .pull_manifest_and_config(&reference, &auth)
190            .await
191            .map_err(Error::RegistryError)?;
192        tracing::trace!(?manifest, "Got manifest");
193
194        let version = version.to_owned();
195        let content_digest = manifest
196            .layers
197            .into_iter()
198            .next()
199            .ok_or_else(|| {
200                Error::InvalidPackageManifest("Returned manifest had no layers".to_string())
201            })?
202            .digest
203            .parse()?;
204        Ok(Release {
205            version,
206            content_digest,
207        })
208    }
209
210    async fn stream_content_unvalidated(
211        &mut self,
212        package: &PackageRef,
213        release: &Release,
214    ) -> Result<BoxStream<Result<Bytes, Error>>, Error> {
215        let reference = self.make_reference(package, None);
216        let descriptor = OciDescriptor {
217            digest: release.content_digest.to_string(),
218            ..Default::default()
219        };
220        self.auth(&reference).await?;
221        let stream = self
222            .client
223            .pull_blob_stream(&reference, &descriptor)
224            .await
225            .map_err(oci_registry_error)?;
226        Ok(stream.map_err(Into::into).boxed())
227    }
228}
229
230fn oci_registry_error(err: OciDistributionError) -> Error {
231    match err {
232        // Technically this could be a missing version too, but there really isn't a way to find out
233        OciDistributionError::ImageManifestNotFoundError(_) => Error::PackageNotFound,
234        _ => Error::RegistryError(err.into()),
235    }
236}