1use super::DockerConfigAuth;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::PathBuf;
10
11#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum AuthSource {
15 #[default]
17 Anonymous,
18
19 Basic { username: String, password: String },
21
22 DockerConfig,
24
25 EnvVar {
27 username_var: String,
28 password_var: String,
29 },
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34pub struct RegistryAuthConfig {
35 pub registry: String,
37
38 pub source: AuthSource,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
44pub struct AuthConfig {
45 #[serde(default)]
47 pub registries: Vec<RegistryAuthConfig>,
48
49 #[serde(default)]
51 pub default: AuthSource,
52
53 pub docker_config_path: Option<PathBuf>,
55}
56
57impl Default for AuthConfig {
58 fn default() -> Self {
59 Self {
60 registries: Vec::new(),
61 default: AuthSource::DockerConfig,
62 docker_config_path: None,
63 }
64 }
65}
66
67pub struct AuthResolver {
69 config: AuthConfig,
70 docker_config: Option<DockerConfigAuth>,
71 registry_map: HashMap<String, AuthSource>,
72}
73
74impl AuthResolver {
75 pub fn new(config: AuthConfig) -> Self {
77 let registry_map: HashMap<String, AuthSource> = config
79 .registries
80 .iter()
81 .map(|r| (r.registry.clone(), r.source.clone()))
82 .collect();
83
84 let needs_docker_config = config.default == AuthSource::DockerConfig
86 || registry_map
87 .values()
88 .any(|s| matches!(s, AuthSource::DockerConfig));
89
90 let docker_config = if needs_docker_config {
91 Self::load_docker_config(&config.docker_config_path)
92 } else {
93 None
94 };
95
96 Self {
97 config,
98 docker_config,
99 registry_map,
100 }
101 }
102
103 pub fn resolve(&self, image: &str) -> oci_client::secrets::RegistryAuth {
108 let registry = Self::extract_registry(image);
109 let source = self
110 .registry_map
111 .get(®istry)
112 .unwrap_or(&self.config.default);
113
114 self.resolve_source(source, ®istry)
115 }
116
117 fn resolve_source(
119 &self,
120 source: &AuthSource,
121 registry: &str,
122 ) -> oci_client::secrets::RegistryAuth {
123 match source {
124 AuthSource::Anonymous => oci_client::secrets::RegistryAuth::Anonymous,
125
126 AuthSource::Basic { username, password } => {
127 oci_client::secrets::RegistryAuth::Basic(username.clone(), password.clone())
128 }
129
130 AuthSource::DockerConfig => {
131 if let Some(ref docker_config) = self.docker_config {
132 if let Some((username, password)) = docker_config.get_credentials(registry) {
133 return oci_client::secrets::RegistryAuth::Basic(username, password);
134 }
135 }
136 oci_client::secrets::RegistryAuth::Anonymous
138 }
139
140 AuthSource::EnvVar {
141 username_var,
142 password_var,
143 } => {
144 let username = std::env::var(username_var).unwrap_or_default();
145 let password = std::env::var(password_var).unwrap_or_default();
146
147 if !username.is_empty() && !password.is_empty() {
148 oci_client::secrets::RegistryAuth::Basic(username, password)
149 } else {
150 oci_client::secrets::RegistryAuth::Anonymous
151 }
152 }
153 }
154 }
155
156 fn extract_registry(image: &str) -> String {
163 let image_without_digest = image.split('@').next().unwrap_or(image);
165
166 let parts: Vec<&str> = image_without_digest.split('/').collect();
168
169 if parts.len() == 1 {
171 return "docker.io".to_string();
172 }
173
174 let first_part = parts[0];
176 if first_part.contains('.') || first_part.contains(':') || first_part == "localhost" {
177 first_part.to_string()
178 } else {
179 "docker.io".to_string()
181 }
182 }
183
184 fn load_docker_config(path: &Option<PathBuf>) -> Option<DockerConfigAuth> {
186 let config = if let Some(path) = path {
187 DockerConfigAuth::load_from_path(path).ok()
188 } else {
189 DockerConfigAuth::load().ok()
190 };
191
192 if config.is_none() {
193 tracing::debug!("Failed to load Docker config, using anonymous auth as fallback");
194 }
195
196 config
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use oci_client::secrets::RegistryAuth;

    /// Builds a resolver with no per-registry entries and the given default source.
    fn resolver_with_default(default: AuthSource) -> AuthResolver {
        AuthResolver::new(AuthConfig {
            default,
            ..Default::default()
        })
    }

    /// Fails the test unless `auth` is `Basic` with exactly these credentials.
    fn assert_basic(auth: RegistryAuth, expected_user: &str, expected_pass: &str, context: &str) {
        match auth {
            RegistryAuth::Basic(username, password) => {
                assert_eq!(username, expected_user);
                assert_eq!(password, expected_pass);
            }
            _ => panic!("Expected Basic auth {}", context),
        }
    }

    #[test]
    fn test_extract_registry() {
        let cases = [
            ("ubuntu", "docker.io"),
            ("ubuntu:latest", "docker.io"),
            ("library/ubuntu", "docker.io"),
            ("ghcr.io/owner/repo", "ghcr.io"),
            ("ghcr.io/owner/repo:tag", "ghcr.io"),
            ("localhost:5000/image", "localhost:5000"),
            ("localhost/image", "localhost"),
            ("myregistry.com/path/to/image:v1.0", "myregistry.com"),
        ];

        for (image, expected) in cases {
            assert_eq!(
                AuthResolver::extract_registry(image),
                expected,
                "image reference: {}",
                image
            );
        }
    }

    #[test]
    fn test_extract_registry_ignores_digest() {
        // The ':' inside a digest must not be mistaken for a registry port.
        assert_eq!(
            AuthResolver::extract_registry("ubuntu@sha256:abc123"),
            "docker.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("ghcr.io/owner/repo@sha256:abc123"),
            "ghcr.io"
        );
    }

    #[test]
    fn test_anonymous_auth() {
        let resolver = resolver_with_default(AuthSource::Anonymous);

        let auth = resolver.resolve("ubuntu:latest");

        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_basic_auth() {
        let resolver = resolver_with_default(AuthSource::Basic {
            username: "user".to_string(),
            password: "pass".to_string(),
        });

        let auth = resolver.resolve("ubuntu:latest");

        assert_basic(auth, "user", "pass", "from the default source");
    }

    #[test]
    fn test_per_registry_auth() {
        let config = AuthConfig {
            registries: vec![RegistryAuthConfig {
                registry: "ghcr.io".to_string(),
                source: AuthSource::Basic {
                    username: "ghcr_user".to_string(),
                    password: "ghcr_pass".to_string(),
                },
            }],
            default: AuthSource::Anonymous,
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);

        // An image on the configured registry uses that registry's entry.
        let auth = resolver.resolve("ghcr.io/owner/repo:tag");
        assert_basic(auth, "ghcr_user", "ghcr_pass", "for ghcr.io");

        // Any other registry falls back to the default source.
        let auth = resolver.resolve("ubuntu:latest");
        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_env_var_auth() {
        // Environment variables are process-global and tests run in parallel,
        // so the names are specific to this test.
        const USER_VAR: &str = "AUTH_RESOLVER_TEST_ENV_USERNAME";
        const PASS_VAR: &str = "AUTH_RESOLVER_TEST_ENV_PASSWORD";

        std::env::set_var(USER_VAR, "env_user");
        std::env::set_var(PASS_VAR, "env_pass");

        let resolver = resolver_with_default(AuthSource::EnvVar {
            username_var: USER_VAR.to_string(),
            password_var: PASS_VAR.to_string(),
        });
        let auth = resolver.resolve("ubuntu:latest");

        // Clean up before asserting so a failure does not leave the variables set.
        std::env::remove_var(USER_VAR);
        std::env::remove_var(PASS_VAR);

        assert_basic(auth, "env_user", "env_pass", "from env vars");
    }

    #[test]
    fn test_env_var_auth_fallback() {
        let resolver = resolver_with_default(AuthSource::EnvVar {
            username_var: "NONEXISTENT_USER".to_string(),
            password_var: "NONEXISTENT_PASS".to_string(),
        });

        let auth = resolver.resolve("ubuntu:latest");

        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_env_var_auth_partial_credentials() {
        // A username without a password must not produce Basic auth.
        const USER_VAR: &str = "AUTH_RESOLVER_TEST_PARTIAL_USERNAME";
        const PASS_VAR: &str = "AUTH_RESOLVER_TEST_PARTIAL_PASSWORD_UNSET";

        std::env::set_var(USER_VAR, "only_user");

        let resolver = resolver_with_default(AuthSource::EnvVar {
            username_var: USER_VAR.to_string(),
            password_var: PASS_VAR.to_string(),
        });
        let auth = resolver.resolve("ubuntu:latest");

        std::env::remove_var(USER_VAR);

        assert!(matches!(auth, RegistryAuth::Anonymous));
    }
}