warg_loader/config/
toml.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4};
5
6use anyhow::Context;
7use base64::{
8    engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
9    Engine,
10};
11use secrecy::{ExposeSecret, SecretString};
12use serde::Deserialize;
13
14use crate::{
15    source::{local::LocalConfig, oci::OciConfig, warg::WargConfig},
16    Error,
17};
18
19use super::BasicCredentials;
20
21impl super::ClientConfig {
22    pub fn from_toml(s: &str) -> Result<Self, Error> {
23        let toml_cfg: TomlConfig = toml::from_str(s)
24            .context("error parsing TOML")
25            .map_err(Error::InvalidConfig)?;
26        toml_cfg.try_into().map_err(Error::InvalidConfig)
27    }
28
29    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Error> {
30        tracing::debug!("Reading config file from {:?}", path.as_ref());
31        Self::from_toml(std::fs::read_to_string(path)?.as_str())
32    }
33
34    pub fn from_default_file() -> Result<Option<Self>, Error> {
35        let Some(config_dir) = dirs::config_dir() else {
36            return Ok(None);
37        };
38        let path = config_dir.join("warg").join("config.toml");
39        if !path.exists() {
40            return Ok(None);
41        }
42        Ok(Some(Self::from_file(path)?))
43    }
44}
45
46#[derive(Deserialize)]
47#[serde(deny_unknown_fields)]
48struct TomlConfig {
49    default_registry: Option<String>,
50    #[serde(default)]
51    namespace: HashMap<String, TomlNamespaceConfig>,
52    #[serde(default)]
53    registry: HashMap<String, TomlRegistryConfig>,
54}
55
56impl TryFrom<TomlConfig> for super::ClientConfig {
57    type Error = anyhow::Error;
58
59    fn try_from(value: TomlConfig) -> Result<Self, Self::Error> {
60        let TomlConfig {
61            default_registry,
62            namespace,
63            registry,
64        } = value;
65        let namespace_registries = namespace
66            .into_iter()
67            .map(|(name, config)| (name, config.registry))
68            .collect();
69        let registry_configs = registry
70            .into_iter()
71            .map(|(k, v)| Ok((k, v.try_into()?)))
72            .collect::<Result<_, Self::Error>>()?;
73        Ok(Self {
74            default_registry,
75            namespace_registries,
76            registry_configs,
77        })
78    }
79}
80
81#[derive(Deserialize)]
82#[serde(deny_unknown_fields)]
83struct TomlNamespaceConfig {
84    registry: String,
85}
86
87#[derive(Deserialize)]
88#[serde(tag = "type", rename_all = "snake_case")]
89#[serde(deny_unknown_fields)]
90enum TomlRegistryConfig {
91    Local {
92        root: PathBuf,
93    },
94    Oci {
95        auth: Option<TomlAuth>,
96    },
97    Warg {
98        auth_token: Option<SecretString>,
99        config_file: Option<PathBuf>,
100    },
101}
102
103impl TryFrom<TomlRegistryConfig> for super::RegistryConfig {
104    type Error = anyhow::Error;
105
106    fn try_from(value: TomlRegistryConfig) -> Result<Self, Self::Error> {
107        Ok(match value {
108            TomlRegistryConfig::Local { root } => Self::Local(LocalConfig { root }),
109            TomlRegistryConfig::Oci { auth } => {
110                let credentials = auth.map(TryInto::try_into).transpose()?;
111                Self::Oci(OciConfig {
112                    client_config: None,
113                    credentials,
114                })
115            }
116            TomlRegistryConfig::Warg {
117                auth_token,
118                config_file,
119            } => {
120                let client_config = match config_file {
121                    Some(path) => warg_client::Config::from_file(path)?,
122                    None => warg_client::Config::from_default_file()?.unwrap_or_default(),
123                };
124                Self::Warg(WargConfig {
125                    auth_token,
126                    client_config,
127                })
128            }
129        })
130    }
131}
132
133#[derive(Deserialize)]
134#[serde(untagged)]
135#[serde(deny_unknown_fields)]
136enum TomlAuth {
137    Base64(SecretString),
138    UsernamePassword {
139        username: String,
140        password: SecretString,
141    },
142}
143
144const OCI_AUTH_BASE64: GeneralPurpose = GeneralPurpose::new(
145    &base64::alphabet::STANDARD,
146    GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent),
147);
148
149impl TryFrom<TomlAuth> for BasicCredentials {
150    type Error = anyhow::Error;
151
152    fn try_from(value: TomlAuth) -> Result<Self, Self::Error> {
153        match value {
154            TomlAuth::Base64(b64) => {
155                fn decode_b64_creds(b64: &str) -> anyhow::Result<BasicCredentials> {
156                    let bs = OCI_AUTH_BASE64.decode(b64)?;
157                    let s = String::from_utf8(bs)?;
158                    let (username, password) = s
159                        .split_once(':')
160                        .context("expected <username>:<password> but no ':' found")?;
161                    Ok(BasicCredentials {
162                        username: username.into(),
163                        password: password.to_string().into(),
164                    })
165                }
166                decode_b64_creds(b64.expose_secret()).context("invalid base64-encoded creds")
167            }
168            TomlAuth::UsernamePassword { username, password } => {
169                Ok(BasicCredentials { username, password })
170            }
171        }
172    }
173}
174
#[cfg(test)]
mod tests {
    use crate::config::{ClientConfig, RegistryConfig};

    use super::*;

    /// Returns the credentials of the OCI registry named `registry`,
    /// panicking if the entry is not an OCI config or has no credentials.
    fn oci_credentials<'a>(cfg: &'a ClientConfig, registry: &str) -> &'a BasicCredentials {
        match &cfg.registry_configs[registry] {
            RegistryConfig::Oci(oci_config) => oci_config.credentials.as_ref().unwrap(),
            _ => panic!("not an oci config"),
        }
    }

    /// Parses a config exercising both auth forms and checks the result.
    #[test]
    fn smoke_test() {
        let toml_config = r#"
            default_registry = "example.com"

            [namespace.wasi]
            registry = "wasi.dev"

            [registry."example.com"]
            type = "oci"
            auth = { username = "open", password = "sesame" }

            [registry."wasi.dev"]
            type = "oci"
            auth = "cGluZzpwb25n"
        "#;
        let cfg = ClientConfig::from_toml(toml_config).unwrap();

        assert_eq!(cfg.default_registry.as_deref(), Some("example.com"));
        assert_eq!(cfg.namespace_registries["wasi"], "wasi.dev");

        // Table form: username and password given explicitly.
        let explicit = oci_credentials(&cfg, "example.com");
        assert_eq!(explicit.username, "open");
        assert_eq!(explicit.password.expose_secret(), "sesame");

        // String form: base64 of "ping:pong".
        let encoded = oci_credentials(&cfg, "wasi.dev");
        assert_eq!(encoded.username, "ping");
        assert_eq!(encoded.password.expose_secret(), "pong");
    }
}