1use anyhow::Context;
2use base64::{
3 engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
4 Engine,
5};
6use oci_client::client::ClientConfig;
7use secrecy::{ExposeSecret, SecretString};
8use serde::{Deserialize, Serialize, Serializer};
9use wasm_pkg_common::{config::RegistryConfig, Error};
10
/// Configuration for the OCI registry backend.
///
/// Built from a registry's `oci` backend section via `TryFrom<&RegistryConfig>`, and
/// serialized by converting into `OciRegistryConfigToml` (`#[serde(into = ...)]`). serde's
/// `into` conversion requires `Clone`, which is implemented by hand for this type.
#[derive(Default, Serialize)]
#[serde(into = "OciRegistryConfigToml")]
pub struct OciRegistryConfig {
    /// Settings for the `oci_client` client. Only `protocol` is read from / written to the
    /// TOML form; every other field keeps its `ClientConfig::default()` value when loaded.
    pub client_config: ClientConfig,
    /// Optional basic-auth credentials for the registry.
    pub credentials: Option<BasicCredentials>,
}
20
21impl Clone for OciRegistryConfig {
22 fn clone(&self) -> Self {
23 let client_config = ClientConfig {
24 protocol: self.client_config.protocol.clone(),
25 extra_root_certificates: self.client_config.extra_root_certificates.clone(),
26 platform_resolver: None,
27 https_proxy: self.client_config.https_proxy.clone(),
28 no_proxy: self.client_config.no_proxy.clone(),
29 ..self.client_config
30 };
31 Self {
32 client_config,
33 credentials: self.credentials.clone(),
34 }
35 }
36}
37
38impl std::fmt::Debug for OciRegistryConfig {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("OciConfig")
41 .field("client_config", &"...")
42 .field("credentials", &self.credentials)
43 .finish()
44 }
45}
46
47impl TryFrom<&RegistryConfig> for OciRegistryConfig {
48 type Error = Error;
49
50 fn try_from(registry_config: &RegistryConfig) -> Result<Self, Self::Error> {
51 let OciRegistryConfigToml { auth, protocol } =
52 registry_config.backend_config("oci")?.unwrap_or_default();
53 let mut client_config = ClientConfig::default();
54 if let Some(protocol) = protocol {
55 client_config.protocol = oci_client_protocol(&protocol)?;
56 };
57 let credentials = auth
58 .map(TryInto::try_into)
59 .transpose()
60 .map_err(Error::InvalidConfig)?;
61 Ok(Self {
62 client_config,
63 credentials,
64 })
65 }
66}
67
/// On-disk (TOML) shape of a registry's `oci` backend section.
#[derive(Default, Deserialize, Serialize)]
struct OciRegistryConfigToml {
    /// Registry credentials; see `TomlAuth` for the accepted forms.
    auth: Option<TomlAuth>,
    /// Transport protocol name; `oci_client_protocol` accepts `"http"` or `"https"`.
    protocol: Option<String>,
}
73
74impl From<OciRegistryConfig> for OciRegistryConfigToml {
75 fn from(value: OciRegistryConfig) -> Self {
76 OciRegistryConfigToml {
77 auth: value.credentials.map(|c| TomlAuth::UsernamePassword {
78 username: c.username,
79 password: c.password,
80 }),
81 protocol: Some(oci_protocol_string(&value.client_config.protocol)),
82 }
83 }
84}
85
/// The `auth` value as written in TOML.
///
/// Accepted forms:
/// - a base64 string encoding `<username>:<password>`;
/// - an inline table with `username` and `password`.
///
/// Because the enum is `untagged`, serde tries the variants in declaration order, so their
/// order matters. Secrets are serialized in clear text via `serialize_secret`.
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
enum TomlAuth {
    /// Base64 of `<username>:<password>`; decoded by `TryFrom<TomlAuth> for BasicCredentials`.
    #[serde(serialize_with = "serialize_secret")]
    Base64(SecretString),
    /// Explicit username and password.
    UsernamePassword {
        username: String,
        #[serde(serialize_with = "serialize_secret")]
        password: SecretString,
    },
}
98
/// Username/password credentials for an OCI registry.
///
/// The password is held as a `SecretString`; the derived `Debug` relies on that type's
/// `Debug` impl to avoid printing it.
#[derive(Clone, Debug)]
pub struct BasicCredentials {
    /// Registry username.
    pub username: String,
    /// Registry password (or token).
    pub password: SecretString,
}
104
105const OCI_AUTH_BASE64: GeneralPurpose = GeneralPurpose::new(
106 &base64::alphabet::STANDARD,
107 GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent),
108);
109
110impl TryFrom<TomlAuth> for BasicCredentials {
111 type Error = anyhow::Error;
112
113 fn try_from(value: TomlAuth) -> Result<Self, Self::Error> {
114 match value {
115 TomlAuth::Base64(b64) => {
116 fn decode_b64_creds(b64: &str) -> anyhow::Result<BasicCredentials> {
117 let bs = OCI_AUTH_BASE64.decode(b64)?;
118 let s = String::from_utf8(bs)?;
119 let (username, password) = s
120 .split_once(':')
121 .context("expected <username>:<password> but no ':' found")?;
122 Ok(BasicCredentials {
123 username: username.into(),
124 password: password.to_string().into(),
125 })
126 }
127 decode_b64_creds(b64.expose_secret()).context("invalid base64-encoded creds")
128 }
129 TomlAuth::UsernamePassword { username, password } => {
130 Ok(BasicCredentials { username, password })
131 }
132 }
133 }
134}
135
136fn oci_client_protocol(text: &str) -> Result<oci_client::client::ClientProtocol, Error> {
137 match text {
138 "http" => Ok(oci_client::client::ClientProtocol::Http),
139 "https" => Ok(oci_client::client::ClientProtocol::Https),
140 _ => Err(Error::InvalidConfig(anyhow::anyhow!(
141 "Unknown OCI protocol {text:?}"
142 ))),
143 }
144}
145
146fn oci_protocol_string(protocol: &oci_client::client::ClientProtocol) -> String {
147 match protocol {
148 oci_client::client::ClientProtocol::Http => "http".into(),
149 oci_client::client::ClientProtocol::Https => "https".into(),
150 _ => "https".into(),
152 }
153}
154
155fn serialize_secret<S: Serializer>(
156 secret: &SecretString,
157 serializer: S,
158) -> Result<S::Ok, S::Error> {
159 secret.expose_secret().serialize(serializer)
160}
161
#[cfg(test)]
mod tests {
    use wasm_pkg_common::config::RegistryMapping;

    use crate::oci::OciRegistryMetadata;

    use super::*;

    /// Loads both auth forms (table and base64) plus an explicit protocol from TOML.
    #[test]
    fn smoke_test() {
        let toml_config = r#"
            [registry."example.com"]
            type = "oci"
            [registry."example.com".oci]
            auth = { username = "open", password = "sesame" }
            protocol = "http"

            [registry."wasi.dev"]
            type = "oci"
            [registry."wasi.dev".oci]
            auth = "cGluZzpwb25n"
        "#;
        let cfg = wasm_pkg_common::config::Config::from_toml(toml_config).unwrap();

        let oci_config: OciRegistryConfig = cfg
            .registry_config(&"example.com".parse().unwrap())
            .unwrap()
            .try_into()
            .unwrap();
        let BasicCredentials { username, password } = oci_config.credentials.as_ref().unwrap();
        assert_eq!(username, "open");
        assert_eq!(password.expose_secret(), "sesame");
        assert_eq!(
            oci_client::client::ClientProtocol::Http,
            oci_config.client_config.protocol
        );

        // "cGluZzpwb25n" is base64 for "ping:pong".
        let oci_config: OciRegistryConfig = cfg
            .registry_config(&"wasi.dev".parse().unwrap())
            .unwrap()
            .try_into()
            .unwrap();
        let BasicCredentials { username, password } = oci_config.credentials.as_ref().unwrap();
        assert_eq!(username, "ping");
        assert_eq!(password.expose_secret(), "pong");
    }

    /// Writing a config into a registry entry and reading it back preserves its values.
    #[test]
    fn test_roundtrip() {
        let config = OciRegistryConfig {
            client_config: oci_client::client::ClientConfig {
                protocol: oci_client::client::ClientProtocol::Http,
                ..Default::default()
            },
            credentials: Some(BasicCredentials {
                username: "open".into(),
                password: SecretString::new("sesame".into()),
            }),
        };

        let mut conf = crate::Config::empty();

        let registry: crate::Registry = "example.com:8080".parse().unwrap();
        let reg_conf = conf.get_or_insert_registry_config_mut(&registry);
        reg_conf
            .set_backend_config("oci", &config)
            .expect("Unable to set config");

        let reg_conf = conf.registry_config(&registry).unwrap();

        let roundtripped = OciRegistryConfig::try_from(reg_conf).expect("Unable to load config");
        assert_eq!(
            roundtripped.client_config.protocol, config.client_config.protocol,
            "Protocol should be set to the right value"
        );
        let creds = config.credentials.unwrap();
        let roundtripped_creds = roundtripped.credentials.expect("Should have creds");
        assert_eq!(
            creds.username, roundtripped_creds.username,
            "Username should be set to the right value"
        );
        assert_eq!(
            creds.password.expose_secret(),
            roundtripped_creds.password.expose_secret(),
            "Password should be set to the right value"
        );
    }

    /// Every protocol name accepted by `oci_client_protocol` renders back unchanged.
    #[test]
    fn protocol_names_roundtrip() {
        for name in ["http", "https"] {
            let protocol = oci_client_protocol(name).expect("known protocol should parse");
            assert_eq!(oci_protocol_string(&protocol), name);
        }
        assert!(oci_client_protocol("ftp").is_err());
    }

    /// A base64 `auth` value without a ':' separator is rejected.
    #[test]
    fn base64_auth_without_separator_is_rejected() {
        // "cGluZ3Bvbmc=" is base64 for "pingpong", which has no ':' separator.
        let auth = TomlAuth::Base64("cGluZ3Bvbmc=".to_string().into());
        assert!(BasicCredentials::try_from(auth).is_err());
    }

    /// Namespace metadata can carry OCI-specific registry and prefix settings.
    #[test]
    fn test_custom_namespace_config() {
        let toml_config = toml::toml! {
            [namespace_registries]
            test = { registry = "localhost:1234", metadata = { preferredProtocol = "oci", "oci" = { registry = "ghcr.io", namespacePrefix = "webassembly/" } } }
        };

        let cfg = wasm_pkg_common::config::Config::from_toml(&toml_config.to_string())
            .expect("Should be able to load config");

        let ns_config = cfg
            .namespace_registry(&"test".parse().unwrap())
            .expect("Should have a namespace config");
        let custom = match ns_config {
            RegistryMapping::Custom(c) => c,
            _ => panic!("Should have a custom namespace config"),
        };
        let map: OciRegistryMetadata = custom
            .metadata
            .protocol_config("oci")
            .expect("Should be able to deserialize config")
            .expect("protocol config should be present");
        assert_eq!(map.namespace_prefix, Some("webassembly/".to_string()));
        assert_eq!(map.registry, Some("ghcr.io".to_string()));
    }
}