Skip to main content

wasm_pkg_client/oci/
config.rs

1use anyhow::Context;
2use base64::{
3    engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
4    Engine,
5};
6use oci_client::client::ClientConfig;
7use secrecy::{ExposeSecret, SecretString};
8use serde::{Deserialize, Serialize, Serializer};
9use wasm_pkg_common::{config::RegistryConfig, Error};
10
/// Registry configuration for OCI backends.
///
/// See: [`RegistryConfig::backend_config`]
///
/// Serialization is delegated to the private `OciRegistryConfigToml` type via
/// `#[serde(into = ...)]`; serde requires `Clone` for that, which is why this type
/// has a (hand-written) `Clone` implementation.
#[derive(Default, Serialize)]
#[serde(into = "OciRegistryConfigToml")]
pub struct OciRegistryConfig {
    /// Settings for the underlying OCI client. When loaded from a registry config
    /// only `protocol` is set (everything else is `ClientConfig::default()`), and
    /// only `protocol` is written back out when this config is serialized.
    pub client_config: ClientConfig,
    /// Optional username/password credentials for this registry.
    pub credentials: Option<BasicCredentials>,
}
20
21impl Clone for OciRegistryConfig {
22    fn clone(&self) -> Self {
23        let client_config = ClientConfig {
24            protocol: self.client_config.protocol.clone(),
25            extra_root_certificates: self.client_config.extra_root_certificates.clone(),
26            tls_certs_only: self.client_config.tls_certs_only.clone(),
27            platform_resolver: None,
28            http_proxy: self.client_config.http_proxy.clone(),
29            https_proxy: self.client_config.https_proxy.clone(),
30            no_proxy: self.client_config.no_proxy.clone(),
31            ..self.client_config
32        };
33        Self {
34            client_config,
35            credentials: self.credentials.clone(),
36        }
37    }
38}
39
40impl std::fmt::Debug for OciRegistryConfig {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("OciConfig")
43            .field("client_config", &"...")
44            .field("credentials", &self.credentials)
45            .finish()
46    }
47}
48
49impl TryFrom<&RegistryConfig> for OciRegistryConfig {
50    type Error = Error;
51
52    fn try_from(registry_config: &RegistryConfig) -> Result<Self, Self::Error> {
53        let OciRegistryConfigToml { auth, protocol } =
54            registry_config.backend_config("oci")?.unwrap_or_default();
55        let mut client_config = ClientConfig::default();
56        if let Some(protocol) = protocol {
57            client_config.protocol = oci_client_protocol(&protocol)?;
58        };
59        let credentials = auth
60            .map(TryInto::try_into)
61            .transpose()
62            .map_err(Error::InvalidConfig)?;
63        Ok(Self {
64            client_config,
65            credentials,
66        })
67    }
68}
69
/// On-disk (TOML) shape of a registry's `oci` backend section.
///
/// Field declaration order is also the order in which serde writes the keys.
#[derive(Default, Deserialize, Serialize)]
struct OciRegistryConfigToml {
    // Either a base64 string or a `{ username, password }` table; see `TomlAuth`.
    auth: Option<TomlAuth>,
    // Accepted values are "http" and "https" (see `oci_client_protocol`).
    protocol: Option<String>,
}
75
76impl From<OciRegistryConfig> for OciRegistryConfigToml {
77    fn from(value: OciRegistryConfig) -> Self {
78        OciRegistryConfigToml {
79            auth: value.credentials.map(|c| TomlAuth::UsernamePassword {
80                username: c.username,
81                password: c.password,
82            }),
83            protocol: Some(oci_protocol_string(&value.client_config.protocol)),
84        }
85    }
86}
87
/// The `auth` value of an `oci` backend section.
///
/// The enum is `untagged`, so serde tries the variants in declaration order: a bare
/// string deserializes as `Base64`, a table as `UsernamePassword`. Do not reorder.
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
// NOTE(review): presumably meant to reject unknown keys in the table form — confirm
// how serde applies `deny_unknown_fields` to untagged enums.
#[serde(deny_unknown_fields)]
enum TomlAuth {
    // Base64 of "<username>:<password>"; written back as a plain string.
    #[serde(serialize_with = "serialize_secret")]
    Base64(SecretString),
    UsernamePassword {
        username: String,
        // Written out in plaintext via `serialize_secret`.
        #[serde(serialize_with = "serialize_secret")]
        password: SecretString,
    },
}
100
/// A username/password pair used to authenticate to an OCI registry.
///
/// The password is held in a `SecretString`. NOTE(review): the derived `Debug`
/// relies on `SecretString`'s own `Debug` impl to keep the password out of output.
#[derive(Clone, Debug)]
pub struct BasicCredentials {
    /// Registry username.
    pub username: String,
    /// Registry password (or token).
    pub password: SecretString,
}
106
/// Base64 engine used to decode the string form of `auth`.
///
/// Uses the standard alphabet and `DecodePaddingMode::Indifferent`, so input is
/// accepted with or without trailing `=` padding.
const OCI_AUTH_BASE64: GeneralPurpose = GeneralPurpose::new(
    &base64::alphabet::STANDARD,
    GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent),
);
111
112impl TryFrom<TomlAuth> for BasicCredentials {
113    type Error = anyhow::Error;
114
115    fn try_from(value: TomlAuth) -> Result<Self, Self::Error> {
116        match value {
117            TomlAuth::Base64(b64) => {
118                fn decode_b64_creds(b64: &str) -> anyhow::Result<BasicCredentials> {
119                    let bs = OCI_AUTH_BASE64.decode(b64)?;
120                    let s = String::from_utf8(bs)?;
121                    let (username, password) = s
122                        .split_once(':')
123                        .context("expected <username>:<password> but no ':' found")?;
124                    Ok(BasicCredentials {
125                        username: username.into(),
126                        password: password.to_string().into(),
127                    })
128                }
129                decode_b64_creds(b64.expose_secret()).context("invalid base64-encoded creds")
130            }
131            TomlAuth::UsernamePassword { username, password } => {
132                Ok(BasicCredentials { username, password })
133            }
134        }
135    }
136}
137
138fn oci_client_protocol(text: &str) -> Result<oci_client::client::ClientProtocol, Error> {
139    match text {
140        "http" => Ok(oci_client::client::ClientProtocol::Http),
141        "https" => Ok(oci_client::client::ClientProtocol::Https),
142        _ => Err(Error::InvalidConfig(anyhow::anyhow!(
143            "Unknown OCI protocol {text:?}"
144        ))),
145    }
146}
147
148fn oci_protocol_string(protocol: &oci_client::client::ClientProtocol) -> String {
149    match protocol {
150        oci_client::client::ClientProtocol::Http => "http".into(),
151        oci_client::client::ClientProtocol::Https => "https".into(),
152        // Default to https if not specified
153        _ => "https".into(),
154    }
155}
156
157fn serialize_secret<S: Serializer>(
158    secret: &SecretString,
159    serializer: S,
160) -> Result<S::Ok, S::Error> {
161    secret.expose_secret().serialize(serializer)
162}
163
#[cfg(test)]
mod tests {
    use wasm_pkg_common::config::RegistryMapping;

    use crate::oci::OciRegistryMetadata;

    use super::*;

    /// Loads both `auth` forms (table and base64 string) plus an explicit protocol.
    #[test]
    fn smoke_test() {
        let toml_config = r#"
            [registry."example.com"]
            type = "oci"
            [registry."example.com".oci]
            auth = { username = "open", password = "sesame" }
            protocol = "http"

            [registry."wasi.dev"]
            type = "oci"
            [registry."wasi.dev".oci]
            auth = "cGluZzpwb25n"
        "#;
        let cfg = wasm_pkg_common::config::Config::from_toml(toml_config).unwrap();

        let oci_config: OciRegistryConfig = cfg
            .registry_config(&"example.com".parse().unwrap())
            .unwrap()
            .try_into()
            .unwrap();
        let BasicCredentials { username, password } = oci_config.credentials.as_ref().unwrap();
        assert_eq!(username, "open");
        assert_eq!(password.expose_secret(), "sesame");
        assert_eq!(
            oci_client::client::ClientProtocol::Http,
            oci_config.client_config.protocol
        );

        // "cGluZzpwb25n" is base64 for "ping:pong".
        let oci_config: OciRegistryConfig = cfg
            .registry_config(&"wasi.dev".parse().unwrap())
            .unwrap()
            .try_into()
            .unwrap();
        let BasicCredentials { username, password } = oci_config.credentials.as_ref().unwrap();
        assert_eq!(username, "ping");
        assert_eq!(password.expose_secret(), "pong");
    }

    /// Writes a config through `set_backend_config` and reads it back unchanged.
    #[test]
    fn test_roundtrip() {
        let config = OciRegistryConfig {
            client_config: oci_client::client::ClientConfig {
                protocol: oci_client::client::ClientProtocol::Http,
                ..Default::default()
            },
            credentials: Some(BasicCredentials {
                username: "open".into(),
                password: SecretString::new("sesame".into()),
            }),
        };

        // Set the data and then try to load it back
        let mut conf = crate::Config::empty();

        let registry: crate::Registry = "example.com:8080".parse().unwrap();
        let reg_conf = conf.get_or_insert_registry_config_mut(&registry);
        reg_conf
            .set_backend_config("oci", &config)
            .expect("Unable to set config");

        let reg_conf = conf.registry_config(&registry).unwrap();

        let roundtripped = OciRegistryConfig::try_from(reg_conf).expect("Unable to load config");
        assert_eq!(
            roundtripped.client_config.protocol, config.client_config.protocol,
            "Protocol should be set to the right value"
        );
        let creds = config.credentials.unwrap();
        let roundtripped_creds = roundtripped.credentials.expect("Should have creds");
        assert_eq!(
            creds.username, roundtripped_creds.username,
            "Username should be set to the right value"
        );
        assert_eq!(
            creds.password.expose_secret(),
            roundtripped_creds.password.expose_secret(),
            "Password should be set to the right value"
        );
    }

    /// Reads OCI-specific metadata from a custom namespace mapping.
    #[test]
    fn test_custom_namespace_config() {
        let toml_config = toml::toml! {
            [namespace_registries]
            test = { registry = "localhost:1234", metadata = { preferredProtocol = "oci", "oci" = { registry = "ghcr.io", namespacePrefix = "webassembly/" } } }
        };

        let cfg = wasm_pkg_common::config::Config::from_toml(&toml_config.to_string())
            .expect("Should be able to load config");

        let ns_config = cfg
            .namespace_registry(&"test".parse().unwrap())
            .expect("Should have a namespace config");
        let custom = match ns_config {
            RegistryMapping::Custom(c) => c,
            _ => panic!("Should have a custom namespace config"),
        };
        let map: OciRegistryMetadata = custom
            .metadata
            .protocol_config("oci")
            .expect("Should be able to deserialize config")
            .expect("protocol config should be present");
        assert_eq!(map.namespace_prefix, Some("webassembly/".to_string()));
        assert_eq!(map.registry, Some("ghcr.io".to_string()));
    }
}